22 assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
23 self.register_buffer(
"pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
24 self.register_buffer(
"pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
34 self.
heads_extra_bn = get_norm(cfg.MODEL.BACKBONE.NORM, cfg.MODEL.BACKBONE.FEAT_DIM)
77 def losses(self, outs):
79 Compute loss from modeling's outputs, the loss function input arguments
80 must be the same as the outputs of the model forwarding.
83 outputs = outs[
"outputs"]
84 gt_labels = outs[
"targets"]
86 pred_class_logits = outputs[
'pred_class_logits'].detach()
87 cls_outputs = outputs[
'cls_outputs']
88 pred_features = outputs[
'features']
92 log_accuracy(pred_class_logits, gt_labels)
95 loss_names = self.
_cfg.MODEL.LOSSES.NAME
97 if "CrossEntropyLoss" in loss_names:
98 loss_dict[
'loss_cls'] = cross_entropy_loss(
101 self.
_cfg.MODEL.LOSSES.CE.EPSILON,
102 self.
_cfg.MODEL.LOSSES.CE.ALPHA,
103 ) * self.
_cfg.MODEL.LOSSES.CE.SCALE
105 if "TripletLoss" in loss_names:
106 loss_dict[
'loss_triplet'] = triplet_loss(
109 self.
_cfg.MODEL.LOSSES.TRI.MARGIN,
110 self.
_cfg.MODEL.LOSSES.TRI.NORM_FEAT,
111 self.
_cfg.MODEL.LOSSES.TRI.HARD_MINING,
112 ) * self.
_cfg.MODEL.LOSSES.TRI.SCALE
114 if "CircleLoss" in loss_names:
115 loss_dict[
'loss_circle'] = circle_loss(
118 self.
_cfg.MODEL.LOSSES.CIRCLE.MARGIN,
119 self.
_cfg.MODEL.LOSSES.CIRCLE.ALPHA,
120 ) * self.
_cfg.MODEL.LOSSES.CIRCLE.SCALE