# NOTE(review): fragment of an EmbeddingHead-style `__init__` — the `def` header
# (original lines before 20) is outside this chunk, the text is hard-wrapped
# mid-statement, and original lines 26-27, 30-33, 37-38, 44-48, 50-53 and 55 are
# missing. The leading numbers on some lines are embedded original line numbers,
# not code. Comments below annotate intent without altering any code token.
# Read backbone/head dimensions and layer choices from the config node.
20 feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
21 num_classes = cfg.MODEL.HEADS.NUM_CLASSES
22 pool_type = cfg.MODEL.HEADS.POOL_LAYER
23 cls_type = cfg.MODEL.HEADS.CLS_LAYER
24 with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
25 norm_type = cfg.MODEL.HEADS.NORM
# Select the global pooling layer by configured name. The opening `if` branch
# (original lines 26-27) is missing from this chunk — cannot tell which pool
# type it handled; confirm against the original file.
28 elif pool_type ==
'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
29 elif pool_type ==
'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
# Branches for original lines 30-33 are missing from this chunk.
34 elif pool_type ==
"identity": self.pool_layer = nn.Identity()
35 elif pool_type ==
"flatten": self.pool_layer =
Flatten()
# Unknown pool names fail loudly rather than silently falling back.
# NOTE(review): the wrapping has detached the `f` prefix from its string
# literal; in the original this is presumably f"{pool_type} is not supported!".
36 else:
raise KeyError(f
"{pool_type} is not supported!")
# Select the classification layer: a plain linear projection, or one of the
# margin-based softmax heads (which also take the config node).
39 if cls_type ==
'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=
False)
40 elif cls_type ==
'arcSoftmax': self.classifier =
ArcSoftmax(cfg, feat_dim, num_classes)
41 elif cls_type ==
'circleSoftmax': self.classifier =
CircleSoftmax(cfg, feat_dim, num_classes)
42 elif cls_type ==
'amSoftmax': self.classifier =
AMSoftmax(cfg, feat_dim, num_classes)
# Unknown classifier names also fail loudly (same detached-`f` wrapping issue
# as above — presumably f"{cls_type} is not supported!" in the original).
43 else:
raise KeyError(f
"{cls_type} is not supported!")
# Start of the BNNeck bottleneck. NOTE(review): BatchNorm1d is sized with
# num_classes here, but the bottleneck normally normalizes the *feature*
# vector, i.e. feat_dim — verify against the original; this looks like a bug
# or a transcription artifact. Lines 50-53 (assembling self.bottleneck and
# applying weights_init_kaiming) are missing from this chunk.
49 bottleneck = [nn.BatchNorm1d(num_classes)]
# Classifier weights get their own dedicated init scheme.
54 self.classifier.apply(weights_init_classifier)
# NOTE(review): fragment of forward() — hard-wrapped, original lines 57, 59,
# 62, 64, 67-72 and 74-75 are missing, and the method continues past line 76
# (the return statement is not visible). Leading numbers are embedded original
# line numbers, not code.
56 def forward(self, features, targets=None):
# Docstring fragment from the original (its quotes are on the missing lines).
58 See :class:`ReIDHeads.forward`.
# Pool the backbone feature map globally, then drop the trailing 1x1 spatial
# dims to get a flat (batch, feat_dim) vector. Assumes the pool layer emits
# a 4-D (N, C, 1, 1) tensor — TODO confirm for the 'flatten'/'identity' pools.
60 global_feat = self.pool_layer(features)
61 global_feat = global_feat[..., 0, 0]
# Dispatch on the classifier's class name: nn.Linear takes only features,
# while the margin-based softmax heads also need the ground-truth targets.
63 classifier_name = self.classifier.__class__.__name__
65 if classifier_name ==
'Linear': cls_outputs = self.classifier(global_feat)
66 else: cls_outputs = self.classifier(global_feat, targets)
# Part of the returned training-outputs dict; the surrounding dict literal
# (original lines 67-72, 74-75) is missing from this chunk.
73 "cls_outputs": cls_outputs,
# Squashes logits into (0, 1). NOTE(review): applying sigmoid to cls_outputs
# differs from the usual raw-logits contract of such heads — confirm intent
# and where this value is consumed; the enclosing context is not visible.
76 cls_outputs = torch.sigmoid(cls_outputs)