Safemotion Lib
Loading...
Searching...
No Matches
effnet.py
Go to the documentation of this file.
1# !/usr/bin/env python3
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4#
5# This source code is licensed under the MIT license found in the
6# LICENSE file in the root directory of this source tree.
7
8"""EfficientNet models."""
9
10import logging
11
12import torch
13import torch.nn as nn
14
15from fastreid.layers import *
16from fastreid.modeling.backbones.build import BACKBONE_REGISTRY
17from fastreid.utils import comm
18from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
19from .config import cfg as effnet_cfg
20from .regnet import drop_connect, init_weights
21
logger = logging.getLogger(__name__)

# Pretrained EfficientNet checkpoints published with pycls ("dds_baselines").
# Every variant is stored under its own numeric run-id directory; the original
# table reused the EN-B1 run id (161304979) for most entries, which makes the
# B2/B3/B5 URLs point at files that do not exist under that directory.
# NOTE(review): the pycls model zoo only lists EN-B0..EN-B5. The 'b6'/'b7'
# entries are kept unchanged for backward compatibility, but presumably do not
# resolve -- confirm or remove.
model_urls = {
    'b0': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305613/EN-B0_dds_8gpu.pyth',
    'b1': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B1_dds_8gpu.pyth',
    'b2': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305015/EN-B2_dds_8gpu.pyth',
    'b3': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305060/EN-B3_dds_8gpu.pyth',
    'b4': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305098/EN-B4_dds_8gpu.pyth',
    'b5': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305138/EN-B5_dds_8gpu.pyth',
    'b6': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B6_dds_8gpu.pyth',
    'b7': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B7_dds_8gpu.pyth',
}
33
34
class EffHead(nn.Module):
    """EfficientNet head used as a backbone tail: 1x1 conv, BN, Swish.

    This head contains no pooling, dropout or fully connected layer; it only
    projects the incoming feature map from ``w_in`` to ``w_out`` channels.

    Args:
        w_in: number of input channels.
        w_out: number of output channels.
        bn_norm: normalization specifier forwarded to ``get_norm``.
    """

    def __init__(self, w_in, w_out, bn_norm):
        super(EffHead, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
        self.conv_bn = get_norm(bn_norm, w_out)
        # forward() applies this activation; without the attribute the very
        # first forward pass fails with AttributeError.
        self.conv_swish = Swish()

    def forward(self, x):
        """Apply conv -> norm -> Swish to ``x`` and return the result."""
        return self.conv_swish(self.conv_bn(self.conv(x)))
47
48
class Swish(nn.Module):
    """Parameter-free Swish activation, i.e. ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` scaled element-wise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
57
58
class SE(nn.Module):
    """Squeeze-and-Excitation block with Swish.

    Pipeline: global average pool -> 1x1 conv -> Swish -> 1x1 conv -> Sigmoid.
    The resulting per-channel gate is multiplied back onto the input.

    Args:
        w_in: number of channels of the input (and of the produced gate).
        w_se: number of channels inside the bottleneck of the excitation path.
    """

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Layers are created in the same order as they run, so the indices
        # inside the Sequential (0..3) stay stable for checkpoint loading.
        squeeze = nn.Conv2d(w_in, w_se, 1, bias=True)
        activation = Swish()
        excite = nn.Conv2d(w_se, w_in, 1, bias=True)
        gate = nn.Sigmoid()
        self.f_ex = nn.Sequential(squeeze, activation, excite, gate)

    def forward(self, x):
        """Return ``x`` re-weighted by its channel gate."""
        pooled = self.avg_pool(x)
        scale = self.f_ex(pooled)
        return x * scale
74
75
class MBConv(nn.Module):
    """Mobile inverted bottleneck block with SE (MBConv).

    Layout: optional 1x1 expansion (conv, norm, Swish) -> depthwise conv
    (conv, norm, Swish) -> SE -> 1x1 linear projection (conv, norm) ->
    optional residual connection.

    Args:
        w_in: number of input channels.
        exp_r: expansion ratio; the expanded width is ``int(w_in * exp_r)``.
        kernel: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        se_r: SE ratio; the SE bottleneck width is ``int(w_in * se_r)``.
        w_out: number of output channels.
        bn_norm: normalization specifier forwarded to ``get_norm``.
    """

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm):
        super(MBConv, self).__init__()
        # The expansion stage only exists when it actually changes the width.
        self.exp = None
        w_exp = int(w_in * exp_r)
        if w_exp != w_in:
            self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
            self.exp_bn = get_norm(bn_norm, w_exp)
            # Used by forward(); was missing, which raised AttributeError.
            self.exp_swish = Swish()
        # groups == channels makes this a depthwise convolution; the padding
        # keeps the spatial size unchanged for stride 1 and odd kernels.
        dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
        self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
        self.dwise_bn = get_norm(bn_norm, w_exp)
        # Used by forward(); was missing, which raised AttributeError.
        self.dwise_swish = Swish()
        self.se = SE(w_exp, int(w_in * se_r))
        self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
        self.lin_proj_bn = get_norm(bn_norm, w_out)
        # Skip connection if in and out shapes are the same (MN-V2 style)
        self.has_skip = stride == 1 and w_in == w_out

    def forward(self, x):
        """Run the block on ``x`` and return the transformed feature map."""
        f_x = x
        # Explicit None check: do not rely on the truthiness of a module.
        if self.exp is not None:
            f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
        f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
        f_x = self.se(f_x)
        f_x = self.lin_proj_bn(self.lin_proj(f_x))
        if self.has_skip:
            # Drop-connect is applied to the residual branch in training only.
            if self.training and effnet_cfg.EN.DC_RATIO > 0.0:
                f_x = drop_connect(f_x, effnet_cfg.EN.DC_RATIO)
            f_x = x + f_x
        return f_x
110
111
class EffStage(nn.Module):
    """EfficientNet stage: ``d`` MBConv blocks registered as ``b1`` .. ``b<d>``.

    Only the first block receives the stage's ``stride`` and input width
    ``w_in``; every following block uses stride 1 and ``w_out`` input channels.
    """

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm):
        super(EffStage, self).__init__()
        for index in range(1, d + 1):
            is_first = index == 1
            block_stride = stride if is_first else 1
            block_w_in = w_in if is_first else w_out
            block = MBConv(block_w_in, exp_r, kernel, block_stride, se_r, w_out, bn_norm)
            self.add_module("b{}".format(index), block)

    def forward(self, x):
        """Feed ``x`` through the child blocks in registration order."""
        out = x
        for child in self.children():
            out = child(out)
        return out
127
128
class StemIN(nn.Module):
    """EfficientNet ImageNet stem: 3x3 stride-2 conv, norm, Swish.

    Args:
        w_in: number of input channels.
        w_out: number of output channels.
        bn_norm: normalization specifier forwarded to ``get_norm``.
    """

    def __init__(self, w_in, w_out, bn_norm):
        super(StemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = get_norm(bn_norm, w_out)
        self.swish = Swish()

    def forward(self, x):
        """Apply the registered child layers to ``x`` in registration order."""
        out = x
        for child in self.children():
            out = child(out)
        return out
142
143
class EffNet(nn.Module):
    """EfficientNet backbone: stem, stages ``s1`` .. ``sN``, and head."""

    @staticmethod
    def get_args():
        """Collect the constructor keyword arguments from ``effnet_cfg.EN``."""
        en = effnet_cfg.EN
        return {
            "stem_w": en.STEM_W,
            "ds": en.DEPTHS,
            "ws": en.WIDTHS,
            "exp_rs": en.EXP_RATIOS,
            "se_r": en.SE_R,
            "ss": en.STRIDES,
            "ks": en.KERNELS,
            "head_w": en.HEAD_W,
        }

    def __init__(self, last_stride, bn_norm, **kwargs):
        super(EffNet, self).__init__()
        # Without explicit architecture kwargs, fall back to the global config.
        if not kwargs:
            kwargs = self.get_args()
        self._construct(last_stride=last_stride, bn_norm=bn_norm, **kwargs)
        self.apply(init_weights)

    def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm):
        """Build the stem, one ``EffStage`` per parameter tuple, and the head."""
        stage_specs = tuple(zip(ds, ws, exp_rs, ss, ks))
        self.stem = StemIN(3, stem_w, bn_norm)
        w_prev = stem_w
        for idx, (depth, width, exp_r, stride, kernel) in enumerate(stage_specs):
            # The stage at index 5 (registered as "s6") takes its stride from
            # ``last_stride`` instead of the configured value.
            if idx == 5:
                stride = last_stride
            stage = EffStage(w_prev, exp_r, kernel, stride, se_r, width, depth, bn_norm)
            self.add_module("s{}".format(idx + 1), stage)
            w_prev = width
        self.head = EffHead(w_prev, head_w, bn_norm)

    def forward(self, x):
        """Pass ``x`` through every child module in registration order."""
        out = x
        for child in self.children():
            out = child(out)
        return out
181
182
def init_pretrained_weights(key):
    """Fetch the pretrained EfficientNet checkpoint for ``key`` and return its weights.

    The checkpoint is cached in ``<torch home>/checkpoints``; it is downloaded
    only when the cached file is missing, and only by the main process. This
    function does not touch any model: the caller loads the returned state dict.

    Args:
        key: entry of ``model_urls`` selecting the variant (e.g. ``'b0'``).

    Returns:
        The ``'model_state'`` entry of the loaded checkpoint.

    Raises:
        KeyError: if ``key`` is not in ``model_urls`` or the checkpoint has no
            ``'model_state'`` entry.
    """
    import os
    import gdown

    def _get_torch_home():
        """Resolve $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch."""
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        default_home = os.path.join(
            os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
        )
        return os.path.expanduser(os.getenv(ENV_TORCH_HOME, default_home))

    # Resolve the URL first so an unknown key fails before any filesystem work.
    url = model_urls[key]

    model_dir = os.path.join(_get_torch_home(), 'checkpoints')
    # exist_ok replaces the manual errno.EEXIST handling and also tolerates
    # another process creating the directory concurrently.
    os.makedirs(model_dir, exist_ok=True)

    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        # Only the main process downloads, so workers do not write the same file.
        if comm.is_main_process():
            gdown.download(url, cached_file, quiet=False)

    # NOTE(review): presumably a barrier that makes the other processes wait
    # for the download to finish -- confirm against fastreid.utils.comm.
    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    # SECURITY: torch.load unpickles the file; only use checkpoints from a
    # trusted source.
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))['model_state']

    return state_dict
232
233
@BACKBONE_REGISTRY.register()
def build_effnet_backbone(cfg):
    """Build an EfficientNet backbone from ``cfg`` and optionally load pretrained weights.

    NOTE(review): the ``def`` line of this function was lost in extraction; the
    name follows upstream fast-reid and the ``cfg`` parameter is taken from the
    body -- confirm against the real file.

    Args:
        cfg: config object; reads ``MODEL.BACKBONE.PRETRAIN``, ``PRETRAIN_PATH``,
            ``LAST_STRIDE``, ``NORM`` and ``DEPTH``.

    Returns:
        The constructed ``EffNet`` instance.

    Raises:
        KeyError: if ``DEPTH`` is not one of ``'b0'`` .. ``'b5'``.
        FileNotFoundError: if ``PRETRAIN_PATH`` is set but does not exist.
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    # Architecture YAML for the requested depth. The paths are relative, so
    # they are resolved against the current working directory.
    cfg_files = {
        'b0': 'fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml',
        'b1': 'fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml',
        'b2': 'fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml',
        'b3': 'fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml',
        'b4': 'fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml',
        'b5': 'fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml',
    }[depth]

    # EffNet reads its architecture from the module-level effnet_cfg, so the
    # file has to be merged before the model is constructed.
    effnet_cfg.merge_from_file(cfg_files)
    model = EffNet(last_stride, bn_norm)

    if pretrain:
        # Prefer an explicitly specified checkpoint path over the download.
        if pretrain_path:
            try:
                # The loaded object is used as the state dict as-is (no
                # 'model_state' lookup, unlike the downloaded checkpoints).
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise
            except KeyError:
                logger.info("State dict keys error! Please check the state dict.")
                raise
        else:
            key = depth
            state_dict = init_pretrained_weights(key)

        # strict=False: mismatching keys are reported below instead of raising.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model
__init__(self, w_in, w_out, bn_norm)
Definition effnet.py:38
_construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm)
Definition effnet.py:166
__init__(self, last_stride, bn_norm, **kwargs)
Definition effnet.py:160
__init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm)
Definition effnet.py:115
__init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm)
Definition effnet.py:79
__init__(self, w_in, w_out, bn_norm)
Definition effnet.py:132