|
Safemotion Lib
|
Public Member Functions | |
| __init__ (self, task_key, pred_keys, gt_keys, target_tasks, weights=None, alpha=0.25, gamma=2, beta=0, reduction='mean', **args) | |
| sigmoid_focal_loss (self, inputs, targets) | |
| calc_loss (self, pred, gt) | |
| forward (self, data) | |
Public Attributes | |
| task_key | |
| pred_keys | |
| gt_keys | |
| target_tasks | |
| weights | |
| alpha | |
| gamma | |
| beta | |
| reduction | |
멀티 라벨 또는 멀티 테스크 문제를 학습하기 위한 시그모이드 포컬로스 클래스
멀티 라벨 데이터를 기반, 멀티 라벨 각각을 테스크로 정의함, 학습데이터는 타겟 테스크에 대한 라벨값을 가지고 있음
학습 데이터는 타겟 테스크가 설정되어 있음, 하나의 학습 데이터는 타겟 테스크에 대한 경로만 학습함
args:
task_key (str): 학습할 테스크 또는 라벨에 대한 정보가 설정된 키
pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
target_tasks (list[str]): 학습 데이터의 테스크, task_key에 설정됨, 해당 테스크 경로만 학습함
weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
alpha (float): 클래스 불균형 파라미터
gamma (float): 샘플 난이도 가중치
beta (float): 테스크 난이도에 대한 가중치
TODO: pjm 추가, 성능이 좋지 않다고 판단되면 제거
reduction (str): 최종 출력 로스에 적용할 연산
'mean': 평균
'sum': 합
None : 연산이 적용되지 않음
return (Tensor): 각 테스크별 로스의 가중합
Definition at line 192 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.__init__ | ( | self, | |
| task_key, | |||
| pred_keys, | |||
| gt_keys, | |||
| target_tasks, | |||
| weights = None, | |||
| alpha = 0.25, | |||
| gamma = 2, | |||
| beta = 0, | |||
| reduction = 'mean', | |||
| ** | args ) |
Definition at line 213 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.calc_loss | ( | self, | |
| pred, | |||
| gt ) |
시그모이드 포컬로스를 계산하는 기능
바이너리 크로스엔트로피 적용을 위해 GT에 대한 전처리(one hot encording) 장치가 추가됨
args:
pred (Tensor): 모델의 예측 결과(logit)
gt (Tensor): 클래스 번호
Definition at line 266 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.forward | ( | self, | |
| data ) |
Definition at line 281 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.sigmoid_focal_loss | ( | self, | |
| inputs, | |||
| targets ) |
시그모이드 포컬로스 계산하는 기능
args:
inputs (Tensor): 모델의 출력 결과
targets (Tensor): 모델의 출력에 대응하는 GT
return (Tensor)
Definition at line 229 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.alpha |
Definition at line 224 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.beta |
Definition at line 226 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.gamma |
Definition at line 225 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.gt_keys |
Definition at line 218 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.pred_keys |
Definition at line 217 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.reduction |
Definition at line 227 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.target_tasks |
Definition at line 219 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.task_key |
Definition at line 216 of file classification_loss.py.
| classification_loss.MutiTaskSigmoidFocalLoss.weights |
Definition at line 220 of file classification_loss.py.