Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | List of all members
classification_loss.MutiTaskSigmoidFocalLoss Class Reference
Inheritance diagram for classification_loss.MutiTaskSigmoidFocalLoss:

Public Member Functions

 __init__ (self, task_key, pred_keys, gt_keys, target_tasks, weights=None, alpha=0.25, gamma=2, beta=0, reduction='mean', **args)
 
 sigmoid_focal_loss (self, inputs, targets)
 
 calc_loss (self, pred, gt)
 
 forward (self, data)
 

Public Attributes

 task_key
 
 pred_keys
 
 gt_keys
 
 target_tasks
 
 weights
 
 alpha
 
 gamma
 
 beta
 
 reduction
 

Detailed Description

멀티 라벨 또는 멀티 테스크 문제를 학습하기 위한 시그모이드 포컬로스 클래스
멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 타겟 테스크에 대한 라벨값을 가지고 있음
학습 데이터는 타겟 테스크가 설정되어 있음, 하나의 학습 데이터는 타겟 테스크에 대한 경로만 학습함
args:
    task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
    pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
    gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
    target_tasks (list[str]): 학습 데이터의 테스크, task_key에 설정됨, 해당 테스크 경로만 학습함
    weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
    alpha (float): 클래스 불균형 파라미터
    gamma (float): 샘플 난이도 가중치
    beta (float): 테스크 난이도에 대한 가중치
        TODO: pjm 추가, 성능이 좋지 않다고 판단되면 제거
    reduction (str): 최종 출력 로스에 적용할 연산
        'mean': 평균
        'sum': 합
        None : 연산이 적용되지 않음
return (Tensor): 각 테스크별 로스의 가중합

Definition at line 192 of file classification_loss.py.

Constructor & Destructor Documentation

◆ __init__()

classification_loss.MutiTaskSigmoidFocalLoss.__init__ ( self,
task_key,
pred_keys,
gt_keys,
target_tasks,
weights = None,
alpha = 0.25,
gamma = 2,
beta = 0,
reduction = 'mean',
** args )

Definition at line 213 of file classification_loss.py.

213 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=None, alpha=0.25, gamma=2, beta=0, reduction='mean', **args):
214 super(MutiTaskSigmoidFocalLoss, self).__init__()
215
216 self.task_key = task_key
217 self.pred_keys = pred_keys
218 self.gt_keys = gt_keys
219 self.target_tasks = target_tasks
220 self.weights = weights
221 if weights is None:
222 self.weights = [1.0] * len(pred_keys)
223
224 self.alpha = alpha
225 self.gamma = gamma
226 self.beta = beta
227 self.reduction = reduction
228

Member Function Documentation

◆ calc_loss()

classification_loss.MutiTaskSigmoidFocalLoss.calc_loss ( self,
pred,
gt )
시그모이드 포컬로스를 계산하는 기능
바이너리 크로스엔트로피 적용을 위해 GT에 대한 전처리(one hot encording) 장치가 추가됨
args:
    pred (Tensor): 모델의 예측 결과(logit)
    gt (Tensor): 클래스 번호

Definition at line 266 of file classification_loss.py.

266 def calc_loss(self, pred, gt):
267 """
268 시그모이드 포컬로스를 계산하는 기능
269 바이너리 크로스엔트로피 적용을 위해 GT에 대한 전처리(one hot encording) 장치가 추가됨
270 args:
271 pred (Tensor): 모델의 예측 결과(logit)
272 gt (Tensor): 클래스 번호
273 """
274 class_num = pred.shape[1] #클래스 수
275 gt_one_hot = F.one_hot(gt, class_num) #one hot encording
276
277 #로스 계산
278 return self.sigmoid_focal_loss(pred, gt_one_hot)
279
280

◆ forward()

classification_loss.MutiTaskSigmoidFocalLoss.forward ( self,
data )

Definition at line 281 of file classification_loss.py.

281 def forward(self, data):
282 losses = []
283
284 for pred_key, gt_key, target_task, weight in zip(self.pred_keys, self.gt_keys, self.target_tasks, self.weights):
285
286 #task에 해당하는 샘플 인덱스
287 target_idx = [task == target_task for task in data[self.task_key]]
288 if sum(target_idx) == 0:
289 continue
290
291 target_idx = torch.tensor(target_idx)
292
293 #로스 계산
294 losses.append( weight * self.calc_loss(data[pred_key][target_idx], data[gt_key][target_idx]) )
295
296 return sum(losses)

◆ sigmoid_focal_loss()

classification_loss.MutiTaskSigmoidFocalLoss.sigmoid_focal_loss ( self,
inputs,
targets )
시그모이드 포컬로스 계산하는 기능
args:
    inputs (Tensor): 모델의 출력 결과
    targets (Tensor): 모델의 출력에 대응하는 GT
return (Tensor)

Definition at line 229 of file classification_loss.py.

229 def sigmoid_focal_loss(self, inputs, targets):
230 """
231 시그모이드 포컬로스 계산하는 기능
232 args:
233 inputs (Tensor): 모델의 출력 결과
234 targets (Tensor): 모델의 출력에 대응하는 GT
235 return (Tensor)
236 """
237
238 inputs = inputs.float() #타입변환
239 targets = targets.float() #타입변환
240
241 #시그모이드 포컬로스 계산
242 p = torch.sigmoid(inputs) #시그모이드
243 ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") #바이너리 크로스엔트로피
244 p_t = p * targets + (1 - p) * (1 - targets)
245 loss = ce_loss * ((1 - p_t) ** self.gamma) #샘플 난이도 가중치
246
247 if self.alpha >= 0:
248 #클래스 불균형 조정
249 alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
250 loss = alpha_t * loss
251
252 if self.beta > 0:
253 #테스크 난이도 가중치
254 w = torch.sum((1 - p_t) * targets, dim=1).mean() ** self.beta
255 #w = (1 - p_t).mean() ** self.beta
256 loss = w * loss
257
258 #reduction
259 if self.reduction == "mean":
260 loss = loss.mean()
261 elif self.reduction == "sum":
262 loss = loss.sum()
263
264 return loss
265

Member Data Documentation

◆ alpha

classification_loss.MutiTaskSigmoidFocalLoss.alpha

Definition at line 224 of file classification_loss.py.

◆ beta

classification_loss.MutiTaskSigmoidFocalLoss.beta

Definition at line 226 of file classification_loss.py.

◆ gamma

classification_loss.MutiTaskSigmoidFocalLoss.gamma

Definition at line 225 of file classification_loss.py.

◆ gt_keys

classification_loss.MutiTaskSigmoidFocalLoss.gt_keys

Definition at line 218 of file classification_loss.py.

◆ pred_keys

classification_loss.MutiTaskSigmoidFocalLoss.pred_keys

Definition at line 217 of file classification_loss.py.

◆ reduction

classification_loss.MutiTaskSigmoidFocalLoss.reduction

Definition at line 227 of file classification_loss.py.

◆ target_tasks

classification_loss.MutiTaskSigmoidFocalLoss.target_tasks

Definition at line 219 of file classification_loss.py.

◆ task_key

classification_loss.MutiTaskSigmoidFocalLoss.task_key

Definition at line 216 of file classification_loss.py.

◆ weights

classification_loss.MutiTaskSigmoidFocalLoss.weights

Definition at line 220 of file classification_loss.py.


The documentation for this class was generated from the following file: