Safemotion Lib
Loading...
Searching...
No Matches
smaction
models
heads
i3d_fusion.py
Go to the documentation of this file.
1
import
torch
2
import
torch.nn
as
nn
3
4
class I3DFusion(nn.Module):
    """Fuse several 5-D feature maps into a single flat feature vector.

    Each input tensor selected by ``input_key`` is average-pooled to
    ``(B, C_i, 1, 1, 1)``. The pooled tensors are concatenated along the
    channel dimension and flattened to ``(B, sum(C_i))``. Optionally a
    dropout -> fc -> bn -> relu stage then projects the result to
    ``out_channels``.

    Args:
        in_channels (int): Sum of the channel counts of all input tensors.
            Only used to size the fc layer.
        out_channels (int): Output channel count. When greater than 0 the
            dropout/fc/bn/relu stage is applied; otherwise the pooled,
            concatenated features are returned unchanged.
        dropout_ratio (float): Dropout probability applied before the fc
            layer. Only used when ``out_channels > 0``; 0 or None disables
            dropout.
        input_key (Iterable[str] | str): Keys of the input-dict entries to
            fuse, in concatenation order. A single string is treated as one
            key.
    """

    def __init__(self, in_channels, out_channels, dropout_ratio, input_key):
        super().__init__()
        self.out_channels = out_channels
        self.dropout_ratio = dropout_ratio
        self.input_key = input_key
        # Collapse (T, H, W) to 1x1x1 so each input becomes (B, C, 1, 1, 1).
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))

        if self.out_channels > 0:
            self.fc = nn.Linear(in_channels, out_channels)
            self.bn = nn.BatchNorm1d(out_channels)
            self.relu = nn.ReLU(inplace=True)
            # Both 0 and None mean "no dropout" (None used to crash nn.Dropout).
            if self.dropout_ratio:
                self.dropout = nn.Dropout(p=self.dropout_ratio)
            else:
                self.dropout = None
        else:
            # Keep the attribute defined even when the projection stage is off.
            self.dropout = None

    def forward(self, feat_dict):
        """Pool, concatenate and optionally project the selected features.

        Args:
            feat_dict (dict): Input data; must contain every key listed in
                ``self.input_key``. Each selected value is a Tensor of shape
                (B, C, T, H, W) -- B: batch size, C: input channels,
                T: time, H/W: spatial size. B must match across inputs;
                T, H and W may differ because each tensor is pooled
                independently.

        Returns:
            Tensor: Feature vector of shape (B, C_o), where C_o is
            ``out_channels`` when the projection stage is enabled, and the
            sum of the input channel counts otherwise.
        """
        keys = [self.input_key] if isinstance(self.input_key, str) else self.input_key
        pooled = [self.avg_pool(feat_dict[key]) for key in keys]

        # (B, sum(C_i), 1, 1, 1) -> (B, sum(C_i)). Flattening from dim 1
        # keeps the batch dimension without relying on a tensor leaked
        # from the pooling loop.
        fusion_feat = torch.flatten(torch.cat(pooled, dim=1), 1)

        if self.out_channels > 0:
            if self.dropout is not None:
                fusion_feat = self.dropout(fusion_feat)
            fusion_feat = self.relu(self.bn(self.fc(fusion_feat)))

        return fusion_feat
i3d_fusion.I3DFusion
Definition
i3d_fusion.py:4
i3d_fusion.I3DFusion.fc
fc
Definition
i3d_fusion.py:22
i3d_fusion.I3DFusion.dropout
dropout
Definition
i3d_fusion.py:27
i3d_fusion.I3DFusion.relu
relu
Definition
i3d_fusion.py:24
i3d_fusion.I3DFusion.avg_pool
avg_pool
Definition
i3d_fusion.py:19
i3d_fusion.I3DFusion.__init__
__init__(self, in_channels, out_channels, dropout_ratio, input_key)
Definition
i3d_fusion.py:14
i3d_fusion.I3DFusion.out_channels
out_channels
Definition
i3d_fusion.py:16
i3d_fusion.I3DFusion.forward
forward(self, feat_dict)
Definition
i3d_fusion.py:31
i3d_fusion.I3DFusion.dropout_ratio
dropout_ratio
Definition
i3d_fusion.py:17
i3d_fusion.I3DFusion.input_key
input_key
Definition
i3d_fusion.py:18
i3d_fusion.I3DFusion.bn
bn
Definition
i3d_fusion.py:23
torch.nn
Generated by
1.10.0