Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | List of all members
classification_loss.MutiTaskMultiLabelCrossEntropyLoss Class Reference
Inheritance diagram for classification_loss.MutiTaskMultiLabelCrossEntropyLoss:

Public Member Functions

 __init__ (self, task_key, pred_keys, gt_keys, target_tasks, train_tasks, target_task_weight=1.0, non_target_task_weight=0.025, weights=None, **args)
 
 calc_loss (self, pred, gt)
 
 forward (self, data)
 

Public Attributes

 task_key
 
 pred_keys
 
 gt_keys
 
 target_tasks
 
 train_tasks
 
 weights
 
 target_task_weight
 
 non_target_task_weight
 
 loss
 

Detailed Description

멀티 라벨, 멀티 테스크 문제를 학습하기 위한 크로스엔트로피 클래스
멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 모든 테스크에 대한 라벨값을 가지고 있음(멀티 라벨 데이터)
학습 데이터는 메인 테스크가 설정되어 있음, 하나의 학습 데이터에 대해 모든 테스크에 대해 로스를 계산하고 가중합을 취함
메인 테스크와 메인이 아닌 테스크에 대한 중요도 조절을 위해 가중치 조정이 필요함
TODO
    모든 라벨에 대해서 학습하기 때문에 데이터 불균현 문제가 존재함 -> 해결 방안이 필요함
args:
    task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
    pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
    gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
    target_tasks (list[str]): 학습 데이터의 메인 테스크, task_key에 설정됨
    train_tasks (list[str]): 학습할 모든 테스크(라벨)
    target_task_weight (float): 메인 테스크에 대한 가중치
    non_target_task_weight (float): 메인이 아닌 테스크에 대한 가중치
    weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
return (Tensor): 각 테스크별 로스의 가중합

Definition at line 109 of file classification_loss.py.

Constructor & Destructor Documentation

◆ __init__()

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.__init__ ( self,
task_key,
pred_keys,
gt_keys,
target_tasks,
train_tasks,
target_task_weight = 1.0,
non_target_task_weight = 0.025,
weights = None,
** args )

Definition at line 128 of file classification_loss.py.

128 def __init__(self, task_key, pred_keys, gt_keys, target_tasks, train_tasks, target_task_weight=1.0, non_target_task_weight=0.025, weights=None, **args):
129 super(MutiTaskMultiLabelCrossEntropyLoss, self).__init__()
130 self.task_key = task_key
131 self.pred_keys = pred_keys
132 self.gt_keys = gt_keys
133 self.target_tasks = target_tasks
134 self.train_tasks = train_tasks
135 self.weights = weights
136 if weights is None:
137 self.weights = [1.0]*len(pred_keys)
138
139 self.target_task_weight = target_task_weight
140 self.non_target_task_weight = non_target_task_weight
141
142 self.loss = nn.CrossEntropyLoss(ignore_index=-1)
143
144

Member Function Documentation

◆ calc_loss()

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.calc_loss ( self,
pred,
gt )

Definition at line 145 of file classification_loss.py.

145 def calc_loss(self, pred, gt):
146 return self.loss(pred, gt)
147
148

◆ forward()

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.forward ( self,
data )

Definition at line 149 of file classification_loss.py.

149 def forward(self, data):
150
151 losses = []
152
153 for pred_key, gt_key, target_task, weight in zip(self.pred_keys, self.gt_keys, self.target_tasks, self.weights):
154
155 #TODO : 코드가 비효율적으로 작성된것 처럼 보임, 효율적인 구조로 변경 필요
156 if self.non_target_task_weight > 0: #하나의 샘플이 모든 테스크에 대해 학습
157 for train_task in self.train_tasks: #학습 테스크
158
159 #가중치 설정, 테스크 기본 가중치* 메인 or Non-Main 테크스 가중치
160 if train_task == target_task: #메인 테스크
161 weight = weight*self.target_task_weight
162 else: #메인 테스크가 아닌 테스크
163 weight = weight*self.non_target_task_weight
164
165 #학습 테스크 샘플 인덱스
166 idx = [ task == train_task for task in data[self.task_key]]
167 if sum(idx) == 0:
168 continue
169
170 idx = torch.tensor(idx)
171
172 #학습 테스크의 클래스가 -1만 있을경우 처리, -1은 크로스 엔트로피에 ignore_index로 설정되어 있음, 따라서 -1인 경우 학습되지 않음
173 if (data[gt_key][idx] != -1).sum() == 0:
174 continue
175
176 losses.append(weight*self.calc_loss(data[pred_key][idx], data[gt_key][idx]))
177 else: #하나의 샘플이 메인 테스크에 대해서만 학습, MutiTaskCrossEntropyLoss와 동일
178 #task에 해당하는 샘플 인덱스
179 target_idx = [task == target_task for task in data[self.task_key]]
180 if sum(target_idx) == 0:
181 continue
182
183 target_idx = torch.tensor(target_idx)
184
185 #로스 계산
186 losses.append( weight * self.calc_loss(data[pred_key][target_idx], data[gt_key][target_idx]) )
187
188
189 return sum(losses)
190
191

Member Data Documentation

◆ gt_keys

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.gt_keys

Definition at line 132 of file classification_loss.py.

◆ loss

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.loss

Definition at line 142 of file classification_loss.py.

◆ non_target_task_weight

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.non_target_task_weight

Definition at line 140 of file classification_loss.py.

◆ pred_keys

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.pred_keys

Definition at line 131 of file classification_loss.py.

◆ target_task_weight

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.target_task_weight

Definition at line 139 of file classification_loss.py.

◆ target_tasks

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.target_tasks

Definition at line 133 of file classification_loss.py.

◆ task_key

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.task_key

Definition at line 130 of file classification_loss.py.

◆ train_tasks

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.train_tasks

Definition at line 134 of file classification_loss.py.

◆ weights

classification_loss.MutiTaskMultiLabelCrossEntropyLoss.weights

Definition at line 135 of file classification_loss.py.


The documentation for this class was generated from the following file: