Safemotion Lib
Loading...
Searching...
No Matches
reid_evaluation.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6import copy
7import logging
8from collections import OrderedDict
9from sklearn import metrics
10
11import numpy as np
12import torch
13import torch.nn.functional as F
14
15from .evaluator import DatasetEvaluator
16from .query_expansion import aqe
17from .rank import evaluate_rank
18from .rerank import re_ranking
19from .roc import evaluate_roc
20from fastreid.utils import comm
21
# Module-level logger, named after the module it lives in.
logger: logging.Logger = logging.getLogger(name=__name__)
23
24
def __init__(self, cfg, num_query, output_dir=None):
    """Create an evaluator that buffers per-batch results until ``evaluate``.

    Args:
        cfg: config node; ``evaluate`` reads its ``TEST`` options.
        num_query: how many of the buffered samples, counted from the front,
            are queries; ``evaluate`` treats the remainder as the gallery.
        output_dir: kept as ``self._output_dir``; nothing in this file's
            visible code reads it.
    """
    self.cfg = cfg
    self._num_query, self._output_dir = num_query, output_dir

    # Per-sample buffers, appended to by ``process`` and cleared by ``reset``.
    self.features, self.pids, self.camids = [], [], []
34
def reset(self):
    """Drop every buffered feature, person id and camera id."""
    # Each buffer gets its own fresh, empty list.
    for buffer_name in ("features", "pids", "camids"):
        setattr(self, buffer_name, [])
39
def process(self, inputs, outputs):
    """Append one batch of labels and model outputs to the running buffers.

    Args:
        inputs: batch mapping whose ``"targets"`` (person ids) and
            ``"camids"`` entries are iterated and appended element-wise.
        outputs: batch features; ``.cpu()`` is called on it and the result
            is stored as a single entry (``evaluate`` later joins the entries
            with ``torch.cat``).
    """
    # Labels are flattened into per-sample lists: person ids first, then
    # camera ids.
    for input_key, buffer_name in (("targets", "pids"), ("camids", "camids")):
        getattr(self, buffer_name).extend(inputs[input_key])
    # Features stay grouped per batch.
    self.features.append(outputs.cpu())
44
@staticmethod
def cal_dist(metric: str, query_feat: torch.Tensor, gallery_feat: torch.Tensor):
    """Compute the pairwise query-to-gallery distance matrix.

    Args:
        metric: ``"cosine"`` or ``"euclidean"``.
        query_feat: 2-D tensor with one query feature per row.
        gallery_feat: 2-D tensor with one gallery feature per row.

    Returns:
        ``numpy.ndarray`` of shape ``(len(query_feat), len(gallery_feat))``.
        For ``"cosine"`` the value is ``1 - q @ g.T``; this is the cosine
        distance only when the rows are L2-normalised (``evaluate`` normalises
        them before calling). For ``"euclidean"`` it is the L2 distance.

    Raises:
        ValueError: if ``metric`` is not a supported name.
    """
    if metric not in ("cosine", "euclidean"):
        # Explicit raise instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let an unknown metric silently fall
        # through to the euclidean branch.
        raise ValueError("must choose from [cosine, euclidean], but got {}".format(metric))

    if metric == "cosine":
        dist = 1 - torch.mm(query_feat, gallery_feat.t())
    else:
        # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g, evaluated for all pairs.
        m, n = query_feat.size(0), gallery_feat.size(0)
        xx = torch.pow(query_feat, 2).sum(1, keepdim=True).expand(m, n)
        yy = torch.pow(gallery_feat, 2).sum(1, keepdim=True).expand(n, m).t()
        dist = xx + yy
        dist.addmm_(query_feat, gallery_feat.t(), beta=1, alpha=-2)
        # Round-off can push the squared distance slightly below zero; clamp
        # before sqrt so it never yields NaN.
        dist = dist.clamp(min=1e-12).sqrt()
    return dist.cpu().numpy()
58
def evaluate(self):
    """Compute ReID retrieval metrics over every sample buffered by ``process``.

    The first ``self._num_query`` buffered samples are the queries; the
    remainder is the gallery. Optional steps are driven by ``self.cfg.TEST``:
    ``AQE.ENABLED`` (query expansion via ``aqe``), ``METRIC`` (with
    ``"cosine"`` the features are L2-normalised first), ``RERANK.ENABLED``
    (``re_ranking``) and ``ROC_ENABLED`` (TPR at fixed FPRs).

    Returns:
        A deep copy of the results ``OrderedDict`` with keys ``Rank-1``,
        ``Rank-5``, ``Rank-10``, ``mAP``, ``mINP`` and, when ROC is enabled,
        ``TPR@FPR=1e-04``, ``TPR@FPR=1e-03`` and ``TPR@FPR=1e-02``. In a
        multi-process run, every process except the main one returns ``{}``.
    """
    if comm.get_world_size() > 1:
        comm.synchronize()
        # Flatten the per-process buffers returned by ``comm.gather`` in order.
        # A nested comprehension is linear, whereas ``sum(lists, [])``
        # re-copies the growing list on every addition. The gathers keep their
        # original order (they are presumably collectives that must line up
        # across processes).
        features = [feat for part in comm.gather(self.features) for feat in part]
        pids = [pid for part in comm.gather(self.pids) for pid in part]
        camids = [camid for part in comm.gather(self.camids) for camid in part]

        # fmt: off
        if not comm.is_main_process(): return {}
        # fmt: on
    else:
        features = self.features
        pids = self.pids
        camids = self.camids

    features = torch.cat(features, dim=0)
    # Buffers are ordered queries first, then gallery.
    query_features = features[:self._num_query]
    query_pids = np.asarray(pids[:self._num_query])
    query_camids = np.asarray(camids[:self._num_query])

    gallery_features = features[self._num_query:]
    gallery_pids = np.asarray(pids[self._num_query:])
    gallery_camids = np.asarray(camids[self._num_query:])

    self._results = OrderedDict()

    if self.cfg.TEST.AQE.ENABLED:
        logger.info("Test with AQE setting")
        qe_time = self.cfg.TEST.AQE.QE_TIME
        qe_k = self.cfg.TEST.AQE.QE_K
        alpha = self.cfg.TEST.AQE.ALPHA
        query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)

    metric = self.cfg.TEST.METRIC
    if metric == "cosine":
        # cal_dist's cosine branch is ``1 - dot product``, so rows must be
        # unit length for it to be a cosine distance.
        query_features = F.normalize(query_features, dim=1)
        gallery_features = F.normalize(gallery_features, dim=1)

    dist = self.cal_dist(metric, query_features, gallery_features)

    # ``rank_dist`` is the matrix handed to evaluate_rank: the re-ranked one
    # when re-ranking is enabled, otherwise ``dist`` itself.
    rerank = bool(self.cfg.TEST.RERANK.ENABLED)
    rank_dist = dist
    if rerank:
        logger.info("Test with rerank setting")
        k1 = self.cfg.TEST.RERANK.K1
        k2 = self.cfg.TEST.RERANK.K2
        lambda_value = self.cfg.TEST.RERANK.LAMBDA
        q_q_dist = self.cal_dist(metric, query_features, query_features)
        g_g_dist = self.cal_dist(metric, gallery_features, gallery_features)
        rank_dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)

    query_features = query_features.numpy()
    gallery_features = gallery_features.numpy()
    # ``use_distmat`` is True exactly when the re-ranked matrix is passed, as
    # in the two separate call sites this replaces (its meaning is defined by
    # ``evaluate_rank``).
    cmc, all_AP, all_INP = evaluate_rank(rank_dist, query_features, gallery_features,
                                         query_pids, gallery_pids, query_camids,
                                         gallery_camids, use_distmat=rerank)

    for r in [1, 5, 10]:
        self._results['Rank-{}'.format(r)] = cmc[r - 1]
    self._results['mAP'] = np.mean(all_AP)
    self._results['mINP'] = np.mean(all_INP)

    if self.cfg.TEST.ROC_ENABLED:
        # NOTE: ROC is computed on the original ``dist``, not on the
        # re-ranked matrix.
        scores, labels = evaluate_roc(dist, query_features, gallery_features,
                                      query_pids, gallery_pids, query_camids, gallery_camids)
        fprs, tprs, _ = metrics.roc_curve(labels, scores)

        for fpr in [1e-4, 1e-3, 1e-2]:
            # Report the TPR at the ROC point whose FPR is nearest the target.
            ind = np.argmin(np.abs(fprs - fpr))
            self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]

    return copy.deepcopy(self._results)
__init__(self, cfg, num_query, output_dir=None)
cal_dist(str metric, torch.tensor query_feat, torch.tensor gallery_feat)