Safemotion Lib
Loading...
Searching...
No Matches
classification_loss.py
Go to the documentation of this file.
1import torch
2import torch.nn as nn
3from torch.nn import functional as F
4
5class CrossEntropyLoss(nn.Module):
6 """
7 모델의 출력(라벨 or 테스크)이 여러개일때 크로스엔트로피 로스를 계산하기위한 클래스
8 args:
9 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
10 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
11 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
12 data_num (list[int] or None): 클래스의 데이터 수량
13 device (str): 모델이 구동하는 디바이스
14 return (Tensor): 각 테스크 로스의 가중합
15 """
16 def __init__(self, pred_keys, gt_keys, weights=None, data_num=None, device='cpu'):
17 super(CrossEntropyLoss, self).__init__()
18
19
20 self.pred_keys = pred_keys
21 self.gt_keys = gt_keys
22 self.weights = weights
23 if weights is None:
24 self.weights = [1.0]*len(pred_keys)
25
26 self.loss = nn.ModuleDict()
27 for key in pred_keys:
28
29 if data_num is None:
30 self.loss[key] = nn.CrossEntropyLoss()
31 else:
32 d_num = data_num[key]
33 s_num = sum(d_num)
34 weight = [n/s_num for n in d_num]
35 weight = torch.tensor(weight).to(device)
36 self.loss[key] = nn.CrossEntropyLoss(weight=weight)
37
38 def calc_loss(self, pred, gt, pred_key):
39 return self.loss[pred_key](pred, gt)
40
41
42 def forward(self, data):
43 losses = []
44
45 for pred_key, gt_key, weight in zip(self.pred_keys, self.gt_keys, self.weights):
46 losses.append( weight * self.calc_loss(data[pred_key], data[gt_key], pred_key) )
47
48 return sum(losses)
49
50
51class MutiTaskCrossEntropyLoss(nn.Module):
52 """
53 멀티 라벨 or 멀티 테스크 문제를 학습하기 위한 크로스엔트로피 클래스
54 멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 타겟 테스크에 대한 라벨값을 가지고 있음
55 학습 데이터는 타겟 테스크가 설정되어 있음, 하나의 학습 데이터는 타겟 테스크에 대한 경로만 학습함
56 args:
57 task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
58 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
59 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
60 target_tasks (list[str]): 학습 데이터의 타겟 테스크, task_key에 설정됨
61 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
62 data_num (dict or None): 테스크별 클래스의 데이터 수량
63 device (str): 모델이 구동하는 디바이스
64 return (Tensor): 각 테스크별 로스의 가중합
65 """
66 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=[1.0], data_num=None, device='cpu'):
67 super(MutiTaskCrossEntropyLoss, self).__init__()
68 self.task_key = task_key
69 self.target_tasks = target_tasks
70 self.weights = weights
71 if weights is None:
72 self.weights = [1.0]*len(pred_keys)
73 self.pred_keys = pred_keys
74 self.gt_keys = gt_keys
75
76 self.loss = nn.ModuleDict()
77 for key in pred_keys:
78
79 if data_num is None:
80 self.loss[key] = nn.CrossEntropyLoss(ignore_index=-1)
81 else:
82 d_num = data_num[key]
83 s_num = sum(d_num)
84 weight = [n/s_num for n in d_num]
85 weight = torch.tensor(weight).to(device)
86 self.loss[key] = nn.CrossEntropyLoss(weight=weight, ignore_index=-1)
87
88 def calc_loss(self, pred, gt, pred_key):
89 return self.loss[pred_key](pred, gt)
90
91
92 def forward(self, data):
93 losses = []
94
95 for pred_key, gt_key, target_task, weight in zip(self.pred_keys, self.gt_keys, self.target_tasks, self.weights):
96
97 #task에 해당하는 샘플 인덱스
98 target_idx = [task == target_task for task in data[self.task_key]]
99 if sum(target_idx) == 0:
100 continue
101
102 target_idx = torch.tensor(target_idx)
103
104 #로스 계산
105 losses.append( weight * self.calc_loss(data[pred_key][target_idx], data[gt_key][target_idx], pred_key) )
106
107 return sum(losses)
108
110 """
111 멀티 라벨, 멀티 테스크 문제를 학습하기 위한 크로스엔트로피 클래스
112 멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 모든 테스크에 대한 라벨값을 가지고 있음(멀티 라벨 데이터)
113 학습 데이터는 메인 테스크가 설정되어 있음, 하나의 학습 데이터에 대해 모든 테스크에 대해 로스를 계산하고 가중합을 취함
114 메인 테스크와 메인이 아닌 테스크에 대한 중요도 조절을 위해 가중치 조정이 필요함
115 TODO
116 모든 라벨에 대해서 학습하기 때문에 데이터 불균현 문제가 존재함 -> 해결 방안이 필요함
117 args:
118 task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
119 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
120 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
121 target_tasks (list[str]): 학습 데이터의 메인 테스크, task_key에 설정됨
122 train_tasks (list[str]): 학습할 모든 테스크(라벨)
123 target_task_weight (float): 메인 테스크에 대한 가중치
124 non_target_task_weight (float): 메인이 아닌 테스크에 대한 가중치
125 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
126 return (Tensor): 각 테스크별 로스의 가중합
127 """
128 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, train_tasks, target_task_weight=1.0, non_target_task_weight=0.025, weights=None, **args):
129 super(MutiTaskMultiLabelCrossEntropyLoss, self).__init__()
130 self.task_key = task_key
131 self.pred_keys = pred_keys
132 self.gt_keys = gt_keys
133 self.target_tasks = target_tasks
134 self.train_tasks = train_tasks
135 self.weights = weights
136 if weights is None:
137 self.weights = [1.0]*len(pred_keys)
138
139 self.target_task_weight = target_task_weight
140 self.non_target_task_weight = non_target_task_weight
141
142 self.loss = nn.CrossEntropyLoss(ignore_index=-1)
143
144
145 def calc_loss(self, pred, gt):
146 return self.loss(pred, gt)
147
148
149 def forward(self, data):
150
151 losses = []
152
153 for pred_key, gt_key, target_task, weight in zip(self.pred_keys, self.gt_keys, self.target_tasks, self.weights):
154
155 #TODO : 코드가 비효율적으로 작성된것 처럼 보임, 효율적인 구조로 변경 필요
156 if self.non_target_task_weight > 0: #하나의 샘플이 모든 테스크에 대해 학습
157 for train_task in self.train_tasks: #학습 테스크
158
159 #가중치 설정, 테스크 기본 가중치* 메인 or Non-Main 테크스 가중치
160 if train_task == target_task: #메인 테스크
161 weight = weight*self.target_task_weight
162 else: #메인 테스크가 아닌 테스크
163 weight = weight*self.non_target_task_weight
164
165 #학습 테스크 샘플 인덱스
166 idx = [ task == train_task for task in data[self.task_key]]
167 if sum(idx) == 0:
168 continue
169
170 idx = torch.tensor(idx)
171
172 #학습 테스크의 클래스가 -1만 있을경우 처리, -1은 크로스 엔트로피에 ignore_index로 설정되어 있음, 따라서 -1인 경우 학습되지 않음
173 if (data[gt_key][idx] != -1).sum() == 0:
174 continue
175
176 losses.append(weight*self.calc_loss(data[pred_key][idx], data[gt_key][idx]))
177 else: #하나의 샘플이 메인 테스크에 대해서만 학습, MutiTaskCrossEntropyLoss와 동일
178 #task에 해당하는 샘플 인덱스
179 target_idx = [task == target_task for task in data[self.task_key]]
180 if sum(target_idx) == 0:
181 continue
182
183 target_idx = torch.tensor(target_idx)
184
185 #로스 계산
186 losses.append( weight * self.calc_loss(data[pred_key][target_idx], data[gt_key][target_idx]) )
187
188
189 return sum(losses)
190
191
193 """
194 멀티 라벨 또는 멀티 테스크 문제를 학습하기 위한 시그모이드 포컬로스 클래스
195 멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 타겟 테스크에 대한 라벨값을 가지고 있음
196 학습 데이터는 타겟 테스크가 설정되어 있음, 하나의 학습 데이터는 타겟 테스크에 대한 경로만 학습함
197 args:
198 task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
199 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
200 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
201 target_tasks (list[str]): 학습 데이터의 테스크, task_key에 설정됨, 해당 테스크 경로만 학습함
202 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
203 alpha (float): 클래스 불균형 파라미터
204 gamma (float): 샘플 난이도 가중치
205 beta (float): 테스크 난이도에 대한 가중치
206 TODO: pjm 추가, 성능이 좋지 않다고 판단되면 제거
207 reduction (str): 최종 출력 로스에 적용할 연산
208 'mean': 평균
209 'sum': 합
210 None : 연산이 적용되지 않음
211 return (Tensor): 각 테스크별 로스의 가중합
212 """
213 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=None, alpha=0.25, gamma=2, beta=0, reduction='mean', **args):
214 super(MutiTaskSigmoidFocalLoss, self).__init__()
215
216 self.task_key = task_key
217 self.pred_keys = pred_keys
218 self.gt_keys = gt_keys
219 self.target_tasks = target_tasks
220 self.weights = weights
221 if weights is None:
222 self.weights = [1.0] * len(pred_keys)
223
224 self.alpha = alpha
225 self.gamma = gamma
226 self.beta = beta
227 self.reduction = reduction
228
229 def sigmoid_focal_loss(self, inputs, targets):
230 """
231 시그모이드 포컬로스 계산하는 기능
232 args:
233 inputs (Tensor): 모델의 출력 결과
234 targets (Tensor): 모델의 출력에 대응하는 GT
235 return (Tensor)
236 """
237
238 inputs = inputs.float() #타입변환
239 targets = targets.float() #타입변환
240
241 #시그모이드 포컬로스 계산
242 p = torch.sigmoid(inputs) #시그모이드
243 ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") #바이너리 크로스엔트로피
244 p_t = p * targets + (1 - p) * (1 - targets)
245 loss = ce_loss * ((1 - p_t) ** self.gamma) #샘플 난이도 가중치
246
247 if self.alpha >= 0:
248 #클래스 불균형 조정
249 alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
250 loss = alpha_t * loss
251
252 if self.beta > 0:
253 #테스크 난이도 가중치
254 w = torch.sum((1 - p_t) * targets, dim=1).mean() ** self.beta
255 #w = (1 - p_t).mean() ** self.beta
256 loss = w * loss
257
258 #reduction
259 if self.reduction == "mean":
260 loss = loss.mean()
261 elif self.reduction == "sum":
262 loss = loss.sum()
263
264 return loss
265
266 def calc_loss(self, pred, gt):
267 """
268 시그모이드 포컬로스를 계산하는 기능
269 바이너리 크로스엔트로피 적용을 위해 GT에 대한 전처리(one hot encording) 장치가 추가됨
270 args:
271 pred (Tensor): 모델의 예측 결과(logit)
272 gt (Tensor): 클래스 번호
273 """
274 class_num = pred.shape[1] #클래스 수
275 gt_one_hot = F.one_hot(gt, class_num) #one hot encording
276
277 #로스 계산
278 return self.sigmoid_focal_loss(pred, gt_one_hot)
279
280
281 def forward(self, data):
282 losses = []
283
284 for pred_key, gt_key, target_task, weight in zip(self.pred_keys, self.gt_keys, self.target_tasks, self.weights):
285
286 #task에 해당하는 샘플 인덱스
287 target_idx = [task == target_task for task in data[self.task_key]]
288 if sum(target_idx) == 0:
289 continue
290
291 target_idx = torch.tensor(target_idx)
292
293 #로스 계산
294 losses.append( weight * self.calc_loss(data[pred_key][target_idx], data[gt_key][target_idx]) )
295
296 return sum(losses)
__init__(self, pred_keys, gt_keys, weights=None, data_num=None, device='cpu')
__init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=[1.0], data_num=None, device='cpu')
__init__(self, task_key, pred_keys, gt_keys, target_tasks, train_tasks, target_task_weight=1.0, non_target_task_weight=0.025, weights=None, **args)
__init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=None, alpha=0.25, gamma=2, beta=0, reduction='mean', **args)