Safemotion Lib
Loading...
Searching...
No Matches
resnext.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: xingyu liao
4@contact: sherlockliao01@gmail.com
5"""
6
7# based on:
8# https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py
9
10import logging
11import math
12
13import torch
14import torch.nn as nn
15
16from fastreid.layers import *
17from fastreid.utils import comm
18from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
19from .build import BACKBONE_REGISTRY
20
# Module logger, plus the download URLs of the pretrained checkpoints. Keys follow the
# scheme used by the backbone builder: the depth string, prefixed with 'ibn_' for IBN models.
logger = logging.getLogger(__name__)
model_urls = dict(
    ibn_101x='https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnext101_ibn_a-6ace051d.pth',
)
25
26
class Bottleneck(nn.Module):
    """
    ResNeXt bottleneck type C: 1x1 conv -> 3x3 grouped conv -> 1x1 conv, added to a shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, bn_norm, with_ibn, baseWidth, cardinality, stride=1,
                 downsample=None):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: base channel count; the block outputs ``planes * expansion`` channels
            bn_norm: normalization type passed to ``get_norm`` / ``IBN``
            with_ibn: use ``IBN`` instead of a plain norm layer after the first 1x1 conv
            baseWidth: base width.
            cardinality: num of convolution groups.
            stride: stride of the 3x3 grouped conv. Replaces pooling layer.
            downsample: optional module applied to the input to build the shortcut; it must
                map the input to the output's shape whenever the two differ.
        """
        super(Bottleneck, self).__init__()

        # Per-group width (D) times the number of groups (C) gives the grouped-conv width.
        D = int(math.floor(planes * (baseWidth / 64)))
        C = cardinality
        width = D * C
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, padding=0, bias=False)
        if with_ibn:
            self.bn1 = IBN(width, bn_norm)
        else:
            self.bn1 = get_norm(bn_norm, width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
        self.bn2 = get_norm(bn_norm, width)
        self.conv3 = nn.Conv2d(width, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = get_norm(bn_norm, out_planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when the main path changed the shape (stride / channels).
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
81
82
class ResNeXt(nn.Module):
    """
    ResNext optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf

    Used here as a backbone: ``forward`` returns the output of the last stage (layer4),
    optionally with Non-local blocks inserted after selected bottlenecks.
    """

    def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers,
                 baseWidth=4, cardinality=32):
        """ Constructor
        Args:
            last_stride: stride of the first block of the last stage (layer4).
            bn_norm: normalization type passed to ``get_norm`` and to every block.
            with_ibn: forwarded to every block built by ``_make_layer``.
            with_nl: build Non-local blocks (see ``_build_nonlocal``).
            block: block class used for every stage; must expose ``expansion``.
            layers: config of layers, e.g., [3, 4, 6, 3]
            non_layers: number of Non-local blocks per stage; only used when ``with_nl``.
            baseWidth: baseWidth for ResNeXt.
            cardinality: number of convolution groups.
        """
        super(ResNeXt, self).__init__()

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.inplanes = 64
        self.output_size = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = get_norm(bn_norm, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn)
        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn)
        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn)

        # random_init only visits modules that already exist, so calling it before
        # _build_nonlocal leaves the Non-local blocks with their constructor's initialisation.
        # Keep this order.
        self.random_init()

        if with_nl:
            self._build_nonlocal(layers, non_layers, bn_norm)
        else:
            # One independent empty list per stage (no shared, aliased list object).
            self.NL_1_idx, self.NL_2_idx, self.NL_3_idx, self.NL_4_idx = [], [], [], []

    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', with_ibn=False):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            block: block type used to construct ResNext
            planes: number of output channels (need to multiply by block.expansion)
            blocks: number of blocks to be built
            stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
            bn_norm: normalization type passed to ``get_norm`` and to each block.
            with_ibn: forwarded to each block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        # Only the first block of a stage can change resolution or channel count, so it is the
        # only one that may need a projection shortcut.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                get_norm(bn_norm, planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, bn_norm, with_ibn,
                        self.baseWidth, self.cardinality, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, bn_norm, with_ibn, self.baseWidth, self.cardinality, 1, None))

        return nn.Sequential(*layers)

    def _build_nonlocal(self, layers, non_layers, bn_norm):
        """Create ``self.NL_<k>`` (Non-local modules) and ``self.NL_<k>_idx`` for stages 1-4.

        Stage k receives ``non_layers[k-1]`` Non-local blocks, placed after the last
        ``non_layers[k-1]`` bottlenecks of that stage; the indices are stored ascending.
        """
        for stage, channels in enumerate((256, 512, 1024, 2048)):
            num_blocks, num_nl = layers[stage], non_layers[stage]
            setattr(self, f'NL_{stage + 1}', nn.ModuleList(
                [Non_local(channels, bn_norm) for _ in range(num_nl)]))
            setattr(self, f'NL_{stage + 1}_idx', list(range(num_blocks - num_nl, num_blocks)))

    @staticmethod
    def _forward_stage(x, blocks, nl_blocks, nl_idx):
        """Run one stage, applying ``nl_blocks[j]`` right after block number ``nl_idx[j]``.

        ``nl_idx`` must be ascending; ``nl_blocks`` is never touched when ``nl_idx`` is empty.
        """
        nl_pos = 0
        for i, blk in enumerate(blocks):
            x = blk(x)
            if nl_pos < len(nl_idx) and i == nl_idx[nl_pos]:
                x = nl_blocks[nl_pos](x)
                nl_pos += 1
        return x

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)

        for stage in range(1, 5):
            # NL_<k> only exists when the model was built with Non-local blocks.
            x = self._forward_stage(
                x,
                getattr(self, f'layer{stage}'),
                getattr(self, f'NL_{stage}', None),
                getattr(self, f'NL_{stage}_idx'),
            )
        return x

    def random_init(self):
        """He-normal init for Conv2d; weight=1 / bias=0 for BatchNorm2d and InstanceNorm2d."""
        # Redundant with the loop below (conv1 is a 7x7 conv with 64 output channels), but kept
        # so the random-number stream, and hence seeded initialisations, stay unchanged.
        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Norm layers built with affine=False have no weight/bias to initialise.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
220
221
def init_pretrained_weights(key):
    """Download (if needed) and load the pretrained checkpoint registered under ``key``.

    The file is cached as ``<torch_home>/checkpoints/<url basename>``, where torch_home is
    ``$TORCH_HOME``, or ``$XDG_CACHE_HOME/torch``, or ``~/.cache/torch``. Only the main process
    downloads; every process then loads the cached file. Loading the result into a model is
    left to the caller.

    Args:
        key: entry of ``model_urls``, e.g. ``'ibn_101x'``.

    Returns:
        Whatever ``torch.load`` returns for the cached file, mapped to CPU.

    Raises:
        KeyError: if no URL is registered for ``key``.
    """
    import os

    if key not in model_urls:
        raise KeyError(
            f"No pretrained ResNeXt weights registered for '{key}'; available: {sorted(model_urls)}")
    url = model_urls[key]

    torch_home = os.path.expanduser(
        os.getenv('TORCH_HOME',
                  os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
    model_dir = os.path.join(torch_home, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, url.split('/')[-1])

    if not os.path.exists(cached_file) and comm.is_main_process():
        # Imported lazily: gdown is only required when the checkpoint must be downloaded.
        import gdown
        gdown.download(url, cached_file, quiet=False)

    # Presumably a barrier across ranks, so non-main processes wait for the download.
    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    # NOTE(review): torch.load unpickles the file; only register trusted checkpoint URLs.
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))

    return state_dict
271
272
273@BACKBONE_REGISTRY.register()
def build_resnext_backbone(cfg):
    """
    Create a ResNeXt instance from config.

    Reads ``cfg.MODEL.BACKBONE``: PRETRAIN, PRETRAIN_PATH, LAST_STRIDE, NORM, WITH_IBN,
    WITH_NL and DEPTH ('50x', '101x' or '152x').

    Returns:
        ResNeXt: a :class:`ResNeXt` instance.

    Raises:
        KeyError: unsupported DEPTH; WITH_NL requested for a depth without a Non-local layout;
            a PRETRAIN_PATH checkpoint without a 'model' entry; or no download URL for the
            requested ImageNet weights.
        FileNotFoundError: PRETRAIN_PATH does not exist.
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    blocks_per_depth = {
        '50x': [3, 4, 6, 3],
        '101x': [3, 4, 23, 3],
        '152x': [3, 8, 36, 3],
    }
    # Non-local blocks per stage; no layout is defined for '152x'.
    nl_per_depth = {
        '50x': [0, 2, 3, 0],
        '101x': [0, 2, 3, 0],
    }
    if depth not in blocks_per_depth:
        raise KeyError(
            f"Unsupported ResNeXt depth '{depth}'; expected one of {sorted(blocks_per_depth)}")
    num_blocks_per_stage = blocks_per_depth[depth]

    if depth in nl_per_depth:
        nl_layers_per_stage = nl_per_depth[depth]
    elif with_nl:
        raise KeyError(
            f"No Non-local layout defined for ResNeXt depth '{depth}'; "
            f"disable MODEL.BACKBONE.WITH_NL or use one of {sorted(nl_per_depth)}")
    else:
        # ResNeXt only reads the Non-local layout when with_nl is set.
        nl_layers_per_stage = [0, 0, 0, 0]

    model = ResNeXt(last_stride, bn_norm, with_ibn, with_nl, Bottleneck,
                    num_blocks_per_stage, nl_layers_per_stage)
    if pretrain:
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
            except FileNotFoundError:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise
            except KeyError:
                logger.info("State dict keys error! Please check the state dict.")
                raise
            # Drop the first two dotted components of every key ("module.encoder" in the
            # original code's words) and keep only entries matching this model by name and shape.
            model_state = model.state_dict()
            new_state_dict = {}
            for k, v in state_dict.items():
                new_k = '.'.join(k.split('.')[2:])
                if new_k in model_state and model_state[new_k].shape == v.shape:
                    new_state_dict[new_k] = v
            state_dict = new_state_dict
            logger.info(f"Loading pretrained model from {pretrain_path}")
        else:
            key = f'ibn_{depth}' if with_ibn else depth
            state_dict = init_pretrained_weights(key)

        # strict=False: report, rather than fail on, keys present on only one side.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
__init__(self, inplanes, planes, bn_norm, with_ibn, baseWidth, cardinality, stride=1, downsample=None)
Definition resnext.py:34
_build_nonlocal(self, layers, non_layers, bn_norm)
Definition resnext.py:147
__init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers, baseWidth=4, cardinality=32)
Definition resnext.py:90
_make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', with_ibn=False)
Definition resnext.py:120