Safemotion Lib
stgcn.STGCNBlock Class Reference

Public Member Functions

 __init__ (self, in_channels, out_channels, kernel_size, stride=1, dropout=0, residual=True)
 
 forward (self, x, A)
 

Public Attributes

 gcn
 
 tcn
 
 residual
 
 relu
 

Detailed Description

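Applies a spatial-temporal graph convolution over an input graph sequence: a graph convolution over the nodes, followed by a temporal convolution, an optional residual connection and a ReLU.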

Args:
    in_channels (int): Number of channels in the input sequence data
    out_channels (int): Number of channels produced by the convolution
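    kernel_size (tuple): Sizes of the temporal convolving kernel and the graph convolving kernel, as ``(temporal, spatial)``; the temporal size must be odd
    stride (int, optional): Stride of the temporal convolution. Default: 1
    dropout (float, optional): Dropout rate, passed to the temporal convolution block (``TCNBlock``). Default: 0
    residual (bool, optional): If ``True``, adds a residual connection: the identity when ``in_channels == out_channels`` and ``stride == 1``, otherwise a 1x1 convolution followed by batch normalization. Default: ``True``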

Shape:
    - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
    - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
    - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
    - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

    where
        :math:`N` is the batch size,
        :math:`K` is the spatial kernel size, equal to ``kernel_size[1]``,
        :math:`T_{in}/T_{out}` is the length of the input/output sequence,
        :math:`V` is the number of graph nodes.
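The shape rules above can be checked with a small calculation. The sketch below is a hypothetical helper (stgcn_block_output_shapes is not part of Safemotion Lib) and assumes that TCNBlock pads its temporal convolution by (kernel_size[0] - 1) // 2 with stride (stride, 1), as the reference ST-GCN implementation does; this page does not confirm TCNBlock's padding. Under that assumption the sketch also shows why the constructor requires an odd temporal kernel: only then does the temporal branch produce the same T_out as the residual branch, so the two can be added.

```python
def stgcn_block_output_shapes(x_shape, a_shape, out_channels, kernel_size, stride=1):
    """Hypothetical helper: expected (x_out, A_out) shapes for STGCNBlock.

    ASSUMPTION: TCNBlock uses a temporal convolution with padding
    (kernel_size[0] - 1) // 2 and stride `stride` along T (not confirmed
    by this page).
    """
    n, in_channels, t_in, v = x_shape
    k, v_rows, v_cols = a_shape
    t_kernel, s_kernel = kernel_size

    # Mirror the constructor's checks.
    assert t_kernel % 2 == 1, "temporal kernel size must be odd"
    # A holds K = kernel_size[1] partitions, each a V x V matrix.
    assert k == s_kernel and v_rows == v_cols == v

    # Temporal branch: standard convolution length formula along T.
    pad = (t_kernel - 1) // 2
    t_out = (t_in + 2 * pad - t_kernel) // stride + 1

    # Residual projection: 1x1 conv with stride (stride, 1), no padding.
    t_res = (t_in - 1) // stride + 1
    # With an odd kernel, t_in + 2 * pad - t_kernel == t_in - 1, so they match.
    assert t_out == t_res

    # The adjacency matrix is passed through with shape (K, V, V).
    return (n, out_channels, t_out, v), (k, v, v)
```

For example, an input of shape (8, 64, 300, 25) with a (3, 25, 25) adjacency, out_channels=128, kernel_size=(9, 3) and stride=2 gives an output of shape (8, 128, 150, 25).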

Definition at line 7 of file stgcn.py.

Constructor & Destructor Documentation

◆ __init__()

stgcn.STGCNBlock.__init__ ( self,
in_channels,
out_channels,
kernel_size,
stride = 1,
dropout = 0,
residual = True )

Definition at line 32 of file stgcn.py.

def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             dropout=0,
             residual=True):
    super().__init__()

    # kernel_size = (temporal, spatial); the temporal size must be odd.
    assert len(kernel_size) == 2
    assert kernel_size[0] % 2 == 1

    # Spatial graph convolution first, then temporal convolution.
    self.gcn = GCNBlock(in_channels, out_channels, kernel_size[1])
    self.tcn = TCNBlock(out_channels, kernel_size[0], stride, dropout)

    if not residual:
        # No skip connection: contributes 0 in forward().
        self.residual = lambda x: 0

    elif (in_channels == out_channels) and (stride == 1):
        # Shapes already match: identity skip connection.
        self.residual = lambda x: x

    else:
        # Match channels and temporal stride with a 1x1 conv + batch norm.
        self.residual = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=(stride, 1)),
            nn.BatchNorm2d(out_channels),
        )

    self.relu = nn.ReLU(inplace=True)
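To see what each residual branch returns, the selection logic from the constructor can be lifted into a standalone function. The name make_residual is hypothetical (it is not part of stgcn.py); the three branches are copied from the constructor listing, and the tensor sizes are arbitrary examples.

```python
import torch
import torch.nn as nn


def make_residual(in_channels, out_channels, stride=1, residual=True):
    """Hypothetical helper mirroring the residual selection in STGCNBlock.__init__."""
    if not residual:
        # No skip connection: adds 0 to the temporal branch output.
        return lambda x: 0
    if in_channels == out_channels and stride == 1:
        # Shapes already match: identity skip connection.
        return lambda x: x
    # Project channels and subsample T by `stride` with a 1x1 conv + batch norm.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(stride, 1)),
        nn.BatchNorm2d(out_channels),
    )


x = torch.randn(4, 64, 300, 25)                    # (N, C_in, T_in, V)
print(make_residual(64, 64)(x).shape)              # torch.Size([4, 64, 300, 25])
print(make_residual(64, 128, stride=2)(x).shape)   # torch.Size([4, 128, 150, 25])
print(make_residual(64, 128, residual=False)(x))   # 0
```

The projection branch is needed whenever the block changes the channel count or the temporal stride, because the identity path would otherwise have a different shape from the temporal branch output.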

Member Function Documentation

◆ forward()

stgcn.STGCNBlock.forward ( self,
x,
A )

Definition at line 65 of file stgcn.py.

def forward(self, x, A):

    res = self.residual(x)      # skip path, computed from the block input
    x, A = self.gcn(x, A)       # spatial graph convolution
    x = self.tcn(x) + res       # temporal convolution plus skip path

    return self.relu(x), A
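The forward pass computes the residual from the block input, applies the graph convolution and then the temporal convolution, adds the two, and applies ReLU. The self-contained sketch below reproduces that data flow end to end. GraphConvSketch and TemporalConvSketch are hypothetical stand-ins modeled on the reference ST-GCN implementation (a 1x1 convolution split into K partitions and aggregated with torch.einsum, followed by BN-ReLU-Conv-BN-Dropout along time); the real GCNBlock and TCNBlock in stgcn.py may be built differently.

```python
import torch
import torch.nn as nn


class GraphConvSketch(nn.Module):
    """Hypothetical stand-in for GCNBlock: K-partition graph convolution."""

    def __init__(self, in_channels, out_channels, k):
        super().__init__()
        self.k = k
        # One 1x1 conv produces out_channels features per spatial partition.
        self.conv = nn.Conv2d(in_channels, out_channels * k, kernel_size=1)

    def forward(self, x, A):
        x = self.conv(x)                             # (N, K*C, T, V)
        n, kc, t, v = x.shape
        x = x.view(n, self.k, kc // self.k, t, v)    # (N, K, C, T, V)
        # Aggregate neighbours with each partition's adjacency, summing over K.
        x = torch.einsum('nkctv,kvw->nctw', x, A)    # (N, C, T, V)
        return x.contiguous(), A


class TemporalConvSketch(nn.Sequential):
    """Hypothetical stand-in for TCNBlock: BN-ReLU-Conv(T)-BN-Dropout."""

    def __init__(self, channels, t_kernel, stride, dropout):
        pad = (t_kernel - 1) // 2  # keeps T unchanged when stride == 1
        super().__init__(
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, (t_kernel, 1), (stride, 1), (pad, 0)),
            nn.BatchNorm2d(channels),
            nn.Dropout(dropout),
        )


class STGCNBlockSketch(nn.Module):
    """Same wiring as STGCNBlock, built on the stand-ins above."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dropout=0, residual=True):
        super().__init__()
        assert len(kernel_size) == 2 and kernel_size[0] % 2 == 1
        self.gcn = GraphConvSketch(in_channels, out_channels, kernel_size[1])
        self.tcn = TemporalConvSketch(out_channels, kernel_size[0], stride, dropout)
        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels and stride == 1:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        res = self.residual(x)   # skip path from the block input
        x, A = self.gcn(x, A)    # spatial step: mix features across nodes
        x = self.tcn(x) + res    # temporal step, then add the skip path
        return self.relu(x), A


A = torch.rand(3, 25, 25)                 # (K, V, V), K = kernel_size[1]
block = STGCNBlockSketch(64, 128, (9, 3), stride=2)
x = torch.randn(8, 64, 300, 25)           # (N, C_in, T_in, V)
y, A_out = block(x, A)
print(y.shape)                            # torch.Size([8, 128, 150, 25])
```

In this sketch the adjacency A passes through unchanged and is returned alongside the features, so blocks can be chained by feeding each block's output pair into the next.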

Member Data Documentation

◆ gcn

stgcn.STGCNBlock.gcn

Definition at line 44 of file stgcn.py.

◆ relu

stgcn.STGCNBlock.relu

Definition at line 63 of file stgcn.py.

◆ residual

stgcn.STGCNBlock.residual

Definition at line 48 of file stgcn.py.

◆ tcn

stgcn.STGCNBlock.tcn

Definition at line 45 of file stgcn.py.


The documentation for this class was generated from the following file: