def init_pretrained_weights(model, key=''):
    """Initialize ``model`` with ImageNet-pretrained weights downloaded via gdown.

    The checkpoint for ``key`` is fetched from ``model_urls[key]`` into the
    torch cache directory (download happens on the main process only; all
    processes then synchronize before reading). Every checkpoint tensor whose
    name and size match an entry of the model's state dict is loaded; layers
    that don't match are kept unchanged and reported.

    Args:
        model: torch.nn.Module to receive the pretrained weights (updated in place).
        key: model identifier; indexes ``model_urls`` and names the cached file
            ``<key>_imagenet.pth``.
    """
    import os
    import gdown
    from collections import OrderedDict
    import warnings
    import logging

    logger = logging.getLogger(__name__)

    def _get_torch_home():
        # Mirror torch.hub's cache-dir resolution: $TORCH_HOME, falling back
        # to $XDG_CACHE_HOME/torch, falling back to ~/.cache/torch.
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        return os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    # exist_ok=True replaces the legacy try/except errno.EEXIST idiom.
    os.makedirs(model_dir, exist_ok=True)
    filename = key + '_imagenet.pth'
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        # Only the main process downloads; everyone then waits at the barrier
        # below so non-main ranks never read a partially written file.
        if comm.is_main_process():
            gdown.download(model_urls[key], cached_file, quiet=False)

    comm.synchronize()

    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        # Strip the 'module.' prefix that DataParallel-saved checkpoints carry.
        if k.startswith('module.'):
            k = k[7:]

        # Keep a tensor only when both the key and the shape match the model.
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(cached_file)
        )
    else:
        logger.info(
            'Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file)
        )
        if len(discarded_layers) > 0:
            logger.info(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.format(discarded_layers)
            )
483
484
485@BACKBONE_REGISTRY.register()