Safemotion Lib
Loading...
Searching...
No Matches
gcn_utils.py
Go to the documentation of this file.
1import torch
2import torch.nn as nn
3
class GCNBlock(nn.Module):

    r"""The basic module for applying a graph convolution.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()

        self.kernel_size = kernel_size
        # A single 2D convolution produces all K spatial-kernel feature maps
        # at once (out_channels * kernel_size output channels); forward()
        # later splits the channel axis back into (K, C).
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias)

    def forward(self, x, A):
        """Apply the graph convolution.

        Args:
            x (Tensor): input sequence of shape ``(N, in_channels, T, V)``.
            A (Tensor): stacked adjacency matrices of shape ``(K, V, V)``.

        Returns:
            tuple: ``(out, A)`` where ``out`` has shape
            ``(N, out_channels, T_out, V)`` and ``A`` is passed through
            unchanged.
        """
        # One adjacency matrix is required per spatial kernel slot.
        assert A.size(0) == self.kernel_size

        x = self.conv(x)

        # Split channels into (K, C), then aggregate each node's features
        # over its neighbours using the K adjacency matrices.
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        # Operands passed positionally: the tuple form torch.einsum(eq, (a, b))
        # is a deprecated legacy calling convention.
        x = torch.einsum('nkctv,kvw->nctw', x, A)

        return x.contiguous(), A
65
class TCNBlock(nn.Module):
    """Temporal convolution block: BN -> ReLU -> temporal Conv2d -> BN -> Dropout.

    Operates on inputs of shape ``(N, C, T, V)``. The convolution slides only
    along the temporal axis ``T`` (kernel width is 1 over the node axis ``V``),
    and the channel count is preserved.
    """

    def __init__(self,
                 in_channels,
                 kernel_size,
                 stride,
                 dropout=0):
        super().__init__()
        self.kernel_size = kernel_size
        # "Same"-style padding along time only (exact for odd kernel sizes);
        # no padding along the node axis.
        temporal_pad = ((kernel_size - 1) // 2, 0)
        layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels,
                in_channels,
                (kernel_size, 1),
                (stride, 1),
                temporal_pad,
            ),
            nn.BatchNorm2d(in_channels),
            nn.Dropout(dropout, inplace=True),
        ]
        self.tcn = nn.Sequential(*layers)

    def forward(self, x):
        """Run the temporal pipeline; returns a ``(N, C, T_out, V)`` tensor."""
        return self.tcn(x)
93
forward(self, x, A)
Definition gcn_utils.py:55
__init__(self, in_channels, out_channels, kernel_size, t_kernel_size=1, t_stride=1, t_padding=0, t_dilation=1, bias=True)
Definition gcn_utils.py:42
__init__(self, in_channels, kernel_size, stride, dropout=0)
Definition gcn_utils.py:72