def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
    """
    Launch multi-gpu or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to main_func
    """
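    # A minimal usage sketch (illustrative only; `train` and `cfg` are hypothetical
    # stand-ins for the caller's own entry point and its arguments):
    #
    #     launch(
    #         train,
    #         num_gpus_per_machine=8,
    #         num_machines=1,
    #         machine_rank=0,
    #         dist_url="auto",
    #         args=(cfg,),
    #     )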
    world_size = num_machines * num_gpus_per_machine
    if dist_url == "auto":
        assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
        # Pick a free localhost port (assumes a helper such as _find_free_port
        # defined elsewhere in this module).
        port = _find_free_port()
        dist_url = f"tcp://127.0.0.1:{port}"
    if num_machines > 1 and dist_url.startswith("file://"):
        logger = logging.getLogger(__name__)
        logger.warning(
            "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
        )
    print(f'num_gpus_per_machine = {num_gpus_per_machine}\n')
    print(f'args = {args}\n')
    mp.spawn(
        _distributed_worker,
        nprocs=num_gpus_per_machine,
        args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
    )


def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
    print(f'local_rank = {local_rank}\n')
    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
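    # e.g. with num_gpus_per_machine=4, the worker with machine_rank=1 and
    # local_rank=2 gets global_rank = 1 * 4 + 2 = 6.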
    try:
        dist.init_process_group(
            backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e
    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    # Set up the local (per-machine) process group; every process must take part
    # in each new_group call, so all groups are created on every rank.
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg
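    # e.g. with num_machines=2 and num_gpus_per_machine=4, the loop builds the groups
    # [0, 1, 2, 3] and [4, 5, 6, 7]; each machine keeps only the group covering its
    # own ranks in comm._LOCAL_PROCESS_GROUP.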