|
Safemotion Lib
|
Public Member Functions | |
| __init__ (self, task_key, pred_keys, gt_keys, target_tasks, train_tasks, target_task_weight=1.0, non_target_task_weight=0.025, weights=None, **args) | |
| calc_loss (self, pred, gt) | |
| forward (self, data) | |
Public Attributes | |
| task_key | |
| pred_keys | |
| gt_keys | |
| target_tasks | |
| train_tasks | |
| weights | |
| target_task_weight | |
| non_target_task_weight | |
| loss | |
멀티 라벨, 멀티 테스크 문제를 학습하기 위한 크로스엔트로피 클래스
멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 모든 테스크에 대한 라벨값을 가지고 있음(멀티 라벨 데이터)
학습 데이터는 메인 테스크가 설정되어 있음, 하나의 학습 데이터에 대해 모든 테스크에 대해 로스를 계산하고 가중합을 취함
메인 테스크와 메인이 아닌 테스크에 대한 중요도 조절을 위해 가중치 조정이 필요함
TODO
모든 라벨에 대해서 학습하기 때문에 데이터 불균현 문제가 존재함 -> 해결 방안이 필요함
args:
task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
target_tasks (list[str]): 학습 데이터의 메인 테스크, task_key에 설정됨
train_tasks (list[str]): 학습할 모든 테스크(라벨)
target_task_weight (float): 메인 테스크에 대한 가중치
non_target_task_weight (float): 메인이 아닌 테스크에 대한 가중치
weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
return (Tensor): 각 테스크별 로스의 가중합
Definition at line 109 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.__init__ | ( | self, | |
| task_key, | |||
| pred_keys, | |||
| gt_keys, | |||
| target_tasks, | |||
| train_tasks, | |||
| target_task_weight = 1.0, | |||
| non_target_task_weight = 0.025, | |||
| weights = None, | |||
| ** | args ) |
Definition at line 128 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.calc_loss | ( | self, | |
| pred, | |||
| gt ) |
Definition at line 145 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.forward | ( | self, | |
| data ) |
Definition at line 149 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.gt_keys |
Definition at line 132 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.loss |
Definition at line 142 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.non_target_task_weight |
Definition at line 140 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.pred_keys |
Definition at line 131 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.target_task_weight |
Definition at line 139 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.target_tasks |
Definition at line 133 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.task_key |
Definition at line 130 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.train_tasks |
Definition at line 134 of file classification_loss.py.
| classification_loss.MutiTaskMultiLabelCrossEntropyLoss.weights |
Definition at line 135 of file classification_loss.py.