7 모델의 출력(라벨 or 테스크)이 여러개일때 크로스엔트로피 로스를 계산하기위한 클래스
9 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
10 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
11 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
12 data_num (list[int] or None): 클래스의 데이터 수량
13 device (str): 모델이 구동하는 디바이스
14 return (Tensor): 각 테스크 로스의 가중합
16 def __init__(self, pred_keys, gt_keys, weights=None, data_num=None, device='cpu'):
17 super(CrossEntropyLoss, self).
__init__()
24 self.
weights = [1.0]*len(pred_keys)
26 self.
loss = nn.ModuleDict()
30 self.
loss[key] = nn.CrossEntropyLoss()
34 weight = [n/s_num
for n
in d_num]
35 weight = torch.tensor(weight).to(device)
36 self.
loss[key] = nn.CrossEntropyLoss(weight=weight)
39 return self.
loss[pred_key](pred, gt)
46 losses.append( weight * self.
calc_loss(data[pred_key], data[gt_key], pred_key) )
53 멀티 라벨 or 멀티 테스크 문제를 학습하기 위한 크로스엔트로피 클래스
54 멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 타겟 테스크에 대한 라벨값을 가지고 있음
55 학습 데이터는 타겟 테스크가 설정되어 있음, 하나의 학습 데이터는 타겟 테스크에 대한 경로만 학습함
57 task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
58 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
59 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
60 target_tasks (list[str]): 학습 데이터의 타겟 테스크, task_key에 설정됨
61 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
62 data_num (dict or None): 테스크별 클래스의 데이터 수량
63 device (str): 모델이 구동하는 디바이스
64 return (Tensor): 각 테스크별 로스의 가중합
66 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=[1.0], data_num=None, device='cpu'):
67 super(MutiTaskCrossEntropyLoss, self).
__init__()
72 self.
weights = [1.0]*len(pred_keys)
76 self.
loss = nn.ModuleDict()
80 self.
loss[key] = nn.CrossEntropyLoss(ignore_index=-1)
84 weight = [n/s_num
for n
in d_num]
85 weight = torch.tensor(weight).to(device)
86 self.
loss[key] = nn.CrossEntropyLoss(weight=weight, ignore_index=-1)
89 return self.
loss[pred_key](pred, gt)
98 target_idx = [task == target_task
for task
in data[self.
task_key]]
99 if sum(target_idx) == 0:
102 target_idx = torch.tensor(target_idx)
105 losses.append( weight * self.
calc_loss(data[pred_key][target_idx], data[gt_key][target_idx], pred_key) )
111 멀티 라벨, 멀티 테스크 문제를 학습하기 위한 크로스엔트로피 클래스
112 멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 모든 테스크에 대한 라벨값을 가지고 있음(멀티 라벨 데이터)
113 학습 데이터는 메인 테스크가 설정되어 있음, 하나의 학습 데이터에 대해 모든 테스크에 대해 로스를 계산하고 가중합을 취함
114 메인 테스크와 메인이 아닌 테스크에 대한 중요도 조절을 위해 가중치 조정이 필요함
116 모든 라벨에 대해서 학습하기 때문에 데이터 불균현 문제가 존재함 -> 해결 방안이 필요함
118 task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
119 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
120 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
121 target_tasks (list[str]): 학습 데이터의 메인 테스크, task_key에 설정됨
122 train_tasks (list[str]): 학습할 모든 테스크(라벨)
123 target_task_weight (float): 메인 테스크에 대한 가중치
124 non_target_task_weight (float): 메인이 아닌 테스크에 대한 가중치
125 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
126 return (Tensor): 각 테스크별 로스의 가중합
128 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, train_tasks, target_task_weight=1.0, non_target_task_weight=0.025, weights=None, **args):
129 super(MutiTaskMultiLabelCrossEntropyLoss, self).
__init__()
137 self.
weights = [1.0]*len(pred_keys)
142 self.
loss = nn.CrossEntropyLoss(ignore_index=-1)
146 return self.
loss(pred, gt)
160 if train_task == target_task:
166 idx = [ task == train_task
for task
in data[self.
task_key]]
170 idx = torch.tensor(idx)
173 if (data[gt_key][idx] != -1).sum() == 0:
176 losses.append(weight*self.
calc_loss(data[pred_key][idx], data[gt_key][idx]))
179 target_idx = [task == target_task
for task
in data[self.
task_key]]
180 if sum(target_idx) == 0:
183 target_idx = torch.tensor(target_idx)
186 losses.append( weight * self.
calc_loss(data[pred_key][target_idx], data[gt_key][target_idx]) )
194 멀티 라벨 또는 멀티 테스크 문제를 학습하기 위한 시그모이드 포컬로스 클래스
195 멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 타겟 테스크에 대한 라벨값을 가지고 있음
196 학습 데이터는 타겟 테스크가 설정되어 있음, 하나의 학습 데이터는 타겟 테스크에 대한 경로만 학습함
198 task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
199 pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
200 gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
201 target_tasks (list[str]): 학습 데이터의 테스크, task_key에 설정됨, 해당 테스크 경로만 학습함
202 weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
203 alpha (float): 클래스 불균형 파라미터
204 gamma (float): 샘플 난이도 가중치
205 beta (float): 테스크 난이도에 대한 가중치
206 TODO: pjm 추가, 성능이 좋지 않다고 판단되면 제거
207 reduction (str): 최종 출력 로스에 적용할 연산
211 return (Tensor): 각 테스크별 로스의 가중합
213 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=None, alpha=0.25, gamma=2, beta=0, reduction='mean', **args):
214 super(MutiTaskSigmoidFocalLoss, self).
__init__()
222 self.
weights = [1.0] * len(pred_keys)
233 inputs (Tensor): 모델의 출력 결과
234 targets (Tensor): 모델의 출력에 대응하는 GT
238 inputs = inputs.float()
239 targets = targets.float()
242 p = torch.sigmoid(inputs)
243 ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=
"none")
244 p_t = p * targets + (1 - p) * (1 - targets)
245 loss = ce_loss * ((1 - p_t) ** self.
gamma)
249 alpha_t = self.
alpha * targets + (1 - self.
alpha) * (1 - targets)
250 loss = alpha_t * loss
254 w = torch.sum((1 - p_t) * targets, dim=1).mean() ** self.
beta
269 바이너리 크로스엔트로피 적용을 위해 GT에 대한 전처리(one hot encording) 장치가 추가됨
271 pred (Tensor): 모델의 예측 결과(logit)
274 class_num = pred.shape[1]
275 gt_one_hot = F.one_hot(gt, class_num)
287 target_idx = [task == target_task
for task
in data[self.
task_key]]
288 if sum(target_idx) == 0:
291 target_idx = torch.tensor(target_idx)
294 losses.append( weight * self.
calc_loss(data[pred_key][target_idx], data[gt_key][target_idx]) )
__init__(self, pred_keys, gt_keys, weights=None, data_num=None, device='cpu')
calc_loss(self, pred, gt, pred_key)
calc_loss(self, pred, gt, pred_key)
__init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=[1.0], data_num=None, device='cpu')
__init__(self, task_key, pred_keys, gt_keys, target_tasks, train_tasks, target_task_weight=1.0, non_target_task_weight=0.025, weights=None, **args)
calc_loss(self, pred, gt)
__init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=None, alpha=0.25, gamma=2, beta=0, reduction='mean', **args)
calc_loss(self, pred, gt)
sigmoid_focal_loss(self, inputs, targets)