Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Static Public Member Functions | Public Attributes | List of all members
fastreid.utils.visualizer.Visualizer Class Reference

Public Member Functions

 __init__ (self, dataset)
 
 get_model_output (self, all_ap, dist, q_pids, g_pids, q_camids, g_camids)
 
 get_matched_result (self, q_index)
 
 save_rank_result (self, query_indices, output, max_rank=5, vis_label=False, label_sort='ascending', actmap=False)
 
 vis_rank_list (self, output, vis_label, num_vis=100, rank_sort="ascending", label_sort="ascending", max_rank=5, actmap=False)
 
 vis_roc_curve (self, output)
 

Static Public Member Functions

 plot_roc_curve (fpr, tpr, name='model', fig=None)
 
 plot_distribution (pos, neg, name='model', fig=None)
 
 save_roc_info (output, fpr, tpr, pos, neg)
 
 load_roc_info (path)
 

Public Attributes

 dataset
 
 all_ap
 
 dist
 
 sim
 
 q_pids
 
 g_pids
 
 q_camids
 
 g_camids
 
 indices
 
 matches
 
 num_query
 

Detailed Description

Visualize ranking lists of images (and, optionally, activation maps) for features generated by ReID models.

Definition at line 20 of file visualizer.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.utils.visualizer.Visualizer.__init__ ( self,
dataset )

Definition at line 23 of file visualizer.py.

def __init__(self, dataset):
    """Keep a reference to the dataset whose images will be visualized.

    The dataset is indexed later with query indices directly and with gallery
    indices offset by ``num_query``; each item is read for the keys 'images'
    and 'camids' (and optionally 'img_paths').
    """
    self.dataset = dataset

Member Function Documentation

◆ get_matched_result()

fastreid.utils.visualizer.Visualizer.get_matched_result ( self,
q_index )

Definition at line 40 of file visualizer.py.

def get_matched_result(self, q_index):
    """Return the ranked match flags and gallery indices for one query.

    Gallery items that share BOTH the query's person id and its camera id are
    removed from the ranking before returning.

    Args:
        q_index: position of the query in ``self.q_pids`` / ``self.indices``.

    Returns:
        (cmc, sort_idx): ``cmc`` is the 0/1 row of ``self.matches`` and
        ``sort_idx`` the matching gallery indices, in the order stored in
        ``self.indices`` (nearest first, as built by ``get_model_output``),
        both restricted to the kept items.
    """
    ranked = self.indices[q_index]
    same_pid = self.g_pids[ranked] == self.q_pids[q_index]
    same_cam = self.g_camids[ranked] == self.q_camids[q_index]
    keep = ~(same_pid & same_cam)
    return self.matches[q_index][keep], ranked[keep]

◆ get_model_output()

fastreid.utils.visualizer.Visualizer.get_model_output ( self,
all_ap,
dist,
q_pids,
g_pids,
q_camids,
g_camids )

Definition at line 26 of file visualizer.py.

def get_model_output(self, all_ap, dist, q_pids, g_pids, q_camids, g_camids):
    """Cache evaluation outputs and precompute the ranking of every query.

    Args:
        all_ap: per-query average precision, one value per query.
        dist: query-by-gallery distance matrix; row i holds the distances of
            query i to every gallery item.
        q_pids, g_pids: person ids of the queries / gallery items.
        q_camids, g_camids: camera ids of the queries / gallery items.

    All inputs are passed through ``np.asarray`` so that plain lists (or other
    array-likes) work with the fancy indexing used here and in the other
    methods; numpy arrays are stored as-is, without copying.
    """
    self.all_ap = np.asarray(all_ap)
    self.dist = np.asarray(dist)
    # Similarity is 1 - distance; it is shown in the rank-list titles.
    self.sim = 1 - self.dist
    self.q_pids = np.asarray(q_pids)
    self.g_pids = np.asarray(g_pids)
    self.q_camids = np.asarray(q_camids)
    self.g_camids = np.asarray(g_camids)

    # Gallery indices of each query sorted by increasing distance.
    self.indices = np.argsort(self.dist, axis=1)
    # matches[i, k] == 1 iff the k-th ranked gallery item has query i's pid.
    self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)

    self.num_query = len(self.q_pids)

◆ load_roc_info()

fastreid.utils.visualizer.Visualizer.load_roc_info ( path)
static

Definition at line 245 of file visualizer.py.

def load_roc_info(path):
    """Load the ROC dictionary written by ``save_roc_info`` from ``path``.

    NOTE(review): this uses ``pickle.load``; only open files from trusted
    sources, because unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as fh:
        loaded = pickle.load(fh)
    return loaded

◆ plot_distribution()

fastreid.utils.visualizer.Visualizer.plot_distribution ( pos,
neg,
name = 'model',
fig = None )
static

Definition at line 207 of file visualizer.py.

def plot_distribution(pos, neg, name='model', fig=None):
    """Plot density histograms of positive and negative values with normal fits.

    Args:
        pos: values of positive (same id) pairs; ``vis_roc_curve`` passes
            pair distances.
        neg: values of negative (different id) pairs.
        name: model name used in the legend labels.
        fig: figure to draw into; a new pyplot figure is created when None.

    Returns:
        The figure that was drawn on.
    """
    if fig is None:
        fig = plt.figure()
    # Draw on the given figure's axes explicitly: pyplot state-machine calls
    # would draw on whatever figure is current, which need not be ``fig``.
    ax = fig.gca()
    for values, kind, alpha in ((pos, 'positive', 0.7), (neg, 'negative', 0.5)):
        color = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
        _, bins, _ = ax.hist(values, bins=80, alpha=alpha, density=True,
                             color=color,
                             label='{} with {}'.format(kind, name))
        # Overlay a normal pdf fitted with the sample mean and std.
        ax.plot(bins, norm.pdf(bins, np.mean(values), np.std(values)), color=color)

    ax.set_xticks(np.arange(0, 1.5, 0.1))
    ax.set_title('positive and negative pairs distribution')
    ax.legend(loc='best')
    return fig

◆ plot_roc_curve()

fastreid.utils.visualizer.Visualizer.plot_roc_curve ( fpr,
tpr,
name = 'model',
fig = None )
static

Definition at line 194 of file visualizer.py.

def plot_roc_curve(fpr, tpr, name='model', fig=None):
    """Plot an ROC curve with a logarithmic false-positive-rate axis.

    Args:
        fpr, tpr: false / true positive rates (e.g. from ``metrics.roc_curve``).
        name: model name used in the legend.
        fig: figure to draw into; a new pyplot figure is created when None.

    Returns:
        The figure that was drawn on.
    """
    if fig is None:
        fig = plt.figure()
    # Use the figure's own axes so a caller-supplied ``fig`` is honoured even
    # when it is not the current pyplot figure.
    ax = fig.gca()
    # Chance diagonal; the x == 0 point cannot be shown on a log axis.
    chance = np.arange(0, 1, 0.01)
    ax.semilogx(chance, chance, 'r', linestyle='--', label='Random guess')
    ax.semilogx(fpr, tpr, color=(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)),
                label='ROC curve with {}'.format(name))
    ax.set_title('Receiver Operating Characteristic')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.legend(loc='best')
    return fig

◆ save_rank_result()

fastreid.utils.visualizer.Visualizer.save_rank_result ( self,
query_indices,
output,
max_rank = 5,
vis_label = False,
label_sort = 'ascending',
actmap = False )

Definition at line 51 of file visualizer.py.

def save_rank_result(self, query_indices, output, max_rank=5, vis_label=False, label_sort='ascending',
                     actmap=False):
    """Save one ranking figure per query as ``<output>/<cnt>.jpg``.

    Each figure shows the query image followed by its top ``max_rank``
    gallery results (red frame: same person id, blue frame: different id).
    With ``vis_label`` a second row shows up to ``max_rank`` of the correctly
    matched gallery items.

    Args:
        query_indices: iterable of query indices to visualize; ``cnt`` in the
            file name is the position within this iterable.
        output: directory the images are written to (it is not created here).
        max_rank: number of ranked gallery results to show per query.
        vis_label: whether to add the row of correct matches.
        label_sort: 'ascending' lists the correct matches from lowest to
            highest similarity; any other value from highest to lowest.
        actmap: currently unused (activation-map rendering is not implemented).
    """

    def to_hwc_uint8(img):
        # rollaxis moves axis 0 to the end, (C, H, W) -> (H, W, C), for imshow.
        # np.asarray accepts numpy arrays as well as CPU tensors, so query and
        # gallery images are converted the same way (an explicit ``.numpy()``
        # on the query previously broke numpy-backed datasets).
        return np.rollaxis(np.asarray(img, dtype=np.uint8), 0, 3)

    def draw_framed(ax, img, edgecolor, title):
        # Coloured frame around the thumbnail, then the image and its caption.
        ax.add_patch(plt.Rectangle(xy=(0, 0), width=img.shape[1] - 1,
                                   height=img.shape[0] - 1, edgecolor=edgecolor,
                                   fill=False, linewidth=5))
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")

    n_rows = 2 if vis_label else 1
    fig, _ = plt.subplots(n_rows, max_rank + 1, figsize=(3 * max_rank, 6 * n_rows))
    try:
        for cnt, q_idx in enumerate(tqdm.tqdm(query_indices)):
            cmc, sort_idx = self.get_matched_result(q_idx)
            query_info = self.dataset[q_idx]
            query_cam = query_info['camids']

            # Datasets without 'img_paths' get a placeholder name for the title.
            if 'img_paths' in query_info:
                query_name = query_info['img_paths'].split('/')[-1]
            else:
                query_name = f'{cnt}.jpg'

            # Start from an empty figure; the query sits in the first column of
            # a single-row grid, so it spans the full height.
            fig.clf()
            ax = fig.add_subplot(1, max_rank + 1, 1)
            ax.imshow(to_hwc_uint8(query_info['images']))
            ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], query_cam))
            ax.axis("off")

            # Top row: ranked results. Never index past the filtered gallery,
            # which can hold fewer than max_rank items.
            for i in range(min(max_rank, len(sort_idx))):
                ax = fig.add_subplot(n_rows, max_rank + 1, i + 2)
                gallery_info = self.dataset[self.num_query + sort_idx[i]]
                cam_id = gallery_info['camids']
                img = to_hwc_uint8(gallery_info['images'])
                if cmc[i] == 1:
                    label, edgecolor = 'true', (1, 0, 0)
                else:
                    label, edgecolor = 'false', (0, 0, 1)
                draw_framed(ax, img, edgecolor,
                            f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}')

            if vis_label:
                # Ranked positions of the correct matches, nearest first.
                label_indice = np.where(cmc == 1)[0]
                if label_sort == "ascending":
                    label_indice = label_indice[::-1]
                for i, j in enumerate(label_indice[:max_rank]):
                    gallery_info = self.dataset[self.num_query + sort_idx[j]]
                    cam_id = gallery_info['camids']
                    img = to_hwc_uint8(gallery_info['images'])
                    # Second row, skipping the column occupied by the query.
                    ax = fig.add_subplot(2, max_rank + 1, max_rank + 3 + i)
                    draw_framed(ax, img, (1, 0, 0),
                                f'{self.sim[q_idx, sort_idx[j]]:.3f}/cam{cam_id}')

            fig.tight_layout()
            fig.savefig(os.path.join(output, "{}.jpg".format(cnt)))
    finally:
        # Release the figure even if loading or saving an image fails.
        plt.close(fig)

◆ save_roc_info()

fastreid.utils.visualizer.Visualizer.save_roc_info ( output,
fpr,
tpr,
pos,
neg )
static

Definition at line 234 of file visualizer.py.

def save_roc_info(output, fpr, tpr, pos, neg):
    """Pickle ROC data to ``<output>/roc_info.pickle``.

    Args:
        output: target directory; created if it does not exist yet (an empty
            string means the current directory).
        fpr, tpr: ROC false / true positive rates.
        pos, neg: positive / negative pair scores the curve was built from.

    The values are stored as numpy arrays under the keys "fpr", "tpr", "pos"
    and "neg"; ``load_roc_info`` reads them back.
    """
    results = {
        "fpr": np.asarray(fpr),
        "tpr": np.asarray(tpr),
        "pos": np.asarray(pos),
        "neg": np.asarray(neg),
    }
    # Like vis_roc_curve, make sure the directory exists before writing.
    # os.makedirs('') would raise, so skip it for the current directory.
    if output:
        os.makedirs(output, exist_ok=True)
    with open(os.path.join(output, "roc_info.pickle"), "wb") as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

◆ vis_rank_list()

fastreid.utils.visualizer.Visualizer.vis_rank_list ( self,
output,
vis_label,
num_vis = 100,
rank_sort = "ascending",
label_sort = "ascending",
max_rank = 5,
actmap = False )
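Visualize the rank list of query instances.
Args:
output (str): directory in which the rank-list images are saved.
vis_label (bool): whether to add a second row showing each query's correct (same-id) gallery matches.
num_vis (int): number of queries to visualize.
rank_sort (str): order in which queries are selected and saved;
    "ascending" goes from low to high AP, "descending" from high to low.
label_sort (str): order of the correct-match row; "ascending" lists matches from lowest to highest similarity, any other value the reverse.
max_rank (int): maximum number of ranked results to visualize per query.
actmap (bool): activation-map flag; currently unused.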

Definition at line 147 of file visualizer.py.

def vis_rank_list(self, output, vis_label, num_vis=100, rank_sort="ascending", label_sort="ascending",
                  max_rank=5, actmap=False):
    r"""Visualize the rank list of query instances.

    Args:
        output (str): directory in which the rank-list images are saved.
        vis_label (bool): whether to add a row showing each query's correct
            (same-id) gallery matches.
        num_vis (int): number of queries to visualize.
        rank_sort (str): order in which queries are selected and saved;
            "ascending" goes from low to high AP, "descending" the reverse.
        label_sort (str): order of the correct-match row; "ascending" lists
            matches from lowest to highest similarity, any other value the
            reverse.
        max_rank (int): maximum number of ranked results to show per query.
        actmap (bool): forwarded to ``save_rank_result``, which currently
            does not use it.

    Raises:
        ValueError: if ``rank_sort`` is not "ascending" or "descending".
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if rank_sort not in ('ascending', 'descending'):
        raise ValueError("{} not match [ascending, descending]".format(rank_sort))

    # Queries ordered by AP, lowest first; flipped for "descending".
    query_indices = np.argsort(self.all_ap)
    if rank_sort == 'descending':
        query_indices = query_indices[::-1]

    self.save_rank_result(query_indices[:num_vis], output, max_rank, vis_label, label_sort, actmap)

◆ vis_roc_curve()

fastreid.utils.visualizer.Visualizer.vis_roc_curve ( self,
output )

Definition at line 168 of file visualizer.py.

def vis_roc_curve(self, output):
    """Compute an ROC curve from the cached distances and save its plots.

    For every query, the distances to correctly matched gallery items (after
    the same-id/same-camera filtering of ``get_matched_result``) are collected
    in ``pos`` and the remaining distances in ``neg``. ``roc.jpg`` and
    ``pos_neg_dist.jpg`` are written into ``output``, which is created first.

    Returns:
        (fpr, tpr, pos, neg)
    """
    PathManager.mkdirs(output)
    pos, neg = [], []
    for q_idx in range(len(self.q_pids)):
        cmc, sort_idx = self.get_matched_result(q_idx)
        q_dist = self.dist[q_idx]
        pos.extend(q_dist[sort_idx[cmc == 1]])
        neg.extend(q_dist[sort_idx[cmc == 0]])

    scores = np.hstack((pos, neg))
    # NOTE: negative pairs get label 1 and the raw distance is the score, so
    # roc_curve treats "different id" as the positive class (larger distance
    # => predicted different id).
    labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
    fpr, tpr, _ = metrics.roc_curve(labels, scores)

    # Save through the returned figures and close them, so repeated calls do
    # not accumulate open pyplot figures.
    roc_fig = self.plot_roc_curve(fpr, tpr)
    roc_fig.savefig(os.path.join(output, "roc.jpg"))
    plt.close(roc_fig)
    dist_fig = self.plot_distribution(pos, neg)
    dist_fig.savefig(os.path.join(output, "pos_neg_dist.jpg"))
    plt.close(dist_fig)
    return fpr, tpr, pos, neg

Member Data Documentation

◆ all_ap

fastreid.utils.visualizer.Visualizer.all_ap

Definition at line 27 of file visualizer.py.

◆ dataset

fastreid.utils.visualizer.Visualizer.dataset

Definition at line 24 of file visualizer.py.

◆ dist

fastreid.utils.visualizer.Visualizer.dist

Definition at line 28 of file visualizer.py.

◆ g_camids

fastreid.utils.visualizer.Visualizer.g_camids

Definition at line 33 of file visualizer.py.

◆ g_pids

fastreid.utils.visualizer.Visualizer.g_pids

Definition at line 31 of file visualizer.py.

◆ indices

fastreid.utils.visualizer.Visualizer.indices

Definition at line 35 of file visualizer.py.

◆ matches

fastreid.utils.visualizer.Visualizer.matches

Definition at line 36 of file visualizer.py.

◆ num_query

fastreid.utils.visualizer.Visualizer.num_query

Definition at line 38 of file visualizer.py.

◆ q_camids

fastreid.utils.visualizer.Visualizer.q_camids

Definition at line 32 of file visualizer.py.

◆ q_pids

fastreid.utils.visualizer.Visualizer.q_pids

Definition at line 30 of file visualizer.py.

◆ sim

fastreid.utils.visualizer.Visualizer.sim

Definition at line 29 of file visualizer.py.


The documentation for this class was generated from the following file: