def load_model(model, path, ignore_key=None, device='cuda:0'):
    """Load saved parameters into the sub-modules of ``model``.

    The file at ``path`` is read with ``torch.load`` and indexed by sub-module
    name; each entry is treated as that sub-module's own state dict. Only the
    names returned by ``model.get_module_names()`` (minus ``ignore_key``) are
    loaded.

    Tensors are paired with the sub-module's state-dict entries by POSITION
    (iteration order of the two mappings), not by key name, so a saved dict
    whose keys are named differently but whose layout matches still loads.

    Args:
        model: PyTorch model exposing ``get_module_names()`` and one attribute
            per returned name.
        path (str): Path to the saved parameter file.
        ignore_key (list[str] | None): Sub-module names that are not loaded.
        device (str): Passed to ``torch.load`` as ``map_location``.

    Raises:
        RuntimeError: If a saved tensor's shape differs from the shape of the
            model tensor it would be copied into.
    """
    print('model path : ', path)

    state_dict = torch.load(path, map_location=device)

    module_names = model.get_module_names()
    if ignore_key is not None:
        module_names = remove_items(module_names, ignore_key)

    for name in module_names:
        try:
            m = getattr(model, name)
        except AttributeError:
            print(f'The model does not have {name} attribute.')
            print('Please check the model’s get_module_names() function.')
            continue

        if name not in state_dict:
            print(f'{name} is missing from the parameter file.')
            continue

        _dict = state_dict[name]
        print('load ', name)

        # Build the target mapping once; m.state_dict() reconstructs the whole
        # dict on every call. Its tensors share storage with the module's
        # parameters/buffers, so copy_() below updates the module in place.
        target = m.state_dict()

        if len(target) != len(_dict):
            # zip() stops at the shorter side, so some tensors would be
            # silently skipped; make the truncation visible.
            print(f'{name}: model has {len(target)} tensors but the parameter '
                  f'file has {len(_dict)}; only the first '
                  f'{min(len(target), len(_dict))} are loaded.')

        for target_key, saved_key in zip(target, _dict):
            param = _dict[saved_key]
            dst = target[target_key]
            # copy_() broadcasts, so e.g. a (1,) tensor would silently fill an
            # (N,) parameter; require an exact shape match instead.
            if dst.shape != param.shape:
                raise RuntimeError(
                    f'{name}: shape mismatch for {target_key!r} '
                    f'(model {tuple(dst.shape)}) vs saved {saved_key!r} '
                    f'(file {tuple(param.shape)}).')
            dst.copy_(param)