def train(cfg, device_ids):
    """Run the full training loop described by *cfg*.

    Args:
        cfg: project config object providing ``train``, ``data_loader``,
            ``loss``, ``collect_keys`` and ``metric_args`` sections
            (schema defined elsewhere in the project).
        device_ids: list of CUDA device ids, or ``None`` to run on CPU.
            Only the first id is used — no DataParallel here.

    Side effects:
        Writes per-epoch checkpoints under ``<save_root>/weights`` and a
        running ``train_result.json`` message log under ``<save_root>``.
    """
    num_workers = cfg.train.num_workers
    init_lr = cfg.train.init_lr
    epochs = cfg.train.epochs
    batch_size = cfg.train.batch_size
    checkpoint_folder = os.path.join(cfg.train.save_root, 'weights')

    device = 'cpu' if device_ids is None else f'cuda:{device_ids[0]}'

    print('---- data loader setting ----')
    # deepcopy so popping 'type' below does not mutate the caller's cfg
    data_loader_cfg = copy.deepcopy(cfg.data_loader)
    DataSetClass = __data_loader__[data_loader_cfg.pop('type')]

    train_dataset = DataSetClass(mode='train', **data_loader_cfg)
    train_data_loader = DataLoader(train_dataset, batch_size, shuffle=True,
                                   num_workers=num_workers, drop_last=True)

    # FIX: validation must see every sample exactly once. The previous
    # shuffle=True / drop_last=True silently discarded up to batch_size-1
    # samples per pass and evaluated on a shuffled subset.
    test_dataset = DataSetClass(mode='val', **data_loader_cfg)
    test_data_loader = DataLoader(test_dataset, batch_size, shuffle=False,
                                  num_workers=num_workers, drop_last=False)

    print('---- build model ----')
    model = build_model(cfg)
    model.to(device)
    if cfg.train.pretrained is not None:
        load_model(cfg, model, cfg.train.pretrained, device=device)
    else:
        init_weights(model)

    print('---- build loss ----')
    # Inject the runtime device into every loss config before building.
    for _, loss_class in cfg.loss.items():
        loss_class.update(dict(device=device))
    loss_cls = build_loss(cfg)

    print('---- optimizer setting ----')
    # FIX: was an `assert`, which is stripped under `python -O`;
    # validate the config explicitly instead.
    if cfg.train.optimizer not in __optimizer__:
        raise KeyError(f'not found optimizer type : {cfg.train.optimizer}')

    optimizer = __optimizer__[cfg.train.optimizer](model.parameters(), lr=init_lr,
                                                   **cfg.train.optimizer_args)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, **cfg.train.scheduler_args)

    create_directory(checkpoint_folder)

    result_msg = []
    msg_path = os.path.join(cfg.train.save_root, 'train_result.json')

    for ep in range(epochs):
        # FIX: re-enter train mode every epoch. validation() presumably
        # switches the model to eval mode (TODO confirm), and the old single
        # model.train() before the loop left all epochs after the first
        # validation running with eval-mode dropout/batch-norm.
        model.train()

        if cfg.train.scheduler == 'StepLR':
            lr = adjust_learning_rate(optimizer, ep, init_lr,
                                      cfg.train.adjust_lr_epoch,
                                      cfg.train.adjust_lr_rate)
        else:
            lr = scheduler.get_last_lr()
        print(f'[{ep+1}]/[{epochs}] -> time( {datetime.datetime.now()} ), lr({lr})')

        avg_loss = train_epoch(model, loss_cls, optimizer, train_data_loader,
                               cfg.collect_keys, device)

        if (ep + 1) % cfg.train.val_interval == 0:
            meter, val_loss = validation(model, loss_cls, test_data_loader,
                                         cfg.collect_keys, device, cfg.metric_args)

            if cfg.train.update_loss_weight and (ep + 1) % cfg.train.update_loss_weight_interval == 0:
                # Re-weight each loss term by how far its metric is from 1.0.
                # NOTE(review): `pred_keys` and `base_weight` are not defined in
                # this function — they must exist at module scope, otherwise
                # this branch raises NameError. TODO confirm.
                val_acc = meter.mean()
                new_weights = [base_weight + (1 - val_acc[key]) for key in pred_keys]
                loss_cls.weights = new_weights

            msg = f'[{ep+1}]/[{epochs}] train loss : {avg_loss:.6f} val loss : {val_loss:.6f} {meter.print_str()}'
            print(msg)

            result_msg.append(msg)
            save_json(result_msg, msg_path)

        # StepLR is driven by adjust_learning_rate() above; only the cosine
        # scheduler steps here.
        if cfg.train.scheduler != 'StepLR':
            scheduler.step()

        # Checkpoint filenames keep the 0-based epoch index (log lines above
        # are 1-based) — kept as-is for backward compatibility with existing
        # runs and any resume tooling.
        checkpoint_path = os.path.join(checkpoint_folder, f'{ep}.pth')
        save_model(model, checkpoint_path)

    print(f"\n-end- ( {datetime.datetime.now()} )")