Safemotion Lib
classification_loss.CrossEntropyLoss Class Reference

Public Member Functions

 __init__ (self, pred_keys, gt_keys, weights=None, data_num=None, device='cpu')
 
 calc_loss (self, pred, gt, pred_key)
 
 forward (self, data)
 

Public Attributes

 pred_keys
 
 gt_keys
 
 weights
 
 loss
 

Detailed Description

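Computes the cross-entropy loss for a model with multiple outputs (labels or tasks).
args:
    pred_keys (list[str]): keys under which the model's predictions (inference outputs) are stored in the input dict
    gt_keys (list[str]): keys under which the ground-truth labels are stored; entry i must be the label key that corresponds to pred_keys[i]
    weights (list[float]): weight applied to each task's loss; if not set, every task gets the same weight (1.0)
    data_num (dict[str, list[int]] or None): per-class sample counts for each task, keyed by prediction key; used to build class weights (each class weight is its count divided by the task's total)
    device (str): device the model runs on; the class-weight tensors are moved there
return (Tensor): weighted sum of the per-task losses (returned by forward)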
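Read together with the listings for `__init__` and `forward`, the returned objective can be written as follows. This is a sketch that assumes the default `reduction='mean'` of `nn.CrossEntropyLoss`, integer class-index targets, and no samples hitting `ignore_index`. Here $K$ is the number of entries in `pred_keys`, $z^{(k)}$ is `data[pred_keys[k]]` (logits of shape $N \times C$), $y^{(k)}$ is `data[gt_keys[k]]`, $w_k$ comes from `weights` (1 when it is `None`), and $m^{(k)}_j$ is `data_num[pred_keys[k]][j]`. When `data_num` is `None`, every $c^{(k)}_j = 1$ and $\ell_k$ reduces to the plain batch mean.

```latex
\mathcal{L} \;=\; \sum_{k=1}^{K} w_k\,\ell_k,
\qquad
\ell_k \;=\; \frac{\sum_{n=1}^{N} c^{(k)}_{y^{(k)}_n}\,\Bigl(-\log \dfrac{\exp z^{(k)}_{n,\,y^{(k)}_n}}{\sum_{j}\exp z^{(k)}_{n,j}}\Bigr)}{\sum_{n=1}^{N} c^{(k)}_{y^{(k)}_n}},
\qquad
c^{(k)}_j \;=\; \frac{m^{(k)}_j}{\sum_i m^{(k)}_i}
```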

Definition at line 5 of file classification_loss.py.

Constructor & Destructor Documentation

◆ __init__()

classification_loss.CrossEntropyLoss.__init__ (self, pred_keys, gt_keys, weights = None, data_num = None, device = 'cpu')

Definition at line 16 of file classification_loss.py.

16 def __init__(self, pred_keys, gt_keys, weights=None, data_num=None, device='cpu'):
17     super(CrossEntropyLoss, self).__init__()
18
19
20     self.pred_keys = pred_keys
21     self.gt_keys = gt_keys
22     self.weights = weights
23     if weights is None:
24         self.weights = [1.0]*len(pred_keys)
25
26     self.loss = nn.ModuleDict()
27     for key in pred_keys:
28
29         if data_num is None:
30             self.loss[key] = nn.CrossEntropyLoss()
31         else:
32             d_num = data_num[key]
33             s_num = sum(d_num)
34             weight = [n/s_num for n in d_num]
35             weight = torch.tensor(weight).to(device)
36             self.loss[key] = nn.CrossEntropyLoss(weight=weight)
37
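Because the constructor looks up `data_num[key]` for each entry of `pred_keys`, `data_num` has to be a mapping from prediction key to a list of per-class sample counts, and each class weight is that class's share of the task's samples. The sketch below reproduces just that arithmetic in plain Python so it can be checked without PyTorch; the helper name `class_weights_from_counts`, the key `'action'`, and the counts are hypothetical.

```python
def class_weights_from_counts(data_num, pred_keys):
    """Mirror lines 32-34: weight[c] = count[c] / sum(counts), per prediction key.

    data_num is assumed to be a dict keyed by prediction key, because the
    constructor indexes it with data_num[key].
    """
    weights = {}
    for key in pred_keys:
        d_num = data_num[key]              # per-class sample counts for this task
        s_num = sum(d_num)                 # total number of samples for this task
        weights[key] = [n / s_num for n in d_num]
    return weights


# Hypothetical task 'action' with three classes.
w = class_weights_from_counts({"action": [60, 30, 10]}, ["action"])
print(w["action"])  # [0.6, 0.3, 0.1]
```

Under this scheme the most frequent class receives the largest weight. In the class itself, each list is then wrapped with `torch.tensor(...).to(device)` and passed as `weight=` to `nn.CrossEntropyLoss`.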

Member Function Documentation

◆ calc_loss()

classification_loss.CrossEntropyLoss.calc_loss (self, pred, gt, pred_key)

Definition at line 38 of file classification_loss.py.

38 def calc_loss(self, pred, gt, pred_key):
39     return self.loss[pred_key](pred, gt)
40
41
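`calc_loss` looks up the `nn.CrossEntropyLoss` module stored for `pred_key` and applies it to `(pred, gt)`. To make explicit what that module computes when class weights are set, the sketch below re-derives it in plain Python from the formula documented for `torch.nn.CrossEntropyLoss` with the default `reduction='mean'`: each sample's negative log-softmax is scaled by the weight of its target class, and the total is divided by the sum of those weights. It assumes class-index targets and does not handle `ignore_index`; the function name `weighted_cross_entropy` is hypothetical.

```python
import math


def weighted_cross_entropy(logits, targets, class_weight=None):
    """Batch cross-entropy with optional per-class weights (mean reduction).

    logits: list of per-sample score lists, shape (N, C)
    targets: list of class indices, shape (N,)
    class_weight: optional list of length C; None means every class weighs 1.0
    """
    num_classes = len(logits[0])
    if class_weight is None:
        class_weight = [1.0] * num_classes
    total, norm = 0.0, 0.0
    for scores, y in zip(logits, targets):
        # log-sum-exp with max subtraction for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        nll = log_z - scores[y]            # -log softmax(scores)[y]
        total += class_weight[y] * nll     # scale by the target class's weight
        norm += class_weight[y]            # mean is taken over the summed weights
    return total / norm


loss = weighted_cross_entropy([[2.0, 0.0], [0.0, 2.0]], [0, 1])
print(round(loss, 4))  # 0.1269
```

With uniform weights the denominator is just $N$, which matches the unweighted `nn.CrossEntropyLoss()` created when `data_num` is `None`.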

◆ forward()

classification_loss.CrossEntropyLoss.forward (self, data)

Definition at line 42 of file classification_loss.py.

42 def forward(self, data):
43     losses = []
44
45     for pred_key, gt_key, weight in zip(self.pred_keys, self.gt_keys, self.weights):
46         losses.append( weight * self.calc_loss(data[pred_key], data[gt_key], pred_key) )
47
48     return sum(losses)
49
50
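`forward` walks `pred_keys`, `gt_keys`, and `weights` in lockstep with `zip`, which stops at the shortest list, so a `weights` list shorter than `pred_keys` silently drops the remaining tasks from the sum. The sketch below shows the same weighted sum on already-computed scalar losses, plus an explicit length check. The check, the helper name `weighted_task_sum`, and the keys `'action'`/`'pose'` are hypothetical additions, not part of the original class.

```python
def weighted_task_sum(per_task_loss, pred_keys, weights=None):
    """Weighted sum of per-task losses, mirroring lines 45-48 of forward().

    per_task_loss maps each prediction key to an already-computed scalar loss.
    Unlike the original, this raises if the lengths differ instead of letting
    zip() drop the extra keys.
    """
    if weights is None:
        weights = [1.0] * len(pred_keys)   # same default as __init__
    if len(weights) != len(pred_keys):
        raise ValueError("weights and pred_keys must have the same length")
    return sum(w * per_task_loss[k] for k, w in zip(pred_keys, weights))


total = weighted_task_sum({"action": 0.8, "pose": 0.4}, ["action", "pose"], [1.0, 0.5])
print(total)  # 1.0
```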

Member Data Documentation

◆ gt_keys

classification_loss.CrossEntropyLoss.gt_keys

Definition at line 21 of file classification_loss.py.

◆ loss

classification_loss.CrossEntropyLoss.loss

Definition at line 26 of file classification_loss.py.

◆ pred_keys

classification_loss.CrossEntropyLoss.pred_keys

Definition at line 20 of file classification_loss.py.

◆ weights

classification_loss.CrossEntropyLoss.weights

Definition at line 22 of file classification_loss.py.


The documentation for this class was generated from the following file: