Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | List of all members
cat_layer.CatLayer Class Reference
Inheritance diagram for cat_layer.CatLayer:

Public Member Functions

 __init__ (self, input_key, input_type)
 
 forward (self, feat_dict)
 

Public Attributes

 input_key
 
 input_type
 
 avg_pool_dict
 

Detailed Description

입력 텐서들을 연결하는 기능의 모듈
입력 텐서들을 풀링하여 1차원으로 만들고 하나의 텐서로 이어붙이는 기능을 수행함
args:
    input_key (list[str]) : 모듈의 inference에 사용하는 입력데이터의 키값들
    input_type (list[str]) : input_key에 대응하는 데이터의 구조
        ex) ['2d', '3d']

Definition at line 4 of file cat_layer.py.

Constructor & Destructor Documentation

◆ __init__()

cat_layer.CatLayer.__init__ ( self,
input_key,
input_type )

Definition at line 13 of file cat_layer.py.

13 def __init__(self, input_key, input_type):
14 super().__init__()
15 self.input_key = input_key
16 self.input_type = input_type
17
18 self.avg_pool_dict = nn.ModuleDict()
19 for key, in_type in zip(self.input_key, self.input_type):
20 if in_type == '3d':
21 self.avg_pool_dict[key] = nn.AdaptiveAvgPool3d((1, 1, 1))
22 elif in_type == '2d':
23 self.avg_pool_dict[key] = nn.AdaptiveAvgPool2d((1, 1))
24

Member Function Documentation

◆ forward()

cat_layer.CatLayer.forward ( self,
feat_dict )
args:
    sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
        self.input_key의 아이템은 Tensor 타입 -> shape (B, C, T, H, W) or (B, C, H(or T), W)
            B : 배치 크기
            C : 입력 채널
            T : 시간

return (Tensor):
    1차원 특징 벡터 -> shape (B, C_o)
        B : 배치 크기
        C_o : 채널 크기, 입력 채널의 합

Definition at line 25 of file cat_layer.py.

25 def forward(self, feat_dict):
26 """
27 args:
28 sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
29 self.input_key의 아이템은 Tensor 타입 -> shape (B, C, T, H, W) or (B, C, H(or T), W)
30 B : 배치 크기
31 C : 입력 채널
32 T : 시간
33
34 return (Tensor):
35 1차원 특징 벡터 -> shape (B, C_o)
36 B : 배치 크기
37 C_o : 채널 크기, 입력 채널의 합
38 """
39 feats = []
40 for key in self.input_key:
41 x = feat_dict[key]
42 x = self.avg_pool_dict[key](x)
43 x = x.view(x.shape[0], -1)
44 feats.append(x)
45
46 cat_feat = torch.cat(feats, dim=1)
47
48 return cat_feat

Member Data Documentation

◆ avg_pool_dict

cat_layer.CatLayer.avg_pool_dict

Definition at line 18 of file cat_layer.py.

◆ input_key

cat_layer.CatLayer.input_key

Definition at line 15 of file cat_layer.py.

◆ input_type

cat_layer.CatLayer.input_type

Definition at line 16 of file cat_layer.py.


The documentation for this class was generated from the following file: