Safemotion Lib
comm.py
1"""
2This file contains primitives for multi-gpu communication.
3This is useful when doing distributed training.
4"""
5
6import functools
7import logging
8import numpy as np
9import pickle
10import torch
11import torch.distributed as dist
12
13_LOCAL_PROCESS_GROUP = None
14"""
15A torch process group which only includes processes that on the same machine as the current process.
16This variable is set when processes are spawned by `launch()` in "engine/launch.py".
17"""
18
19
def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    return get_rank() == 0


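# Usage sketch (not part of the original file): the rank helpers above are
# typically used to guard per-process work, e.g. so that only rank 0 writes
# checkpoints. `save_checkpoint` and `output_dir` below are hypothetical names.
#
#     if is_main_process():
#         save_checkpoint(model, output_dir)
#     print(f"rank {get_rank()}/{get_world_size()}, "
#           f"local rank {get_local_rank()}/{get_local_size()}")

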
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


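# Usage sketch (illustrative, assumes the process group is already initialized):
# call synchronize() when every rank must reach the same point before continuing,
# e.g. before other ranks read files that rank 0 has just written.
# `prepare_dataset_cache` is a hypothetical helper.
#
#     if is_main_process():
#         prepare_dataset_cache()
#     synchronize()   # all other ranks wait here until rank 0 is done

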
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    # gloo communicates CPU tensors, nccl communicates CUDA tensors
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


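# Illustration (assumed values, not executed): with 3 ranks whose serialized
# payloads are 5, 9 and 7 bytes, every rank pads its byte tensor to 9 elements
# and `size_list == [5, 9, 7]`; the callers below truncate each gathered tensor
# back to its true size before unpickling.

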
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


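# Usage sketch (not part of the original file): gather per-rank Python objects,
# e.g. evaluation results, onto every rank. `predictions_on_this_rank` is a
# hypothetical variable.
#
#     all_predictions = all_gather(predictions_on_this_rank)
#     # same flattened result on every rank
#     flattened = [p for per_rank in all_predictions for p in per_rank]

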
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []


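# Usage sketch (not part of the original file): like all_gather, but only the
# destination rank receives the data; every other rank gets an empty list.
# `per_rank_stats` and `summarize` are hypothetical names.
#
#     results = gather(per_rank_stats, dst=0)
#     if is_main_process():
#         summarize(results)

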
def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.

    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    all_ints = all_gather(ints)
    return all_ints[0]


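# Usage sketch (not part of the original file): seed a per-process RNG so that
# all workers draw the same "random" decisions, e.g. for a shared data shuffle.
#
#     seed = shared_random_seed()          # every rank must make this call
#     rng = np.random.RandomState(seed)    # identical stream on all ranks

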
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.
    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum
    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
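

# Usage sketch (not part of the original file): average scalar loss tensors
# across ranks for logging on the main process. `loss_dict` is hypothetical and
# its values are assumed to be scalar CUDA tensors, as the docstring requires.
#
#     reduced = reduce_dict(loss_dict, average=True)
#     if is_main_process():
#         print({k: v.item() for k, v in reduced.items()})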