52 actmap=False):
53 if vis_label:
54 fig, axes = plt.subplots(2, max_rank + 1, figsize=(3 * max_rank, 12))
55 else:
56 fig, axes = plt.subplots(1, max_rank + 1, figsize=(3 * max_rank, 6))
57 for cnt, q_idx in enumerate(tqdm.tqdm(query_indices)):
58 all_imgs = []
59 cmc, sort_idx = self.get_matched_result(q_idx)
60 query_info = self.dataset[q_idx]
61 query_img = query_info['images']
62 cam_id = query_info['camids']
63
64
65
66 if 'img_paths' in query_info:
67 query_name = query_info['img_paths'].split('/')[-1]
68 else:
69 query_name = f'{cnt}.jpg'
70
71 all_imgs.append(query_img)
72 query_img = np.rollaxis(np.asarray(query_img.numpy(), dtype=np.uint8), 0, 3)
73 plt.clf()
74 ax = fig.add_subplot(1, max_rank + 1, 1)
75 ax.imshow(query_img)
76 ax.set_title('{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], cam_id))
77 ax.axis("off")
78 for i in range(max_rank):
79 if vis_label:
80 ax = fig.add_subplot(2, max_rank + 1, i + 2)
81 else:
82 ax = fig.add_subplot(1, max_rank + 1, i + 2)
83 g_idx = self.num_query + sort_idx[i]
84 gallery_info = self.dataset[g_idx]
85 gallery_img = gallery_info['images']
86 cam_id = gallery_info['camids']
87 all_imgs.append(gallery_img)
88 gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3)
89 if cmc[i] == 1:
90 label = 'true'
91 ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
92 height=gallery_img.shape[0] - 1, edgecolor=(1, 0, 0),
93 fill=False, linewidth=5))
94 else:
95 label = 'false'
96 ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
97 height=gallery_img.shape[0] - 1,
98 edgecolor=(0, 0, 1), fill=False, linewidth=5))
99 ax.imshow(gallery_img)
100 ax.set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}')
101 ax.axis("off")
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120 if vis_label:
121 label_indice = np.where(cmc == 1)[0]
122 if label_sort == "ascending": label_indice = label_indice[::-1]
123 label_indice = label_indice[:max_rank]
124 for i in range(max_rank):
125 if i >= len(label_indice): break
126 j = label_indice[i]
127 g_idx = self.num_query + sort_idx[j]
128 gallery_info = self.dataset[g_idx]
129 gallery_img = gallery_info['images']
130 cam_id = gallery_info['camids']
131 gallery_img = np.rollaxis(np.asarray(gallery_img, dtype=np.uint8), 0, 3)
132 ax = fig.add_subplot(2, max_rank + 1, max_rank + 3 + i)
133 ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1,
134 height=gallery_img.shape[0] - 1,
135 edgecolor=(1, 0, 0),
136 fill=False, linewidth=5))
137 ax.imshow(gallery_img)
138 ax.set_title(f'{self.sim[q_idx, sort_idx[j]]:.3f}/cam{cam_id}')
139 ax.axis("off")
140
141 plt.tight_layout()
142 filepath = os.path.join(output, "{}.jpg".format(cnt))
143 fig.savefig(filepath)
144
145 plt.close(fig)
146