Safemotion Lib
Loading...
Searching...
No Matches
loss_builder.py
Go to the documentation of this file.
1from smrunner.losses.classification_loss import CrossEntropyLoss, MutiTaskCrossEntropyLoss, MutiTaskSigmoidFocalLoss, MutiTaskMultiLabelCrossEntropyLoss
2
3
4__loss_builder__ = {
5 "CrossEntropyLoss" : CrossEntropyLoss,
6 "MutiTaskCrossEntropyLoss" : MutiTaskCrossEntropyLoss,
7 "MutiTaskSigmoidFocalLoss" : MutiTaskSigmoidFocalLoss,
8 "MutiTaskMultiLabelCrossEntropyLoss" : MutiTaskMultiLabelCrossEntropyLoss
9
10}
11
13 """
14 다양한 로스의 결합을 지원하는 클래스
15 args:
16 loss_evaluators (list): 로스 클래스 인스턴스 리스트
17 loss_weights (list[float]): 각 로스별 가중치
18 """
19 def __init__(self, loss_evaluators, loss_weights):
20 self.loss_evaluators = loss_evaluators
21 self.loss_weights = loss_weights
22
23 def __call__(self, data):
24 losses = []
25 for weight, loss_evaluator in zip(self.loss_weights, self.loss_evaluators):
26 losses.append(weight*loss_evaluator(data))
27
28 return sum(losses)
29
30
31def build_loss(cfg):
32 """
33 모델의 학습에 사용할 로스 계산기를 빌드하는 기능
34 args:
35 cfg (str or Config): 학습 파라미터가 정의된 모델의 config 파일 경로 또는 mmengine.config.Config.fromfile()로 생성한 객체
36 """
37
38 #config 로드
39 if isinstance(cfg, str):
40 cfg = Config.fromfile(cfg_path)
41
42 loss_evaluators = []
43 loss_weights = []
44
45 losses = cfg.loss #로스 파라미터
46 loss_types = losses.keys() #로스 클래스 목록
47
48 for type in loss_types:
49
50 #예외처리
51 assert type in __loss_builder__, \
52 f'not found loss type : {type}'
53
54 loss_cfg = losses.get(type) #losses[type]
55 loss_weights.append(loss_cfg.weight) #로스에 대한 가중치
56 loss_cfg.pop('weight')
57 loss_evaluators.append(__loss_builder__[type](**loss_cfg)) #로스 클래스 빌드
58
59 #로스의 결합 클래스 빌드
60 return CombinedLossEvaluators(loss_evaluators, loss_weights)
__init__(self, loss_evaluators, loss_weights)