94 edge_importance_weighting, input_key, **kwargs):
98 self.
graph = Graph(**graph_args)
99 A = torch.tensor(self.
graph.A, dtype=torch.float32, requires_grad=
False)
100 self.register_buffer(
'A', A)
103 spatial_kernel_size = A.size(0)
104 temporal_kernel_size = 9
105 kernel_size = (temporal_kernel_size, spatial_kernel_size)
106 self.
data_bn = nn.BatchNorm1d(in_channels * A.size(1))
107 kwargs0 = {k: v
for k, v
in kwargs.items()
if k !=
'dropout'}
109 STGCNBlock(in_channels, 64, kernel_size, 1, residual=
False, **kwargs0),
113 STGCNBlock(64, 128, kernel_size, 2, **kwargs),
114 STGCNBlock(128, 128, kernel_size, 1, **kwargs),
115 STGCNBlock(128, 128, kernel_size, 1, **kwargs),
116 STGCNBlock(128, 256, kernel_size, 2, **kwargs),
117 STGCNBlock(256, 256, kernel_size, 1, **kwargs),
118 STGCNBlock(256, 256, kernel_size, 1, **kwargs),
122 if edge_importance_weighting:
124 nn.Parameter(torch.ones(self.A.size()))
143 N, T, V, C = x.size()
144 x = x.permute(0, 2, 3, 1).contiguous()
145 x = x.view(N, V * C, T)
148 x = x.view(N, V, C, T)
149 x = x.permute(0, 2, 3, 1).contiguous()
153 x, _ = gcn(x, self.A * importance)