Safemotion Lib
Loading...
Searching...
No Matches
train.py
Go to the documentation of this file.
2import torch
3import torch.nn as nn
4import torch.optim as optim
5from torch.utils.data import DataLoader
6
7from smaction.datasets.action_dataset_loader import ActionDatasetLoader, ActionDatasetLoader_v22, ActionDatasetLoader_mtml
8from smutils.utils_os import create_directory, save_json
9from smrunner.model_runner_builder import build_model
10from smrunner.losses.loss_builder import build_loss
11from smrunner.utils import *
12from smrunner.metrics import *
13import json
14
15import datetime
16import numpy as np
17import copy
18import os
19
20#TODO
21#빌더로 빼서 따로 관리 필요할 것 같음 : 데이터 로더, 옵티마이저, 스케줄러
22#데이터 로더를 어떻게 관리할지 결정 필요
23# 현재 : 행동인식만 학습하기 때문에 smaction/dataset에 데이터로더를 구현해둠
24# 변경 안 : 데이터 로더만 따로 모아서 관리
25# -> smrunner에 데이터로더 관련 기능들 정리
26# -> smdataset에 데이터로더 관련 기능들 정리
27#학습 코드를 어떻게 관리할지 결정 필요
28# 현재 : smrunner/train.py에 학습 코드가 구현되어 있음. 특별 케이스만 작동됨
29# 변경 안: smrunner/trainer에 학습 관련 기능 구현(일반화시켜서)
30#데이터 전처리 기능이 임시로 구현되어 있음 : tmp_pipeline
31# 일반화 시켜서 따로 관리 필요함 : mmlab 참조해보기
32# 현재 : 일부 전처리 기능은 데이터로더에서 작업됨, 디바이스로 전송하는 기능은 학습 코드에 임시로 작성됨
33#학습 파라미터 저장 및 로드 구조 개선 필요
34#멀티 gpu 학습이 가능하도록 코드 및 구조 개선 필요
35
36
37#데이터 로더
38__data_loader__ = {
39 "ActionDatasetLoader" : ActionDatasetLoader,
40 "ActionDatasetLoader_v22" : ActionDatasetLoader_v22,
41 "ActionDatasetLoader_mtml" : ActionDatasetLoader_mtml,
42}
43
44#옵티마이저
45__optimizer__ = {
46 "SGD" : optim.SGD,
47 "Adam" : optim.Adam
48}
49
50#lr 스케줄러(step)
51def adjust_learning_rate(optimizer, epoch, lr, adjust_epoch, adjust_rate):
52 #적용 lr 계산
53 for ep, rate in zip(adjust_epoch, adjust_rate):
54 if epoch >= ep:
55 lr=rate*lr
56
57 #lr 셋팅
58 for param_group in optimizer.param_groups:
59 param_group['lr'] = lr
60
61 return lr
62
63#데이터 전처리, 학습 디바이스로 데이터 전송
64def tmp_pipeline(sample, keys, device):
65 for key in keys:
66 sample[key] = sample[key].to(device)
67
68 return sample
69
70#배치에 대해 학습
71def train_sample(model, loss_cls, optimizer, sample):
72 optimizer.zero_grad() #옵티마이저 기울기 초기화
73
74 result = model(sample) #inference
75 sample.update(result) #샘플에 inference 결과 추가
76
77 loss = loss_cls(sample) #로스 계산
78 loss = torch.mean(loss) #로스 평균
79 loss.backward() #백프로파게이션, 기울기 계산
80 optimizer.step() #모델 파라미터 업데이트
81
82 return loss.item()
83
84#1-epoch 학습 기능
85def train_epoch(model, loss_cls, optimizer, data_loader, collect_keys, device):
86
87 model.train() #모델 학습 모드로 변경
88 max_step = len(data_loader) #몇번 업데이트하는지 확인, 출력용도
89 sum_loss = 0 #로스 합, 출력용도
90 for step, sample in enumerate(data_loader):
91
92 sample = tmp_pipeline(sample, collect_keys, device) #데이터를 디바이스로 전송
93 loss = train_sample(model, loss_cls, optimizer, sample) #샘플(배치) 훈련
94
95 #진행도 화면 출력
96 sum_loss += loss
97 print(f" train => [{step+1}/{max_step}] -> avg loss : {(sum_loss/(step+1)):.6f} , loss : {loss:.6f}", end='\r')
98 print('')
99 return sum_loss/max_step
100
101#배치에 대해 평가
102def validation_sample(model, loss_cls, sample, metric_args):
103
104 result = model(sample) #inference
105 sample.update(result) #샘플에 inference 결과 추가
106 loss = loss_cls(sample) #로스 계산
107 loss = torch.mean(loss) #로스 평균
108
109 #top-1 정확도 계산
110 # acc_dict = top1_acc_dict(sample, ['pred_action', 'pred_pose'], ['label_action', 'label_pose'])
111 acc_dict = top1_acc_multi_task(sample, **metric_args)
112 # pred_key=['pred_pose', 'pred_hand', 'pred_foot'],
113 # gt_key=['label', 'label', 'label'],
114 # target_tasks=['pose', 'hand', 'foot'],
115 # task_key='category')
116 return acc_dict, loss.item()
117
118#전체 평가 데이터에 대한 평가
119def validation(model, loss_cls, data_loader, collect_keys, device, metric_args):
120 model.eval() #평가모드
121 max_step = len(data_loader) #몇 스텝 반복하는지 확인, 출력용도
122 sum_loss = 0 #로스 합, 출력용도
123 meter = AverageMeterDict(model.predict_keys.keys()) #평균 계산기
124
125 #평가
126 with torch.no_grad():
127 for step, sample in enumerate(data_loader):
128
129 #평가용 데이터를 디바이스로 전송
130 sample = tmp_pipeline(sample, collect_keys, device)
131
132 #배치에 대해 평가
133 val_acc, val_loss = validation_sample(model, loss_cls, sample, metric_args)
134
135 meter.update(val_acc) #평균 계산기에 정확도 계산 결과 업데이트
136 sum_loss += val_loss #로스합
137 print(f' val => [{step+1}/{max_step}] -> {meter.print_str()}', end='\r')
138 print('')
139 return meter, sum_loss/max_step
140
141#학습 함수, 1-gpu를 이용해서 학습함
142def train(cfg, device_ids):
143 #변수 셋팅
144 num_workers = cfg.train.num_workers
145 init_lr = cfg.train.init_lr
146 epochs = cfg.train.epochs
147 batch_size = cfg.train.batch_size
148 checkpoint_folder = os.path.join(cfg.train.save_root, 'weights')
149
150 #학습 디바이스 셋팅
151 if device_ids is None:
152 device = 'cpu'
153 else:
154 device = f'cuda:{device_ids[0]}'
155
156 # 데이터셋 로더 셋팅
157 print('---- data loader setting ----')
158 data_loader_cfg = copy.deepcopy(cfg.data_loader) #데이터 로더 파라미터
159 DataSetClass = __data_loader__[data_loader_cfg.pop('type')] #사용할 데이터 로더 선택
160
161 #학습용 데이터 로더 빌드
162 train_dataset = DataSetClass(mode='train', **data_loader_cfg)
163 train_data_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
164 #평가용 데이터 로더 빌드
165 test_dataset = DataSetClass(mode='val', **data_loader_cfg)
166 test_data_loader = DataLoader(test_dataset, batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
167
168 # 모델 빌드
169 print('---- build model ----')
170 model = build_model(cfg)
171 model.to(device)
172 if cfg.train.pretrained is not None:
173 load_model(cfg, model, cfg.train.pretrained , device=device) #사전학습파라미터 로드
174 else:
175 init_weights(model) #모델 파라미터 초기화
176
177 # 로스 빌드
178 print('---- build loss ----')
179 for key, loss_class in cfg.loss.items():
180 loss_class.update(dict(device=device))
181 loss_cls = build_loss(cfg)
182
183 # 옵티마이저 설정
184 print('---- optimizer setting ----')
185 assert cfg.train.optimizer in __optimizer__, \
186 f'not found optimizer type : {cfg.train.optimizer}'
187
188 optimizer = __optimizer__[cfg.train.optimizer](model.parameters(), lr=init_lr, **cfg.train.optimizer_args)
189 scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, **cfg.train.scheduler_args)
190
191 #훈련 결과 저장 폴더 생성
192 create_directory(checkpoint_folder)
193
194 result_msg = []
195 msg_path = os.path.join(cfg.train.save_root, 'train_result.json')
196
197 #훈련 시작
198 model.train()
199 for ep in range(epochs):
200 #lr 조정
201 if cfg.train.scheduler == 'StepLR':
202 lr = adjust_learning_rate(optimizer, ep, init_lr, cfg.train.adjust_lr_epoch, cfg.train.adjust_lr_rate)
203 else:
204 lr = scheduler.get_last_lr()
205 print(f'[{ep+1}]/[{epochs}] -> time( {datetime.datetime.now()} ), lr({lr})')
206 # print(f'[{ep+1}]/[{epochs}] -> time( {datetime.datetime.now()} )')
207
208 #train
209 avg_loss = train_epoch(model, loss_cls, optimizer, train_data_loader, cfg.collect_keys, device)
210
211 #eval
212 if (ep+1) % cfg.train.val_interval == 0:
213 meter, val_loss = validation(model, loss_cls, test_data_loader, cfg.collect_keys, device, cfg.metric_args)
214
215 if cfg.train.update_loss_weight and (ep+1) % cfg.train.update_loss_weight_interval == 0:
216 val_acc = meter.mean()
217 new_weights = [base_weight + (1-val_acc[key]) for key in pred_keys]
218 loss_cls.weights = new_weights
219
220 msg = f'[{ep+1}]/[{epochs}] train loss : {avg_loss:.6f} val loss : {val_loss:.6f} {meter.print_str()}'
221 print(msg)
222
223 result_msg.append(msg)
224 save_json(result_msg, msg_path)
225
226 if cfg.train.scheduler != 'StepLR':
227 scheduler.step()
228
229 #모델 파라미터 저장
230 checkpoint_path = os.path.join(checkpoint_folder, f'{ep}.pth')
231 save_model(model, checkpoint_path)
232
233 print(f"\n-end- ( {datetime.datetime.now()} )")
234
top1_acc_multi_task(data, pred_key, gt_key, target_tasks, task_key)
Definition metrics.py:114
train_epoch(model, loss_cls, optimizer, data_loader, collect_keys, device)
Definition train.py:85
validation(model, loss_cls, data_loader, collect_keys, device, metric_args)
Definition train.py:119
adjust_learning_rate(optimizer, epoch, lr, adjust_epoch, adjust_rate)
Definition train.py:51
tmp_pipeline(sample, keys, device)
Definition train.py:64
train_sample(model, loss_cls, optimizer, sample)
Definition train.py:71
validation_sample(model, loss_cls, sample, metric_args)
Definition train.py:102
save_model(model, save_path)
Definition utils.py:50