Safemotion Lib
Loading...
Searching...
No Matches
mgn.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6import copy
7
8import torch
9from torch import nn
10
11from fastreid.layers import get_norm
12from fastreid.modeling.backbones import build_backbone
13from fastreid.modeling.backbones.resnet import Bottleneck
14from fastreid.modeling.heads import build_heads
15from fastreid.modeling.losses import *
16from .build import META_ARCH_REGISTRY
17
18
@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
    """Multiple Granularity Network (MGN) meta-architecture for re-ID.

    A shared backbone stem (conv1 through the first block of layer3) feeds
    three branches:

    * ``b1``: global branch using the backbone's original ``layer4``.
    * ``b2``: part branch whose feature map is additionally split into
      2 horizontal stripes (``b21``, ``b22``).
    * ``b3``: part branch whose feature map is additionally split into
      3 horizontal stripes (``b31``, ``b32``, ``b33``).

    The part branches run a stride-1 re-build of ``layer4``
    (``res_p_conv5``) so the output map stays tall enough to chunk along
    the height axis.  Every whole-branch and every stripe feature gets its
    own embedding head built from ``cfg``.
    """

    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        # Buffers (not parameters): saved in checkpoints and moved by .to()/.cuda().
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # fmt: off
        # backbone
        bn_norm  = cfg.MODEL.BACKBONE.NORM
        with_se  = cfg.MODEL.BACKBONE.WITH_SE
        extra_bn = cfg.MODEL.BACKBONE.EXTRA_BN
        # fmt: on

        backbone = build_backbone(cfg)
        # Shared stem: everything up to and including the first block of layer3.
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3[0]
        )
        # Remainder of layer3; deep-copied once per branch below.
        res_conv4 = nn.Sequential(*backbone.layer3[1:])
        # Global branch keeps the original (downsampling) layer4.
        res_g_conv5 = backbone.layer4

        # Stride-1 rebuild of layer4 for the part branches: the 1x1 conv
        # downsample matches layer4's projection shapes, so the pretrained
        # layer4 weights can be loaded directly.
        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, bn_norm, False, with_se, downsample=nn.Sequential(
                nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
            Bottleneck(2048, 512, bn_norm, False, with_se),
            Bottleneck(2048, 512, bn_norm, False, with_se))
        res_p_conv5.load_state_dict(backbone.layer4.state_dict())

        # branch1: global feature only
        self.b1 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_g_conv5)
        )
        self.b1_head = build_heads(cfg)

        # branch2: whole-branch feature + 2 horizontal stripes
        self.b2 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        self.b2_head = build_heads(cfg)
        self.b21_head = build_heads(cfg)
        self.b22_head = build_heads(cfg)

        # branch3: whole-branch feature + 3 horizontal stripes
        self.b3 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        self.b3_head = build_heads(cfg)
        self.b31_head = build_heads(cfg)
        self.b32_head = build_heads(cfg)
        self.b33_head = build_heads(cfg)

        self.has_extra_bn = extra_bn
        if extra_bn:
            # Optional per-branch BN applied to each branch's feature map
            # before it is fed to the heads.
            self.heads_extra_bn_b1 = get_norm(bn_norm, 2048)
            self.heads_extra_bn_b2 = get_norm(bn_norm, 2048)
            self.heads_extra_bn_b3 = get_norm(bn_norm, 2048)

    @property
    def device(self):
        """Device the model lives on, inferred from a registered buffer."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """Run all three branches.

        Training: returns a dict of per-head outputs plus ``targets``,
        consumed by :meth:`losses`.
        Inference: returns the concatenation of all eight pooled head
        features along dim 1.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)  # shared mid-level feature map

        # branch1 (global)
        b1_feat = self.b1(features)
        if self.has_extra_bn: b1_feat = self.heads_extra_bn_b1(b1_feat)

        # branch2: whole map plus a 2-way split along the height axis
        b2_feat = self.b2(features)
        if self.has_extra_bn: b2_feat = self.heads_extra_bn_b2(b2_feat)
        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)

        # branch3: whole map plus a 3-way split along the height axis
        b3_feat = self.b3(features)
        if self.has_extra_bn: b3_feat = self.heads_extra_bn_b3(b3_feat)
        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)

        if self.training:
            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
            targets = batched_inputs["targets"].long().to(self.device)

            # NOTE(review): negative targets appear to be a sentinel for
            # "unlabeled" batches; zeroing keeps the heads runnable — confirm
            # against the data loader's conventions.
            if targets.sum() < 0: targets.zero_()

            b1_outputs = self.b1_head(b1_feat, targets)
            b2_outputs = self.b2_head(b2_feat, targets)
            b21_outputs = self.b21_head(b21_feat, targets)
            b22_outputs = self.b22_head(b22_feat, targets)
            b3_outputs = self.b3_head(b3_feat, targets)
            b31_outputs = self.b31_head(b31_feat, targets)
            b32_outputs = self.b32_head(b32_feat, targets)
            b33_outputs = self.b33_head(b33_feat, targets)

            return {
                "b1_outputs": b1_outputs,
                "b2_outputs": b2_outputs,
                "b21_outputs": b21_outputs,
                "b22_outputs": b22_outputs,
                "b3_outputs": b3_outputs,
                "b31_outputs": b31_outputs,
                "b32_outputs": b32_outputs,
                "b33_outputs": b33_outputs,
                "targets": targets,
            }
        else:
            b1_pool_feat = self.b1_head(b1_feat)
            b2_pool_feat = self.b2_head(b2_feat)
            b21_pool_feat = self.b21_head(b21_feat)
            b22_pool_feat = self.b22_head(b22_feat)
            b3_pool_feat = self.b3_head(b3_feat)
            b31_pool_feat = self.b31_head(b31_feat)
            b32_pool_feat = self.b32_head(b32_feat)
            b33_pool_feat = self.b33_head(b33_feat)

            pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
                                   b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
            return pred_feat

    def preprocess_image(self, batched_inputs):
        r"""
        Normalize and batch the input images.

        Accepts either a dict with an ``"images"`` tensor or a raw tensor.
        Note: normalization is done in place on the input tensor.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs["images"].to(self.device)
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs.to(self.device)
        else:
            raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))

        images.sub_(self.pixel_mean).div_(self.pixel_std)
        return images

    def losses(self, outs):
        """Compute MGN losses from the training outputs of :meth:`forward`.

        Cross-entropy is applied to all eight classifier heads (each scaled
        by 1/8); the triplet loss is applied to the three whole-branch
        embeddings plus the concatenated stripe embeddings of b2 and b3
        (each scaled by 1/5).
        """
        gt_labels = outs["targets"]

        # Log prediction accuracy using the global branch's logits.
        log_accuracy(outs["b1_outputs"]['pred_class_logits'].detach(), gt_labels)

        # Stripe embeddings are concatenated per branch for the triplet loss.
        b22_pool_feat = torch.cat((outs["b21_outputs"]['features'],
                                   outs["b22_outputs"]['features']), dim=1)
        b33_pool_feat = torch.cat((outs["b31_outputs"]['features'],
                                   outs["b32_outputs"]['features'],
                                   outs["b33_outputs"]['features']), dim=1)

        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME

        if "CrossEntropyLoss" in loss_names:
            ce_cfg = self._cfg.MODEL.LOSSES.CE
            # One classification loss per head, each weighted 1/8.
            for branch in ("b1", "b2", "b21", "b22", "b3", "b31", "b32", "b33"):
                loss_dict['loss_cls_' + branch] = cross_entropy_loss(
                    outs[branch + "_outputs"]['cls_outputs'],
                    gt_labels,
                    ce_cfg.EPSILON,
                    ce_cfg.ALPHA,
                ) * ce_cfg.SCALE * 0.125

        if "TripletLoss" in loss_names:
            tri_cfg = self._cfg.MODEL.LOSSES.TRI
            # Triplet loss on whole-branch features and on the concatenated
            # stripe features, each weighted 1/5.
            tri_feats = (
                ("b1", outs["b1_outputs"]['features']),
                ("b2", outs["b2_outputs"]['features']),
                ("b3", outs["b3_outputs"]['features']),
                ("b22", b22_pool_feat),
                ("b33", b33_pool_feat),
            )
            for branch, feat in tri_feats:
                loss_dict['loss_triplet_' + branch] = triplet_loss(
                    feat,
                    gt_labels,
                    tri_cfg.MARGIN,
                    tri_cfg.NORM_FEAT,
                    tri_cfg.HARD_MINING,
                ) * tri_cfg.SCALE * 0.2

        return loss_dict
forward(self, batched_inputs)
Definition mgn.py:91
preprocess_image(self, batched_inputs)
Definition mgn.py:149