Safemotion Lib
fastreid.modeling.meta_arch.baseline.Baseline Class Reference

Public Member Functions

 __init__ (self, cfg)
 
 device (self)
 
 forward (self, batched_inputs)
 
 preprocess_image (self, batched_inputs)
 
 losses (self, outs)
 

Public Attributes

 backbone
 
 heads
 
 has_extra_bn
 
 heads_extra_bn
 
 pixel_mean
 
 pixel_std
 

Protected Attributes

 _cfg
 

Detailed Description

Definition at line 18 of file baseline.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.modeling.meta_arch.baseline.Baseline.__init__(self, cfg)

Definition at line 19 of file baseline.py.

19 def __init__(self, cfg):
20     super().__init__()
21     self._cfg = cfg
22     assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
23     self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
24     self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
25
26     # backbone
27     self.backbone = build_backbone(cfg)
28
29     # head
30     self.heads = build_heads(cfg)
31
32     self.has_extra_bn = cfg.MODEL.BACKBONE.EXTRA_BN
33     if self.has_extra_bn:
34         self.heads_extra_bn = get_norm(cfg.MODEL.BACKBONE.NORM, cfg.MODEL.BACKBONE.FEAT_DIM)
35
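The constructor reads only a handful of config keys. The sketch below mirrors its checks in plain Python, assuming a hypothetical nested namespace in place of fastreid's real config object; the example values (ImageNet statistics scaled to 0-255, "BN", 2048) and the helper name plan_components are illustrative assumptions, not part of the library.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the config object: only the keys that
# Baseline.__init__ reads are present, and the values are illustrative.
cfg = SimpleNamespace(
    MODEL=SimpleNamespace(
        PIXEL_MEAN=[123.675, 116.28, 103.53],
        PIXEL_STD=[58.395, 57.12, 57.375],
        BACKBONE=SimpleNamespace(EXTRA_BN=True, NORM="BN", FEAT_DIM=2048),
    )
)


def plan_components(cfg):
    """Hypothetical helper: report what Baseline.__init__ would build."""
    mean, std = cfg.MODEL.PIXEL_MEAN, cfg.MODEL.PIXEL_STD
    # Same check as the constructor: one mean and one std per channel.
    assert len(mean) == len(std)
    # .view(1, -1, 1, 1) turns each list into a (1, C, 1, 1) buffer.
    buffer_shape = (1, len(mean), 1, 1)
    components = ["backbone", "heads"]
    if cfg.MODEL.BACKBONE.EXTRA_BN:
        # Extra norm layer sized to the backbone's output feature dimension.
        bb = cfg.MODEL.BACKBONE
        components.append(f"heads_extra_bn({bb.NORM}, {bb.FEAT_DIM})")
    return buffer_shape, components


print(plan_components(cfg))
```

The heads_extra_bn attribute exists only when EXTRA_BN is true, which is why forward() checks has_extra_bn before using it.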

Member Function Documentation

◆ device()

fastreid.modeling.meta_arch.baseline.Baseline.device(self)

Definition at line 37 of file baseline.py.

36 @property
37 def device(self):
38     return self.pixel_mean.device
39
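forward() and preprocess_image() pass self.device straight to .to(), so device has to behave as an attribute (a @property) rather than a plain method. The plain-Python sketch below, with hypothetical class names and a fake buffer in place of a tensor, shows the difference: without the decorator, self.device evaluates to a bound method instead of a device value.

```python
class FakeBuffer:
    """Hypothetical stand-in for a registered tensor buffer."""

    def __init__(self, device):
        self.device = device


class WithProperty:
    def __init__(self):
        self.pixel_mean = FakeBuffer("cuda:0")

    @property
    def device(self):
        # Read the device from a buffer, which moves together with the module.
        return self.pixel_mean.device


class WithoutProperty:
    def __init__(self):
        self.pixel_mean = FakeBuffer("cuda:0")

    def device(self):
        return self.pixel_mean.device


print(WithProperty().device)               # cuda:0
print(callable(WithoutProperty().device))  # True: a bound method, not a device
```

Reading the device from a buffer is a common PyTorch idiom: buffers follow the module through .to() and .cuda(), so the reported device stays correct after the model is moved.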

◆ forward()

fastreid.modeling.meta_arch.baseline.Baseline.forward(self, batched_inputs)

Definition at line 40 of file baseline.py.

40 def forward(self, batched_inputs):
41     images = self.preprocess_image(batched_inputs)
42     features = self.backbone(images)
43
44     if self.has_extra_bn: features = self.heads_extra_bn(features)
45     if self.training:
46         assert "targets" in batched_inputs, "Person ID annotations are missing in training!"
47         targets = batched_inputs["targets"].to(self.device)
48
49         # PreciseBN flag: when PreciseBN is run on a different dataset, the number of classes
50         # in the new dataset may be larger than in the original one, so the circle/arcface
51         # heads would throw an error. Setting all targets to 0 avoids this problem.
52         if targets.sum() < 0: targets.zero_()
53
54         outputs = self.heads(features, targets)
55         return {
56             "outputs": outputs,
57             "targets": targets,
58         }
59     else:
60         outputs = self.heads(features)
61         return outputs
62
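The control flow of forward() can be exercised without PyTorch. In the sketch below, StubBaseline is hypothetical: the backbone, extra norm layer and heads are stand-in lambdas, images and targets are plain lists, and preprocessing is omitted. It reproduces only the training/inference split and the negative-target PreciseBN convention.

```python
class StubBaseline:
    """Hypothetical pure-Python mirror of Baseline.forward's branching."""

    def __init__(self, has_extra_bn=False):
        self.training = True
        self.has_extra_bn = has_extra_bn
        self.backbone = lambda images: [x * 2 for x in images]      # stand-in backbone
        self.heads_extra_bn = lambda feats: [x - 1 for x in feats]  # stand-in norm layer
        # Stand-in heads: the training call receives targets, inference does not.
        self.heads = lambda feats, targets=None: {"features": feats, "targets_seen": targets}

    def forward(self, batched_inputs):
        # Preprocessing (device move + normalization) is omitted in this sketch.
        features = self.backbone(batched_inputs["images"])
        if self.has_extra_bn:
            features = self.heads_extra_bn(features)
        if self.training:
            assert "targets" in batched_inputs, "Person ID annotations are missing in training!"
            targets = list(batched_inputs["targets"])
            # A negative sum marks a PreciseBN pass: zero the labels so that
            # class-dependent heads never see out-of-range IDs.
            if sum(targets) < 0:
                targets = [0] * len(targets)
            return {"outputs": self.heads(features, targets), "targets": targets}
        return self.heads(features)


model = StubBaseline()
print(model.forward({"images": [1, 2], "targets": [-1, -1]}))
model.training = False
print(model.forward({"images": [1, 2]}))
```

In training mode the returned dict is exactly what losses() expects as its outs argument; in inference mode the heads' output is returned directly.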

◆ losses()

fastreid.modeling.meta_arch.baseline.Baseline.losses(self, outs)
Computes losses from the model's outputs. The argument outs must be the dict returned by forward() in training mode, which carries the keys "outputs" and "targets".

Definition at line 77 of file baseline.py.

77 def losses(self, outs):
78     r"""
79     Compute losses from the model's outputs; outs must be the dict
80     returned by forward() in training mode.
81     """
82     # fmt: off
83     outputs = outs["outputs"]
84     gt_labels = outs["targets"]
85     # model predictions
86     pred_class_logits = outputs['pred_class_logits'].detach()
87     cls_outputs = outputs['cls_outputs']
88     pred_features = outputs['features']
89     # fmt: on
90
91     # Log prediction accuracy
92     log_accuracy(pred_class_logits, gt_labels)
93
94     loss_dict = {}
95     loss_names = self._cfg.MODEL.LOSSES.NAME
96
97     if "CrossEntropyLoss" in loss_names:
98         loss_dict['loss_cls'] = cross_entropy_loss(
99             cls_outputs,
100             gt_labels,
101             self._cfg.MODEL.LOSSES.CE.EPSILON,
102             self._cfg.MODEL.LOSSES.CE.ALPHA,
103         ) * self._cfg.MODEL.LOSSES.CE.SCALE
104
105     if "TripletLoss" in loss_names:
106         loss_dict['loss_triplet'] = triplet_loss(
107             pred_features,
108             gt_labels,
109             self._cfg.MODEL.LOSSES.TRI.MARGIN,
110             self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
111             self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
112         ) * self._cfg.MODEL.LOSSES.TRI.SCALE
113
114     if "CircleLoss" in loss_names:
115         loss_dict['loss_circle'] = circle_loss(
116             pred_features,
117             gt_labels,
118             self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
119             self._cfg.MODEL.LOSSES.CIRCLE.ALPHA,
120         ) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE
121
122     return loss_dict
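With HARD_MINING enabled, triplet losses of this kind commonly use batch-hard mining: for each anchor, take the farthest sample with the same ID and the closest sample with a different ID in the batch. The pure-Python sketch below illustrates that technique, assuming Euclidean distance and a hinge with a fixed margin; it is not fastreid's triplet_loss, whose feature normalization (NORM_FEAT) and margin handling may differ. The final multiplication mirrors how losses() scales each term by its configured SCALE.

```python
import math


def batch_hard_triplet(features, labels, margin, scale=1.0):
    """Batch-hard triplet loss over a list of feature vectors (sketch)."""
    n = len(features)
    # Pairwise Euclidean distances between all samples in the batch.
    dist = [[math.dist(features[i], features[j]) for j in range(n)] for i in range(n)]
    per_anchor = []
    for i in range(n):
        pos = [dist[i][j] for j in range(n) if labels[j] == labels[i] and j != i]
        neg = [dist[i][j] for j in range(n) if labels[j] != labels[i]]
        if not pos or not neg:
            continue  # this anchor has no valid triplet in the batch
        hardest_pos = max(pos)  # farthest sample with the same ID
        hardest_neg = min(neg)  # closest sample with a different ID
        per_anchor.append(max(0.0, hardest_pos - hardest_neg + margin))
    loss = sum(per_anchor) / len(per_anchor) if per_anchor else 0.0
    return loss * scale


# Two samples of ID 0 at x=0 and x=2, one sample of ID 1 between them.
print(batch_hard_triplet([[0, 0], [2, 0], [1, 0]], [0, 0, 1], margin=0.5))  # 1.5
```

Batches built for this loss usually contain several images per identity, otherwise most anchors have no positive and are skipped.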

◆ preprocess_image()

fastreid.modeling.meta_arch.baseline.Baseline.preprocess_image(self, batched_inputs)
Moves the input images to the model's device and normalizes them in place with pixel_mean and pixel_std.

Definition at line 63 of file baseline.py.

63 def preprocess_image(self, batched_inputs):
64     r"""
65     Move the input images to the model's device and normalize them.
66     """
67     if isinstance(batched_inputs, dict):
68         images = batched_inputs["images"].to(self.device)
69     elif isinstance(batched_inputs, torch.Tensor):
70         images = batched_inputs.to(self.device)
71     else:
72         raise TypeError("batched_inputs must be dict or torch.Tensor, but got {}".format(type(batched_inputs)))
73
74     images.sub_(self.pixel_mean).div_(self.pixel_std)
75     return images
76
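The line images.sub_(self.pixel_mean).div_(self.pixel_std) relies on broadcasting: each (1, C, 1, 1) buffer is stretched across the batch, height and width dimensions, so every pixel in channel c becomes (x - mean[c]) / std[c]. The pure-Python sketch below spells that out on nested lists indexed [n][c][h][w]; the function name normalize_nchw is hypothetical and, unlike the original, it returns a new structure instead of modifying its input.

```python
def normalize_nchw(images, mean, std):
    """Per-channel (x - mean[c]) / std[c] over a nested [N][C][H][W] list."""
    assert len(mean) == len(std)
    return [
        [
            # Every row of channel c is shifted and scaled by that channel's stats.
            [[(x - mean[c]) / std[c] for x in row] for row in channel]
            for c, channel in enumerate(sample)
        ]
        for sample in images
    ]


# One image, two channels, each 1x2 pixels.
batch = [[[[10.0, 20.0]], [[30.0, 50.0]]]]
print(normalize_nchw(batch, mean=[10.0, 30.0], std=[5.0, 10.0]))
# [[[[0.0, 2.0]], [[0.0, 2.0]]]]
```

Because the original normalizes in place, and Tensor.to() returns the same tensor when it already has the requested device and dtype, a batch that is already on the model's device is also modified in the caller's hands. An out-of-place form such as (images - self.pixel_mean) / self.pixel_std avoids that side effect.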

Member Data Documentation

◆ _cfg

fastreid.modeling.meta_arch.baseline.Baseline._cfg (protected)

Definition at line 21 of file baseline.py.

◆ backbone

fastreid.modeling.meta_arch.baseline.Baseline.backbone

Definition at line 27 of file baseline.py.

◆ has_extra_bn

fastreid.modeling.meta_arch.baseline.Baseline.has_extra_bn

Definition at line 32 of file baseline.py.

◆ heads

fastreid.modeling.meta_arch.baseline.Baseline.heads

Definition at line 30 of file baseline.py.

◆ heads_extra_bn

fastreid.modeling.meta_arch.baseline.Baseline.heads_extra_bn

Definition at line 34 of file baseline.py.

◆ pixel_mean

fastreid.modeling.meta_arch.baseline.Baseline.pixel_mean

Definition at line 23 of file baseline.py (registered with register_buffer).

◆ pixel_std

fastreid.modeling.meta_arch.baseline.Baseline.pixel_std

Definition at line 24 of file baseline.py (registered with register_buffer).


The documentation for this class was generated from the file baseline.py.