80 """Forward pass for all input predictions: preds - (batch_size x feat_dims) """
83 embedding = F.normalize(embedding, dim=1)
85 feat_dim = embedding.size(1)
88 if comm.get_world_size() > 1:
89 all_embedding = concat_all_gather(embedding)
90 all_targets = concat_all_gather(targets)
92 all_embedding = embedding
95 sim_dist = torch.matmul(embedding, all_embedding.t())
96 N, M = sim_dist.size()
99 mask_indx = 1.0 - torch.eye(M, device=sim_dist.device)
100 mask_indx = mask_indx.unsqueeze(dim=0).repeat(N, 1, 1)
103 sim_dist_repeat = sim_dist.unsqueeze(dim=1).repeat(1, M, 1)
107 sim_diff = sim_dist_repeat - sim_dist_repeat.permute(0, 2, 1)
113 sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
115 pos_mask = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()
117 pos_mask_repeat = pos_mask.unsqueeze(1).repeat(1, M, 1)
120 pos_sim_sg = sim_sg * pos_mask_repeat
121 sim_pos_rk = torch.sum(pos_sim_sg, dim=-1) + 1
126 for ind
in range(self.
num_id):
127 pos_divide = torch.sum(
128 sim_pos_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]))
129 ap += pos_divide / torch.sum(pos_mask[ind*group]) / N