    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")
    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
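# Example sketch of how the two helpers above compose (assumes torch.distributed
# is already initialized and `group` is a valid gloo/nccl process group):
#
#   t = _serialize_to_tensor({"rank": get_rank()}, group)
#   size_list, padded = _pad_to_largest_tensor(t, group)
#   # size_list[i] is the byte count contributed by rank i; padded.numel() == max(size_list)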


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if dist.get_world_size(group) == 1:
        return [data]
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
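# Example sketch (assumes torch.distributed is initialized on every rank;
# `stats` is a hypothetical per-rank payload):
#
#   stats = {"rank": get_rank(), "num_samples": 123}
#   all_stats = all_gather(stats)
#   total = sum(s["num_samples"] for s in all_stats)  # identical on every rank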


def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving buffers are only needed on the destination rank
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)
        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []
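# Example sketch (assumes torch.distributed is initialized on every rank;
# `predictions` is a hypothetical per-rank list of results):
#
#   results = gather(predictions, dst=0)  # non-destination ranks receive []
#   if get_rank() == 0:
#       merged = [p for per_rank in results for p in per_rank]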


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that the process
    with rank 0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensors.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only the main process gets the accumulated result,
            # so only divide there
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
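# Example sketch (assumes torch.distributed is initialized; `loss_dict` holds
# hypothetical scalar CUDA tensors, one per loss term):
#
#   loss_dict = {"loss_cls": loss_cls.detach(), "loss_box_reg": loss_box_reg.detach()}
#   reduced = reduce_dict(loss_dict)  # averaged across ranks
#   if get_rank() == 0:
#       print({k: v.item() for k, v in reduced.items()})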