24 assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
25 self.register_buffer(
"pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
26 self.register_buffer(
"pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
30 bn_norm = cfg.MODEL.BACKBONE.NORM
31 with_se = cfg.MODEL.BACKBONE.WITH_SE
32 extra_bn = cfg.MODEL.BACKBONE.EXTRA_BN
35 backbone = build_backbone(cfg)
45 res_conv4 = nn.Sequential(*backbone.layer3[1:])
46 res_g_conv5 = backbone.layer4
48 res_p_conv5 = nn.Sequential(
49 Bottleneck(1024, 512, bn_norm,
False, with_se, downsample=nn.Sequential(
50 nn.Conv2d(1024, 2048, 1, bias=
False), get_norm(bn_norm, 2048))),
51 Bottleneck(2048, 512, bn_norm,
False, with_se),
52 Bottleneck(2048, 512, bn_norm,
False, with_se))
53 res_p_conv5.load_state_dict(backbone.layer4.state_dict())
56 self.
b1 = nn.Sequential(
57 copy.deepcopy(res_conv4),
58 copy.deepcopy(res_g_conv5)
63 self.
b2 = nn.Sequential(
64 copy.deepcopy(res_conv4),
65 copy.deepcopy(res_p_conv5)
72 self.
b3 = nn.Sequential(
73 copy.deepcopy(res_conv4),
74 copy.deepcopy(res_p_conv5)
96 b1_feat = self.
b1(features)
100 b2_feat = self.
b2(features)
102 b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
105 b3_feat = self.
b3(features)
107 b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
110 assert "targets" in batched_inputs,
"Person ID annotation are missing in training!"
111 targets = batched_inputs[
"targets"].long().to(self.
device)
113 if targets.sum() < 0: targets.zero_()
115 b1_outputs = self.
b1_head(b1_feat, targets)
116 b2_outputs = self.
b2_head(b2_feat, targets)
117 b21_outputs = self.
b21_head(b21_feat, targets)
118 b22_outputs = self.
b22_head(b22_feat, targets)
119 b3_outputs = self.
b3_head(b3_feat, targets)
120 b31_outputs = self.
b31_head(b31_feat, targets)
121 b32_outputs = self.
b32_head(b32_feat, targets)
122 b33_outputs = self.
b33_head(b33_feat, targets)
125 "b1_outputs": b1_outputs,
126 "b2_outputs": b2_outputs,
127 "b21_outputs": b21_outputs,
128 "b22_outputs": b22_outputs,
129 "b3_outputs": b3_outputs,
130 "b31_outputs": b31_outputs,
131 "b32_outputs": b32_outputs,
132 "b33_outputs": b33_outputs,
136 b1_pool_feat = self.
b1_head(b1_feat)
137 b2_pool_feat = self.
b2_head(b2_feat)
138 b21_pool_feat = self.
b21_head(b21_feat)
139 b22_pool_feat = self.
b22_head(b22_feat)
140 b3_pool_feat = self.
b3_head(b3_feat)
141 b31_pool_feat = self.
b31_head(b31_feat)
142 b32_pool_feat = self.
b32_head(b32_feat)
143 b33_pool_feat = self.
b33_head(b33_feat)
145 pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
146 b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
163 def losses(self, outs):
165 b1_outputs = outs[
"b1_outputs"]
166 b2_outputs = outs[
"b2_outputs"]
167 b21_outputs = outs[
"b21_outputs"]
168 b22_outputs = outs[
"b22_outputs"]
169 b3_outputs = outs[
"b3_outputs"]
170 b31_outputs = outs[
"b31_outputs"]
171 b32_outputs = outs[
"b32_outputs"]
172 b33_outputs = outs[
"b33_outputs"]
173 gt_labels = outs[
"targets"]
175 pred_class_logits = b1_outputs[
'pred_class_logits'].detach()
176 b1_logits = b1_outputs[
'cls_outputs']
177 b2_logits = b2_outputs[
'cls_outputs']
178 b21_logits = b21_outputs[
'cls_outputs']
179 b22_logits = b22_outputs[
'cls_outputs']
180 b3_logits = b3_outputs[
'cls_outputs']
181 b31_logits = b31_outputs[
'cls_outputs']
182 b32_logits = b32_outputs[
'cls_outputs']
183 b33_logits = b33_outputs[
'cls_outputs']
184 b1_pool_feat = b1_outputs[
'features']
185 b2_pool_feat = b2_outputs[
'features']
186 b3_pool_feat = b3_outputs[
'features']
187 b21_pool_feat = b21_outputs[
'features']
188 b22_pool_feat = b22_outputs[
'features']
189 b31_pool_feat = b31_outputs[
'features']
190 b32_pool_feat = b32_outputs[
'features']
191 b33_pool_feat = b33_outputs[
'features']
195 log_accuracy(pred_class_logits, gt_labels)
197 b22_pool_feat = torch.cat((b21_pool_feat, b22_pool_feat), dim=1)
198 b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)
201 loss_names = self.
_cfg.MODEL.LOSSES.NAME
203 if "CrossEntropyLoss" in loss_names:
204 loss_dict[
'loss_cls_b1'] = cross_entropy_loss(
207 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
208 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
209 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
210 loss_dict[
'loss_cls_b2'] = cross_entropy_loss(
213 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
214 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
215 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
216 loss_dict[
'loss_cls_b21'] = cross_entropy_loss(
219 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
220 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
221 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
222 loss_dict[
'loss_cls_b22'] = cross_entropy_loss(
225 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
226 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
227 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
228 loss_dict[
'loss_cls_b3'] = cross_entropy_loss(
231 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
232 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
233 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
234 loss_dict[
'loss_cls_b31'] = cross_entropy_loss(
237 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
238 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
239 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
240 loss_dict[
'loss_cls_b32'] = cross_entropy_loss(
243 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
244 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
245 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
246 loss_dict[
'loss_cls_b33'] = cross_entropy_loss(
249 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
250 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
251 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE * 0.125
253 if "TripletLoss" in loss_names:
254 loss_dict[
'loss_triplet_b1'] = triplet_loss(
257 self.
_cfg.MODEL.LOSSES.TRI.MARGIN,
258 self.
_cfg.MODEL.LOSSES.TRI.NORM_FEAT,
259 self.
_cfg.MODEL.LOSSES.TRI.HARD_MINING,
260 ) * self.
_cfg.MODEL.LOSSES.TRI.SCALE * 0.2
261 loss_dict[
'loss_triplet_b2'] = triplet_loss(
264 self.
_cfg.MODEL.LOSSES.TRI.MARGIN,
265 self.
_cfg.MODEL.LOSSES.TRI.NORM_FEAT,
266 self.
_cfg.MODEL.LOSSES.TRI.HARD_MINING,
267 ) * self.
_cfg.MODEL.LOSSES.TRI.SCALE * 0.2
268 loss_dict[
'loss_triplet_b3'] = triplet_loss(
271 self.
_cfg.MODEL.LOSSES.TRI.MARGIN,
272 self.
_cfg.MODEL.LOSSES.TRI.NORM_FEAT,
273 self.
_cfg.MODEL.LOSSES.TRI.HARD_MINING,
274 ) * self.
_cfg.MODEL.LOSSES.TRI.SCALE * 0.2
275 loss_dict[
'loss_triplet_b22'] = triplet_loss(
278 self.
_cfg.MODEL.LOSSES.TRI.MARGIN,
279 self.
_cfg.MODEL.LOSSES.TRI.NORM_FEAT,
280 self.
_cfg.MODEL.LOSSES.TRI.HARD_MINING,
281 ) * self.
_cfg.MODEL.LOSSES.TRI.SCALE * 0.2
282 loss_dict[
'loss_triplet_b33'] = triplet_loss(
285 self.
_cfg.MODEL.LOSSES.TRI.MARGIN,
286 self.
_cfg.MODEL.LOSSES.TRI.NORM_FEAT,
287 self.
_cfg.MODEL.LOSSES.TRI.HARD_MINING,
288 ) * self.
_cfg.MODEL.LOSSES.TRI.SCALE * 0.2