Safemotion Lib
Loading...
Searching...
No Matches
test.py
Go to the documentation of this file.
2import torch
3import torch.nn as nn
4import torch.optim as optim
5from torch.utils.data import DataLoader
6
7from smaction.datasets.action_dataset_loader import ActionDatasetLoader
8from smutils.utils_os import create_directory, save_json
9from smrunner.model_runner_builder import build_model
10from smrunner.losses.loss_builder import build_loss
11from smrunner.utils import *
12from smrunner.metrics import *
13import json
14
15import datetime
16import numpy as np
17import copy
18import os
19
# Registry mapping a config `type` string to its dataset-loader class.
# `test()` pops cfg.data_loader['type'] and looks the class up here before
# instantiating it with the remaining config keys.
__data_loader__ = {
    "ActionDatasetLoader": ActionDatasetLoader,
}
23
def tmp_pipeline(sample, keys, device):
    """Move the tensors stored under ``keys`` in ``sample`` to ``device``.

    Mutates ``sample`` in place (each listed entry is replaced by its
    ``.to(device)`` copy) and also returns it for call-chaining.

    Args:
        sample: dict-like batch; values under ``keys`` must support ``.to()``
            (torch tensors). Other entries are left untouched.
        keys: iterable of keys whose values should be transferred.
        device: torch device spec, e.g. ``'cpu'`` or ``'cuda:0'``.

    Returns:
        The same ``sample`` object, with the listed entries on ``device``.
    """
    for key in keys:
        sample[key] = sample[key].to(device)
    return sample
31
32
def validation_sample(model, loss_cls, sample, device,
                      keys=('keypoints', 'label_action', 'label_pose')):
    """Evaluate the model on a single batch and return (accuracy, loss).

    The batch tensors listed in ``keys`` are moved to ``device``, the model's
    outputs are merged back into the batch dict, and the loss and top-1
    action accuracy are computed on the merged dict.

    Args:
        model: callable taking the batch dict and returning a dict of outputs
            (must include the key ``'action_labels'`` used by ``top1_acc``).
        loss_cls: callable mapping the merged batch dict to a loss tensor.
        sample: batch dict; mutated in place (device transfer + model outputs).
        device: torch device spec, e.g. ``'cpu'`` or ``'cuda:0'``.
        keys: batch entries to transfer to ``device``. Defaults to the keys
            this pipeline previously hard-coded, so existing callers are
            unaffected.

    Returns:
        Tuple ``(acc, loss_value)`` where ``acc`` is the top-1 accuracy from
        ``top1_acc`` and ``loss_value`` is the mean loss as a Python float.
    """
    sample = tmp_pipeline(sample, keys, device)
    result = model(sample)
    sample.update(result)
    # Reduce the per-element loss tensor to a scalar before extracting it.
    loss = torch.mean(loss_cls(sample))

    acc = top1_acc(sample, 'action_labels', 'label_action')
    return acc, loss.item()
44
45
def validation(model, loss_cls, data_loader, device):
    """Run one full evaluation pass and return (mean accuracy, mean loss).

    Puts the model in eval mode, iterates ``data_loader`` under
    ``torch.no_grad()``, accumulates per-batch top-1 accuracy in an
    ``AverageMeter`` and sums the per-batch losses, printing a carriage-return
    progress line as it goes.
    """
    model.eval()
    total_steps = len(data_loader)
    loss_total = 0
    acc_meter = AverageMeter()
    with torch.no_grad():
        for step_no, batch in enumerate(data_loader, start=1):
            batch_acc, batch_loss = validation_sample(model, loss_cls, batch, device)
            acc_meter.update(batch_acc)
            loss_total += batch_loss
            print(f' val => [{step_no}/{total_steps}] -> top1 acc : {acc_meter.mean()*100:.6f}%', end='\r')
    # Terminate the in-place progress line before returning.
    print('')
    return acc_meter.mean(), loss_total / total_steps
59
60
def test(cfg, device_ids):
    """Run the full test/evaluation pipeline described by ``cfg``.

    Builds the test dataset loader, the model (with pretrained weights) and
    the loss, evaluates on the test split, prints the result, and saves it as
    JSON under ``cfg.test.save_root``.

    Args:
        cfg: config object with at least ``test`` (num_workers, batch_size,
            pretrained, save_root), ``data_loader`` (with a ``type`` key) and
            ``loss`` sections — presumably an addict/EasyDict-style config;
            TODO confirm against the caller.
        device_ids: list of CUDA device ids, or ``None`` to run on CPU.
            Only the first id is used for model/tensor placement.
    """
    num_workers = cfg.test.num_workers
    batch_size = cfg.test.batch_size

    if device_ids is None:
        device = 'cpu'
    else:
        device = f'cuda:{device_ids[0]}'

    # Dataset loader setup
    print('---- data loader setting ----')
    # Deep-copy so pop('type') does not mutate the shared cfg object.
    data_loader_cfg = copy.deepcopy(cfg.data_loader)
    DataSetClass = __data_loader__[data_loader_cfg.pop('type')]
    test_dataset = DataSetClass(mode='test', **data_loader_cfg)
    test_data_loader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

    # Build the model and load pretrained weights
    print('---- build model ----')
    model = build_model(cfg)
    model.to(device)

    load_model(cfg, model, cfg.test.pretrained , device=device)

    # Build the loss; inject the device into every loss sub-config first.
    print('---- build loss ----')
    for key, loss_class in cfg.loss.items():
        loss_class.update(dict(device=device))
    loss_cls = build_loss(cfg)

    # Create the directory where test results are saved
    create_directory(cfg.test.save_root)

    result_msg = []
    msg_path = os.path.join(cfg.test.save_root, 'test_result.json')

    print(f"\n-start- ( {datetime.datetime.now()} )")

    val_acc, val_loss = validation(model, loss_cls, test_data_loader, device)

    msg = f'val loss : {val_loss:.6f} val acc : {val_acc*100:.6f}'
    print(msg)

    result_msg.append(msg)
    save_json(result_msg, msg_path)



    print(f"\n-end- ( {datetime.datetime.now()} )")
109
top1_acc(data, pred_key, gt_key, target_idx=None)
Definition metrics.py:91
validation_sample(model, loss_cls, sample, device)
Definition test.py:33
validation(model, loss_cls, data_loader, device)
Definition test.py:46
tmp_pipeline(sample, keys, device)
Definition test.py:24