Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | List of all members
mlp_head.MLPHead Class Reference
Inheritance diagram for mlp_head.MLPHead:

Public Member Functions

 __init__ (self, in_channels, num_classes, layer_channels, dropout_ratio, input_key, input_type='linear')
 
 forward (self, feat_dict)
 

Public Attributes

 dropout_ratio
 
 input_key
 
 mlp
 
 avg_pool
 

Detailed Description

여러 텐서들을 입력으로 다층 레이어를 기반으로 클래스 스코어를 출력하는 기능
3차원 또는 1차원 입력 텐서들을 받아서 1차원 테서로 풀링한 뒤 하나의 텐서로 이어 붙임
이후 MLP레이어를 통과시켜 클래스 스코어를 출력함
args:
    in_channels (int) : 입력 텐서들의 채널 합
    num_classes (int) : 클래스 수
    layer_channels (list[int]) : 각 측의 레이어 채널 수
    dropout_ratio (int) : 드랍아웃 비율
    input_key (list[str]) : 모듈의 inference에 사용하는 입력데이터의 키값들
    input_type (str) : input_key에 대응하는 데이터의 구조
        ex) ['linear', '3d']

Definition at line 4 of file mlp_head.py.

Constructor & Destructor Documentation

◆ __init__()

mlp_head.MLPHead.__init__ ( self,
in_channels,
num_classes,
layer_channels,
dropout_ratio,
input_key,
input_type = 'linear' )

Definition at line 18 of file mlp_head.py.

18 def __init__(self, in_channels, num_classes, layer_channels, dropout_ratio, input_key, input_type='linear'):
19 super().__init__()
20
21 self.dropout_ratio = dropout_ratio
22 self.input_key = input_key
23
24 layer_num = len(layer_channels)
25 layer_list = []
26
27 in_ch = in_channels
28 for i in range(layer_num):
29 out_ch = layer_channels[i]
30 if self.dropout_ratio != 0:
31 layer_list.append(nn.Dropout(p=self.dropout_ratio))
32 layer_list.append(nn.Linear(in_ch, out_ch))
33 layer_list.append(nn.BatchNorm1d(out_ch))
34 layer_list.append(nn.ReLU(inplace=True))
35 in_ch = out_ch
36
37 if self.dropout_ratio != 0:
38 layer_list.append(nn.Dropout(p=self.dropout_ratio))
39
40 layer_list.append(nn.Linear(in_ch, num_classes))
41
42 self.mlp = nn.Sequential(*layer_list)
43
44 if input_type == '3d':
45 self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
46 else:
47 self.avg_pool = None
48
49

Member Function Documentation

◆ forward()

mlp_head.MLPHead.forward ( self,
feat_dict )
args:
    sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
        self.input_key의 아이템은 Tensor 타입 -> shape (B, C), (B, C, T, H, W)
            B : 배치 크기
            C : 채널 크기
            T : 시간축 크기

return (Tensor):
    1차원 특징 벡터 -> shape (B, num_classes)

Definition at line 50 of file mlp_head.py.

50 def forward(self, feat_dict):
51 """
52 args:
53 sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
54 self.input_key의 아이템은 Tensor 타입 -> shape (B, C), (B, C, T, H, W)
55 B : 배치 크기
56 C : 채널 크기
57 T : 시간축 크기
58
59 return (Tensor):
60 1차원 특징 벡터 -> shape (B, num_classes)
61 """
62 x = feat_dict[self.input_key]
63 if self.avg_pool is not None:
64 x = self.avg_pool(x)
65 x = x.view(x.shape[0], -1)
66
67 return self.mlp(x)
68

Member Data Documentation

◆ avg_pool

mlp_head.MLPHead.avg_pool

Definition at line 45 of file mlp_head.py.

◆ dropout_ratio

mlp_head.MLPHead.dropout_ratio

Definition at line 21 of file mlp_head.py.

◆ input_key

mlp_head.MLPHead.input_key

Definition at line 22 of file mlp_head.py.

◆ mlp

mlp_head.MLPHead.mlp

Definition at line 42 of file mlp_head.py.


The documentation for this class was generated from the following file: