def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
35 """
36 Launch multi-gpu or distributed training.
37 This function must be called on all machines involved in the training.
38 It will spawn child processes (defined by ``num_gpus_per_machine`) on each machine.
39 Args:
40 main_func: a function that will be called by `main_func(*args)`
41 num_gpus_per_machine (int): number of GPUs per machine
42 num_machines (int): the total number of machines
43 machine_rank (int): the rank of this machine
44 dist_url (str): url to connect to for distributed jobs, including protocol
45 e.g. "tcp://127.0.0.1:8686".
46 Can be set to "auto" to automatically select a free port on localhost
47 args (tuple): arguments passed to main_func
48 """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )
        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )
    else:
        main_func(*args)
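

# `_find_free_port` is referenced above but defined elsewhere in this module; a
# minimal sketch, assuming the usual trick of binding a socket to port 0 so the
# OS assigns a free ephemeral port:
import socket


def _find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 asks the OS for any currently free port.
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is a small race window in which another process could claim
    # the port before it is reused by the distributed init.
    return port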
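

# `_distributed_worker` is the spawn target above; it is defined elsewhere in
# this module. A rough sketch of what each child process is expected to do,
# assuming the NCCL backend (the real worker may add error handling and
# per-machine process groups):
import torch
import torch.distributed as dist


def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
    # mp.spawn passes the local process index as the first argument; the global
    # rank combines the machine rank with that local GPU index.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    dist.init_process_group(
        backend="nccl", init_method=dist_url, world_size=world_size, rank=global_rank
    )
    # Pin this worker to its GPU before running the user's entry point.
    torch.cuda.set_device(local_rank)
    main_func(*args)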
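

# Minimal usage sketch. `demo_main` is a hypothetical entry point; in a real
# multi-machine job this script runs once per machine, each invocation with the
# same dist_url but its own machine_rank.
if __name__ == "__main__":

    def demo_main(message):
        print(message)

    # world_size is 1 here, so `launch` calls demo_main directly in-process;
    # with num_gpus_per_machine > 1 it would go through mp.spawn instead.
    launch(demo_main, num_gpus_per_machine=1, args=("hello from launch",))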