Safemotion Lib
smooth_ap.py
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

# based on:
# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py

import torch
import torch.nn.functional as F

from fastreid.utils import comm
from fastreid.modeling.losses.utils import concat_all_gather

def sigmoid(tensor, temp=1.0):
    """Temperature-controlled sigmoid.

    Takes a torch tensor and passes it through a sigmoid whose steepness is
    controlled by the temperature ``temp``.
    """
    exponent = -tensor / temp
    # clamp the exponent for numerical stability
    exponent = torch.clamp(exponent, min=-50, max=50)
    y = 1.0 / (1.0 + torch.exp(exponent))
    return y
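
# Illustration: with a small temperature the sigmoid approaches a hard step function,
# e.g. sigmoid(torch.tensor([-1.0, 0.0, 1.0]), temp=0.01) is approximately [0.0, 0.5, 1.0].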


class SmoothAP(object):
    r"""PyTorch implementation of the Smooth-AP loss.

    Takes as input a mini-batch of CNN-produced feature embeddings together with their labels and
    returns the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of
    classes, each class must have the same number of instances, and the instances must be ordered
    sequentially by class, e.g. the labels for a mini-batch with batch size 9 and 3 represented
    classes (A, B, C) must look like:

        labels = (A, A, A, B, B, B, C, C, C)

    (the order of the classes themselves does not matter).

    For each instance in the mini-batch, the loss computes the Smooth-AP when that instance is used
    as the query and the rest of the mini-batch is used as the retrieval set. The positive set is
    formed of the other instances in the batch from the same class. The loss returns the average
    Smooth-AP across all instances in the mini-batch.

    Args:
        cfg (CfgNode): config node. The number of classes per batch is derived as
            ``cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE``. The sigmoid temperature
            (anneal) is fixed to 0.01; a low temperature gives a steep sigmoid that tightly
            approximates the Heaviside step function in the ranking function.

    Shape:
        - Input (embedding): (batch_size, feat_dims); (targets): (batch_size,)
        - Output: scalar

    Examples::
        >>> loss = SmoothAP(cfg)
        >>> embedding = torch.randn(60, 256, requires_grad=True).cuda()
        >>> targets = torch.arange(6).repeat_interleave(10).cuda()
        >>> output = loss(embedding, targets)
        >>> output.backward()
    """

    def __init__(self, cfg):
        r"""
        Parameters
        ----------
        cfg : CfgNode
            configuration node; ``cfg.SOLVER.IMS_PER_BATCH`` is the total batch size and
            ``cfg.DATALOADER.NUM_INSTANCE`` is the number of instances per class, so their ratio
            gives the number of different classes represented in the batch.
        """
        self.anneal = 0.01
        self.num_id = cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE
        # self.num_id = 6

    def __call__(self, embedding, targets):
        """Compute the Smooth-AP loss for a mini-batch of embeddings (batch_size x feat_dims) and their labels."""

        # ------ differentiable ranking of all retrieval set ------
        embedding = F.normalize(embedding, dim=1)

        feat_dim = embedding.size(1)

        # For distributed training, gather the features from all processes.
        if comm.get_world_size() > 1:
            all_embedding = concat_all_gather(embedding)
            all_targets = concat_all_gather(targets)
        else:
            all_embedding = embedding
            all_targets = targets

        sim_dist = torch.matmul(embedding, all_embedding.t())
        N, M = sim_dist.size()
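        # N is the number of local queries and M the size of the (possibly gathered)
        # retrieval set; with a single process N == M.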

        # Mask that zeroes out each item's comparison with itself in the ranking
        mask_indx = 1.0 - torch.eye(M, device=sim_dist.device)
        mask_indx = mask_indx.unsqueeze(dim=0).repeat(N, 1, 1)  # (N, M, M)

        # sim_dist -> (N, 1, M) -> (N, M, M)
        sim_dist_repeat = sim_dist.unsqueeze(dim=1).repeat(1, M, 1)  # (N, M, M)
        # sim_dist_repeat_t = sim_dist.t().unsqueeze(dim=1).repeat(1, N, 1)  # (N, N, M)

        # Compute the difference matrix
        sim_diff = sim_dist_repeat - sim_dist_repeat.permute(0, 2, 1)  # (N, M, M)

        # Pass through the sigmoid
        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask_indx

        # Compute all the rankings
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1  # (N, M)
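
        # For query i and gallery item j this yields the soft rank
        #   R(i, j) = 1 + sum_{k != j} sigmoid((s_ik - s_ij) / anneal),
        # a differentiable surrogate for the hard rank 1 + |{k != j : s_ik > s_ij}|.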

        pos_mask = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()  # (N, M)

        pos_mask_repeat = pos_mask.unsqueeze(1).repeat(1, M, 1)  # (N, M, M)

        # Compute positive rankings
        pos_sim_sg = sim_sg * pos_mask_repeat
        sim_pos_rk = torch.sum(pos_sim_sg, dim=-1) + 1  # (N, M)
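
        # sim_pos_rk[i, j] is the soft rank of item j among the positives of query i,
        #   R_pos(i, j) = 1 + sum_{k != j, y_k == y_i} sigmoid((s_ik - s_ij) / anneal).
        # The loop below averages R_pos / R over each query's positive set to obtain the
        # per-query Smooth-AP, and the loss is 1 minus the batch mean of these values.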

        # Sum the values of the Smooth-AP for all instances in the mini-batch
        ap = 0
        group = N // self.num_id
        for ind in range(self.num_id):
            pos_divide = torch.sum(
                sim_pos_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]
                / sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]
            )
            ap += pos_divide / torch.sum(pos_mask[ind * group]) / N
        return 1 - ap

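
# Original-style implementation with a fixed (anneal, batch_size, num_id, feat_dims)
# interface; the self-test at the bottom of this file compares it numerically against SmoothAP.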
class SmoothAP_old(torch.nn.Module):
    """PyTorch implementation of the Smooth-AP loss.

    Takes as input a mini-batch of CNN-produced feature embeddings and returns the value of the
    Smooth-AP loss. The mini-batch must be formed of a defined number of classes, each class must
    have the same number of instances, and the instances must be ordered sequentially by class,
    e.g. the labels for a mini-batch with batch size 9 and 3 represented classes (A, B, C) must
    look like:

        labels = (A, A, A, B, B, B, C, C, C)

    (the order of the classes themselves does not matter).

    For each instance in the mini-batch, the loss computes the Smooth-AP when that instance is used
    as the query and the rest of the mini-batch is used as the retrieval set. The positive set is
    formed of the other instances in the batch from the same class. The loss returns the average
    Smooth-AP across all instances in the mini-batch.

    Args:
        anneal : float
            the temperature of the sigmoid that is used to smooth the ranking function. A low value
            of the temperature results in a steep sigmoid that tightly approximates the Heaviside
            step function in the ranking function.
        batch_size : int
            the batch size being used during training.
        num_id : int
            the number of different classes that are represented in the batch.
        feat_dims : int
            the dimension of the input feature embeddings.

    Shape:
        - Input (preds): (batch_size, feat_dims) (a torch float tensor)
        - Output: scalar

    Examples::
        >>> loss = SmoothAP_old(0.01, 60, 6, 256)
        >>> input = torch.randn(60, 256, requires_grad=True).cuda()
        >>> output = loss(input)
        >>> output.backward()
    """

    def __init__(self, anneal, batch_size, num_id, feat_dims):
        """
        Parameters
        ----------
        anneal : float
            the temperature of the sigmoid that is used to smooth the ranking function
        batch_size : int
            the batch size being used
        num_id : int
            the number of different classes that are represented in the batch
        feat_dims : int
            the dimension of the input feature embeddings
        """
        super().__init__()

        assert batch_size % num_id == 0

        self.anneal = anneal
        self.batch_size = batch_size
        self.num_id = num_id
        self.feat_dims = feat_dims

    def forward(self, preds):
        """Forward pass for all input predictions: preds - (batch_size x feat_dims)"""

        preds = F.normalize(preds, dim=1)
        # ------ differentiable ranking of all retrieval set ------
        # mask that zeroes out each item's comparison with itself in the ranking
        mask = 1.0 - torch.eye(self.batch_size, device=preds.device)
        mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
        sim_all = torch.mm(preds, preds.t())
        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
        # compute the difference matrix
        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the sigmoid
        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask
        # compute the rankings
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1

        # ------ differentiable ranking of only positive set in retrieval set ------
        # compute the mask which only gives non-zero weights to the positive set
        xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
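        # xs groups the batch as (num_id, instances_per_class, feat_dims), so the positive
        # ranks below are computed only within each class block.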
        pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id), device=preds.device)
        pos_mask = pos_mask.unsqueeze(dim=0).unsqueeze(dim=0).repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
        # compute the relevance scores
        sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(1, 1, int(self.batch_size / self.num_id), 1)
        # compute the difference matrix
        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
        # pass through the sigmoid
        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask
        # compute the rankings of the positive set
        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1

        # sum the values of the Smooth-AP for all instances in the mini-batch
        ap = torch.zeros(1, device=preds.device)
        group = int(self.batch_size / self.num_id)
        for ind in range(self.num_id):
            pos_divide = torch.sum(
                sim_pos_rk[ind] / sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]
            )
            ap = ap + ((pos_divide / group) / self.batch_size)

        return 1 - ap


if __name__ == '__main__':
    # Minimal self-test comparing the two implementations on a CPU batch of
    # 6 identities x 10 instances. SmoothAP expects a config node, so a small
    # stand-in object carrying the two fields it reads is built here.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        SOLVER=SimpleNamespace(IMS_PER_BATCH=60),
        DATALOADER=SimpleNamespace(NUM_INSTANCE=10),
    )
    loss1 = SmoothAP(cfg)
    loss2 = SmoothAP_old(0.01, 60, 6, 256)

    inputs = torch.randn(60, 256, requires_grad=True)
    targets = []
    for i in range(6):
        targets.extend([i] * 10)
    targets = torch.LongTensor(targets)

    output1 = loss1(inputs, targets)
    output2 = loss2(inputs)

    print(torch.sum(output1 - output2))
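    # With a single process and these matching hyper-parameters (anneal=0.01, 60 images,
    # 6 identities, 10 instances each), both losses compute the same quantity, so the
    # printed difference should be approximately zero (up to floating-point error).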