|
Safemotion Lib
|
Public Member Functions | |
| __init__ (self, pred_keys, gt_keys, weights=None, data_num=None, device='cpu') | |
| calc_loss (self, pred, gt, pred_key) | |
| forward (self, data) | |
Public Attributes | |
| pred_keys | |
| gt_keys | |
| weights | |
| loss | |
모델의 출력(라벨 or 테스크)이 여러개일때 크로스엔트로피 로스를 계산하기위한 클래스
args:
pred_keys (list[str]): 모델이 예측한 결과(inference) 값이 저장된 키
gt_keys (list[str]): 학습 데이터가 저장된 키, pred_keys에 대응하는 키가 순서에 맞게 들어있어야함
weights (list[float]): 각 테스크의 로스에 대한 가중치, 설정하지 않으면 동일 가중치 사용
data_num (list[int] or None): 클래스의 데이터 수량
device (str): 모델이 구동하는 디바이스
return (Tensor): 각 테스크 로스의 가중합
Definition at line 5 of file classification_loss.py.
| classification_loss.CrossEntropyLoss.__init__ | ( | self, | |
| pred_keys, | |||
| gt_keys, | |||
| weights = None, | |||
| data_num = None, | |||
| device = 'cpu' ) |
Definition at line 16 of file classification_loss.py.
| classification_loss.CrossEntropyLoss.calc_loss | ( | self, | |
| pred, | |||
| gt, | |||
| pred_key ) |
Definition at line 38 of file classification_loss.py.
| classification_loss.CrossEntropyLoss.forward | ( | self, | |
| data ) |
Definition at line 42 of file classification_loss.py.
| classification_loss.CrossEntropyLoss.gt_keys |
Definition at line 21 of file classification_loss.py.
| classification_loss.CrossEntropyLoss.loss |
Definition at line 26 of file classification_loss.py.
| classification_loss.CrossEntropyLoss.pred_keys |
Definition at line 20 of file classification_loss.py.
| classification_loss.CrossEntropyLoss.weights |
Definition at line 22 of file classification_loss.py.