Safemotion Lib
Loading...
Searching...
No Matches
metrics.py
Go to the documentation of this file.
1import copy
2import torch
3
4class AverageMeter(object):
5 """
6 평균 계산을 위한 클래스
7 수치의 합과 데이터의 수량을 기록함
8 """
9 def __init__(self):
10 self.sum_value = 0.
11 self.count = 0
12
13 def update(self, x):
14 """
15 평균을 위한 데이터를 업데이트 하는 기능
16 args:
17 x (list[float]): 측정 데이터
18 """
19 self.sum_value += sum(x)
20 self.count += len(x)
21
22 def mean(self):
23 """
24 평균을 출력해주는 기능
25 return (float): 측정치의 평균
26 """
27 if self.count == 0:
28 return 0
29 return self.sum_value / self.count
30
31class AverageMeterDict(object):
32 """
33 데이터 또는 측정 방법에 대해 키를 설정하고 해당 키에 대한 측정치 평균계산을 위한 클래스
34 args:
35 keys (list[str]): 설정할 키
36 """
37 def __init__(self, keys):
38 self.meters = {}
39 for key in keys:
40 self.meters[key] = AverageMeter()
41
42 def update(self, data):
43 """
44 키에 대한 측정치를 업데이트 하는 기능
45 args:
46 data (dict): 키에 대한 측정 데이터
47 """
48 for key, x in data.items():
49 self.meters[key].update(x)
50
51 def mean(self):
52 """
53 키에 대한 측정치 평균을 출력해주는 기능
54 return (dict): 측정치의 평균
55 """
56 ret = {}
57 for key, meter in self.meters.items():
58 ret[key] = meter.mean()
59 return ret
60
61 def calc_saummary(self):
62 """
63 요약정보를 출력하는 기능
64 모든 데이터에 대한 평균과 키의 평균에대한 평균값을 계산하는 기능
65 """
66 total_correct = 0
67 total_cnt = 0
68 sum_acc = 0
69
70 for key, meter in self.meters.items():
71 total_correct += meter.sum_value
72 total_cnt += meter.count
73 sum_acc += meter.mean()
74
75 return total_correct/total_cnt, sum_acc/len(self.meters.keys())
76
77 def print_str(self):
78 """
79 요약 내용을 화면에 출력해 주는 기능
80 모든 데이터의 평균, 측정치의 평균
81 """
82 msg = ''
83
84 for key, meter in self.meters.items():
85 msg += f' {key} : {meter.mean()*100:.2f}'
86
87 total_acc, acc_mean = self.calc_saummary()
88 msg += f' total acc : {total_acc*100:.2f}, acc_mean : {acc_mean*100:.2f}'
89 return msg
90
91def top1_acc(data, pred_key, gt_key, target_idx=None):
92 """
93 top1 정확도를 계산하는 기능
94 args:
95 data ():
96 pred_key ():
97 target_idx ():
98 """
99 if target_idx is None:
100 return data[pred_key] == data[gt_key]
101 else:
102 return data[pred_key][target_idx] == data[gt_key][target_idx]
103
104def top1_acc_dict(data, pred_key, gt_key):
105 """
106 여러 결과에 대한 top1 정확도를 계산하는 기능
107 """
108 acc = {}
109 for p_key, g_key in zip(pred_key, gt_key):
110 acc[p_key] = top1_acc(data, p_key, g_key)
111
112 return acc
113
114def top1_acc_multi_task(data, pred_key, gt_key, target_tasks, task_key):
115 """
116 특정 테스크에 대한 top1 정확도를 계산하는 기능
117 """
118 acc = {}
119 for p_key, g_key, t_task in zip(pred_key, gt_key, target_tasks):
120 t_idx = [task == t_task for task in data[task_key]]
121 t_idx = torch.tensor(t_idx)
122
123 acc[p_key] = top1_acc(data, p_key, g_key, t_idx)
124 return acc
top1_acc_dict(data, pred_key, gt_key)
Definition metrics.py:104
top1_acc_multi_task(data, pred_key, gt_key, target_tasks, task_key)
Definition metrics.py:114
top1_acc(data, pred_key, gt_key, target_idx=None)
Definition metrics.py:91