Safemotion Lib
Loading...
Searching...
No Matches
timm_backbone.py
Go to the documentation of this file.
1import timm
2import torch
3import torch.nn as nn
4
5class TimmBackbone(nn.Module):
6 """
7 timm 패키지의 백본을 사용하기 위한 클래스
8 timm 모델의 head는 nn.Identity()로 치완함
9 """
10
11 def __init__(self, model_name, input_key, pretrained=False):
12 """
13 args:
14 model_name (str) : timm 패키지에서 지원하는 모델 이름
15 input_key (str) : inference에 사용하는 입력데이터의 키값
16 pretrained (bool) : timm 에서 제공하는 프리트레인 파라미터 사용 여부
17 """
18 super().__init__()
19
20 self.input_key = input_key
21
22 if isinstance(pretrained, bool):
23 self.timm_model = timm.create_model(model_name, pretrained=pretrained)
24 self.timm_model.head = nn.Identity()
25 else:
26 self.timm_model = timm.create_model(model_name, pretrained=False)
27 self.timm_model.head = nn.Identity()
28
29 #TODO : 파라미터 로드하는 코드 작성 필요
30
31 def forward(self, sample):
32 x = sample[self.input_key]
33 return self.timm_model(x)
__init__(self, model_name, input_key, pretrained=False)