def init_pretrained_weights(key):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are
    kept unchanged.

    Args:
        key: lookup key into the module-level ``model_urls`` table
            selecting which pretrained checkpoint to fetch.

    Returns:
        The checkpoint's state dict, loaded onto CPU.
    """

    def _get_torch_home():
        # Mirrors torch.hub's cache resolution order:
        # $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch.
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(ENV_TORCH_HOME,
                      os.path.join(
                          os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR),
                          'torch')))
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    # exist_ok=True replaces the older makedirs + errno.EEXIST dance;
    # any other OSError still propagates, matching the original behavior.
    os.makedirs(model_dir, exist_ok=True)

    filename = model_urls[key].split('/')[-1]

    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        # Only the main process downloads, so concurrent workers don't
        # race on the same cache file.
        if comm.is_main_process():
            gdown.download(model_urls[key], cached_file, quiet=False)

    # NOTE(review): non-main processes must wait for the download to
    # finish before reading the file — the original presumably called
    # comm.synchronize() here; confirm against upstream.
    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))

    return state_dict
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg):
    """
    Create a ResNet instance from config.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_se       = cfg.MODEL.BACKBONE.WITH_SE
    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    # NOTE(review): entries other than '101x'/'152x' were lost in this
    # copy; values below reconstructed from the standard ResNet family —
    # confirm against upstream.
    num_blocks_per_stage = {
        '18x': [2, 2, 2, 2],
        '34x': [3, 4, 6, 3],
        '50x': [3, 4, 6, 3],
        '101x': [3, 4, 23, 3],
        '152x': [3, 8, 36, 3],
    }[depth]

    nl_layers_per_stage = {
        '18x': [0, 0, 0, 0],
        '34x': [0, 2, 3, 0],
        '50x': [0, 2, 3, 0],
        '101x': [0, 2, 9, 0],
        '152x': [0, 4, 12, 0],
    }[depth]

    # Shallow variants use BasicBlock, deep variants Bottleneck.
    block = {
        '18x': BasicBlock,
        '34x': BasicBlock,
        '50x': Bottleneck,
        '101x': Bottleneck,
        '152x': Bottleneck,
    }[depth]

    model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block,
                   num_blocks_per_stage, nl_layers_per_stage)

    if pretrain:
        if pretrain_path:
            try:
                state_dict = torch.load(
                    pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                # Re-raise: without a state_dict the load below would hit
                # an unbound local instead of a clear error.
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            # Compose the model_urls key from the requested variants.
            key = depth
            if with_ibn: key = 'ibn_' + key
            if with_se: key = 'se_' + key

            state_dict = init_pretrained_weights(key)

        # strict=False: layers that don't match in name or size are
        # simply reported, not treated as fatal.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
__init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False, stride=1, downsample=None, reduction=16)
__init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False, stride=1, downsample=None, reduction=16)