48 max_step = len(data_loader)
52 for step, sample
in enumerate(data_loader):
56 print(f
' val => [{step+1}/{max_step}] -> top1 acc : {meter.mean()*100:.6f}%', end=
'\r')
58 return meter.mean(), sum_loss/max_step
def test(cfg, device_ids):
    """Evaluate a pretrained model on the test split and save the result.

    Args:
        cfg: Experiment config. Reads ``cfg.test.num_workers``,
            ``cfg.test.batch_size``, ``cfg.test.pretrained``,
            ``cfg.test.save_root``, ``cfg.data_loader`` (its ``'type'`` key
            selects the dataset class from ``__data_loader__``) and
            ``cfg.loss``.
        device_ids: Sequence of CUDA device indices; only ``device_ids[0]``
            is used. ``None`` selects the CPU.

    Side effects:
        Adds a ``device`` entry to every loss config in ``cfg.loss`` (in
        place), creates ``cfg.test.save_root`` via ``create_directory`` and
        writes ``test_result.json`` into it via ``save_json``.
    """
    num_workers = cfg.test.num_workers
    batch_size = cfg.test.batch_size

    # NOTE(review): the original None-branch is not visible in this excerpt;
    # as shown it would index ``device_ids[0]`` on None. Fall back to CPU
    # instead -- confirm this matches the intended behavior.
    if device_ids is None:
        device = 'cpu'
    else:
        device = f'cuda:{device_ids[0]}'

    print('---- data loader setting ----')
    # Deep-copy so popping 'type' does not mutate cfg.data_loader itself.
    data_loader_cfg = copy.deepcopy(cfg.data_loader)
    DataSetClass = __data_loader__[data_loader_cfg.pop('type')]
    test_dataset = DataSetClass(mode='test', **data_loader_cfg)
    # No shuffling and no dropped tail batch: every test sample is
    # evaluated, in dataset order.
    test_data_loader = DataLoader(test_dataset, batch_size, shuffle=False,
                                  num_workers=num_workers, drop_last=False)

    print('---- build model ----')
    model = build_model(cfg)
    load_model(cfg, model, cfg.test.pretrained, device=device)

    print('---- build loss ----')
    # Inject the device into every loss config before build_loss(cfg) runs
    # (presumably consumed there -- verify). This mutates cfg.loss in place.
    for loss_class in cfg.loss.values():
        loss_class.update(dict(device=device))
    loss_cls = build_loss(cfg)

    create_directory(cfg.test.save_root)
    msg_path = os.path.join(cfg.test.save_root, 'test_result.json')

    print(f"\n-start- ( {datetime.datetime.now()} )")
    val_acc, val_loss = validation(model, loss_cls, test_data_loader, device)

    msg = f'val loss : {val_loss:.6f} val acc : {val_acc*100:.6f}'
    # result_msg had no visible initialization; build the list locally.
    result_msg = [msg]
    save_json(result_msg, msg_path)

    print(f"\n-end- ( {datetime.datetime.now()} )")