Safemotion Lib
Loading...
Searching...
No Matches
utils.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: xingyu liao
4@contact: sherlockliao01@gmail.com
5"""
6
7import torch
8
9
11 """
12 Performs all_gather operation on the provided tensors.
13 *** Warning ***: torch.distributed.all_gather has no gradient.
14 """
15 tensors_gather = [torch.ones_like(tensor)
16 for _ in range(torch.distributed.get_world_size())]
17 torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
18
19 output = torch.cat(tensors_gather, dim=0)
20 return output
21
22
def normalize(x, axis=-1):
    """Scale ``x`` to unit L2 length along one dimension.

    Args:
        x: tensor to normalize.
        axis: dimension along which the L2 norm is computed (default: last).

    Returns:
        Tensor with the same shape as ``x``. A small constant (1e-12) is added
        to every norm, so an all-zero slice comes back as zeros instead of NaN.
    """
    # keepdim=True keeps a size-1 axis, so the norm broadcasts against x.
    magnitude = torch.norm(x, 2, axis, keepdim=True)
    return (1. * x) / (magnitude + 1e-12)
32
33
35 m, n = x.size(0), y.size(0)
36 xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
37 yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
38 dist = xx + yy - 2 * torch.matmul(x, y.t())
39 dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
40 return dist
41
42
def cosine_dist(x, y):
    """Pairwise cosine distance between the rows of ``x`` and ``y``.

    Args:
        x: 2-D tensor of shape (m, d).
        y: 2-D tensor of shape (n, d).

    Returns:
        (m, n) tensor whose entry (i, j) is ``1 - cos(x[i], y[j])``, i.e. in
        [0, 2] up to rounding, with 0 meaning "same direction". A row that is
        entirely zero has no direction; its cosine is treated as 0 (distance 1)
        rather than producing NaN.
    """
    # Inner products of every row of x with every row of y: (m, n).
    frac_up = torch.matmul(x, y.transpose(0, 1))
    # Row L2 norms shaped (m, 1) and (1, n); their product broadcasts to (m, n)
    # without materialising repeated copies.
    x_norm = torch.sqrt(torch.sum(torch.pow(x, 2), 1)).unsqueeze(1)
    y_norm = torch.sqrt(torch.sum(torch.pow(y, 2), 1)).unsqueeze(0)
    # Clamp so a zero-norm row gives 0 / eps = 0 instead of 0 / 0 = NaN; 1e-12 is
    # the same epsilon this module uses for its other norm/distance helpers.
    frac_down = (x_norm * y_norm).clamp(min=1e-12)
    cosine = frac_up / frac_down
    return 1 - cosine