Safemotion Lib
regnet.py
import logging
import math

import numpy as np
import torch
import torch.nn as nn

from fastreid.layers import get_norm
from fastreid.utils import comm
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .config import cfg as regnet_cfg
from ..build import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)
model_urls = {
    '800x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160905981/RegNetX-200MF_dds_8gpu.pyth',
    '800y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906567/RegNetY-800MF_dds_8gpu.pyth',
    '1600x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160990626/RegNetX-1.6GF_dds_8gpu.pyth',
    '1600y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906681/RegNetY-1.6GF_dds_8gpu.pyth',
    '3200x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906139/RegNetX-3.2GF_dds_8gpu.pyth',
    '3200y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906834/RegNetY-3.2GF_dds_8gpu.pyth',
    '4000x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906383/RegNetX-4.0GF_dds_8gpu.pyth',
    '4000y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906838/RegNetY-4.0GF_dds_8gpu.pyth',
    '6400x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161116590/RegNetX-6.4GF_dds_8gpu.pyth',
    '6400y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160907112/RegNetY-6.4GF_dds_8gpu.pyth',
}


def init_weights(m):
    """Performs ResNet-style weight initialization."""
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        zero_init_gamma = (
            hasattr(m, "final_bn") and m.final_bn and regnet_cfg.BN.ZERO_INIT_FINAL_GAMMA
        )
        m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        m.bias.data.zero_()

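# Quick arithmetic check of the init above (hypothetical layer, not from any
# config): a 3x3 conv with 64 output channels has fan_out = 3 * 3 * 64 = 576,
# so its weights are drawn from N(0, sqrt(2 / 576)) ~= N(0, 0.059).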

def get_stem_fun(stem_type):
    """Retrieves the stem function by name."""
    stem_funs = {
        "res_stem_cifar": ResStemCifar,
        "res_stem_in": ResStemIN,
        "simple_stem_in": SimpleStemIN,
    }
    assert stem_type in stem_funs.keys(), "Stem type '{}' not supported".format(
        stem_type
    )
    return stem_funs[stem_type]


def get_block_fun(block_type):
    """Retrieves the block function by name."""
    block_funs = {
        "vanilla_block": VanillaBlock,
        "res_basic_block": ResBasicBlock,
        "res_bottleneck_block": ResBottleneckBlock,
    }
    assert block_type in block_funs.keys(), "Block type '{}' not supported".format(
        block_type
    )
    return block_funs[block_type]


def drop_connect(x, drop_ratio):
    """Drop connect (adapted from DARTS)."""
    keep_ratio = 1.0 - drop_ratio
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x

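# Illustrative use of drop_connect (hypothetical shapes; meant for training
# only, since it zeroes whole samples and rescales the survivors in place):
#   >>> x = torch.ones(4, 8, 2, 2)
#   >>> y = drop_connect(x, drop_ratio=0.2)  # each sample kept with prob 0.8
#   >>> y.shape
#   torch.Size([4, 8, 2, 2])
# Kept samples are scaled by 1 / 0.8 = 1.25, so the expected activation is unchanged.
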
class AnyHead(nn.Module):
    """AnyNet head."""

    def __init__(self, w_in, nc):
        super(AnyHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class VanillaBlock(nn.Module):
    """Vanilla block: [3x3 conv, BN, ReLU] x2"""

    def __init__(self, w_in, w_out, stride, bn_norm, bm=None, gw=None, se_r=None):
        assert (
            bm is None and gw is None and se_r is None
        ), "Vanilla block does not support bm, gw, and se_r options"
        super(VanillaBlock, self).__init__()
        self.construct(w_in, w_out, stride, bn_norm)

    def construct(self, w_in, w_out, stride, bn_norm):
        # 3x3, BN, ReLU
        self.a = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.a_bn = get_norm(bn_norm, w_out)
        self.a_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
        # 3x3, BN, ReLU
        self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.b_bn = get_norm(bn_norm, w_out)
        self.b_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class BasicTransform(nn.Module):
    """Basic transformation: [3x3 conv, BN, ReLU] x2"""

    def __init__(self, w_in, w_out, stride, bn_norm):
        super(BasicTransform, self).__init__()
        self.construct(w_in, w_out, stride, bn_norm)

    def construct(self, w_in, w_out, stride, bn_norm):
        # 3x3, BN, ReLU
        self.a = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.a_bn = get_norm(bn_norm, w_out)
        self.a_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
        # 3x3, BN
        self.b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.b_bn = get_norm(bn_norm, w_out)
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class ResBasicBlock(nn.Module):
    """Residual basic block: x + F(x), F = basic transform"""

    def __init__(self, w_in, w_out, stride, bn_norm, bm=None, gw=None, se_r=None):
        assert (
            bm is None and gw is None and se_r is None
        ), "Basic transform does not support bm, gw, and se_r options"
        super(ResBasicBlock, self).__init__()
        self.construct(w_in, w_out, stride, bn_norm)

    def _add_skip_proj(self, w_in, w_out, stride, bn_norm):
        self.proj = nn.Conv2d(
            w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
        )
        self.bn = get_norm(bn_norm, w_out)

    def construct(self, w_in, w_out, stride, bn_norm):
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self._add_skip_proj(w_in, w_out, stride, bn_norm)
        self.f = BasicTransform(w_in, w_out, stride, bn_norm)
        self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block"""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self.construct(w_in, w_se)

    def construct(self, w_in, w_se):
        # AvgPool
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # FC, Activation, FC, Sigmoid
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, kernel_size=1, bias=True),
            nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE),
            nn.Conv2d(w_se, w_in, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.f_ex(self.avg_pool(x))

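# Minimal SE sanity check (hypothetical widths): squeeze 64 channels to a
# 16-wide bottleneck, then re-weight the input channel-wise:
#   >>> se = SE(w_in=64, w_se=16)
#   >>> se(torch.randn(2, 64, 7, 7)).shape
#   torch.Size([2, 64, 7, 7])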

class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, 3x3, 1x1"""

    def __init__(self, w_in, w_out, stride, bn_norm, bm, gw, se_r):
        super(BottleneckTransform, self).__init__()
        self.construct(w_in, w_out, stride, bn_norm, bm, gw, se_r)

    def construct(self, w_in, w_out, stride, bn_norm, bm, gw, se_r):
        # Compute the bottleneck width
        w_b = int(round(w_out * bm))
        # Compute the number of groups
        num_gs = w_b // gw
        # 1x1, BN, ReLU
        self.a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=False)
        self.a_bn = get_norm(bn_norm, w_b)
        self.a_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
        # 3x3, BN, ReLU
        self.b = nn.Conv2d(
            w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=False
        )
        self.b_bn = get_norm(bn_norm, w_b)
        self.b_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
        # Squeeze-and-Excitation (SE)
        if se_r:
            w_se = int(round(w_in * se_r))
            self.se = SE(w_b, w_se)
        # 1x1, BN
        self.c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=False)
        self.c_bn = get_norm(bn_norm, w_out)
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

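# Worked example of the width bookkeeping above (hypothetical values): with
# w_out=120, bm=1.0, gw=24 the bottleneck width is w_b = round(120 * 1.0) = 120,
# and the grouped 3x3 conv uses num_gs = 120 // 24 = 5 groups.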

class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform"""

    def __init__(self, w_in, w_out, stride, bn_norm, bm=1.0, gw=1, se_r=None):
        super(ResBottleneckBlock, self).__init__()
        self.construct(w_in, w_out, stride, bn_norm, bm, gw, se_r)

    def _add_skip_proj(self, w_in, w_out, stride, bn_norm):
        self.proj = nn.Conv2d(
            w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=False
        )
        self.bn = get_norm(bn_norm, w_out)

    def construct(self, w_in, w_out, stride, bn_norm, bm, gw, se_r):
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self._add_skip_proj(w_in, w_out, stride, bn_norm)
        self.f = BottleneckTransform(w_in, w_out, stride, bn_norm, bm, gw, se_r)
        self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x


class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR."""

    def __init__(self, w_in, w_out, bn_norm):
        super(ResStemCifar, self).__init__()
        self.construct(w_in, w_out, bn_norm)

    def construct(self, w_in, w_out, bn_norm):
        # 3x3, BN, ReLU
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn = get_norm(bn_norm, w_out)
        self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class ResStemIN(nn.Module):
    """ResNet stem for ImageNet."""

    def __init__(self, w_in, w_out, bn_norm):
        super(ResStemIN, self).__init__()
        self.construct(w_in, w_out, bn_norm)

    def construct(self, w_in, w_out, bn_norm):
        # 7x7, BN, ReLU, maxpool
        self.conv = nn.Conv2d(
            w_in, w_out, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn = get_norm(bn_norm, w_out)
        self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x


class SimpleStemIN(nn.Module):
    """Simple stem for ImageNet."""

    def __init__(self, in_w, out_w, bn_norm):
        super(SimpleStemIN, self).__init__()
        self.construct(in_w, out_w, bn_norm)

    def construct(self, in_w, out_w, bn_norm):
        # 3x3, BN, ReLU
        self.conv = nn.Conv2d(
            in_w, out_w, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn = get_norm(bn_norm, out_w)
        self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

class AnyStage(nn.Module):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(self, w_in, w_out, stride, bn_norm, d, block_fun, bm, gw, se_r):
        super(AnyStage, self).__init__()
        self.construct(w_in, w_out, stride, bn_norm, d, block_fun, bm, gw, se_r)

    def construct(self, w_in, w_out, stride, bn_norm, d, block_fun, bm, gw, se_r):
        # Construct the blocks
        for i in range(d):
            # Stride and w_in apply to the first block of the stage
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            # Construct the block
            self.add_module(
                "b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bn_norm, bm, gw, se_r)
            )

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x


class AnyNet(nn.Module):
    """AnyNet model."""

    def __init__(self, **kwargs):
        super(AnyNet, self).__init__()
        if kwargs:
            self.construct(
                stem_type=kwargs["stem_type"],
                stem_w=kwargs["stem_w"],
                block_type=kwargs["block_type"],
                ds=kwargs["ds"],
                ws=kwargs["ws"],
                ss=kwargs["ss"],
                bn_norm=kwargs["bn_norm"],
                bms=kwargs["bms"],
                gws=kwargs["gws"],
                se_r=kwargs["se_r"],
            )
        else:
            self.construct(
                stem_type=regnet_cfg.ANYNET.STEM_TYPE,
                stem_w=regnet_cfg.ANYNET.STEM_W,
                block_type=regnet_cfg.ANYNET.BLOCK_TYPE,
                ds=regnet_cfg.ANYNET.DEPTHS,
                ws=regnet_cfg.ANYNET.WIDTHS,
                ss=regnet_cfg.ANYNET.STRIDES,
                bn_norm=regnet_cfg.ANYNET.BN_NORM,
                bms=regnet_cfg.ANYNET.BOT_MULS,
                gws=regnet_cfg.ANYNET.GROUP_WS,
                se_r=regnet_cfg.ANYNET.SE_R if regnet_cfg.ANYNET.SE_ON else None,
            )
        self.apply(init_weights)

    def construct(self, stem_type, stem_w, block_type, ds, ws, ss, bn_norm, bms, gws, se_r):
        # Generate dummy bot muls and gs for models that do not use them
        bms = bms if bms else [1.0 for _d in ds]
        gws = gws if gws else [1 for _d in ds]
        # Group params by stage
        stage_params = list(zip(ds, ws, ss, bms, gws))
        # Construct the stem
        stem_fun = get_stem_fun(stem_type)
        self.stem = stem_fun(3, stem_w, bn_norm)
        # Construct the stages
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for i, (d, w, s, bm, gw) in enumerate(stage_params):
            self.add_module(
                "s{}".format(i + 1), AnyStage(prev_w, w, s, bn_norm, d, block_fun, bm, gw, se_r)
            )
            prev_w = w
        # Construct the head
        self.in_planes = prev_w
        # self.head = AnyHead(w_in=prev_w, nc=nc)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x

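# Sketch of constructing an AnyNet directly (all values hypothetical; real
# models are configured through regnet_cfg or the RegNet subclass below):
#   >>> net = AnyNet(stem_type="simple_stem_in", stem_w=32,
#   ...              block_type="res_bottleneck_block",
#   ...              ds=[1, 1], ws=[56, 112], ss=[2, 2], bn_norm="BN",
#   ...              bms=[1.0, 1.0], gws=[8, 8], se_r=None)
#   >>> net(torch.randn(1, 3, 64, 64)).shape  # stem + stages, no classifier head
#   torch.Size([1, 112, 8, 8])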

def quantize_float(f, q):
    """Converts a float to closest non-zero int divisible by q."""
    return int(round(f / q) * q)

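# Example: quantize_float(41.3, 8) -> round(41.3 / 8) * 8 = 5 * 8 = 40.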

def adjust_ws_gs_comp(ws, bms, gs):
    """Adjusts the compatibility of widths and groups."""
    ws_bot = [int(w * b) for w, b in zip(ws, bms)]
    gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
    ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
    ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
    return ws, gs

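# Worked example (hypothetical values): ws=[56, 112], bms=[1.0, 1.0], gs=[24, 24]
# gives bottleneck widths [56, 112]; quantizing each to a multiple of its group
# width returns ws=[48, 120], gs=[24, 24].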

def get_stages_from_blocks(ws, rs):
    """Gets ws/ds of network at each stage from per block values."""
    ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
    ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
    s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
    s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
    return s_ws, s_ds

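# Example (hypothetical): per-block widths [56, 56, 56, 112] (with rs = ws)
# collapse to stage widths s_ws=[56, 112] and stage depths s_ds=[3, 1].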

def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per block ws from RegNet parameters."""
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    ws_cont = np.arange(d) * w_a + w_0
    ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
    ws = w_0 * np.power(w_m, ks)
    ws = np.round(np.divide(ws, q)) * q
    num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
    ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
    return ws, num_stages, max_stage, ws_cont

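# Worked example (hypothetical parameters): generate_regnet(w_a=8.0, w_0=56,
# w_m=2.0, d=4) produces continuous widths [56, 64, 72, 80], which quantize to
# per-block widths [56, 56, 56, 112] across num_stages=2.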

class RegNet(AnyNet):
    """RegNet model."""

    def __init__(self, last_stride, bn_norm):
        # Generate RegNet ws per block
        b_ws, num_s, _, _ = generate_regnet(
            regnet_cfg.REGNET.WA, regnet_cfg.REGNET.W0, regnet_cfg.REGNET.WM, regnet_cfg.REGNET.DEPTH
        )
        # Convert to per stage format
        ws, ds = get_stages_from_blocks(b_ws, b_ws)
        # Generate group widths and bot muls
        gws = [regnet_cfg.REGNET.GROUP_W for _ in range(num_s)]
        bms = [regnet_cfg.REGNET.BOT_MUL for _ in range(num_s)]
        # Adjust the compatibility of ws and gws
        ws, gws = adjust_ws_gs_comp(ws, bms, gws)
        # Use the same stride for each stage
        ss = [regnet_cfg.REGNET.STRIDE for _ in range(num_s)]
        ss[-1] = last_stride
        # Use SE for RegNetY
        se_r = regnet_cfg.REGNET.SE_R if regnet_cfg.REGNET.SE_ON else None
        # Construct the model
        kwargs = {
            "stem_type": regnet_cfg.REGNET.STEM_TYPE,
            "stem_w": regnet_cfg.REGNET.STEM_W,
            "block_type": regnet_cfg.REGNET.BLOCK_TYPE,
            "ss": ss,
            "ds": ds,
            "ws": ws,
            "bn_norm": bn_norm,
            "bms": bms,
            "gws": gws,
            "se_r": se_r,
        }
        super(RegNet, self).__init__(**kwargs)


def init_pretrained_weights(key):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    import os
    import errno
    import gdown

    def _get_torch_home():
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    filename = model_urls[key].split('/')[-1]

    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        if comm.is_main_process():
            gdown.download(model_urls[key], cached_file, quiet=False)

    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))['model_state']

    return state_dict


@BACKBONE_REGISTRY.register()
def build_regnet_backbone(cfg):
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    cfg_files = {
        '800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
        '800y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
        '1600x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
        '1600y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
        '3200x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
        '3200y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
        '4000x': 'fastreid/modeling/backbones/regnet/regnety/RegNetX-4.0GF_dds_8gpu.yaml',
        '4000y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml',
        '6400x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
        '6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
    }[depth]

    regnet_cfg.merge_from_file(cfg_files)
    model = RegNet(last_stride, bn_norm)

    if pretrain:
        # Load pretrained weights from the given path if one is specified
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            key = depth
            state_dict = init_pretrained_weights(key)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model
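
# Hedged usage sketch (assumes the fastreid config conventions referenced
# above; illustrative, not a definitive recipe):
#   from fastreid.config import get_cfg
#   cfg = get_cfg()
#   cfg.MODEL.BACKBONE.DEPTH = '800y'   # picks the RegNetY-800MF yaml above
#   cfg.MODEL.BACKBONE.PRETRAIN = False
#   model = build_regnet_backbone(cfg)  # normally resolved via BACKBONE_REGISTRY
#   feat = model(torch.randn(1, 3, 256, 128))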