Safemotion Lib
classification_loss.MutiTaskCrossEntropyLoss Class Reference
Public Member Functions

 __init__ (self, task_key, pred_keys, gt_keys, target_tasks, weights=[1.0], data_num=None, device='cpu')
 
 calc_loss (self, pred, gt, pred_key)
 
 forward (self, data)
 

Public Attributes

 task_key
 
 target_tasks
 
 weights
 
 pred_keys
 
 gt_keys
 
 loss
 

Detailed Description

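Cross-entropy loss class for training multi-label or multi-task problems.
Built for multi-label data: each label is treated as a separate task, and each training sample carries a label value for its target task.
Each training sample is assigned one target task, and that sample trains only the prediction path (head) for its target task.
args:
    task_key (str): key in the input dict whose value gives, for each sample, the task (label) that sample is trained on
    pred_keys (list[str]): keys under which the model's predictions (logits) are stored
    gt_keys (list[str]): keys under which the ground-truth class indices are stored; must correspond one-to-one, in the same order, to pred_keys. Targets equal to -1 are ignored
    target_tasks (list[str]): target task for each entry of pred_keys, in the same order; a sample is routed to a head when its value under task_key equals that head's target task
    weights (list[float] or None): loss weight for each task; if None, every task gets weight 1.0. Because forward() pairs weights with pred_keys through zip(), the default [1.0] covers only the first task and the remaining tasks are dropped, so pass one weight per task (or None) when training several tasks
    data_num (dict or None): per-class sample counts, keyed by prediction key (data_num[pred_key] is a list with one count per class); when given, the counts normalized to sum to 1 are used as class weights, and when None the classes are unweighted
    device (str): device the model runs on; the class-weight tensors are placed on it
return (Tensor): weighted sum of the per-task losses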
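The routing rule above (each sample trains only the head for its target task) can be illustrated without PyTorch. The sketch below is a plain-Python approximation of the masking done in forward(); the function name route_samples and the key and task names ("action_pred", "pose", ...) are hypothetical examples, not part of Safemotion Lib.

```python
from typing import Dict, List


def route_samples(task_values: List[str],
                  pred_keys: List[str],
                  target_tasks: List[str]) -> Dict[str, List[int]]:
    """Return, for each prediction head, the batch indices it is trained on.

    Sample i contributes to the head at position k only if
    task_values[i] == target_tasks[k] (same rule as the mask in forward()).
    """
    routed = {}
    for pred_key, target_task in zip(pred_keys, target_tasks):
        routed[pred_key] = [i for i, task in enumerate(task_values)
                            if task == target_task]
    return routed


# Hypothetical batch of four samples, each tagged with the one task it has a label for.
batch_tasks = ["action", "pose", "action", "fall"]
routing = route_samples(batch_tasks,
                        pred_keys=["action_pred", "pose_pred"],
                        target_tasks=["action", "pose"])
print(routing)  # {'action_pred': [0, 2], 'pose_pred': [1]}
```

Sample 3 carries a task that no head targets, so it contributes to no loss term; a head with no matching samples in the batch is skipped entirely.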

Definition at line 51 of file classification_loss.py.

Constructor & Destructor Documentation

◆ __init__()

classification_loss.MutiTaskCrossEntropyLoss.__init__ ( self,
task_key,
pred_keys,
gt_keys,
target_tasks,
weights = [1.0],
data_num = None,
device = 'cpu' )

Definition at line 66 of file classification_loss.py.

def __init__(self, task_key, pred_keys, gt_keys, target_tasks, weights=[1.0], data_num=None, device='cpu'):
    super(MutiTaskCrossEntropyLoss, self).__init__()
    self.task_key = task_key
    self.target_tasks = target_tasks
    self.weights = weights
    # Only None triggers equal weighting; the default [1.0] is kept as-is,
    # and forward() zips it with pred_keys, so it covers the first task only.
    if weights is None:
        self.weights = [1.0]*len(pred_keys)
    self.pred_keys = pred_keys
    self.gt_keys = gt_keys

    # One CrossEntropyLoss per prediction head; targets equal to -1 are ignored.
    self.loss = nn.ModuleDict()
    for key in pred_keys:

        if data_num is None:
            self.loss[key] = nn.CrossEntropyLoss(ignore_index=-1)
        else:
            # Class weights = per-class counts normalized to sum to 1
            # (proportional to class frequency, not inverse frequency).
            d_num = data_num[key]
            s_num = sum(d_num)
            weight = [n/s_num for n in d_num]
            weight = torch.tensor(weight).to(device)
            self.loss[key] = nn.CrossEntropyLoss(weight=weight, ignore_index=-1)
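Two details of the constructor are easy to miss: how weights is resolved against pred_keys, and how data_num is turned into class weights. The plain-Python sketch below reproduces both rules; the helpers resolve_task_weights and class_weights and the key names are hypothetical stand-ins for the inline code above, not library functions.

```python
from typing import Dict, List, Optional


def resolve_task_weights(weights: Optional[List[float]],
                         pred_keys: List[str]) -> List[float]:
    # Same rule as __init__: only None expands to one weight of 1.0 per head.
    return [1.0] * len(pred_keys) if weights is None else list(weights)


def class_weights(class_counts: List[int]) -> List[float]:
    # Same normalization as __init__: each count divided by the total,
    # so the weights sum to 1 and grow with class frequency.
    total = sum(class_counts)
    return [n / total for n in class_counts]


pred_keys = ["action_pred", "pose_pred", "fall_pred"]

# zip() stops at the shortest input, so the default weights=[1.0]
# pairs a weight with the first head only.
paired_default = list(zip(pred_keys, resolve_task_weights([1.0], pred_keys)))
paired_none = list(zip(pred_keys, resolve_task_weights(None, pred_keys)))
print(paired_default)  # [('action_pred', 1.0)]
print(paired_none)     # [('action_pred', 1.0), ('pose_pred', 1.0), ('fall_pred', 1.0)]

data_num: Dict[str, List[int]] = {"action_pred": [30, 10]}
print(class_weights(data_num["action_pred"]))  # [0.75, 0.25]
```

Because the normalized weights are proportional to class frequency, frequent classes count more in the loss. When the goal is to counter class imbalance, inverse-frequency weights (for example total / (num_classes * n)) are the more common choice.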

Member Function Documentation

◆ calc_loss()

classification_loss.MutiTaskCrossEntropyLoss.calc_loss ( self,
pred,
gt,
pred_key )

Definition at line 88 of file classification_loss.py.

def calc_loss(self, pred, gt, pred_key):
    # Apply the CrossEntropyLoss registered for this prediction head.
    return self.loss[pred_key](pred, gt)
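calc_loss delegates to nn.CrossEntropyLoss with its default mean reduction. For logits $x_i$ over $C$ classes, targets $y_i$, and class weights $w_c$ (all equal to 1 when data_num is None), the value returned for one head follows the weighted-mean rule documented for PyTorch:

```latex
\ell = \frac{\sum_{i:\, y_i \neq -1} w_{y_i} \left( -\log \frac{\exp(x_{i,y_i})}{\sum_{c=1}^{C} \exp(x_{i,c})} \right)}
            {\sum_{i:\, y_i \neq -1} w_{y_i}}
```

Samples whose target is -1 drop out of both sums. Scaling every $w_c$ by the same constant cancels in the ratio, so normalizing the data_num counts to sum to 1 changes only the balance between classes, not the overall loss scale. If every selected target is -1, both sums are zero and PyTorch returns nan for that head.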

◆ forward()

classification_loss.MutiTaskCrossEntropyLoss.forward ( self,
data )

Definition at line 92 of file classification_loss.py.

def forward(self, data):
    losses = []

    for pred_key, gt_key, target_task, weight in zip(self.pred_keys, self.gt_keys, self.target_tasks, self.weights):

        # Boolean mask of the samples whose task matches this head's target task
        target_idx = [task == target_task for task in data[self.task_key]]
        if sum(target_idx) == 0:
            # No sample in this batch belongs to this task; skip the head
            continue

        target_idx = torch.tensor(target_idx)

        # Compute this head's loss on the selected samples only
        losses.append(weight * self.calc_loss(data[pred_key][target_idx], data[gt_key][target_idx], pred_key))

    # If no head matched any sample, losses is empty and sum() returns the int 0
    return sum(losses)
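For reference, the whole forward pass can be approximated in plain Python. This is a sketch, not the library code: cross_entropy re-implements the weighted-mean rule of nn.CrossEntropyLoss for lists of floats, multitask_loss mirrors forward(), and both names, like the key names in the example batch, are hypothetical. Unlike forward(), it returns the float 0.0 rather than the int 0 when no head matches.

```python
import math
from typing import Dict, List, Optional, Sequence


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int],
                  class_weight: Optional[Sequence[float]] = None,
                  ignore_index: int = -1) -> float:
    """Weighted-mean cross-entropy in the style of nn.CrossEntropyLoss(reduction='mean')."""
    num, den = 0.0, 0.0
    for row, y in zip(logits, targets):
        if y == ignore_index:
            continue  # ignored targets leave both numerator and denominator
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        w = 1.0 if class_weight is None else class_weight[y]
        num += w * (log_sum_exp - row[y])
        den += w
    return num / den if den > 0 else float("nan")


def multitask_loss(data: Dict[str, list], task_key: str, pred_keys: List[str],
                   gt_keys: List[str], target_tasks: List[str],
                   weights: List[float]) -> float:
    """Weighted sum of per-head losses, each computed on its own samples only."""
    total = 0.0
    for pred_key, gt_key, target_task, w in zip(pred_keys, gt_keys, target_tasks, weights):
        idx = [i for i, t in enumerate(data[task_key]) if t == target_task]
        if not idx:
            continue  # no sample in this batch belongs to this head
        preds = [data[pred_key][i] for i in idx]
        gts = [data[gt_key][i] for i in idx]
        total += w * cross_entropy(preds, gts)
    return total


data = {
    "task": ["action", "pose", "action"],
    "action_pred": [[0.0, 0.0], [5.0, -5.0], [0.0, 0.0]],
    "action_gt": [1, -1, 0],
    "pose_pred": [[9.0, 9.0, 9.0], [1.0, 1.0, 1.0], [9.0, 9.0, 9.0]],
    "pose_gt": [-1, 2, -1],
}
loss = multitask_loss(data, "task", ["action_pred", "pose_pred"],
                      ["action_gt", "pose_gt"], ["action", "pose"], [0.5, 2.0])
print(loss)  # equals 0.5*log(2) + 2.0*log(3)
```

In the PyTorch class, an empty losses list makes forward() return the int 0, which has no backward() method; a training loop may want to check for that case and skip the optimizer step for such a batch.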

Member Data Documentation

◆ gt_keys

classification_loss.MutiTaskCrossEntropyLoss.gt_keys

Definition at line 74 of file classification_loss.py.

◆ loss

classification_loss.MutiTaskCrossEntropyLoss.loss

Definition at line 76 of file classification_loss.py.

◆ pred_keys

classification_loss.MutiTaskCrossEntropyLoss.pred_keys

Definition at line 73 of file classification_loss.py.

◆ target_tasks

classification_loss.MutiTaskCrossEntropyLoss.target_tasks

Definition at line 69 of file classification_loss.py.

◆ task_key

classification_loss.MutiTaskCrossEntropyLoss.task_key

Definition at line 68 of file classification_loss.py.

◆ weights

classification_loss.MutiTaskCrossEntropyLoss.weights

Definition at line 70 of file classification_loss.py.


The documentation for this class was generated from the following file: classification_loss.py