import os
import copy
import datetime

import torch
from torch import optim
from torch.utils.data import DataLoader

from smaction.datasets.action_dataset_loader import (
    ActionDatasetLoader, ActionDatasetLoader_v22, ActionDatasetLoader_mtml
)
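
# Project helpers used below (build_model, build_loss, load_model,
# create_directory, save_json, adjust_learning_rate) are imported from other
# smaction modules; those imports are not shown in this excerpt.

# Registries that resolve config strings into classes. Their real definitions
# live elsewhere in this module; the mappings below are an illustrative sketch,
# and the keys are assumptions that may differ from the repo.
__data_loader__ = {
    'ActionDatasetLoader': ActionDatasetLoader,
    'ActionDatasetLoader_v22': ActionDatasetLoader_v22,
    'ActionDatasetLoader_mtml': ActionDatasetLoader_mtml,
}
__optimizer__ = {
    'SGD': optim.SGD,
    'Adam': optim.Adam,
    'AdamW': optim.AdamW,
}
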
def train_epoch(model, loss_cls, optimizer, data_loader, collect_keys, device):
    model.train()
    max_step = len(data_loader)
    sum_loss = 0.0
    for step, sample in enumerate(data_loader):
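        # The per-step body is elided in this excerpt. The lines below are a
        # minimal sketch of the usual pattern, assuming `sample` is a dict of
        # tensors keyed by `collect_keys` and that `loss_cls(preds, batch)`
        # returns a scalar loss tensor; the repo's actual step may differ.
        batch = {key: sample[key].to(device) for key in collect_keys}
        preds = model(batch)
        loss = loss_cls(preds, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()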
        print(f" train => [{step+1}/{max_step}] -> avg loss : {(sum_loss/(step+1)):.6f} , loss : {loss:.6f}", end='\r')
    return sum_loss/max_step


def validation(model, loss_cls, data_loader, collect_keys, device, metric_args):
    model.eval()
    max_step = len(data_loader)
    sum_loss = 0.0
    # `meter` (the per-key accuracy tracker configured by metric_args) is
    # created in code elided from this excerpt.
    with torch.no_grad():
        for step, sample in enumerate(data_loader):
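            # Per-step body elided here. Sketch of the assumed pattern: run the
            # model on the collected inputs and accumulate the loss; `val_acc`
            # (the per-key accuracies for this batch) is computed by the elided
            # metric code described by `metric_args`.
            batch = {key: sample[key].to(device) for key in collect_keys}
            preds = model(batch)
            loss = loss_cls(preds, batch)
            sum_loss += loss.item()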
            meter.update(val_acc)
            print(f' val => [{step+1}/{max_step}] -> {meter.print_str()}', end='\r')
    return meter, sum_loss/max_step


def train(cfg, device_ids):
    num_workers = cfg.train.num_workers
    init_lr = cfg.train.init_lr
    epochs = cfg.train.epochs
    batch_size = cfg.train.batch_size
    checkpoint_folder = os.path.join(cfg.train.save_root, 'weights')
    if device_ids is None:
        device_ids = [0]  # assumed fallback to the first GPU; the original handling here is elided
    device = f'cuda:{device_ids[0]}'

    print('---- data loader setting ----')
    data_loader_cfg = copy.deepcopy(cfg.data_loader)
    DataSetClass = __data_loader__[data_loader_cfg.pop('type')]
    train_dataset = DataSetClass(mode='train', **data_loader_cfg)
    train_data_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

    test_dataset = DataSetClass(mode='val', **data_loader_cfg)
    test_data_loader = DataLoader(test_dataset, batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

    print('---- build model ----')
    model = build_model(cfg)
    if cfg.train.pretrained is not None:
        load_model(cfg, model, cfg.train.pretrained, device=device)
    model = model.to(device)  # assumed: the elided lines here move the model onto the target device

    print('---- build loss ----')
    for key, loss_class in cfg.loss.items():
        loss_class.update(dict(device=device))
    loss_cls = build_loss(cfg)

    print('---- optimizer setting ----')
    assert cfg.train.optimizer in __optimizer__, \
        f'unknown optimizer type: {cfg.train.optimizer}'
    optimizer = __optimizer__[cfg.train.optimizer](model.parameters(), lr=init_lr, **cfg.train.optimizer_args)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, **cfg.train.scheduler_args)

    create_directory(checkpoint_folder)

    msg_path = os.path.join(cfg.train.save_root, 'train_result.json')
    result_msg = []  # accumulated per-epoch result strings saved to msg_path

    for ep in range(epochs):
        if cfg.train.scheduler == 'StepLR':
            lr = adjust_learning_rate(optimizer, ep, init_lr, cfg.train.adjust_lr_epoch, cfg.train.adjust_lr_rate)
        else:
            lr = scheduler.get_last_lr()
        print(f'[{ep+1}]/[{epochs}] -> time( {datetime.datetime.now()} ), lr({lr})')

        avg_loss = train_epoch(model, loss_cls, optimizer, train_data_loader, cfg.collect_keys, device)

        if (ep+1) % cfg.train.val_interval == 0:
            meter, val_loss = validation(model, loss_cls, test_data_loader, cfg.collect_keys, device, cfg.metric_args)

            if cfg.train.update_loss_weight and (ep+1) % cfg.train.update_loss_weight_interval == 0:
                # base_weight and pred_keys come from code elided in this
                # excerpt; each loss weight grows as its key's validation
                # accuracy drops.
                val_acc = meter.mean()
                new_weights = [base_weight + (1-val_acc[key]) for key in pred_keys]
                loss_cls.weights = new_weights

            msg = f'[{ep+1}]/[{epochs}] train loss : {avg_loss:.6f} val loss : {val_loss:.6f} {meter.print_str()}'
            result_msg.append(msg)
            save_json(result_msg, msg_path)

        if cfg.train.scheduler != 'StepLR':
            scheduler.step()  # assumed: the elided body advances the cosine schedule once per epoch

        checkpoint_path = os.path.join(checkpoint_folder, f'{ep}.pth')
        torch.save(model.state_dict(), checkpoint_path)  # assumed: the elided lines save this epoch's weights here

    print(f"\n-end- ( {datetime.datetime.now()} )")