Safemotion Lib
Loading...
Searching...
No Matches
stgcn.py
Go to the documentation of this file.
1import torch
2import torch.nn as nn
3
4from smaction.utils.gcn_utils import GCNBlock, TCNBlock
5from smaction.utils.graph import Graph
6
class STGCNBlock(nn.Module):
    r"""Applies a spatial temporal graph convolution over an input graph sequence.

    The block runs a graph convolution (``GCNBlock``) followed by a temporal
    convolution (``TCNBlock``), adds an optional residual branch, and applies
    a ReLU.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): ``(temporal, spatial)`` sizes of the temporal
            convolving kernel and graph convolving kernel. The temporal size
            must be odd.
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (float, optional): Dropout rate, forwarded to ``TCNBlock``. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2, "kernel_size must be (temporal, spatial)"
        # NOTE(review): odd size presumably lets TCNBlock pad symmetrically -- confirm.
        assert kernel_size[0] % 2 == 1, "temporal kernel size must be odd"

        self.gcn = GCNBlock(in_channels, out_channels, kernel_size[1])
        self.tcn = TCNBlock(out_channels, kernel_size[0], stride, dropout)

        # The residual branch is a plain callable in the first two cases.
        # Named static methods are used instead of lambdas so the block stays
        # picklable (e.g. ``torch.save(model)`` or spawning worker processes);
        # lambdas cannot be pickled.
        if not residual:
            self.residual = self._zero_residual
        elif in_channels == out_channels and stride == 1:
            self.residual = self._identity_residual
        else:
            # Channel count or temporal stride differs: project the input with
            # a strided 1x1 conv + BN so it can be added to the TCN output.
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _zero_residual(x):
        """Residual branch for ``residual=False``: contributes nothing."""
        return 0

    @staticmethod
    def _identity_residual(x):
        """Residual branch used when input and output shapes already match."""
        return x

    def forward(self, x, A):
        """Apply the block.

        Args:
            x: Graph sequence, :math:`(N, in_channels, T_{in}, V)`.
            A: Adjacency matrix, :math:`(K, V, V)`, forwarded to ``GCNBlock``.

        Returns:
            tuple: ``(output, A)`` where ``A`` is the adjacency returned by
            ``GCNBlock``.
        """
        res = self.residual(x)
        x, A = self.gcn(x, A)
        # The addition produces a new tensor, so the in-place ReLU below
        # never overwrites the caller's input.
        x = self.tcn(x) + res

        return self.relu(x), A
72
class STGCN(nn.Module):
    r"""Spatial temporal graph convolutional network backbone.

    Stacks ten :class:`STGCNBlock` layers (64 -> 128 -> 256 channels, with
    temporal stride 2 at the 5th and 8th block) over a fixed graph and
    returns the final feature map. No pooling or classification head is
    applied here.

    Args:
        in_channels (int): Number of channels per graph node in the input data
        graph_args (dict): The arguments for building the graph (passed to ``Graph``)
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph (one mask per block)
        input_key: Key under which :meth:`forward` reads the input tensor
            from its ``sample_dict`` argument
        **kwargs (optional): Other parameters for graph convolution units,
            forwarded to every :class:`STGCNBlock` (``dropout`` is not
            forwarded to the first block)

    Shape:
        - Input: ``sample_dict[input_key]`` of shape :math:`(N, T_{in}, V, in_channels)`
        - Output: :math:`(N, 256, T_{out}, V)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`T_{out}` is the sequence length after the two stride-2
            temporal stages,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self, in_channels, graph_args,
                 edge_importance_weighting, input_key, **kwargs):
        super().__init__()

        # load graph; A is registered as a buffer (non-trainable, but part of
        # the module state so it moves with the model across devices)
        self.graph = Graph(**graph_args)
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        # One BN channel per (node, channel) pair; forward() reshapes the
        # input to (N, V * C, T) so statistics are taken over batch and time.
        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        # The first block is built without dropout.
        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        self.st_gcn_networks = nn.ModuleList((
            STGCNBlock(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
            STGCNBlock(64, 64, kernel_size, 1, **kwargs),
            STGCNBlock(64, 64, kernel_size, 1, **kwargs),
            STGCNBlock(64, 64, kernel_size, 1, **kwargs),
            STGCNBlock(64, 128, kernel_size, 2, **kwargs),
            STGCNBlock(128, 128, kernel_size, 1, **kwargs),
            STGCNBlock(128, 128, kernel_size, 1, **kwargs),
            STGCNBlock(128, 256, kernel_size, 2, **kwargs),
            STGCNBlock(256, 256, kernel_size, 1, **kwargs),
            STGCNBlock(256, 256, kernel_size, 1, **kwargs),
        ))

        # initialize parameters for edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            # Multiplying A by 1 leaves the adjacency unchanged.
            self.edge_importance = [1] * len(self.st_gcn_networks)

        self.input_key = input_key

    def forward(self, sample_dict):
        """Run the backbone on ``sample_dict[self.input_key]``.

        Args:
            sample_dict: Mapping holding the input tensor under
                ``self.input_key``, shaped ``(N, T, V, C)``.

        Returns:
            torch.Tensor: Feature map of shape ``(N, 256, T_out, V)``.

        Raises:
            ValueError: If the input tensor is not 4-dimensional.
        """
        x = sample_dict[self.input_key]
        if x.dim() != 4:
            raise ValueError(
                f"expected input of shape (N, T, V, C), got {tuple(x.shape)}")

        # data normalization: (N, T, V, C) -> (N, V * C, T) for BatchNorm1d
        N, T, V, C = x.size()
        x = x.permute(0, 2, 3, 1).contiguous().view(N, V * C, T)
        x = self.data_bn(x)
        # back to (N, C, T, V), the layout the ST-GCN blocks consume
        x = x.view(N, V, C, T).permute(0, 2, 3, 1).contiguous()

        # forward through the blocks, each with its own edge-importance mask
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        return x
forward(self, x, A)
Definition stgcn.py:65
__init__(self, in_channels, out_channels, kernel_size, stride=1, dropout=0, residual=True)
Definition stgcn.py:38
__init__(self, in_channels, graph_args, edge_importance_weighting, input_key, **kwargs)
Definition stgcn.py:94
st_gcn_networks
Definition stgcn.py:108
forward(self, sample_dict)
Definition stgcn.py:132
edge_importance
Definition stgcn.py:123