Safemotion Lib
Loading...
Searching...
No Matches
smrunner
metrics.py
Go to the documentation of this file.
1
import
copy
2
import
torch
3
4
class
AverageMeter
(object):
5
"""
6
평균 계산을 위한 클래스
7
수치의 합과 데이터의 수량을 기록함
8
"""
9
def
__init__
(self):
10
self.
sum_value
= 0.
11
self.
count
= 0
12
13
def
update
(self, x):
14
"""
15
평균을 위한 데이터를 업데이트 하는 기능
16
args:
17
x (list[float]): 측정 데이터
18
"""
19
self.
sum_value
+= sum(x)
20
self.
count
+= len(x)
21
22
def
mean
(self):
23
"""
24
평균을 출력해주는 기능
25
return (float): 측정치의 평균
26
"""
27
if
self.
count
== 0:
28
return
0
29
return
self.
sum_value
/ self.
count
30
31
class
AverageMeterDict
(object):
32
"""
33
데이터 또는 측정 방법에 대해 키를 설정하고 해당 키에 대한 측정치 평균계산을 위한 클래스
34
args:
35
keys (list[str]): 설정할 키
36
"""
37
def
__init__
(self, keys):
38
self.
meters
= {}
39
for
key
in
keys:
40
self.
meters
[key] =
AverageMeter
()
41
42
def
update
(self, data):
43
"""
44
키에 대한 측정치를 업데이트 하는 기능
45
args:
46
data (dict): 키에 대한 측정 데이터
47
"""
48
for
key, x
in
data.items():
49
self.
meters
[key].
update
(x)
50
51
def
mean
(self):
52
"""
53
키에 대한 측정치 평균을 출력해주는 기능
54
return (dict): 측정치의 평균
55
"""
56
ret = {}
57
for
key, meter
in
self.
meters
.items():
58
ret[key] = meter.mean()
59
return
ret
60
61
def
calc_saummary
(self):
62
"""
63
요약정보를 출력하는 기능
64
모든 데이터에 대한 평균과 키의 평균에대한 평균값을 계산하는 기능
65
"""
66
total_correct = 0
67
total_cnt = 0
68
sum_acc = 0
69
70
for
key, meter
in
self.
meters
.items():
71
total_correct += meter.sum_value
72
total_cnt += meter.count
73
sum_acc += meter.mean()
74
75
return
total_correct/total_cnt, sum_acc/len(self.
meters
.keys())
76
77
def
print_str
(self):
78
"""
79
요약 내용을 화면에 출력해 주는 기능
80
모든 데이터의 평균, 측정치의 평균
81
"""
82
msg =
''
83
84
for
key, meter
in
self.
meters
.items():
85
msg += f
' {key} : {meter.mean()*100:.2f}'
86
87
total_acc, acc_mean = self.
calc_saummary
()
88
msg += f
' total acc : {total_acc*100:.2f}, acc_mean : {acc_mean*100:.2f}'
89
return
msg
90
91
def
top1_acc
(data, pred_key, gt_key, target_idx=None):
92
"""
93
top1 정확도를 계산하는 기능
94
args:
95
data ():
96
pred_key ():
97
target_idx ():
98
"""
99
if
target_idx
is
None
:
100
return
data[pred_key] == data[gt_key]
101
else
:
102
return
data[pred_key][target_idx] == data[gt_key][target_idx]
103
104
def
top1_acc_dict
(data, pred_key, gt_key):
105
"""
106
여러 결과에 대한 top1 정확도를 계산하는 기능
107
"""
108
acc = {}
109
for
p_key, g_key
in
zip(pred_key, gt_key):
110
acc[p_key] =
top1_acc
(data, p_key, g_key)
111
112
return
acc
113
114
def
top1_acc_multi_task
(data, pred_key, gt_key, target_tasks, task_key):
115
"""
116
특정 테스크에 대한 top1 정확도를 계산하는 기능
117
"""
118
acc = {}
119
for
p_key, g_key, t_task
in
zip(pred_key, gt_key, target_tasks):
120
t_idx = [task == t_task
for
task
in
data[task_key]]
121
t_idx = torch.tensor(t_idx)
122
123
acc[p_key] =
top1_acc
(data, p_key, g_key, t_idx)
124
return
acc
smrunner.metrics.AverageMeterDict
Definition
metrics.py:31
smrunner.metrics.AverageMeterDict.__init__
__init__(self, keys)
Definition
metrics.py:37
smrunner.metrics.AverageMeterDict.meters
meters
Definition
metrics.py:38
smrunner.metrics.AverageMeterDict.mean
mean(self)
Definition
metrics.py:51
smrunner.metrics.AverageMeterDict.update
update(self, data)
Definition
metrics.py:42
smrunner.metrics.AverageMeterDict.print_str
print_str(self)
Definition
metrics.py:77
smrunner.metrics.AverageMeterDict.calc_saummary
calc_saummary(self)
Definition
metrics.py:61
smrunner.metrics.AverageMeter
Definition
metrics.py:4
smrunner.metrics.AverageMeter.__init__
__init__(self)
Definition
metrics.py:9
smrunner.metrics.AverageMeter.count
count
Definition
metrics.py:11
smrunner.metrics.AverageMeter.update
update(self, x)
Definition
metrics.py:13
smrunner.metrics.AverageMeter.sum_value
sum_value
Definition
metrics.py:10
smrunner.metrics.AverageMeter.mean
mean(self)
Definition
metrics.py:22
smrunner.metrics.top1_acc_dict
top1_acc_dict(data, pred_key, gt_key)
Definition
metrics.py:104
smrunner.metrics.top1_acc_multi_task
top1_acc_multi_task(data, pred_key, gt_key, target_tasks, task_key)
Definition
metrics.py:114
smrunner.metrics.top1_acc
top1_acc(data, pred_key, gt_key, target_idx=None)
Definition
metrics.py:91
Generated by
1.10.0