Safemotion Lib
Loading...
Searching...
No Matches
rank.py
Go to the documentation of this file.
1# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank.py
2
3import warnings
4from collections import defaultdict
5
6import faiss
7import numpy as np
8
9try:
10 from .rank_cylib.rank_cy import evaluate_cy
11
12 IS_CYTHON_AVAI = True
13except ImportError:
14 IS_CYTHON_AVAI = False
15 warnings.warn(
16 'Cython rank evaluation (very fast so highly recommended) is '
17 'unavailable, now use python evaluation.'
18 )
19
20
def eval_cuhk03(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat):
    """Evaluation with cuhk03 metric (single-gallery-shot setting).

    For each query, gallery samples sharing both its pid and its camid are
    discarded first. Then one image is randomly sampled for each remaining
    gallery identity and a CMC curve is computed on that subset. Sampling is
    repeated ``num_repeats`` (10) times and the curves are averaged. AP is
    computed once per query on the full filtered ranking (no sampling).

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
            Its shape always defines num_query / num_gallery. Its values are used
            for ranking only when ``use_distmat`` is True.
        q_feats: 2-D query features (numpy array or torch tensor). Only used
            when ``use_distmat`` is False.
        g_feats: 2-D gallery features (numpy array or torch tensor). Only used
            when ``use_distmat`` is False.
        q_pids (numpy.ndarray): 1-D identities of the query samples.
        g_pids (numpy.ndarray): 1-D identities of the gallery samples.
        q_camids (numpy.ndarray): 1-D camera ids of the query samples.
        g_camids (numpy.ndarray): 1-D camera ids of the gallery samples.
        max_rank (int): number of CMC ranks to keep; clamped to num_gallery.
        use_distmat (bool): if True, rank the gallery by sorting ``distmat``.
            Otherwise rank it by an exact L2 nearest-neighbour search with faiss.

    Returns:
        tuple: ``(all_cmc, mAP)``. ``all_cmc`` is the float32 CMC curve averaged
        over valid queries. ``mAP`` is the mean AP over valid queries.

    Raises:
        AssertionError: if no query identity appears in the filtered gallery.
    """
    num_repeats = 10

    num_q, num_g = distmat.shape

    if use_distmat:
        indices = np.argsort(distmat, axis=1)
    else:
        def _as_f32(feats):
            # faiss only accepts C-contiguous float32 numpy arrays; also accept
            # torch tensors (possibly on GPU).
            if hasattr(feats, 'detach'):
                feats = feats.detach().cpu().numpy()
            return np.ascontiguousarray(feats, dtype=np.float32)

        q_np = _as_f32(q_feats)
        g_np = _as_f32(g_feats)
        index = faiss.IndexFlatL2(q_np.shape[1])
        index.add(g_np)
        _, indices = index.search(q_np, k=num_g)

    if num_g < max_rank:
        max_rank = num_g
        print(
            'Note: number of gallery samples is quite small, got {}'.
            format(num_g)
        )

    # matches[i, j] == 1 iff the j-th ranked gallery sample of query i has its pid
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid queries

    for q_idx in range(num_q):
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid as the query
        order = indices[q_idx]
        keep = ~((g_pids[order] == q_pid) & (g_camids[order] == q_camid))

        # binary vector over the kept ranking; positions with value 1 are correct matches
        raw_cmc = matches[q_idx][keep]
        if not np.any(raw_cmc):
            # the query identity does not appear in the filtered gallery
            continue

        # group kept ranking positions by gallery identity
        kept_g_pids = g_pids[order][keep]
        g_pids_dict = defaultdict(list)
        for idx, pid in enumerate(kept_g_pids):
            g_pids_dict[pid].append(idx)

        cmc = 0.
        for _ in range(num_repeats):
            # randomly sample one image for each gallery person
            mask = np.zeros(len(raw_cmc), dtype=bool)
            for idxs in g_pids_dict.values():
                mask[np.random.choice(idxs)] = True
            _cmc = raw_cmc[mask].cumsum()
            _cmc[_cmc > 1] = 1
            cmc += _cmc[:max_rank].astype(np.float32)

        cmc /= num_repeats
        all_cmc.append(cmc)

        # average precision over the full filtered ranking:
        # mean of precision@k taken at every correct-match position k
        num_rel = raw_cmc.sum()
        precision_at_k = raw_cmc.cumsum() / np.arange(1, len(raw_cmc) + 1)
        all_AP.append((precision_at_k * raw_cmc).sum() / num_rel)
        num_valid_q += 1.

    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
104
105
def eval_market1501(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat,
                    remove_same_cam=False):
    """Evaluation with market1501-style metric: CMC curve, per-query AP and INP.

    The standard market1501 protocol discards, for each query, the gallery
    images that share both its identity and its camera view. Here that filtering
    is disabled by default, which preserves the results this function has been
    producing: it previously computed the filter and then overwrote it with
    all-True. Pass ``remove_same_cam=True`` to apply the standard protocol.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
            Its shape always defines num_query / num_gallery. Its values are used
            for ranking only when ``use_distmat`` is True.
        q_feats: 2-D query features (numpy array or torch tensor). Only used
            when ``use_distmat`` is False.
        g_feats: 2-D gallery features (numpy array or torch tensor). Only used
            when ``use_distmat`` is False.
        q_pids (numpy.ndarray): 1-D identities of the query samples.
        g_pids (numpy.ndarray): 1-D identities of the gallery samples.
        q_camids (numpy.ndarray): 1-D camera ids of the query samples.
        g_camids (numpy.ndarray): 1-D camera ids of the gallery samples.
        max_rank (int): number of CMC ranks to keep; clamped to num_gallery.
        use_distmat (bool): if True, rank the gallery by sorting ``distmat``.
            Otherwise rank it by an exact L2 nearest-neighbour search with faiss.
        remove_same_cam (bool, optional): drop gallery samples with the same pid
            and camid as the query before scoring. Default is False.

    Returns:
        tuple: ``(all_cmc, all_AP, all_INP)``.
        ``all_cmc`` is the float32 CMC curve (length ``max_rank`` after clamping),
        averaged over valid queries.
        ``all_AP`` is a list with one average precision per valid query.
        ``all_INP`` is a list with one inverse negative penalty per valid query:
        the number of correct matches divided by the 1-based rank of the last
        correct match.

    Raises:
        AssertionError: if no query identity appears in the (filtered) gallery.
    """
    num_q, num_g = distmat.shape

    if num_g < max_rank:
        max_rank = num_g
        print('Note: number of gallery samples is quite small, got {}'.format(num_g))

    if use_distmat:
        indices = np.argsort(distmat, axis=1)
    else:
        def _as_f32(feats):
            # faiss only accepts C-contiguous float32 numpy arrays; also accept
            # torch tensors (possibly on GPU).
            if hasattr(feats, 'detach'):
                feats = feats.detach().cpu().numpy()
            return np.ascontiguousarray(feats, dtype=np.float32)

        q_np = _as_f32(q_feats)
        g_np = _as_f32(g_feats)
        index = faiss.IndexFlatL2(q_np.shape[1])
        index.add(g_np)
        _, indices = index.search(q_np, k=num_g)

    # matches[i, j] == 1 iff the j-th ranked gallery sample of query i has its pid
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    all_cmc = []
    all_AP = []
    all_INP = []
    num_valid_q = 0.  # number of valid queries

    for q_idx in range(num_q):
        # binary vector over the ranking; positions with value 1 are correct matches
        raw_cmc = matches[q_idx]
        if remove_same_cam:
            order = indices[q_idx]
            same = (g_pids[order] == q_pids[q_idx]) & (g_camids[order] == q_camids[q_idx])
            raw_cmc = raw_cmc[~same]
        if not np.any(raw_cmc):
            # the query identity does not appear in the (filtered) gallery
            continue

        cmc = raw_cmc.cumsum()

        # INP: matches found up to the hardest positive / rank of that positive
        max_pos_idx = np.flatnonzero(raw_cmc)[-1]
        all_INP.append(cmc[max_pos_idx] / (max_pos_idx + 1.0))

        cmc[cmc > 1] = 1
        cmc = cmc[:max_rank]
        if len(cmc) < max_rank:
            # Only possible with remove_same_cam=True. The kept ranking ends after
            # a correct match, so every later rank is a hit too; padding keeps
            # all per-query curves the same length.
            cmc = np.pad(cmc, (0, max_rank - len(cmc)), constant_values=1)
        all_cmc.append(cmc)
        num_valid_q += 1.

        # average precision: mean of precision@k at every correct-match position k
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = raw_cmc.sum()
        precision_at_k = raw_cmc.cumsum() / np.arange(1, len(raw_cmc) + 1)
        all_AP.append((precision_at_k * raw_cmc).sum() / num_rel)

    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q

    return all_cmc, all_AP, all_INP
208
209
def evaluate_py(
    distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03, use_distmat
):
    """Evaluate ranking with the pure-Python implementation of the chosen protocol.

    Args:
        distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
        use_distmat: forwarded unchanged to the selected evaluator.
        use_metric_cuhk03 (bool): if True use ``eval_cuhk03``, otherwise
            ``eval_market1501``.

    Returns:
        Whatever the selected evaluator returns.
    """
    if use_metric_cuhk03:
        # q_pids must be forwarded: eval_cuhk03 takes the same 9 positional
        # parameters as eval_market1501 (it was previously dropped here,
        # which made this branch raise TypeError).
        return eval_cuhk03(
            distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat
        )
    return eval_market1501(
        distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat
    )
221
222
def evaluate_rank(
        distmat,
        q_feats,
        g_feats,
        q_pids,
        g_pids,
        q_camids,
        g_camids,
        max_rank=50,
        use_metric_cuhk03=False,
        use_distmat=False,
        use_cython=True
):
    """Evaluate CMC rank, dispatching to the Cython or the pure-Python backend.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        q_feats: 2-D array (or tensor) containing query features.
        g_feats: 2-D array (or tensor) containing gallery features.
        q_pids (numpy.ndarray): 1-D array of person identities, one per query instance.
        g_pids (numpy.ndarray): 1-D array of person identities, one per gallery instance.
        q_camids (numpy.ndarray): 1-D array of camera views in which each query
            instance was captured.
        g_camids (numpy.ndarray): 1-D array of camera views in which each gallery
            instance was captured.
        max_rank (int, optional): maximum CMC rank to compute. Default is 50.
        use_metric_cuhk03 (bool, optional): use the single-gallery-shot setting of
            cuhk03. Default is False. Enable it for the cuhk03 classic split.
        use_distmat (bool, optional): rank the gallery by sorting ``distmat``
            instead of searching over the features. Default is False.
        use_cython (bool, optional): prefer the Cython backend. Default is True.
            It is only used when the compiled module imported successfully. It is
            highly recommended, since it can speed up the CMC computation by more
            than 10x.

    Returns:
        Whatever the selected backend returns. NOTE(review): the Cython
        backend's return value is defined in ``rank_cy``; confirm it matches
        the Python backend's.
    """
    # The conditional expression only evaluates the chosen branch, so
    # evaluate_cy is never looked up when the Cython import failed.
    backend = evaluate_cy if (use_cython and IS_CYTHON_AVAI) else evaluate_py
    return backend(
        distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
        use_metric_cuhk03, use_distmat
    )
evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False, use_distmat=False, use_cython=True)
Definition rank.py:235
evaluate_py(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03, use_distmat)
Definition rank.py:212
eval_cuhk03(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
Definition rank.py:21
eval_market1501(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
Definition rank.py:106