def test(cfg, device_ids):
    """Evaluate a pretrained model on the test split and save the result.

    Builds the test data loader, the model and the loss from ``cfg``, loads
    the weights referenced by ``cfg.test.pretrained``, runs ``validation``
    once, prints the resulting loss/accuracy message and stores it as a
    one-element list in ``<cfg.test.save_root>/test_result.json``.

    Args:
        cfg: Project configuration object. This function reads
            ``cfg.test.num_workers``, ``cfg.test.batch_size``,
            ``cfg.test.pretrained``, ``cfg.test.save_root``,
            ``cfg.data_loader`` and ``cfg.loss``. Every entry of
            ``cfg.loss`` is updated in place with a ``device`` key.
        device_ids: Sequence of CUDA device indices, or ``None``. Only the
            first index is used. ``None`` or an empty sequence selects the
            CPU.

    Returns:
        None. Results are printed and written to disk.
    """
    test_cfg = cfg.test
    num_workers = test_cfg.num_workers
    batch_size = test_cfg.batch_size

    # An empty sequence has no first element, so treat it like None and fall
    # back to the CPU instead of raising IndexError on device_ids[0].
    device = f'cuda:{device_ids[0]}' if device_ids else 'cpu'

    print('---- data loader setting ----')
    # Work on a copy so popping 'type' does not alter cfg.data_loader.
    data_loader_cfg = copy.deepcopy(cfg.data_loader)
    DataSetClass = __data_loader__[data_loader_cfg.pop('type')]
    test_dataset = DataSetClass(mode='test', **data_loader_cfg)
    test_data_loader = DataLoader(
        test_dataset,
        batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )

    print('---- build model ----')
    model = build_model(cfg)
    model.to(device)

    load_model(cfg, model, test_cfg.pretrained, device=device)

    print('---- build loss ----')
    # NOTE: this mutates cfg.loss in place; build_loss(cfg) is called right
    # after and receives the same cfg, so the injected device is visible to it.
    for loss_cfg in cfg.loss.values():
        loss_cfg.update({'device': device})
    loss_cls = build_loss(cfg)

    create_directory(test_cfg.save_root)
    msg_path = os.path.join(test_cfg.save_root, 'test_result.json')

    print(f"\n-start- ( {datetime.datetime.now()} )")

    val_acc, val_loss = validation(model, loss_cls, test_data_loader, device)

    # val_acc is scaled by 100 for display — presumably a fraction in [0, 1];
    # confirm against validation().
    msg = f'val loss : {val_loss:.6f} val acc : {val_acc*100:.6f}'
    print(msg)

    # Stored as a single-element list to keep the JSON layout unchanged.
    save_json([msg], msg_path)

    print(f"\n-end- ( {datetime.datetime.now()} )")
109