8"""EfficientNet models."""
19from .config
import cfg
as effnet_cfg
20from .regnet
import drop_connect, init_weights
22logger = logging.getLogger(__name__)
24 'b0':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305613/EN-B0_dds_8gpu.pyth',
25 'b1':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B1_dds_8gpu.pyth',
26 'b2':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B2_dds_8gpu.pyth',
27 'b3':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B3_dds_8gpu.pyth',
28 'b4':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305098/EN-B4_dds_8gpu.pyth',
29 'b5':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B5_dds_8gpu.pyth',
30 'b6':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B6_dds_8gpu.pyth',
31 'b7':
'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B7_dds_8gpu.pyth',
36 """EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""
40 self.
conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=
False)
50 """Swish activation function: x * sigmoid(x)."""
53 super(Swish, self).__init__()
56 return x * torch.sigmoid(x)
60 """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""
66 nn.Conv2d(w_in, w_se, 1, bias=
True),
68 nn.Conv2d(w_se, w_in, 1, bias=
True),
77 """Mobile inverted bottleneck block w/ SE (MBConv)."""
79 def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm):
83 w_exp = int(w_in * exp_r)
85 self.
exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=
False)
86 self.
exp_bn = get_norm(bn_norm, w_exp)
88 dwise_args = {
"groups": w_exp,
"padding": (kernel - 1) // 2,
"bias":
False}
89 self.
dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
92 self.
se =
SE(w_exp, int(w_in * se_r))
93 self.
lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=
False)
106 if self.training
and effnet_cfg.EN.DC_RATIO > 0.0:
107 f_x = drop_connect(f_x, effnet_cfg.EN.DC_RATIO)
113 """EfficientNet stage."""
115 def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm):
118 b_stride = stride
if i == 0
else 1
119 b_w_in = w_in
if i == 0
else w_out
120 name =
"b{}".format(i + 1)
121 self.add_module(name,
MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out, bn_norm))
124 for block
in self.children():
130 """EfficientNet stem for ImageNet: 3x3, BN, Swish."""
134 self.
conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=
False)
135 self.
bn = get_norm(bn_norm, w_out)
139 for layer
in self.children():
145 """EfficientNet model."""
150 "stem_w": effnet_cfg.EN.STEM_W,
151 "ds": effnet_cfg.EN.DEPTHS,
152 "ws": effnet_cfg.EN.WIDTHS,
153 "exp_rs": effnet_cfg.EN.EXP_RATIOS,
154 "se_r": effnet_cfg.EN.SE_R,
155 "ss": effnet_cfg.EN.STRIDES,
156 "ks": effnet_cfg.EN.KERNELS,
157 "head_w": effnet_cfg.EN.HEAD_W,
160 def __init__(self, last_stride, bn_norm, **kwargs):
162 kwargs = self.
get_args()
if not kwargs
else kwargs
163 self.
_construct(**kwargs, last_stride=last_stride, bn_norm=bn_norm)
164 self.apply(init_weights)
166 def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm):
167 stage_params = list(zip(ds, ws, exp_rs, ss, ks))
170 for i, (d, w, exp_r, stride, kernel)
in enumerate(stage_params):
171 name =
"s{}".format(i + 1)
172 if i == 5: stride = last_stride
173 self.add_module(name,
EffStage(prev_w, exp_r, kernel, stride, se_r, w, d, bn_norm))
178 for module
in self.children():
184 """Initializes model with pretrained weights.
186 Layers that don't match with pretrained layers in name or size are kept unchanged.
192 def _get_torch_home():
193 ENV_TORCH_HOME =
'TORCH_HOME'
194 ENV_XDG_CACHE_HOME =
'XDG_CACHE_HOME'
195 DEFAULT_CACHE_DIR =
'~/.cache'
196 torch_home = os.path.expanduser(
200 os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR),
'torch'
206 torch_home = _get_torch_home()
207 model_dir = os.path.join(torch_home,
'checkpoints')
209 os.makedirs(model_dir)
211 if e.errno == errno.EEXIST:
218 filename = model_urls[key].split(
'/')[-1]
220 cached_file = os.path.join(model_dir, filename)
222 if not os.path.exists(cached_file):
223 if comm.is_main_process():
224 gdown.download(model_urls[key], cached_file, quiet=
False)
228 logger.info(f
"Loading pretrained model from {cached_file}")
229 state_dict = torch.load(cached_file, map_location=torch.device(
'cpu'))[
'model_state']
234@BACKBONE_REGISTRY.register()
237 pretrain = cfg.MODEL.BACKBONE.PRETRAIN
238 pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
239 last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
240 bn_norm = cfg.MODEL.BACKBONE.NORM
241 depth = cfg.MODEL.BACKBONE.DEPTH
245 'b0':
'fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml',
246 'b1':
'fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml',
247 'b2':
'fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml',
248 'b3':
'fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml',
249 'b4':
'fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml',
250 'b5':
'fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml',
253 effnet_cfg.merge_from_file(cfg_files)
254 model =
EffNet(last_stride, bn_norm)
260 state_dict = torch.load(pretrain_path, map_location=torch.device(
'cpu'))
261 logger.info(f
"Loading pretrained model from {pretrain_path}")
262 except FileNotFoundError
as e:
263 logger.info(f
'{pretrain_path} is not found! Please check this path.')
265 except KeyError
as e:
266 logger.info(
"State dict keys error! Please check the state dict.")
272 incompatible = model.load_state_dict(state_dict, strict=
False)
273 if incompatible.missing_keys:
275 get_missing_parameters_message(incompatible.missing_keys)
277 if incompatible.unexpected_keys:
279 get_unexpected_parameters_message(incompatible.unexpected_keys)
__init__(self, w_in, w_out, bn_norm)
_construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm)
__init__(self, last_stride, bn_norm, **kwargs)
__init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm)
__init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm)
__init__(self, w_in, w_se)
__init__(self, w_in, w_out, bn_norm)
build_effnet_backbone(cfg)
init_pretrained_weights(key)