Safemotion Lib
timm_backbone.TimmBackbone Class Reference

Public Member Functions

 __init__ (self, model_name, input_key, pretrained=False)
 
 forward (self, sample)
 

Public Attributes

 input_key
 
 timm_model
 

Detailed Description

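Class for using a backbone from the timm package.
The head of the timm model is replaced with nn.Identity(), so for models whose classifier is stored in `head` the output is features rather than class logits.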

Definition at line 5 of file timm_backbone.py.
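To illustrate what replacing the head with nn.Identity() does, here is a minimal torch-only sketch. `ToyNet` is a hypothetical stand-in for a timm model whose classifier lives in an attribute named `head`; it is not part of timm or of this library. Because `forward()` looks up `self.head` at call time, swapping in nn.Identity() makes the model return the backbone features instead of logits.

```python
import torch
from torch import nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for a timm model whose classifier is stored in `.head`."""

    def __init__(self, in_dim=8, num_features=32, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, num_features), nn.ReLU())
        self.head = nn.Linear(num_features, num_classes)

    def forward(self, x):
        # self.head is looked up on every call, so replacing it changes the output
        return self.head(self.backbone(x))


model = ToyNet()
x = torch.randn(2, 8)
print(model(x).shape)  # torch.Size([2, 10]): class logits

model.head = nn.Identity()  # the same replacement TimmBackbone.__init__ performs
print(model(x).shape)  # torch.Size([2, 32]): backbone features
```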

Constructor & Destructor Documentation

◆ __init__()

timm_backbone.TimmBackbone.__init__(self, model_name, input_key, pretrained=False)
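args:
    model_name (str) : name of a model supported by the timm package
    input_key (str) : key of the input data (in the sample passed to forward) used for inference
    pretrained (bool) : whether to use the pretrained parameters provided by timm
        If a non-bool value is passed, the model is created without pretrained weights; loading parameters from that value is not implemented yet (see the TODO in the source).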

Definition at line 11 of file timm_backbone.py.

11 def __init__(self, model_name, input_key, pretrained=False):
12     """
13     args:
14         model_name (str) : name of a model supported by the timm package
15         input_key (str) : key of the input data used for inference
16         pretrained (bool) : whether to use the pretrained parameters provided by timm
17     """
18     super().__init__()
19
20     self.input_key = input_key
21
22     if isinstance(pretrained, bool):
23         self.timm_model = timm.create_model(model_name, pretrained=pretrained)
24         self.timm_model.head = nn.Identity()
25     else:
26         self.timm_model = timm.create_model(model_name, pretrained=False)
27         self.timm_model.head = nn.Identity()
28
29         # TODO: code that loads the parameters still needs to be written
30
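One caveat about the constructor: assigning nn.Identity() to an attribute named `head` only removes the classifier when the model's forward pass actually uses `self.head` (as timm's Vision Transformers do). timm's ResNets, for example, keep their classifier in `fc`, so the assignment just registers an extra, unused submodule and the model still returns logits. The torch-only sketch below shows this with a hypothetical `FcNet` (not a timm class). With real timm models, passing `num_classes=0` to `timm.create_model` or calling `model.reset_classifier(0)` removes the classifier whichever attribute holds it.

```python
import torch
from torch import nn


class FcNet(nn.Module):
    """Hypothetical stand-in for a model whose classifier is named `fc` (as in timm's ResNets)."""

    def __init__(self, in_dim=8, num_features=32, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, num_features), nn.ReLU())
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.fc(self.backbone(x))


model = FcNet()
x = torch.randn(2, 8)

# Same assignment as TimmBackbone.__init__: it only registers a new, unused submodule
model.head = nn.Identity()
print(model(x).shape)  # torch.Size([2, 10]): still logits

# Replacing the attribute the forward pass actually uses does strip the classifier
model.fc = nn.Identity()
print(model(x).shape)  # torch.Size([2, 32]): backbone features
```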

Member Function Documentation

◆ forward()

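timm_backbone.TimmBackbone.forward(self, sample)
args:
    sample (dict) : input data; only the value stored under input_key is passed to timm_model
returns:
    the output of timm_model for that input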

Definition at line 31 of file timm_backbone.py.

31 def forward(self, sample):
32     x = sample[self.input_key]
33     return self.timm_model(x)
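The dict-based input handling of forward() can be tried without timm installed. The sketch below is a hypothetical stand-in, `DictInputBackbone`, that mirrors TimmBackbone.forward but accepts any nn.Module instead of calling timm.create_model; the keys "image" and "label" are illustrative assumptions, not names used by this library.

```python
import torch
from torch import nn


class DictInputBackbone(nn.Module):
    """Hypothetical stand-in mirroring TimmBackbone.forward; takes any nn.Module
    instead of calling timm.create_model, so it runs without timm installed."""

    def __init__(self, model, input_key):
        super().__init__()
        self.input_key = input_key
        self.timm_model = model

    def forward(self, sample):
        # Only the tensor under input_key is used; other entries in sample are ignored
        x = sample[self.input_key]
        return self.timm_model(x)


backbone = DictInputBackbone(nn.Linear(8, 32), input_key="image")
sample = {"image": torch.randn(4, 8), "label": torch.tensor([0, 1, 2, 3])}
features = backbone(sample)
print(features.shape)  # torch.Size([4, 32])
```

Since forward() indexes `sample` directly, a sample that lacks `input_key` raises a KeyError rather than being silently skipped.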

Member Data Documentation

◆ input_key

timm_backbone.TimmBackbone.input_key

Definition at line 20 of file timm_backbone.py.

◆ timm_model

timm_backbone.TimmBackbone.timm_model

Definition at line 23 of file timm_backbone.py.


The documentation for this class was generated from the following file:
timm_backbone.py