Safemotion Lib
Loading...
Searching...
No Matches
smaction
models
heads
cat_layer.py
Go to the documentation of this file.
1
import
torch
2
import
torch.nn
as
nn
3
4
class
CatLayer
(nn.Module):
5
"""
6
입력 텐서들을 연결하는 기능의 모듈
7
입력 텐서들을 풀링하여 1차원으로 만들고 하나의 텐서로 이어붙이는 기능을 수행함
8
args:
9
input_key (list[str]) : 모듈의 inference에 사용하는 입력데이터의 키값들
10
input_type (list[str]) : input_key에 대응하는 데이터의 구조
11
ex) ['2d', '3d']
12
"""
13
def
__init__
(self, input_key, input_type):
14
super().
__init__
()
15
self.
input_key
= input_key
16
self.
input_type
= input_type
17
18
self.
avg_pool_dict
= nn.ModuleDict()
19
for
key, in_type
in
zip(self.
input_key
, self.
input_type
):
20
if
in_type ==
'3d'
:
21
self.
avg_pool_dict
[key] = nn.AdaptiveAvgPool3d((1, 1, 1))
22
elif
in_type ==
'2d'
:
23
self.
avg_pool_dict
[key] = nn.AdaptiveAvgPool2d((1, 1))
24
25
def
forward
(self, feat_dict):
26
"""
27
args:
28
sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
29
self.input_key의 아이템은 Tensor 타입 -> shape (B, C, T, H, W) or (B, C, H(or T), W)
30
B : 배치 크기
31
C : 입력 채널
32
T : 시간
33
34
return (Tensor):
35
1차원 특징 벡터 -> shape (B, C_o)
36
B : 배치 크기
37
C_o : 채널 크기, 입력 채널의 합
38
"""
39
feats = []
40
for
key
in
self.
input_key
:
41
x = feat_dict[key]
42
x = self.
avg_pool_dict
[key](x)
43
x = x.view(x.shape[0], -1)
44
feats.append(x)
45
46
cat_feat = torch.cat(feats, dim=1)
47
48
return
cat_feat
cat_layer.CatLayer
Definition
cat_layer.py:4
cat_layer.CatLayer.input_type
input_type
Definition
cat_layer.py:16
cat_layer.CatLayer.__init__
__init__(self, input_key, input_type)
Definition
cat_layer.py:13
cat_layer.CatLayer.input_key
input_key
Definition
cat_layer.py:15
cat_layer.CatLayer.forward
forward(self, feat_dict)
Definition
cat_layer.py:25
cat_layer.CatLayer.avg_pool_dict
avg_pool_dict
Definition
cat_layer.py:18
torch.nn
Generated by
1.10.0