Safemotion Lib
Loading...
Searching...
No Matches
triplet_loss.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6
7import torch
8import torch.nn.functional as F
9
10from fastreid.utils import comm
11from fastreid.layers import GatherLayer
12from .utils import concat_all_gather, euclidean_dist, normalize
13
14
def softmax_weights(dist, mask):
    """Compute a masked softmax over dim 1.

    Args:
        dist: tensor of scores; reduced over dim 1 (used in this file with
            shape [N, M]).
        mask: tensor broadcast-compatible with ``dist``; entries equal to 0 are
            excluded from the softmax. In this file it is a 0/1 float mask.

    Returns:
        W: weights of the same shape as ``dist``. Excluded entries get 0. Each
        row's included entries sum to slightly less than 1 because of the
        1e-6 stabilizer in the denominator.
    """
    # Subtract the (masked) row max for numerical stability of exp.
    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
    diff = dist - max_v
    # Zero the exponent on excluded entries *before* exp. Otherwise a large
    # excluded value can overflow exp to inf, and inf * 0 yields NaN, which then
    # leaks into Z and every weight of the row.
    diff = diff.masked_fill(mask == 0, 0)
    exp_diff = torch.exp(diff) * mask
    Z = torch.sum(exp_diff, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    W = exp_diff / Z
    return W
21
22
def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the hardest positive and hardest negative distance.

    The hardest positive is the largest distance among columns marked in
    ``is_pos``. The hardest negative is the smallest distance among columns
    marked in ``is_neg``. Selection uses masking instead of compacting
    ``dist_mat[is_pos]`` into an ``[N, K]`` matrix. As a result, anchors may have
    different numbers of positives/negatives (unbalanced label counts, or
    features gathered from several processes).

    Args:
        dist_mat: pair wise distance between samples, shape [N, M]
        is_pos: positive mask with shape [N, M] (bool, or 0/1 numeric)
        is_neg: negative mask with shape [N, M] (bool, or 0/1 numeric)

    Returns:
        dist_ap: distance(anchor, hardest positive); shape [N].
            -inf for an anchor that has no positive.
        dist_an: distance(anchor, hardest negative); shape [N].
            +inf for an anchor that has no negative.
    """
    assert len(dist_mat.size()) == 2
    # masked_fill needs a bool mask; .bool() is a no-op for bool input.
    is_pos = is_pos.bool()
    is_neg = is_neg.bool()

    # Non-positives become -inf so they can never be selected by max.
    dist_ap, _ = torch.max(dist_mat.masked_fill(~is_pos, float('-inf')), dim=1)
    # Non-negatives become +inf so they can never be selected by min.
    dist_an, _ = torch.min(dist_mat.masked_fill(~is_neg, float('inf')), dim=1)

    return dist_ap, dist_an
71
72
def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, compute softmax-weighted positive and negative distances.

    Positive distances are weighted by a masked softmax of themselves, so far
    positives dominate. Negative distances are weighted by a masked softmax of
    their negation, so close negatives dominate.

    Args:
        dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
        is_pos: mask marking positive pairs; converted to float.
        is_neg: mask marking negative pairs; converted to float.

    Returns:
        dist_ap: pytorch Variable, weighted distance(anchor, positive); shape [N]
        dist_an: pytorch Variable, weighted distance(anchor, negative); shape [N]
    """
    assert dist_mat.dim() == 2

    pos_mask = is_pos.float()
    neg_mask = is_neg.float()

    # Keep only the distances belonging to each pair type.
    masked_pos = dist_mat * pos_mask
    masked_neg = dist_mat * neg_mask

    pos_weights = softmax_weights(masked_pos, pos_mask)
    # Negating emphasizes the *closest* negatives.
    neg_weights = softmax_weights(-masked_neg, neg_mask)

    weighted_ap = (masked_pos * pos_weights).sum(dim=1)
    weighted_an = (masked_neg * neg_weights).sum(dim=1)
    return weighted_ap, weighted_an
97
98
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    r"""Batch triplet loss.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.

    Args:
        embedding: local feature batch; row i is anchor i (assumed shape
            [N, D] — TODO confirm against callers).
        targets: labels of the local batch; reshaped to [N, 1] below.
        margin: if > 0, use ``margin_ranking_loss`` with this margin;
            otherwise use ``soft_margin_loss``.
        norm_feat: if True, pass the embedding through ``normalize(..., axis=-1)``
            first.
        hard_mining: if True, use ``hard_example_mining``; otherwise
            ``weighted_example_mining``.

    Returns:
        Loss tensor returned by the selected torch.nn.functional loss.
    """
    if norm_feat:
        embedding = normalize(embedding, axis=-1)

    # For distributed training, gather all features from different processes so
    # local anchors are compared against the whole gathered batch.
    if comm.get_world_size() > 1:
        all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
        all_targets = concat_all_gather(targets)
    else:
        all_embedding = embedding
        all_targets = targets

    dist_mat = euclidean_dist(embedding, all_embedding)

    N, M = dist_mat.size()
    # [N, 1] vs [1, M] broadcasts to the [N, M] pairwise label masks.
    anchor_labels = targets.view(N, 1)
    other_labels = all_targets.view(1, M)
    is_pos = anchor_labels.eq(other_labels)
    is_neg = anchor_labels.ne(other_labels)

    if hard_mining:
        dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
    else:
        dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

    # Target 1: dist_an should be ranked above (larger than) dist_ap.
    y = torch.ones_like(dist_an)

    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        # If the soft-margin loss evaluates to inf, fall back to a ranking loss
        # with a fixed margin of 0.3.
        if loss == float('Inf'):
            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)

    return loss
hard_example_mining(dist_mat, is_pos, is_neg)
weighted_example_mining(dist_mat, is_pos, is_neg)