Safemotion Lib
Loading...
Searching...
No Matches
gcn_head.py
Go to the documentation of this file.
1import torch.nn as nn
2
class GCNHead(nn.Module):
    """Classification head: global average pooling followed by a linear layer.

    Args:
        in_channels: Number of channels in the incoming feature map; also the
            input width of the linear classifier.
        num_class: Number of scores produced per sample.
        input_key: Key used to fetch the feature tensor from the dict passed
            to ``forward``.
    """

    def __init__(self,
                 in_channels,
                 num_class,
                 input_key):
        super().__init__()

        # Output size 1 collapses the last two dimensions to one value per
        # channel.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, num_class)
        self.input_key = input_key

    def forward(self, feat_dict):
        """Compute class scores from the feature stored under ``input_key``.

        Args:
            feat_dict: Mapping that holds a 4-D feature tensor at
                ``self.input_key``. The original code names the dimensions
                ``N, C, T, V`` (presumably batch, channels, time, vertices --
                TODO confirm against the backbone that produces it).

        Returns:
            Tensor of shape ``(N, num_class)`` with the linear layer's output.

        Raises:
            KeyError: If ``self.input_key`` is missing from ``feat_dict``.
            ValueError: If the feature tensor is not 4-D.
        """
        x = feat_dict[self.input_key]

        # Fail with a descriptive message instead of an opaque tuple-unpacking
        # error when the feature has the wrong rank (same exception type).
        if x.dim() != 4:
            raise ValueError(
                f"GCNHead expects a 4-D (N, C, T, V) feature under key "
                f"{self.input_key!r}, got shape {tuple(x.shape)}"
            )
        batch_size, channels = x.shape[:2]

        # Global pooling: (N, C, T, V) -> (N, C, 1, 1) -> (N, C).
        pooled = self.pool(x).view(batch_size, channels)

        cls_scores = self.fc(pooled)

        return cls_scores
__init__(self, in_channels, num_class, input_key)
Definition gcn_head.py:8
forward(self, feat_dict)
Definition gcn_head.py:16