Safemotion Lib
smaction › models › heads › gcn_head.py
Go to the documentation of this file.
import torch.nn as nn


class GCNHead(nn.Module):
    """Linear classification head over a globally pooled feature map.

    The feature tensor is read from a dict under ``input_key``, average-pooled
    over its last two axes, and projected to class scores by a single
    ``nn.Linear`` layer.

    Args:
        in_channels: Channel count ``C`` of the feature tensor; used as the
            input size of the linear layer.
        num_class: Number of output scores per sample.
        input_key: Key under which :meth:`forward` finds the feature tensor
            in its ``feat_dict`` argument.
    """

    def __init__(self, in_channels, num_class, input_key):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_class)
        self.input_key = input_key

    def forward(self, feat_dict):
        """Return class scores for ``feat_dict[self.input_key]``.

        Args:
            feat_dict: Mapping holding the feature tensor under
                ``self.input_key``. Accepted layouts are 4-D ``(N, C, T, V)``
                and 5-D ``(N, M, C, T, V)``; for 5-D input the pooled
                features are averaged over ``M`` before classification.

        Returns:
            Tensor of shape ``(N, num_class)`` holding the raw output of the
            linear layer (no softmax is applied).

        Raises:
            KeyError: If ``self.input_key`` is not in ``feat_dict``.
            ValueError: If the feature tensor is neither 4-D nor 5-D.
        """
        x = feat_dict[self.input_key]

        if x.dim() == 4:
            # Global average pooling: (N, C, T, V) -> (N, C, 1, 1) -> (N, C).
            x = self.pool(x).flatten(1)
        elif x.dim() == 5:
            n, m, c, t, v = x.shape
            # Fold M into the batch axis so the 2-D pool can run on it,
            # then average the pooled (N, M, C) features over M.
            x = self.pool(x.reshape(n * m, c, t, v)).flatten(1)
            x = x.reshape(n, m, c).mean(dim=1)
        else:
            raise ValueError(
                "GCNHead expects a 4-D (N, C, T, V) or 5-D (N, M, C, T, V) "
                f"feature tensor, got shape {tuple(x.shape)}"
            )

        return self.fc(x)
gcn_head.GCNHead — Definition: gcn_head.py:3
gcn_head.GCNHead.fc — Definition: gcn_head.py:12
gcn_head.GCNHead.__init__(self, in_channels, num_class, input_key) — Definition: gcn_head.py:8
gcn_head.GCNHead.pool — Definition: gcn_head.py:11
gcn_head.GCNHead.forward(self, feat_dict) — Definition: gcn_head.py:16
gcn_head.GCNHead.input_key — Definition: gcn_head.py:13
torch.nn
Generated by Doxygen 1.10.0