def init_pretrained_weights(key):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.

    Args:
        key: lookup key into the module-level ``model_urls`` table selecting
            which pretrained checkpoint to fetch.

    Returns:
        The checkpoint's state dict, loaded onto CPU.

    NOTE(review): the ``def`` header and a few interior lines were cut from the
    visible chunk; this body is reconstructed from the surviving fragments —
    confirm against the upstream file.
    """

    def _get_torch_home():
        # Mirrors torch.hub's cache-dir resolution: $TORCH_HOME wins,
        # else $XDG_CACHE_HOME/torch, else ~/.cache/torch.
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(ENV_TORCH_HOME,
                      os.path.join(
                          os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR),
                          'torch')))
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    # exist_ok=True replaces the original try/except-errno.EEXIST dance:
    # behavior is identical (pre-existing dir is fine, other OSErrors raise).
    os.makedirs(model_dir, exist_ok=True)

    filename = model_urls[key].split('/')[-1]
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        # Only the main process downloads, to avoid concurrent writes in
        # distributed runs; assumes workers rendezvous before reading —
        # TODO confirm the barrier (e.g. comm.synchronize()) lost from this view.
        if comm.is_main_process():
            gdown.download(model_urls[key], cached_file, quiet=False)

    logger.info(f"Loading pretrained model from {cached_file}")
    # map_location='cpu' so loading works without a GPU present.
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))

    return state_dict
@BACKBONE_REGISTRY.register()
def build_resnext_backbone(cfg):
    """
    Create a ResNeXt instance from config.

    Returns:
        ResNeXt: a :class:`ResNeXt` instance.

    NOTE(review): the ``def`` header, the ``if pretrain:``/``try:`` control
    lines and the '50x' table rows were cut from the visible chunk; they are
    reconstructed here — confirm against the upstream file.
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {
        '50x': [3, 4, 6, 3],     # NOTE(review): row reconstructed — verify
        '101x': [3, 4, 23, 3],
        '152x': [3, 8, 36, 3], }[depth]
    nl_layers_per_stage = {
        '50x': [0, 2, 3, 0],     # NOTE(review): row reconstructed — verify
        '101x': [0, 2, 3, 0]}[depth]
    model = ResNeXt(last_stride, bn_norm, with_ibn, with_nl, Bottleneck,
                    num_blocks_per_stage, nl_layers_per_stage)

    if pretrain:
        if pretrain_path:
            try:
                state_dict = torch.load(
                    pretrain_path, map_location=torch.device('cpu'))['model']
                # Strip the leading "<module>.<backbone>." prefix from each key
                # and keep only entries whose name AND shape match this model.
                # Hoisted: the original called model.state_dict() twice per key
                # inside the loop; compute it once.
                model_state_dict = model.state_dict()
                new_state_dict = {}
                for k in state_dict:
                    new_k = '.'.join(k.split('.')[2:])
                    if new_k in model_state_dict \
                            and (model_state_dict[new_k].shape == state_dict[k].shape):
                        new_state_dict[new_k] = state_dict[k]
                state_dict = new_state_dict
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            # No explicit path: fall back to the downloadable checkpoint table,
            # selecting the IBN variant when requested.
            key = depth
            if with_ibn: key = 'ibn_' + key
            state_dict = init_pretrained_weights(key)

        # strict=False: mismatched layers are reported, not fatal — matches
        # the docstring contract ("layers that don't match are kept unchanged").
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
__init__(self, inplanes, planes, bn_norm, with_ibn, baseWidth, cardinality, stride=1, downsample=None)
__init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers, baseWidth=4, cardinality=32)