Safemotion Lib
baseline.py
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

from fastreid.layers import get_norm
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_heads
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)

        # head
        self.heads = build_heads(cfg)

        self.has_extra_bn = cfg.MODEL.BACKBONE.EXTRA_BN
        if self.has_extra_bn:
            self.heads_extra_bn = get_norm(cfg.MODEL.BACKBONE.NORM, cfg.MODEL.BACKBONE.FEAT_DIM)

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)

        if self.has_extra_bn: features = self.heads_extra_bn(features)
        if self.training:
            assert "targets" in batched_inputs, "Person ID annotations are missing in training!"
            targets = batched_inputs["targets"].to(self.device)

            # PreciseBN flag: when running PreciseBN on a different dataset, the number of
            # classes in the new dataset may be larger than in the original one, so the
            # circle/arcface layers would throw an error. Zeroing all targets avoids this.
            if targets.sum() < 0: targets.zero_()

            outputs = self.heads(features, targets)
            return {
                "outputs": outputs,
                "targets": targets,
            }
        else:
            outputs = self.heads(features)
            return outputs

    def preprocess_image(self, batched_inputs):
        r"""
        Normalize and batch the input images.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs["images"].to(self.device)
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs.to(self.device)
        else:
            raise TypeError("batched_inputs must be a dict or torch.Tensor, but got {}".format(type(batched_inputs)))

        images.sub_(self.pixel_mean).div_(self.pixel_std)
        return images

    def losses(self, outs):
        r"""
        Compute losses from the model's outputs; the loss functions' input arguments
        must match the outputs of the model's forward pass.
        """
        # fmt: off
        outputs           = outs["outputs"]
        gt_labels         = outs["targets"]
        # model predictions
        pred_class_logits = outputs['pred_class_logits'].detach()
        cls_outputs       = outputs['cls_outputs']
        pred_features     = outputs['features']
        # fmt: on

        # Log prediction accuracy
        log_accuracy(pred_class_logits, gt_labels)

        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME

        if "CrossEntropyLoss" in loss_names:
            loss_dict['loss_cls'] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE

        if "TripletLoss" in loss_names:
            loss_dict['loss_triplet'] = triplet_loss(
                pred_features,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,
                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE

        if "CircleLoss" in loss_names:
            loss_dict['loss_circle'] = circle_loss(
                pred_features,
                gt_labels,
                self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                self._cfg.MODEL.LOSSES.CIRCLE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE

        return loss_dict
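
For context, the sketch below shows one way this meta-architecture is typically driven end to end: build the model from a config, run a training forward pass (which returns the "outputs"/"targets" dict that losses expects), then an inference pass that yields the retrieval embeddings. This is a minimal sketch, not part of baseline.py: it assumes the standard fastreid helpers get_cfg, build_model, and EventStorage exist in this copy of the library, that the default config selects the "Baseline" meta-architecture and defines the keys this class reads (e.g. MODEL.BACKBONE.EXTRA_BN), and the class count and input sizes are arbitrary example values.

# Hypothetical usage sketch (not part of baseline.py); see the assumptions above.
import torch

from fastreid.config import get_cfg               # assumed standard fastreid helper
from fastreid.modeling import build_model         # assumed standard fastreid helper
from fastreid.utils.events import EventStorage    # losses() logs accuracy to an event storage

cfg = get_cfg()
cfg.MODEL.DEVICE = "cpu"                 # assumption: keep the sketch on CPU
cfg.MODEL.BACKBONE.PRETRAIN = False      # assumption: skip downloading ImageNet weights
cfg.MODEL.HEADS.NUM_CLASSES = 751        # assumption: e.g. the Market-1501 identity count

model = build_model(cfg)                 # dispatches to Baseline via META_ARCH_REGISTRY

# Training step: forward returns {"outputs": ..., "targets": ...}, which is
# exactly the dict that Baseline.losses expects.
images = torch.rand(4, 3, 256, 128) * 255.0                      # (N, C, H, W), values in [0, 255]
targets = torch.randint(0, cfg.MODEL.HEADS.NUM_CLASSES, (4,))    # person IDs
model.train()
with EventStorage(0):                    # log_accuracy() needs an active event storage
    outs = model({"images": images, "targets": targets})
    loss_dict = model.losses(outs)
    total_loss = sum(loss_dict.values())

# Inference: a bare tensor (or a dict with "images") yields the head outputs,
# i.e. the embeddings used for retrieval.
model.eval()
with torch.no_grad():
    query = torch.rand(2, 3, 256, 128) * 255.0
    feats = model(query)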
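
The normalization in preprocess_image relies on the registered pixel_mean / pixel_std buffers being shaped (1, C, 1, 1) so they broadcast over an (N, C, H, W) batch, and on sub_ / div_ modifying the batch in place. Below is a standalone PyTorch illustration of the same arithmetic; the channel statistics are example values standing in for cfg.MODEL.PIXEL_MEAN / cfg.MODEL.PIXEL_STD.

# Standalone illustration of the in-place, per-channel normalization used in
# preprocess_image (plain PyTorch, no fastreid). Channel statistics are examples.
import torch

pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1)  # (1, 3, 1, 1)
pixel_std  = torch.tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1)    # (1, 3, 1, 1)

images = torch.rand(4, 3, 256, 128) * 255.0   # (N, C, H, W) batch with values in [0, 255]

# The (1, 3, 1, 1) tensors broadcast across the batch and spatial dimensions,
# so each channel is shifted and scaled independently; images is modified in place.
images.sub_(pixel_mean).div_(pixel_std)

print(images.mean(dim=(0, 2, 3)))  # per-channel means drop from ~127 to order-of-1 values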