from .config import cfg as regnet_cfg
from ..build import BACKBONE_REGISTRY
logger = logging.getLogger(__name__)

model_urls = {
    '800x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160905981/RegNetX-200MF_dds_8gpu.pyth',
    '800y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906567/RegNetY-800MF_dds_8gpu.pyth',
    '1600x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160990626/RegNetX-1.6GF_dds_8gpu.pyth',
    '1600y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906681/RegNetY-1.6GF_dds_8gpu.pyth',
    '3200x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906139/RegNetX-3.2GF_dds_8gpu.pyth',
    '3200y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906834/RegNetY-3.2GF_dds_8gpu.pyth',
    '4000x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906383/RegNetX-4.0GF_dds_8gpu.pyth',
    '4000y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906838/RegNetY-4.0GF_dds_8gpu.pyth',
    '6400x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161116590/RegNetX-6.4GF_dds_8gpu.pyth',
    '6400y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160907112/RegNetY-6.4GF_dds_8gpu.pyth',
}

30 """Performs ResNet-style weight initialization."""
31 if isinstance(m, nn.Conv2d):
33 fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
34 m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
35 elif isinstance(m, nn.BatchNorm2d):
37 hasattr(m,
"final_bn")
and m.final_bn
and regnet_cfg.BN.ZERO_INIT_FINAL_GAMMA
39 m.weight.data.fill_(0.0
if zero_init_gamma
else 1.0)
41 elif isinstance(m, nn.Linear):
42 m.weight.data.normal_(mean=0.0, std=0.01)
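# Example (illustrative, not part of the original file): for a 3x3 conv with
# 64 output channels, fan_out = 3 * 3 * 64 = 576, so weights are drawn from
# N(0, sqrt(2 / 576)). The usual entry point is `module.apply(init_weights)`,
# as AnyNet.__init__ does below.
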
47 """Retrives the stem function by name."""
49 "res_stem_cifar": ResStemCifar,
50 "res_stem_in": ResStemIN,
51 "simple_stem_in": SimpleStemIN,
53 assert stem_type
in stem_funs.keys(),
"Stem type '{}' not supported".format(
56 return stem_funs[stem_type]
60 """Retrieves the block function by name."""
62 "vanilla_block": VanillaBlock,
63 "res_basic_block": ResBasicBlock,
64 "res_bottleneck_block": ResBottleneckBlock,
66 assert block_type
in block_funs.keys(),
"Block type '{}' not supported".format(
69 return block_funs[block_type]
73 """Drop connect (adapted from DARTS)."""
74 keep_ratio = 1.0 - drop_ratio
75 mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
76 mask.bernoulli_(keep_ratio)
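# Example (illustrative): with drop_ratio=0.2, each sample in the batch is
# zeroed with probability 0.2 and survivors are scaled by 1/0.8, so the
# expected activation is unchanged:
#   x = torch.ones(4, 8, 1, 1)
#   y = drop_connect(x, 0.2)  # per-sample rows are all-zero or all-1.25 (in-place)
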
class AnyHead(nn.Module):
    """AnyNet head."""

    def __init__(self, w_in, nc):
        super(AnyHead, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

97 """Vanilla block: [3x3 conv, BN, Relu] x2"""
99 def __init__(self, w_in, w_out, stride, bn_norm, bm=None, gw=None, se_r=None):
101 bm
is None and gw
is None and se_r
is None
102 ),
"Vanilla block does not support bm, gw, and se_r options"
103 super(VanillaBlock, self).
__init__()
104 self.
construct(w_in, w_out, stride, bn_norm)
109 w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=
False
111 self.
a_bn = get_norm(bn_norm, w_out)
112 self.
a_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
114 self.
b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=
False)
115 self.
b_bn = get_norm(bn_norm, w_out)
116 self.
b_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
119 for layer
in self.children():
125 """Basic transformation: [3x3 conv, BN, Relu] x2"""
128 super(BasicTransform, self).
__init__()
129 self.
construct(w_in, w_out, stride, bn_norm)
134 w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=
False
136 self.
a_bn = get_norm(bn_norm, w_out)
137 self.
a_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
139 self.
b = nn.Conv2d(w_out, w_out, kernel_size=3, stride=1, padding=1, bias=
False)
140 self.
b_bn = get_norm(bn_norm, w_out)
141 self.
b_bn.final_bn =
True
144 for layer
in self.children():
150 """Residual basic block: x + F(x), F = basic transform"""
152 def __init__(self, w_in, w_out, stride, bn_norm, bm=None, gw=None, se_r=None):
154 bm
is None and gw
is None and se_r
is None
155 ),
"Basic transform does not support bm, gw, and se_r options"
156 super(ResBasicBlock, self).
__init__()
157 self.
construct(w_in, w_out, stride, bn_norm)
161 w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=
False
163 self.
bn = get_norm(bn_norm, w_out)
171 self.
relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
175 x = self.
bn(self.
proj(x)) + self.
f(x)
183 """Squeeze-and-Excitation (SE) block"""
194 nn.Conv2d(w_in, w_se, kernel_size=1, bias=
True),
195 nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE),
196 nn.Conv2d(w_se, w_in, kernel_size=1, bias=
True),
205 """Bottlenect transformation: 1x1, 3x3, 1x1"""
207 def __init__(self, w_in, w_out, stride, bn_norm, bm, gw, se_r):
208 super(BottleneckTransform, self).
__init__()
209 self.
construct(w_in, w_out, stride, bn_norm, bm, gw, se_r)
211 def construct(self, w_in, w_out, stride, bn_norm, bm, gw, se_r):
213 w_b = int(round(w_out * bm))
217 self.
a = nn.Conv2d(w_in, w_b, kernel_size=1, stride=1, padding=0, bias=
False)
218 self.
a_bn = get_norm(bn_norm, w_b)
219 self.
a_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
222 w_b, w_b, kernel_size=3, stride=stride, padding=1, groups=num_gs, bias=
False
224 self.
b_bn = get_norm(bn_norm, w_b)
225 self.
b_relu = nn.ReLU(inplace=regnet_cfg.MEM.RELU_INPLACE)
228 w_se = int(round(w_in * se_r))
231 self.
c = nn.Conv2d(w_b, w_out, kernel_size=1, stride=1, padding=0, bias=
False)
232 self.
c_bn = get_norm(bn_norm, w_out)
233 self.
c_bn.final_bn =
True
236 for layer
in self.children():
242 """Residual bottleneck block: x + F(x), F = bottleneck transform"""
244 def __init__(self, w_in, w_out, stride, bn_norm, bm=1.0, gw=1, se_r=None):
245 super(ResBottleneckBlock, self).
__init__()
246 self.
construct(w_in, w_out, stride, bn_norm, bm, gw, se_r)
250 w_in, w_out, kernel_size=1, stride=stride, padding=0, bias=
False
252 self.
bn = get_norm(bn_norm, w_out)
254 def construct(self, w_in, w_out, stride, bn_norm, bm, gw, se_r):
260 self.
relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
264 x = self.
bn(self.
proj(x)) + self.
f(x)
272 """ResNet stem for CIFAR."""
275 super(ResStemCifar, self).
__init__()
281 w_in, w_out, kernel_size=3, stride=1, padding=1, bias=
False
283 self.
bn = get_norm(bn_norm, w_out)
284 self.
relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
287 for layer
in self.children():
293 """ResNet stem for ImageNet."""
302 w_in, w_out, kernel_size=7, stride=2, padding=3, bias=
False
304 self.
bn = get_norm(bn_norm, w_out)
305 self.
relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
306 self.
pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
309 for layer
in self.children():
315 """Simple stem for ImageNet."""
318 super(SimpleStemIN, self).
__init__()
324 in_w, out_w, kernel_size=3, stride=2, padding=1, bias=
False
326 self.
bn = get_norm(bn_norm, out_w)
327 self.
relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
330 for layer
in self.children():
336 """AnyNet stage (sequence of blocks w/ the same output shape)."""
338 def __init__(self, w_in, w_out, stride, bn_norm, d, block_fun, bm, gw, se_r):
340 self.
construct(w_in, w_out, stride, bn_norm, d, block_fun, bm, gw, se_r)
342 def construct(self, w_in, w_out, stride, bn_norm, d, block_fun, bm, gw, se_r):
346 b_stride = stride
if i == 0
else 1
347 b_w_in = w_in
if i == 0
else w_out
350 "b{}".format(i + 1), block_fun(b_w_in, w_out, b_stride, bn_norm, bm, gw, se_r)
354 for block
in self.children():
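# Example (illustrative values): AnyStage(w_in=24, w_out=48, stride=2, ...,
# d=3, ...) builds blocks b1, b2, b3 where only b1 downsamples (stride 2) and
# changes the width (24 -> 48); b2 and b3 run at stride 1 with 48 -> 48.
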
class AnyNet(nn.Module):
    """AnyNet model."""

    def __init__(self, **kwargs):
        super(AnyNet, self).__init__()
        if kwargs:
            self.construct(
                stem_type=kwargs["stem_type"],
                stem_w=kwargs["stem_w"],
                block_type=kwargs["block_type"],
                ds=kwargs["ds"],
                ws=kwargs["ws"],
                ss=kwargs["ss"],
                bn_norm=kwargs["bn_norm"],
                bms=kwargs["bms"],
                gws=kwargs["gws"],
                se_r=kwargs["se_r"],
            )
        else:
            self.construct(
                stem_type=regnet_cfg.ANYNET.STEM_TYPE,
                stem_w=regnet_cfg.ANYNET.STEM_W,
                block_type=regnet_cfg.ANYNET.BLOCK_TYPE,
                ds=regnet_cfg.ANYNET.DEPTHS,
                ws=regnet_cfg.ANYNET.WIDTHS,
                ss=regnet_cfg.ANYNET.STRIDES,
                bn_norm=regnet_cfg.ANYNET.BN_NORM,
                bms=regnet_cfg.ANYNET.BOT_MULS,
                gws=regnet_cfg.ANYNET.GROUP_WS,
                se_r=regnet_cfg.ANYNET.SE_R if regnet_cfg.ANYNET.SE_ON else None,
            )
        self.apply(init_weights)
    def construct(self, stem_type, stem_w, block_type, ds, ws, ss, bn_norm, bms, gws, se_r):
        # Generate dummy bot muls and gs for models that do not use them
        bms = bms if bms else [1.0 for _d in ds]
        gws = gws if gws else [1 for _d in ds]
        # Group params by stage
        stage_params = list(zip(ds, ws, ss, bms, gws))
        # Construct the stem
        stem_fun = get_stem_fun(stem_type)
        self.stem = stem_fun(3, stem_w, bn_norm)
        # Construct the stages
        block_fun = get_block_fun(block_type)
        prev_w = stem_w
        for i, (d, w, s, bm, gw) in enumerate(stage_params):
            self.add_module(
                "s{}".format(i + 1), AnyStage(prev_w, w, s, bn_norm, d, block_fun, bm, gw, se_r)
            )
            prev_w = w
        self.in_planes = prev_w

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x

420 """Converts a float to closest non-zero int divisible by q."""
421 return int(round(f / q) * q)
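# Worked example (illustrative): with q=8, widths snap to the nearest
# multiple of 8:
#   quantize_float(85, 8) -> int(round(85 / 8) * 8) = int(11 * 8) = 88
#   quantize_float(12, 8) -> int(round(12 / 8) * 8) = int(2 * 8)  = 16
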
425 """Adjusts the compatibility of widths and groups."""
426 ws_bot = [int(w * b)
for w, b
in zip(ws, bms)]
427 gs = [min(g, w_bot)
for g, w_bot
in zip(gs, ws_bot)]
428 ws_bot = [
quantize_float(w_bot, g)
for w_bot, g
in zip(ws_bot, gs)]
429 ws = [int(w_bot / b)
for w_bot, b
in zip(ws_bot, bms)]
434 """Gets ws/ds of network at each stage from per block values."""
435 ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
436 ts = [w != wp
or r != rp
for w, wp, r, rp
in ts_temp]
437 s_ws = [w
for w, t
in zip(ws, ts[:-1])
if t]
438 s_ds = np.diff([d
for d, t
in zip(range(len(ts)), ts)
if t]).tolist()
443 """Generates per block ws from RegNet parameters."""
444 assert w_a >= 0
and w_0 > 0
and w_m > 1
and w_0 % q == 0
445 ws_cont = np.arange(d) * w_a + w_0
446 ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
447 ws = w_0 * np.power(w_m, ks)
448 ws = np.round(np.divide(ws, q)) * q
449 num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
450 ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
451 return ws, num_stages, max_stage, ws_cont
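# Worked example (illustrative parameters, not a published config): with
# w_a=24, w_0=24, w_m=2, d=6, q=8:
#   ws_cont = [24, 48, 72, 96, 120, 144]          # linear ramp w_0 + w_a * j
#   ks      = round(log2(ws_cont / 24)) = [0, 1, 2, 2, 2, 3]
#   ws      = 24 * 2**ks = [24, 48, 96, 96, 96, 192]  # already multiples of 8
# so num_stages = 4, and get_stages_from_blocks(ws, ws) would yield stage
# widths [24, 48, 96, 192] with depths [1, 1, 3, 1].
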
class RegNet(AnyNet):
    """RegNet model."""

    def __init__(self, last_stride, bn_norm):
        # Generate RegNet ws per block
        b_ws, num_s, _, _ = generate_regnet(
            regnet_cfg.REGNET.WA, regnet_cfg.REGNET.W0, regnet_cfg.REGNET.WM, regnet_cfg.REGNET.DEPTH
        )
        # Convert to per stage format
        ws, ds = get_stages_from_blocks(b_ws, b_ws)
        # Generate group widths and bot muls
        gws = [regnet_cfg.REGNET.GROUP_W for _ in range(num_s)]
        bms = [regnet_cfg.REGNET.BOT_MUL for _ in range(num_s)]
        # Adjust the compatibility of ws and gws
        ws, gws = adjust_ws_gs_comp(ws, bms, gws)
        # Use the same stride for each stage; the final stage uses last_stride
        ss = [regnet_cfg.REGNET.STRIDE for _ in range(num_s)]
        ss[-1] = last_stride
        # Use SE for RegNetY
        se_r = regnet_cfg.REGNET.SE_R if regnet_cfg.REGNET.SE_ON else None
        # Construct the model
        kwargs = {
            "stem_type": regnet_cfg.REGNET.STEM_TYPE,
            "stem_w": regnet_cfg.REGNET.STEM_W,
            "block_type": regnet_cfg.REGNET.BLOCK_TYPE,
            "ss": ss,
            "ds": ds,
            "ws": ws,
            "bn_norm": bn_norm,
            "bms": bms,
            "gws": gws,
            "se_r": se_r,
        }
        super(RegNet, self).__init__(**kwargs)

491 """Initializes model with pretrained weights.
493 Layers that don't match with pretrained layers in name or size are kept unchanged.
499 def _get_torch_home():
500 ENV_TORCH_HOME =
'TORCH_HOME'
501 ENV_XDG_CACHE_HOME =
'XDG_CACHE_HOME'
502 DEFAULT_CACHE_DIR =
'~/.cache'
503 torch_home = os.path.expanduser(
507 os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR),
'torch'
513 torch_home = _get_torch_home()
514 model_dir = os.path.join(torch_home,
'checkpoints')
516 os.makedirs(model_dir)
518 if e.errno == errno.EEXIST:
525 filename = model_urls[key].split(
'/')[-1]
527 cached_file = os.path.join(model_dir, filename)
529 if not os.path.exists(cached_file):
530 if comm.is_main_process():
531 gdown.download(model_urls[key], cached_file, quiet=
False)
535 logger.info(f
"Loading pretrained model from {cached_file}")
536 state_dict = torch.load(cached_file, map_location=torch.device(
'cpu'))[
'model_state']
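# Example (illustrative): with neither TORCH_HOME nor XDG_CACHE_HOME set,
# checkpoints are cached under ~/.cache/torch/checkpoints, e.g. the '800y'
# key resolves to ~/.cache/torch/checkpoints/RegNetY-800MF_dds_8gpu.pyth.
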
@BACKBONE_REGISTRY.register()
def build_regnet_backbone(cfg):
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm = cfg.MODEL.BACKBONE.NORM
    depth = cfg.MODEL.BACKBONE.DEPTH

    cfg_files = {
        '800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
        '800y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
        '1600x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
        '1600y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
        '3200x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
        '3200y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
        '4000x': 'fastreid/modeling/backbones/regnet/regnety/RegNetX-4.0GF_dds_8gpu.yaml',
        '4000y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml',
        '6400x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
        '6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
    }[depth]

    regnet_cfg.merge_from_file(cfg_files)
    model = RegNet(last_stride, bn_norm)

    if pretrain:
        # Load weights from the given path when one is specified,
        # otherwise fall back to the downloaded model zoo checkpoint
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            state_dict = init_pretrained_weights(depth)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model

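# Minimal usage sketch (assumes fastreid's standard `get_cfg` helper; the
# depth value and input size are illustrative):
#   from fastreid.config import get_cfg
#   cfg = get_cfg()
#   cfg.MODEL.BACKBONE.DEPTH = '800y'    # selects RegNetY-800MF
#   cfg.MODEL.BACKBONE.PRETRAIN = False
#   model = build_regnet_backbone(cfg)   # normally invoked via BACKBONE_REGISTRY
#   feat = model(torch.randn(2, 3, 256, 128))  # 4-D feature map, no head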