Safemotion Lib
Loading...
Searching...
No Matches
osnet.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: xingyu liao
4@contact: sherlockliao01@gmail.com
5"""
6
7# based on:
8# https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/models/osnet.py
9
10import logging
11
12import torch
13from torch import nn
14
15from fastreid.layers import get_norm
16from fastreid.utils import comm
17from .build import BACKBONE_REGISTRY
18
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)

# Download locations of the ImageNet-pretrained OSNet checkpoints, keyed by
# the name that ``init_pretrained_weights`` receives as ``key``.
_GDRIVE_URL = 'https://drive.google.com/uc?id={}'
model_urls = {
    name: _GDRIVE_URL.format(file_id)
    for name, file_id in (
        ('osnet_x1_0', '1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY'),
        ('osnet_x0_75', '1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq'),
        ('osnet_x0_5', '16DGLbZukvVYgINws8u8deSaOqjybZ83i'),
        ('osnet_x0_25', '1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs'),
        ('osnet_ibn_x1_0', '1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'),
    )
}
32
33
34
class ConvLayer(nn.Module):
    """Convolution layer (conv + norm + relu).

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: kernel size of the convolution.
        bn_norm: normalization spec forwarded to ``get_norm`` together with
            ``out_channels``; only used when ``IN`` is false.
        stride: stride of the convolution.
        padding: zero padding of the convolution.
        groups: number of convolution groups.
        IN: when true, normalize with an affine ``nn.InstanceNorm2d``
            instead of the layer returned by ``get_norm``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        bn_norm,
        stride=1,
        padding=0,
        groups=1,
        IN=False
    ):
        super(ConvLayer, self).__init__()
        # The convolution has no bias; it is always followed by a norm layer.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            groups=groups
        )
        # The attribute is called ``bn`` in both cases so that parameter
        # names stay the same whichever normalization is selected.
        if IN:
            self.bn = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            self.bn = get_norm(bn_norm, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, normalization and ReLU, in that order."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
72
73
class Conv1x1(nn.Module):
    """Pointwise (1x1) convolution followed by normalization and ReLU.

    ``bn_norm`` is handed to ``get_norm`` along with ``out_channels`` to
    build the normalization layer.
    """

    def __init__(self, in_channels, out_channels, bn_norm, stride=1, groups=1):
        super().__init__()
        conv_options = dict(stride=stride, padding=0, bias=False, groups=groups)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, **conv_options)
        self.bn = get_norm(bn_norm, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run conv -> norm -> relu on ``x`` and return the result."""
        return self.relu(self.bn(self.conv(x)))
96
97
class Conv1x1Linear(nn.Module):
    """Pointwise (1x1) convolution followed by normalization only.

    No activation is applied, so the output is linear in the normalized
    convolution response. ``bn_norm`` is handed to ``get_norm`` along with
    ``out_channels`` to build the normalization layer.
    """

    def __init__(self, in_channels, out_channels, bn_norm, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=False,
        )
        self.bn = get_norm(bn_norm, out_channels)

    def forward(self, x):
        """Run conv -> norm on ``x`` and return the result."""
        return self.bn(self.conv(x))
112
113
class Conv3x3(nn.Module):
    """3x3 convolution (padding 1) followed by normalization and ReLU.

    ``bn_norm`` is handed to ``get_norm`` along with ``out_channels`` to
    build the normalization layer.
    """

    def __init__(self, in_channels, out_channels, bn_norm, stride=1, groups=1):
        super().__init__()
        conv_options = dict(stride=stride, padding=1, bias=False, groups=groups)
        self.conv = nn.Conv2d(in_channels, out_channels, 3, **conv_options)
        self.bn = get_norm(bn_norm, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run conv -> norm -> relu on ``x`` and return the result."""
        return self.relu(self.bn(self.conv(x)))
136
137
class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.

    A bias-free 1x1 convolution with no norm or activation (linear) feeds a
    3x3 convolution with one group per channel (depthwise); normalization
    and ReLU are applied only after the second convolution.
    """

    def __init__(self, in_channels, out_channels, bn_norm):
        super().__init__()
        # Pointwise projection to the target width.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # groups == channels, so every channel is filtered independently.
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=out_channels,
        )
        self.bn = get_norm(bn_norm, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run conv1 -> conv2 -> norm -> relu on ``x``."""
        pointwise = self.conv1(x)
        depthwise = self.conv2(pointwise)
        return self.relu(self.bn(depthwise))
166
167
168
class ChannelGate(nn.Module):
    """A mini-network that generates channel-wise gates conditioned on input tensor.

    The input is globally average-pooled, passed through a two-layer 1x1
    convolution bottleneck, and the activated result is either returned
    directly or multiplied onto the input.

    Args:
        in_channels: number of channels of the input tensor.
        num_gates: number of gate values produced per sample; defaults to
            ``in_channels``.
        return_gates: when true, ``forward`` returns the gates themselves
            instead of the gated input.
        gate_activation: ``'sigmoid'``, ``'relu'`` or ``'linear'``.
        reduction: divisor applied to ``in_channels`` to obtain the width
            of the bottleneck.
        layer_norm: when true, apply ``nn.LayerNorm`` after the first
            bottleneck layer.

    Raises:
        RuntimeError: if ``gate_activation`` is not one of the known names.
    """

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation='sigmoid',
        reduction=16,
        layer_norm=False
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates

        # Keep at least one hidden channel: for in_channels < reduction the
        # integer division yields 0, which would build a zero-width
        # bottleneck. Widths >= reduction are unaffected.
        hidden_channels = max(1, in_channels // reduction)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        self.fc1 = nn.Conv2d(
            in_channels,
            hidden_channels,
            kernel_size=1,
            bias=True,
            padding=0
        )
        self.norm1 = None
        if layer_norm:
            self.norm1 = nn.LayerNorm((hidden_channels, 1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(
            hidden_channels,
            num_gates,
            kernel_size=1,
            bias=True,
            padding=0
        )
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU(inplace=True)
        elif gate_activation == 'linear':
            self.gate_activation = nn.Identity()
        else:
            raise RuntimeError(
                "Unknown gate activation: {}".format(gate_activation)
            )

    def forward(self, x):
        """Return the gated input, or the raw gates if ``return_gates`` is set."""
        features = x  # kept for the final multiplication
        x = self.global_avgpool(x)
        x = self.fc1(x)
        if self.norm1 is not None:
            x = self.norm1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.gate_activation(x)
        if self.return_gates:
            return x
        # The pooled gates have spatial size 1x1 and broadcast over the map.
        return features * x
227
228
class OSBlock(nn.Module):
    """Omni-scale feature learning block.

    A 1x1 convolution reduces the input to ``out_channels //
    bottleneck_reduction`` channels. Four parallel streams made of 1, 2, 3
    and 4 stacked ``LightConv3x3`` layers then process that tensor. Each
    stream output goes through the same ``ChannelGate`` instance, the gated
    outputs are summed, and a linear 1x1 convolution expands the sum back to
    ``out_channels`` before the residual addition.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        bn_norm: normalization spec forwarded to the convolution helpers.
        IN: when true, apply an affine ``nn.InstanceNorm2d`` to the residual
            sum before the final ReLU.
        bottleneck_reduction: divisor used to derive the bottleneck width.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bn_norm,
        IN=False,
        bottleneck_reduction=4,
        **kwargs
    ):
        super(OSBlock, self).__init__()
        mid_channels = out_channels // bottleneck_reduction
        self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm)
        # Streams of increasing depth; attribute names are parameter names
        # in checkpoints and must not change.
        self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm)
        self.conv2b = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels, bn_norm),
            LightConv3x3(mid_channels, mid_channels, bn_norm),
        )
        self.conv2c = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels, bn_norm),
            LightConv3x3(mid_channels, mid_channels, bn_norm),
            LightConv3x3(mid_channels, mid_channels, bn_norm),
        )
        self.conv2d = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels, bn_norm),
            LightConv3x3(mid_channels, mid_channels, bn_norm),
            LightConv3x3(mid_channels, mid_channels, bn_norm),
            LightConv3x3(mid_channels, mid_channels, bn_norm),
        )
        # A single gate whose weights are shared by all four streams.
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm)
        # Project the shortcut only when the channel count changes.
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm)
        self.IN = None
        if IN:
            self.IN = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(IN?(conv3(sum of gated streams) + shortcut))``."""
        identity = x
        x1 = self.conv1(x)
        x2a = self.conv2a(x1)
        x2b = self.conv2b(x1)
        x2c = self.conv2c(x1)
        x2d = self.conv2d(x1)
        x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
        x3 = self.conv3(x2)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        if self.IN is not None:
            out = self.IN(out)
        return self.relu(out)
284
285
286
class OSNet(nn.Module):
    """Omni-Scale Network.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. arXiv preprint, 2019.

    Args:
        blocks: block classes, one per stage; the first three are used.
        layers: number of blocks in each stage (same length as ``blocks``).
        channels: channel widths; ``channels[0]`` is the stem width and
            ``channels[i + 1]`` the output width of stage ``i``, so it has
            one more entry than ``blocks``.
        bn_norm: normalization spec forwarded to every layer.
        IN: forwarded to the stem convolution and to the blocks of the
            first stage only.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        blocks,
        layers,
        channels,
        bn_norm,
        IN=False,
        **kwargs
    ):
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1

        # convolutional backbone
        self.conv1 = ConvLayer(3, channels[0], 7, bn_norm, stride=2, padding=3, IN=IN)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(
            blocks[0],
            layers[0],
            channels[0],
            channels[1],
            bn_norm,
            reduce_spatial_size=True,
            IN=IN
        )
        self.conv3 = self._make_layer(
            blocks[1],
            layers[1],
            channels[1],
            channels[2],
            bn_norm,
            reduce_spatial_size=True
        )
        # The last stage keeps its spatial resolution.
        self.conv4 = self._make_layer(
            blocks[2],
            layers[2],
            channels[2],
            channels[3],
            bn_norm,
            reduce_spatial_size=False
        )
        self.conv5 = Conv1x1(channels[3], channels[3], bn_norm)

        self._init_params()

    def _make_layer(
        self,
        block,
        layer,
        in_channels,
        out_channels,
        bn_norm,
        reduce_spatial_size,
        IN=False
    ):
        """Build one stage of ``layer`` blocks as an ``nn.Sequential``.

        The first block maps ``in_channels`` to ``out_channels``; the
        remaining ``layer - 1`` blocks keep ``out_channels``. When
        ``reduce_spatial_size`` is true, a 1x1 convolution followed by
        stride-2 average pooling is appended.
        """
        layers = [block(in_channels, out_channels, bn_norm, IN=IN)]
        for _ in range(layer - 1):
            layers.append(block(out_channels, out_channels, bn_norm, IN=IN))

        if reduce_spatial_size:
            layers.append(
                nn.Sequential(
                    Conv1x1(out_channels, out_channels, bn_norm),
                    nn.AvgPool2d(2, stride=2),
                )
            )

        return nn.Sequential(*layers)

    def _init_params(self):
        """Initialize conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the feature map produced by the final 1x1 convolution."""
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x
400
401
def init_pretrained_weights(model, key=''):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.

    The checkpoint is cached as ``<key>_imagenet.pth`` inside the
    ``checkpoints`` folder of the torch cache directory. If it is missing,
    the main process downloads it from ``model_urls[key]`` and all
    processes synchronize before loading.

    Args:
        model: module whose ``state_dict`` is updated in place.
        key: entry of ``model_urls`` identifying the checkpoint.

    Raises:
        KeyError: if the checkpoint is not cached and ``key`` has no entry
            in ``model_urls``.
    """
    import os
    import warnings
    from collections import OrderedDict

    import gdown

    # Cache root: $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch.
    torch_home = os.path.expanduser(
        os.getenv(
            'TORCH_HOME',
            os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')
        )
    )
    model_dir = os.path.join(torch_home, 'checkpoints')
    # exist_ok also covers several processes creating the folder at once.
    os.makedirs(model_dir, exist_ok=True)
    cached_file = os.path.join(model_dir, key + '_imagenet.pth')

    if not os.path.exists(cached_file):
        # Validate on every process: failing only on the main process would
        # leave the other ranks waiting in the synchronization below.
        if key not in model_urls:
            raise KeyError(
                'No pretrained weights are registered for "{}" '
                '(available: {})'.format(key, sorted(model_urls))
            )
        if comm.is_main_process():
            gdown.download(model_urls[key], cached_file, quiet=False)

    # Wait until the main process has finished downloading.
    comm.synchronize()

    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # obtained from a trusted source.
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Overlay the matching tensors on the model's current state so that
    # unmatched parameters keep their existing values.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(cached_file)
        )
    else:
        logger.info(
            'Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file)
        )
        if len(discarded_layers) > 0:
            logger.info(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.format(discarded_layers)
            )
483
484
@BACKBONE_REGISTRY.register()
def build_osnet_backbone(cfg):
    """
    Create a OSNet instance from config.

    Reads ``PRETRAIN``, ``PRETRAIN_PATH``, ``WITH_IBN``, ``NORM`` and
    ``DEPTH`` from ``cfg.MODEL.BACKBONE``. ``DEPTH`` selects the channel
    widths and must be one of ``"x1_0"``, ``"x0_75"``, ``"x0_5"`` or
    ``"x0_25"``. When pretraining is enabled, weights come from
    ``PRETRAIN_PATH`` if it is set, otherwise from the checkpoint fetched
    by ``init_pretrained_weights``.

    Returns:
        OSNet: a :class:`OSNet` instance
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = [2, 2, 2]
    num_channels_per_stage = {
        "x1_0": [64, 256, 384, 512],
        "x0_75": [48, 192, 288, 384],
        "x0_5": [32, 128, 192, 256],
        "x0_25": [16, 64, 96, 128]}[depth]
    model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage,
                  bn_norm, IN=with_ibn)

    if pretrain:
        # Load pretrain path if specifically
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
                model.load_state_dict(state_dict)
            except FileNotFoundError:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise
            except (KeyError, RuntimeError):
                # load_state_dict reports missing/unexpected keys and size
                # mismatches as RuntimeError, so log those as well.
                logger.info("State dict keys error! Please check the state dict.")
                raise
        else:
            if with_ibn:
                pretrain_key = "osnet_ibn_" + depth
            else:
                pretrain_key = "osnet_" + depth

            init_pretrained_weights(model, pretrain_key)
    return model
__init__(self, in_channels, num_gates=None, return_gates=False, gate_activation='sigmoid', reduction=16, layer_norm=False)
Definition osnet.py:182
__init__(self, in_channels, out_channels, bn_norm, stride=1, groups=1)
Definition osnet.py:77
__init__(self, in_channels, out_channels, bn_norm, stride=1)
Definition osnet.py:101
__init__(self, in_channels, out_channels, bn_norm, stride=1, groups=1)
Definition osnet.py:117
__init__(self, in_channels, out_channels, kernel_size, bn_norm, stride=1, padding=0, groups=1, IN=False)
Definition osnet.py:50
__init__(self, in_channels, out_channels, bn_norm)
Definition osnet.py:143
__init__(self, in_channels, out_channels, bn_norm, IN=False, bottleneck_reduction=4, **kwargs)
Definition osnet.py:240
_make_layer(self, block, layer, in_channels, out_channels, bn_norm, reduce_spatial_size, IN=False)
Definition osnet.py:353
__init__(self, blocks, layers, channels, bn_norm, IN=False, **kwargs)
Definition osnet.py:306
init_pretrained_weights(model, key='')
Definition osnet.py:402