def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    """Compute a triplet loss between local embeddings and all gathered embeddings.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in the paper 'In Defense of the
    Triplet Loss for Person Re-Identification'.

    Args:
        embedding: Feature tensor for the local batch, one row per sample
            (presumably shape (N, D) -- confirm against ``euclidean_dist``).
        targets: Label tensor with N elements, one per row of the distance
            matrix built from ``embedding``.
        margin: Ranking margin. A value greater than 0 selects
            ``F.margin_ranking_loss``; otherwise ``F.soft_margin_loss`` is used.
        norm_feat: If truthy, ``embedding`` is passed through ``normalize``
            along the last axis before distances are computed.
        hard_mining: If truthy, positive/negative distances come from
            ``hard_example_mining``; otherwise from ``weighted_example_mining``.

    Returns:
        The loss value produced by ``F.margin_ranking_loss`` or
        ``F.soft_margin_loss``.
    """
    if norm_feat:
        embedding = normalize(embedding, axis=-1)

    # With more than one process, local samples are compared against the
    # embeddings and targets collected from every process.
    if comm.get_world_size() > 1:
        all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
        all_targets = concat_all_gather(targets)
    else:
        all_embedding = embedding
        all_targets = targets

    dist_mat = euclidean_dist(embedding, all_embedding)

    # Broadcasting (N, 1) against (1, M) yields the (N, M) same-label and
    # different-label masks without materialising expanded copies.
    N, M = dist_mat.size()
    local_labels = targets.view(N, 1)
    gathered_labels = all_targets.view(1, M)
    is_pos = local_labels.eq(gathered_labels)
    is_neg = local_labels.ne(gathered_labels)

    if hard_mining:
        dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
    else:
        dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

    # Target of +1 everywhere: dist_an should rank higher than dist_ap.
    y = torch.ones_like(dist_an)

    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        # The soft-margin form can overflow to Inf; fall back to a
        # fixed-margin ranking loss in that case.
        if loss == float('Inf'):
            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)

    return loss