20 feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
21 embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
22 num_classes = cfg.MODEL.HEADS.NUM_CLASSES
23 neck_feat = cfg.MODEL.HEADS.NECK_FEAT
24 pool_type = cfg.MODEL.HEADS.POOL_LAYER
25 cls_type = cfg.MODEL.HEADS.CLS_LAYER
26 with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
27 norm_type = cfg.MODEL.HEADS.NORM
30 elif pool_type ==
'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
31 elif pool_type ==
'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
36 elif pool_type ==
"identity": self.pool_layer = nn.Identity()
37 elif pool_type ==
"flatten": self.pool_layer =
Flatten()
38 else:
raise KeyError(f
"{pool_type} is not supported!")
45 bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=
False))
46 feat_dim = embedding_dim
49 bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=
True))
55 if cls_type ==
'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=
False)
56 elif cls_type ==
'arcSoftmax': self.classifier =
ArcSoftmax(cfg, feat_dim, num_classes)
57 elif cls_type ==
'circleSoftmax': self.classifier =
CircleSoftmax(cfg, feat_dim, num_classes)
58 elif cls_type ==
'amSoftmax': self.classifier =
AMSoftmax(cfg, feat_dim, num_classes)
59 else:
raise KeyError(f
"{cls_type} is not supported!")
63 self.classifier.apply(weights_init_classifier)
65 def forward(self, features, targets=None):
67 See :class:`ReIDHeads.forward`.
69 global_feat = self.pool_layer(features)
71 bn_feat = bn_feat[..., 0, 0]
75 if not self.training:
return bn_feat
79 if self.classifier.__class__.__name__ ==
'Linear':
80 cls_outputs = self.classifier(bn_feat)
81 pred_class_logits = F.linear(bn_feat, self.classifier.weight)
83 cls_outputs = self.classifier(bn_feat, targets)
84 pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
85 F.normalize(self.classifier.weight))
88 if self.
neck_feat ==
"before": feat = global_feat[..., 0, 0]
89 elif self.
neck_feat ==
"after": feat = bn_feat
90 else:
raise KeyError(f
"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
94 "cls_outputs": cls_outputs,
95 "pred_class_logits": pred_class_logits,