Safemotion Lib
Loading...
Searching...
No Matches
smaction
utils
gcn_utils.py
Go to the documentation of this file.
1
import
torch
2
import
torch.nn
as
nn
3
4
class GCNBlock(nn.Module):
    r"""The basic module for applying a graph convolution.

    A single ``nn.Conv2d`` with a ``(t_kernel_size, 1)`` kernel maps the input
    to ``out_channels * kernel_size`` channels. The result is split into
    ``kernel_size`` groups of ``out_channels`` channels. Group ``k`` is
    aggregated over the graph nodes with adjacency matrix ``A[k]``, and the
    ``kernel_size`` aggregated groups are summed.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel, i.e. the
            number of adjacency matrices :math:`K` expected along dim 0 of ``A``
        t_kernel_size (int, optional): Size of the temporal convolving kernel.
            Default: 1
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: The input adjacency matrix ``A``, returned unchanged,
          in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.

    Raises:
        ValueError: From :meth:`forward`, if ``A.size(0) != kernel_size``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()

        self.kernel_size = kernel_size
        # One conv produces all K groups at once; they are separated in forward().
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias)

    def forward(self, x, A):
        """Apply the temporal convolution followed by graph aggregation.

        Args:
            x (Tensor): Input of shape ``(N, in_channels, T_in, V)``.
            A (Tensor): Adjacency matrices of shape ``(K, V, V)`` with
                ``K == self.kernel_size``.

        Returns:
            tuple: ``(out, A)`` where ``out`` has shape
            ``(N, out_channels, T_out, V)`` and ``A`` is the input tensor
            itself (same object, unchanged).

        Raises:
            ValueError: If ``A.size(0) != self.kernel_size``.
        """
        # Explicit check rather than ``assert``: asserts are removed under
        # ``python -O``, leaving only a much less descriptive einsum error.
        if A.size(0) != self.kernel_size:
            raise ValueError(
                f"expected A with {self.kernel_size} adjacency matrices "
                f"along dim 0, got {A.size(0)}")

        x = self.conv(x)

        n, kc, t, v = x.size()
        # Split the K * out_channels conv channels into K groups.
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        # For each group k multiply by A[k] over the node axis, then sum over k.
        x = torch.einsum('nkctv,kvw->nctw', x, A)

        return x.contiguous(), A
65
66
class TCNBlock(nn.Module):
    """Temporal convolution unit: BatchNorm -> ReLU -> Conv -> BatchNorm -> Dropout.

    The convolution has a ``(kernel_size, 1)`` kernel and ``(stride, 1)``
    stride, so it only slides along dim 2 of the 4-D input. It keeps the
    channel count at ``in_channels``. The input is zero-padded by
    ``(kernel_size - 1) // 2`` on both sides of dim 2. This keeps the length
    of that dim unchanged when ``kernel_size`` is odd and ``stride == 1``.

    Args:
        in_channels (int): Number of input (and output) channels.
        kernel_size (int): Size of the convolution kernel along dim 2.
        stride (int): Stride of the convolution along dim 2.
        dropout (float, optional): Dropout probability applied to the
            output. Default: 0

    Shape:
        - Input: 4-D tensor with ``in_channels`` at dim 1, presumably
          :math:`(N, C, T, V)` like the graph sequences in this module
          (NOTE(review): confirm against callers).
        - Output: same layout and channel count; dim 2 becomes
          :math:`\\lfloor (T + 2p - kernel\\_size) / stride \\rfloor + 1`
          with :math:`p = (kernel\\_size - 1) // 2`.
    """

    def __init__(self,
                 in_channels,
                 kernel_size,
                 stride,
                 dropout=0):
        super().__init__()
        self.kernel_size = kernel_size

        half_width = (kernel_size - 1) // 2
        # Layer order is significant: it fixes the ``tcn.<index>`` keys in
        # the state dict.
        layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(kernel_size, 1),
                stride=(stride, 1),
                padding=(half_width, 0),
            ),
            nn.BatchNorm2d(in_channels),
            nn.Dropout(dropout, inplace=True),
        ]
        self.tcn = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the BN/ReLU/Conv/BN/Dropout stack."""
        return self.tcn(x)
93
gcn_utils.GCNBlock
Definition
gcn_utils.py:4
gcn_utils.GCNBlock.forward
forward(self, x, A)
Definition
gcn_utils.py:55
gcn_utils.GCNBlock.__init__
__init__(self, in_channels, out_channels, kernel_size, t_kernel_size=1, t_stride=1, t_padding=0, t_dilation=1, bias=True)
Definition
gcn_utils.py:42
gcn_utils.GCNBlock.kernel_size
kernel_size
Definition
gcn_utils.py:45
gcn_utils.GCNBlock.conv
conv
Definition
gcn_utils.py:46
gcn_utils.TCNBlock
Definition
gcn_utils.py:66
gcn_utils.TCNBlock.__init__
__init__(self, in_channels, kernel_size, stride, dropout=0)
Definition
gcn_utils.py:72
gcn_utils.TCNBlock.tcn
tcn
Definition
gcn_utils.py:76
gcn_utils.TCNBlock.forward
forward(self, x)
Definition
gcn_utils.py:90
gcn_utils.TCNBlock.kernel_size
kernel_size
Definition
gcn_utils.py:75
torch.nn
Generated by
1.10.0