Safemotion Lib
Loading...
Searching...
No Matches
resnet.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6
7import logging
8import math
9
10import torch
11from torch import nn
12
13from fastreid.layers import (
14 IBN,
15 SELayer,
16 Non_local,
17 get_norm,
18)
19from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
20from .build import BACKBONE_REGISTRY
21from fastreid.utils import comm
22
23
logger = logging.getLogger(__name__)

# Pretrained checkpoint URLs. Keys are the cfg.MODEL.BACKBONE.DEPTH string,
# optionally prefixed with "ibn_" and/or "se_" (see build_resnet_backbone).
model_urls = {
    # Plain ResNets hosted on download.pytorch.org (torchvision checkpoints).
    '18x': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    '34x': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    '50x': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    '101x': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    '152x': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    # IBN-Net releases (XingangPan/IBN-Net, v1.0).
    'ibn_18x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',
    'ibn_34x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',
    'ibn_50x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',
    'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',
    'se_ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth',
}
37
38
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (used for the 18x/34x depths).

    Output channels equal ``planes`` (``expansion = 1``).

    Args:
        inplanes: input channel count.
        planes: channel count of both convolutions and of the output.
        bn_norm: norm type passed to ``get_norm`` / ``IBN``.
        with_ibn: use ``IBN`` instead of ``get_norm`` for the first norm layer.
        with_se: apply an ``SELayer`` to the residual branch before the skip add.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut.
        reduction: reduction ratio forwarded to ``SELayer``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
                 stride=1, downsample=None, reduction=16):
        super(BasicBlock, self).__init__()
        # Construction order is kept as-is: it fixes state_dict layout and the
        # order in which parameter initialization consumes the RNG.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        if with_ibn:
            self.bn1 = IBN(planes, bn_norm)
        else:
            self.bn1 = get_norm(bn_norm, planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = get_norm(bn_norm, planes)
        self.relu = nn.ReLU(inplace=True)
        if with_se:
            self.se = SELayer(planes, reduction)
        else:
            self.se = nn.Identity()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """conv-norm-relu -> conv-norm -> SE -> add shortcut -> relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        # Previously self.se was built but never applied, so with_se=True had no
        # effect and left unused parameters. Same placement as Bottleneck.
        out = self.se(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
77
78
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (used for the 50x/101x/152x depths).

    The output has ``planes * expansion`` channels (``expansion = 4``).

    Args:
        inplanes: input channel count.
        planes: width of the 1x1 and 3x3 convolutions.
        bn_norm: norm type passed to ``get_norm`` / ``IBN``.
        with_ibn: use ``IBN`` instead of ``get_norm`` for the first norm layer.
        with_se: apply an ``SELayer`` to the residual branch before the skip add.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
        reduction: reduction ratio forwarded to ``SELayer``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
                 stride=1, downsample=None, reduction=16):
        super().__init__()
        width = planes
        out_channels = planes * self.expansion

        # Submodules are created in the same order as always: that order fixes
        # the state_dict layout and the RNG draws made by parameter init.
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = IBN(width, bn_norm) if with_ibn else get_norm(bn_norm, width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = get_norm(bn_norm, width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = get_norm(bn_norm, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(out_channels, reduction) if with_se else nn.Identity()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Residual branch (three conv/norm stages + SE) plus shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.se(self.bn3(self.conv3(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
125
126
class ResNet(nn.Module):
    """ResNet backbone with optional IBN, SE and Non-local modules.

    ``forward`` runs the stem (7x7 conv, norm, ReLU, max-pool) and the four
    residual stages, and returns the output of ``layer4``. There is no pooling
    head or classifier here.

    Args:
        last_stride: stride of the first block of ``layer4``.
        bn_norm: norm type forwarded to ``get_norm`` and to the blocks.
        with_ibn: use ``IBN`` for the first norm of every block in layer1-layer3.
        with_se: give every block an ``SELayer``.
        with_nl: insert ``Non_local`` modules (counts given by ``non_layers``).
        block: residual block class; must expose a class attribute ``expansion``.
        layers: number of blocks in each of the four stages.
        non_layers: number of ``Non_local`` modules in each of the four stages.
    """

    def __init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers):
        super().__init__()
        self.inplanes = 64  # running input-channel count, advanced by _make_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = get_norm(bn_norm, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn, with_se)
        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn, with_se)
        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn, with_se)
        # with_ibn is not forwarded, so layer4 never uses IBN. This is presumably the
        # IBN-Net-a layout the ibn_* checkpoints expect; verify before changing.
        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_se=with_se)

        # Called before _build_nonlocal, so the Non_local modules are not touched
        # by random_init and keep whatever initialization they apply themselves.
        self.random_init()

        if with_nl:
            self._build_nonlocal(layers, non_layers, bn_norm)
        else:
            # Four independent lists (a chained `a = b = []` would alias one list).
            self.NL_1_idx, self.NL_2_idx, self.NL_3_idx, self.NL_4_idx = [], [], [], []

    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        The first block gets a 1x1 conv + norm projection shortcut whenever the
        stride or the channel count changes. Only the first block is strided.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                get_norm(bn_norm, planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, bn_norm, with_ibn, with_se, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes, bn_norm, with_ibn, with_se) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def _build_nonlocal(self, layers, non_layers, bn_norm):
        """Create ``NL_k`` (a ModuleList of ``Non_local``) and ``NL_k_idx`` for stages k = 1..4.

        Stage k gets ``non_layers[k-1]`` modules, applied after the last
        ``non_layers[k-1]`` blocks of that stage (indices stored in ascending order).
        """
        # NOTE(review): these channel counts match Bottleneck stages (expansion 4);
        # they would be wrong for BasicBlock if a non-zero count were requested.
        for stage, channels in enumerate((256, 512, 1024, 2048)):
            num_blocks, num_nl = layers[stage], non_layers[stage]
            setattr(self, f"NL_{stage + 1}",
                    nn.ModuleList([Non_local(channels, bn_norm) for _ in range(num_nl)]))
            setattr(self, f"NL_{stage + 1}_idx", list(range(num_blocks - num_nl, num_blocks)))

    @staticmethod
    def _forward_stage(x, stage, nl_modules, nl_idx):
        """Apply the blocks of ``stage`` in order, running ``nl_modules[j]`` right after block ``nl_idx[j]``.

        ``nl_idx`` must be ascending. Indices that never match a block (e.g. a
        legacy ``[-1]`` sentinel) are ignored. Module state is never mutated.
        """
        next_nl = 0
        for i, blk in enumerate(stage):
            x = blk(x)
            if next_nl < len(nl_idx) and i == nl_idx[next_nl]:
                x = nl_modules[next_nl](x)
                next_nl += 1
        return x

    def forward(self, x):
        """Return the ``layer4`` feature map for input batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # NL_k only exists when built with with_nl=True. Otherwise its index list
        # is empty, so the () placeholder is never indexed.
        x = self._forward_stage(x, self.layer1, getattr(self, "NL_1", ()), self.NL_1_idx)
        x = self._forward_stage(x, self.layer2, getattr(self, "NL_2", ()), self.NL_2_idx)
        x = self._forward_stage(x, self.layer3, getattr(self, "NL_3", ()), self.NL_3_idx)
        x = self._forward_stage(x, self.layer4, getattr(self, "NL_4", ()), self.NL_4_idx)
        return x

    def random_init(self):
        """Re-initialize every Conv2d and BatchNorm2d currently in the model.

        Conv2d weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))). BatchNorm2d
        weight = 1 and bias = 0. Norm layers that are not ``nn.BatchNorm2d``
        instances (whatever ``get_norm`` may return) are left as constructed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
236
237
def init_pretrained_weights(key):
    """Return the pretrained state dict registered under ``key`` in ``model_urls``.

    The checkpoint is cached in ``<torch home>/checkpoints``. The torch home is
    ``$TORCH_HOME``, falling back to ``$XDG_CACHE_HOME/torch``, then to
    ``~/.cache/torch``. If the file is missing, only the main process downloads
    it (via ``gdown``). All processes then call ``comm.synchronize()`` before
    loading. The state dict is loaded onto the CPU. Applying it to a model
    (and skipping mismatched layers) is left to the caller.

    Raises:
        KeyError: if ``key`` has no entry in ``model_urls``.
    """
    import os

    # Fail fast with a readable message (e.g. 'se_50x' is not registered) before
    # touching the filesystem or requiring gdown to be installed.
    if key not in model_urls:
        raise KeyError(
            f"No pretrained weights registered for {key!r}; "
            f"available keys: {sorted(model_urls)}"
        )
    url = model_urls[key]

    import gdown

    torch_home = os.path.expanduser(
        os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))
    )
    model_dir = os.path.join(torch_home, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, url.split('/')[-1])

    if not os.path.exists(cached_file) and comm.is_main_process():
        gdown.download(url, cached_file, quiet=False)

    # Every process must reach this point so that no process reads the file
    # before the main process has finished downloading it.
    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoint URLs.
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))

    return state_dict
287
288
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg):
    """
    Create a ResNet instance from config.

    Reads ``cfg.MODEL.BACKBONE``: PRETRAIN, PRETRAIN_PATH, LAST_STRIDE, NORM,
    WITH_IBN, WITH_SE, WITH_NL and DEPTH. DEPTH must be one of '18x', '34x',
    '50x', '101x' or '152x'.

    When PRETRAIN is set, weights come from PRETRAIN_PATH if it is non-empty.
    Otherwise they come from ``init_pretrained_weights``, using the depth key
    with an optional 'ibn_' / 'se_' prefix. They are loaded with
    ``strict=False``, and any missing or unexpected keys are logged.

    Returns:
        ResNet: a :class:`ResNet` instance.

    Raises:
        KeyError: if DEPTH is not supported.
        FileNotFoundError: if PRETRAIN_PATH does not exist.
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_se       = cfg.MODEL.BACKBONE.WITH_SE
    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    # depth -> (block class, blocks per stage, Non_local modules per stage)
    arch_settings = {
        '18x': (BasicBlock, [2, 2, 2, 2], [0, 0, 0, 0]),
        '34x': (BasicBlock, [3, 4, 6, 3], [0, 0, 0, 0]),
        '50x': (Bottleneck, [3, 4, 6, 3], [0, 2, 3, 0]),
        '101x': (Bottleneck, [3, 4, 23, 3], [0, 2, 9, 0]),
        '152x': (Bottleneck, [3, 8, 36, 3], [0, 4, 12, 0]),
    }
    if depth not in arch_settings:
        raise KeyError(f"Unsupported ResNet depth {depth!r}; expected one of {list(arch_settings)}")
    block, num_blocks_per_stage, nl_layers_per_stage = arch_settings[depth]

    model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block,
                   num_blocks_per_stage, nl_layers_per_stage)
    if not pretrain:
        return model

    if pretrain_path:
        # Load pretrain path if specified explicitly.
        logger.info(f"Loading pretrained model from {pretrain_path}")
        try:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
        except FileNotFoundError:
            logger.error(f'{pretrain_path} is not found! Please check this path.')
            raise
    else:
        key = depth
        if with_ibn:
            key = 'ibn_' + key
        if with_se:
            key = 'se_' + key
        state_dict = init_pretrained_weights(key)

    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.info(
            get_missing_parameters_message(incompatible.missing_keys)
        )
    if incompatible.unexpected_keys:
        logger.info(
            get_unexpected_parameters_message(incompatible.unexpected_keys)
        )

    return model
__init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False, stride=1, downsample=None, reduction=16)
Definition resnet.py:43
__init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False, stride=1, downsample=None, reduction=16)
Definition resnet.py:83
_build_nonlocal(self, layers, non_layers, bn_norm)
Definition resnet.py:166
__init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers)
Definition resnet.py:128
_make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False)
Definition resnet.py:149