11 def __init__(self, model_name, input_key, pretrained=False):
12 """
13 args:
14 model_name (str) : timm 패키지에서 지원하는 모델 이름
15 input_key (str) : inference에 사용하는 입력데이터의 키값
16 pretrained (bool) : timm 에서 제공하는 프리트레인 파라미터 사용 여부
17 """
18 super().__init__()
19
20 self.input_key = input_key
21
22 if isinstance(pretrained, bool):
23 self.timm_model = timm.create_model(model_name, pretrained=pretrained)
24 self.timm_model.head = nn.Identity()
25 else:
26 self.timm_model = timm.create_model(model_name, pretrained=False)
27 self.timm_model.head = nn.Identity()
28
29
30