def load_model(model, path, ignore_key=None, device='cuda:0'):
    """Load per-submodule weights from a checkpoint file into ``model`` in place.

    The checkpoint is read with ``torch.load`` and indexed by the names
    returned from ``model.get_module_names()``. Each ``state_dict[name]`` is
    itself a mapping of tensors for the submodule ``getattr(model, name)``.
    Tensors are copied positionally: the i-th entry of the submodule's own
    state dict receives the i-th entry of the saved mapping. Key names
    therefore do not have to match, but the entry order must line up.

    Args:
        model: Object exposing ``get_module_names()`` and one attribute
            (a module with ``state_dict()``) per returned name.
        path (str): Path of the checkpoint file.
        ignore_key (list[str]): Parameters not to load; filtered out of the
            module-name list via ``remove_items``.
        device (str): Device the model runs on; passed to ``torch.load`` as
            ``map_location``.

    Returns:
        The same ``model`` object, which is modified in place.
    """
    print('model path : ', path)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(path, map_location=device)

    module_names = model.get_module_names()
    if ignore_key is not None:
        module_names = remove_items(module_names, ignore_key)

    for name in module_names:
        try:
            m = getattr(model, name)
        except AttributeError:
            print(f'The model does not have {name} attribute.')
            print('Please check the model’s get_module_names() function.')
            # Without this, the code below would reference an unbound `m`.
            continue

        if name not in state_dict:
            print(f'{name} is missing from the parameter file.')
            continue

        saved = state_dict[name]
        # Build once: every state_dict() call allocates a fresh mapping.
        target = m.state_dict()
        if len(target) != len(saved):
            # zip() stops at the shorter side, so warn instead of silently
            # loading only part of the submodule.
            print(f'{name}: {len(target)} entries in the model but '
                  f'{len(saved)} in the parameter file; '
                  'only the overlapping prefix is loaded.')
        for target_key, saved_key in zip(target, saved):
            # state_dict() tensors share storage with the live parameters and
            # buffers, so copy_ writes through to the module itself.
            target[target_key].copy_(saved[saved_key])

    return model
100 for m
in model.modules():
101 if isinstance(m, nn.Conv2d):
102 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
103 m.weight.data.normal_(0, math.sqrt(2. / n))
104 elif isinstance(m, nn.Conv3d):
105 n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
106 m.weight.data.normal_(0, math.sqrt(2. / n))
107 elif isinstance(m, nn.BatchNorm2d):
108 m.weight.data.fill_(1)
110 elif isinstance(m, nn.BatchNorm3d):
111 m.weight.data.fill_(1)
113 elif isinstance(m, nn.Linear):