import torch

# The decorator, function name, and return statement are assumed here; the
# original fragment showed only the docstring and body.
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
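
# Usage sketch: gather each rank's feature batch into one tensor, e.g. to build
# a batch-wide pool of contrastive keys. This assumes a distributed process
# group is already initialized (e.g. torchrun + init_process_group); names and
# shapes below are illustrative.
feats = torch.randn(32, 128)              # this rank's [B, D] features (move to GPU for NCCL)
all_feats = concat_all_gather(feats)      # [world_size * B, D]; gradients do not flow back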

# The signature and return are assumed; the original fragment showed only the body.
def euclidean_dist(x, y):
    """Pairwise Euclidean distances between the rows of x [m, d] and y [n, d]."""
    m, n = x.size(0), y.size(0)
    # ||x||^2 and ||y||^2, broadcast to an [m, n] grid.
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
    dist = xx + yy - 2 * torch.matmul(x, y.t())
    dist = dist.clamp(min=1e-12).sqrt()  # clamp before sqrt for numerical stability
    return dist
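
# Usage sketch: a reference check against torch.cdist (values agree up to the
# 1e-12 clamp and float32 rounding).
x = torch.randn(4, 128)
y = torch.randn(6, 128)
d = euclidean_dist(x, y)                  # [4, 6]; d[i, j] = ||x[i] - y[j]||_2
assert torch.allclose(d, torch.cdist(x, y), atol=1e-3)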

# Signature and return assumed; this computes cosine *similarity* (larger = closer).
def cosine_sim(x, y):
    bs1, bs2 = x.size(0), y.size(0)
    frac_up = torch.matmul(x, y.transpose(0, 1))
    # frac_down[i, j] = ||x[i]|| * ||y[j]|| (outer product of the row norms)
    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
    cosine = frac_up / frac_down
    return cosine
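
# Usage sketch: the same matrix via normalize-then-matmul, which is
# mathematically equivalent and avoids the explicit repeat() calls.
import torch.nn.functional as F
x = torch.randn(4, 128)
y = torch.randn(6, 128)
sim = cosine_sim(x, y)                    # [4, 6]
sim_ref = torch.matmul(F.normalize(x, dim=1), F.normalize(y, dim=1).t())
assert torch.allclose(sim, sim_ref, atol=1e-5)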