Safemotion Lib
Loading...
Searching...
No Matches
utils.py
Go to the documentation of this file.
2
3import torch
4import torch.nn as nn
5
6import math
7
8from smutils.utils_data import remove_items
9
10def load_model(model, path, ignore_key=None, device='cuda:0'):
11 """
12 모델의 파라미터를 로드하는 기능
13 args:
14 model : 파이토치 모델
15 path (str): 파라미터 경로
16 ignore_key (list[str]): 로드하지 않는 파라미터
17 device (str): 모델이 구동될 디바이스
18 """
19
20 print('model path : ', path) #모델 파라미터 경로 출력
21
22 state_dict = torch.load(path, map_location=device) #파일 읽기
23
24 #로드할 모듈 셋팅
25 module_names = model.get_module_names() #모듈 네임가져오기, runner에 구현되어 있어야함
26 if ignore_key is not None:
27 module_names = remove_items(module_names, ignore_key)
28
29 #모델 파라미터 로드
30 for name in module_names:
31 #모듈 가져오기, 없으면 메시지 출력하고 다음으로 이동
32 try:
33 m = getattr(model, name)
34 except AttributeError:
35 print(f'The model does not have {name} attribute.')
36 print('Please check the model’s get_module_names() function.')
37 continue
38
39 #파라미터 로드
40 if name in state_dict.keys():
41 _dict = state_dict[name]
42 print('load ', name)
43
44 for i, k in zip(m.state_dict(), _dict):
45 param = _dict[k]
46 m.state_dict()[i].copy_(param)
47 else:
48 print(f'{name} is missing from the parameter file.')
49
50def save_model(model, save_path):
51 """
52 모델의 파라미터를 저장하는 기능
53 args:
54 model : 파이토치 모델
55 save_path (str): 파라미터 저장 경로
56 """
57
58 #로드할 모듈 셋팅
59 checkpoint_model = {}
60 module_names = model.get_module_names() #모듈 네임가져오기, runner에 구현되어 있어야함
61 for name in module_names:
62 try :
63 #모듈 네임 - 파라미터 구조의 딕셔너리 생성
64 checkpoint_model[name] = getattr(model, name).state_dict()
65 except AttributeError:
66 print(f'The model does not have {name} attribute.')
67 print('Please check the model’s get_module_names() function.')
68
69 #저장
70 torch.save(checkpoint_model, save_path)
71
72def load_optimizer(optimizer, path):
73 """
74 옵티마티저 파라미터 로드 기능
75 args:
76 optimizer : 파이토치 옵티마이저
77 path (str): 옵티마이저 파라미터 로드 경로
78 """
79 checkpoint = torch.load(path)
80 optimizer.load_state_dict(checkpoint['optimizer'])
81
82
83def save_optimizer(optimizer, save_path):
84 """
85 옵티마티저 파라미터 저장 기능
86 args:
87 optimizer : 파이토치 옵티마이저
88 save_path (str): 옵티마이저 파라미터 저장 경로
89 """
90 checkpoint_opt = {'optimizer': optimizer.state_dict()}
91 torch.save(checkpoint_opt, save_path)
92
93
94def init_weights(model):
95 """
96 모델 파라미터 초기화 기능
97 args:
98 model : 파이토치 모델
99 """
100 for m in model.modules():
101 if isinstance(m, nn.Conv2d): #conv 2d 초기화
102 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
103 m.weight.data.normal_(0, math.sqrt(2. / n))
104 elif isinstance(m, nn.Conv3d): #conv 3d 초기화
105 n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
106 m.weight.data.normal_(0, math.sqrt(2. / n))
107 elif isinstance(m, nn.BatchNorm2d): #배치놈 2d 초기화
108 m.weight.data.fill_(1)
109 m.bias.data.zero_()
110 elif isinstance(m, nn.BatchNorm3d): #배치놈 3d 초기화
111 m.weight.data.fill_(1)
112 m.bias.data.zero_()
113 elif isinstance(m, nn.Linear): #Linear 초기화, 0으로 초기화
114 m.bias.data.zero_()
115
116# def init_weights(model):
117# for m in model.modules():
118# if isinstance(m, nn.Conv2d):
119# # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
120# # m.weight.data.normal_(0, math.sqrt(2. / n))
121# m.weight.data.normal_(0.0, 0.02)
122# if m.bias is not None:
123# m.bias.data.fill_(0)
124# elif isinstance(m, nn.Conv3d):
125# n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
126# m.weight.data.normal_(0, math.sqrt(2. / n))
127# elif isinstance(m, nn.BatchNorm2d):
128# m.weight.data.normal_(1.0, 0.02)
129# m.bias.data.fill_(0)
130# # m.weight.data.fill_(1)
131# # m.bias.data.zero_()
132# elif isinstance(m, nn.BatchNorm3d):
133# m.weight.data.fill_(1)
134# m.bias.data.zero_()
135# elif isinstance(m, nn.Linear):
136# m.weight.data.normal_(0.0, 0.02)
137# if m.bias is not None:
138# m.bias.data.zero_()
init_weights(model)
Definition utils.py:94
save_optimizer(optimizer, save_path)
Definition utils.py:83
load_model(model, path, ignore_key=None, device='cuda:0')
Definition utils.py:10
save_model(model, save_path)
Definition utils.py:50
load_optimizer(optimizer, path)
Definition utils.py:72