94 edge_importance_weighting, input_key, **kwargs):
95 super().__init__()
96
97
98 self.graph = Graph(**graph_args)
99 A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
100 self.register_buffer('A', A)
101
102
103 spatial_kernel_size = A.size(0)
104 temporal_kernel_size = 9
105 kernel_size = (temporal_kernel_size, spatial_kernel_size)
106 self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
107 kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
108 self.st_gcn_networks = nn.ModuleList((
109 STGCNBlock(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
110 STGCNBlock(64, 64, kernel_size, 1, **kwargs),
111 STGCNBlock(64, 64, kernel_size, 1, **kwargs),
112 STGCNBlock(64, 64, kernel_size, 1, **kwargs),
113 STGCNBlock(64, 128, kernel_size, 2, **kwargs),
114 STGCNBlock(128, 128, kernel_size, 1, **kwargs),
115 STGCNBlock(128, 128, kernel_size, 1, **kwargs),
116 STGCNBlock(128, 256, kernel_size, 2, **kwargs),
117 STGCNBlock(256, 256, kernel_size, 1, **kwargs),
118 STGCNBlock(256, 256, kernel_size, 1, **kwargs),
119 ))
120
121
122 if edge_importance_weighting:
123 self.edge_importance = nn.ParameterList([
124 nn.Parameter(torch.ones(self.A.size()))
125 for i in self.st_gcn_networks
126 ])
127 else:
128 self.edge_importance = [1] * len(self.st_gcn_networks)
129
130 self.input_key = input_key
131