Safemotion Lib
Loading...
Searching...
No Matches
Classes | Functions
loss_builder Namespace Reference

Classes

class  CombinedLossEvaluators
 

Functions

 build_loss (cfg)
 

Function Documentation

◆ build_loss()

loss_builder.build_loss ( cfg)
모델의 학습에 사용할 로스 계산기를 빌드하는 기능
args:
    cfg (str or Config): 학습 파라미터가 정의된 모델의 config 파일 경로 또는 mmengine.config.Config.fromfile()로 생성한 객체

Definition at line 31 of file loss_builder.py.

31def build_loss(cfg):
32 """
33 모델의 학습에 사용할 로스 계산기를 빌드하는 기능
34 args:
35 cfg (str or Config): 학습 파라미터가 정의된 모델의 config 파일 경로 또는 mmengine.config.Config.fromfile()로 생성한 객체
36 """
37
38 #config 로드
39 if isinstance(cfg, str):
40 cfg = Config.fromfile(cfg_path)
41
42 loss_evaluators = []
43 loss_weights = []
44
45 losses = cfg.loss #로스 파라미터
46 loss_types = losses.keys() #로스 클래스 목록
47
48 for type in loss_types:
49
50 #예외처리
51 assert type in __loss_builder__, \
52 f'not found loss type : {type}'
53
54 loss_cfg = losses.get(type) #losses[type]
55 loss_weights.append(loss_cfg.weight) #로스에 대한 가중치
56 loss_cfg.pop('weight')
57 loss_evaluators.append(__loss_builder__[type](**loss_cfg)) #로스 클래스 빌드
58
59 #로스의 결합 클래스 빌드
60 return CombinedLossEvaluators(loss_evaluators, loss_weights)