Safemotion Lib
stgcn.STGCN Class Reference

Public Member Functions

 __init__ (self, in_channels, graph_args, edge_importance_weighting, input_key, **kwargs)
 
 forward (self, sample_dict)
 

Public Attributes

 graph
 
 data_bn
 
 st_gcn_networks
 
 edge_importance
 
 input_key
 

Detailed Description

Spatial temporal graph convolutional network (ST-GCN) backbone.

Returns the feature map of the last ST-GCN block; no pooling or classification layer is applied.

Args:
    in_channels (int): Number of channels in the input data
    graph_args (dict): The arguments for building the graph
    edge_importance_weighting (bool): If ``True``, adds a learnable
        importance weighting to the edges of the graph
    input_key: Key of ``sample_dict`` that holds the input tensor
    **kwargs (optional): Other parameters for graph convolution units;
        ``dropout`` is not passed to the first block

Shape:
    - Input: ``sample_dict[input_key]`` of shape :math:`(N, T_{in}, V_{in}, C_{in})`
    - Output: :math:`(N, 256, T_{out}, V_{in})` where
        :math:`N` is the batch size,
        :math:`T_{in}` is the length of the input sequence,
        :math:`V_{in}` is the number of graph nodes,
        :math:`C_{in}` is ``in_channels``, and
        :math:`T_{out}` is the sequence length after the two blocks built with stride 2.
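The output shape can be predicted without building the model. The plain-Python sketch below mirrors the block list in ``__init__``; the helper name ``expected_output_shape`` is hypothetical, and it assumes that each STGCNBlock pads its temporal convolution by (9 - 1) // 2 = 4 frames and uses its fourth constructor argument as the temporal stride, as in the reference ST-GCN. Neither assumption is confirmed by this page.

```python
# (out_channels, assumed temporal stride) for the ten STGCNBlocks built in __init__.
BLOCK_PLAN = [(64, 1), (64, 1), (64, 1), (64, 1), (128, 2),
              (128, 1), (128, 1), (256, 2), (256, 1), (256, 1)]


def expected_output_shape(input_shape, temporal_kernel_size=9):
    """Predict STGCN.forward's output shape for an input of shape (N, T, V, C).

    Assumption: every block pads its temporal convolution by
    (kernel - 1) // 2 and uses the fourth constructor argument as stride.
    """
    n, t, v, _ = input_shape
    pad = (temporal_kernel_size - 1) // 2
    for _, stride in BLOCK_PLAN:
        # standard convolution length formula: floor((T + 2p - k) / s) + 1
        t = (t + 2 * pad - temporal_kernel_size) // stride + 1
    # channels come from the last block; joints are never pooled
    return (n, BLOCK_PLAN[-1][0], t, v)


print(expected_output_shape((8, 300, 18, 3)))  # (8, 256, 75, 18)
```

Under these assumptions each stride-2 block maps T to ceil(T / 2), so the temporal axis shrinks by roughly a factor of 4 while the joint axis is preserved.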

Definition at line 73 of file stgcn.py.

Constructor & Destructor Documentation

◆ __init__()

stgcn.STGCN.__init__ (self, in_channels, graph_args, edge_importance_weighting, input_key, **kwargs)

Definition at line 93 of file stgcn.py.

 93 def __init__(self, in_channels, graph_args,
 94              edge_importance_weighting, input_key, **kwargs):
 95     super().__init__()
 96
 97     # load graph
 98     self.graph = Graph(**graph_args)
 99     A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
100     self.register_buffer('A', A)
101
102     # build networks
103     spatial_kernel_size = A.size(0)
104     temporal_kernel_size = 9
105     kernel_size = (temporal_kernel_size, spatial_kernel_size)
106     self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
107     kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
108     self.st_gcn_networks = nn.ModuleList((
109         STGCNBlock(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
110         STGCNBlock(64, 64, kernel_size, 1, **kwargs),
111         STGCNBlock(64, 64, kernel_size, 1, **kwargs),
112         STGCNBlock(64, 64, kernel_size, 1, **kwargs),
113         STGCNBlock(64, 128, kernel_size, 2, **kwargs),
114         STGCNBlock(128, 128, kernel_size, 1, **kwargs),
115         STGCNBlock(128, 128, kernel_size, 1, **kwargs),
116         STGCNBlock(128, 256, kernel_size, 2, **kwargs),
117         STGCNBlock(256, 256, kernel_size, 1, **kwargs),
118         STGCNBlock(256, 256, kernel_size, 1, **kwargs),
119     ))
120
121     # initialize parameters for edge importance weighting
122     if edge_importance_weighting:
123         self.edge_importance = nn.ParameterList([
124             nn.Parameter(torch.ones(self.A.size()))
125             for i in self.st_gcn_networks
126         ])
127     else:
128         self.edge_importance = [1] * len(self.st_gcn_networks)
129
130     self.input_key = input_key
131
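The size bookkeeping in the constructor can be traced without PyTorch. The hypothetical helper ``describe_stgcn_config`` below takes only the shape of the graph's adjacency stack ``A``, assumed to be (K, V, V) with K partitions and V joints, and reproduces how ``__init__`` derives ``kernel_size``, the width of ``data_bn``, and the keyword arguments each block receives. The example graph size and channel count are assumptions chosen for illustration.

```python
def describe_stgcn_config(a_shape, in_channels, kwargs, temporal_kernel_size=9):
    """Mirror the size bookkeeping in STGCN.__init__ without building tensors.

    a_shape: shape of the graph's adjacency stack A, assumed to be (K, V, V).
    """
    num_partitions, num_nodes, _ = a_shape
    return {
        # (temporal, spatial) kernel passed to every STGCNBlock
        'kernel_size': (temporal_kernel_size, num_partitions),
        # data_bn has one feature per (joint, channel) pair
        'data_bn_features': in_channels * num_nodes,
        # the first block is built without 'dropout' (and with residual=False)
        'first_block_kwargs': {k: v for k, v in kwargs.items() if k != 'dropout'},
        # the remaining nine blocks receive kwargs unchanged
        'other_block_kwargs': dict(kwargs),
    }


# Assumed example: 3 partitions over 18 joints, 3 input channels per joint.
cfg = describe_stgcn_config((3, 18, 18), in_channels=3, kwargs={'dropout': 0.5})
print(cfg['kernel_size'])         # (9, 3)
print(cfg['data_bn_features'])    # 54
print(cfg['first_block_kwargs'])  # {}
```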

Member Function Documentation

◆ forward()

stgcn.STGCN.forward (self, sample_dict)

Definition at line 132 of file stgcn.py.

132 def forward(self, sample_dict):
133     x = sample_dict[self.input_key]
134     # data normalization
135     # N, C, T, V, M = x.size()  # -> N, M, V, C, T
136     # x = x.permute(0, 4, 3, 1, 2).contiguous()
137     # x = x.view(N * M, V * C, T)
138     # x = self.data_bn(x)
139     # x = x.view(N, M, V, C, T)  # -> N, M, C, T, V
140     # x = x.permute(0, 1, 3, 4, 2).contiguous()
141     # x = x.view(N * M, C, T, V)
142
143     N, T, V, C = x.size()  # -> N, V, C, T
144     x = x.permute(0, 2, 3, 1).contiguous()
145     x = x.view(N, V * C, T)
146
147     x = self.data_bn(x)
148     x = x.view(N, V, C, T)
149     x = x.permute(0, 2, 3, 1).contiguous()  # N, C, T, V
150
151     # forward
152     for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
153         x, _ = gcn(x, self.A * importance)
154
155     # x = x.reshape((N, M) + x.shape[1:])
156     return x
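The permute/view sequence before the blocks only moves axes so that ``data_bn`` sees one feature per (joint, channel) pair. The plain-Python sketch below tracks the shapes step by step; the helper names ``permute_shape``, ``trace_forward_shapes`` and ``bn_feature_index`` are hypothetical and stand in for the tensor operations rather than calling PyTorch.

```python
def permute_shape(shape, dims):
    """Shape that tensor.permute(*dims) would produce."""
    return tuple(shape[d] for d in dims)


def trace_forward_shapes(n, t, v, c):
    """Shape after each reshaping step in STGCN.forward, before the ST-GCN blocks."""
    steps = [('input (N, T, V, C)', (n, t, v, c))]
    x = permute_shape(steps[-1][1], (0, 2, 3, 1))
    steps.append(('permute -> (N, V, C, T)', x))
    x = (n, v * c, t)
    steps.append(('view for data_bn -> (N, V*C, T)', x))
    x = (n, v, c, t)
    steps.append(('view back -> (N, V, C, T)', x))
    x = permute_shape(x, (0, 2, 3, 1))
    steps.append(('permute -> (N, C, T, V)', x))
    return steps


def bn_feature_index(v_idx, c_idx, num_channels):
    """data_bn feature holding joint v_idx, channel c_idx after the contiguous view."""
    return v_idx * num_channels + c_idx


for name, shape in trace_forward_shapes(2, 300, 18, 3):
    print(f'{name}: {shape}')
```

Because ``BatchNorm1d`` on an (N, F, T) input computes statistics per feature over the batch and time axes, each joint's x, y and any further channels are normalized independently.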

Member Data Documentation

◆ data_bn

stgcn.STGCN.data_bn

Definition at line 106 of file stgcn.py.

◆ edge_importance

stgcn.STGCN.edge_importance

Definition at line 123 of file stgcn.py.
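In ``forward`` each block receives the elementwise product ``self.A * importance``. The plain-Python sketch below uses nested lists in place of tensors and a hypothetical helper ``apply_edge_importance``; the toy adjacency and the "learned" mask values are invented for illustration. It shows that both the disabled case, where ``importance`` is the integer 1, and the enabled case at initialization, where each mask is all ones, leave ``A`` unchanged; training then rescales individual edges.

```python
def apply_edge_importance(a, importance):
    """Elementwise product of a (K, V, V) nested list with a scalar or same-shaped mask."""
    if isinstance(importance, (int, float)):
        return [[[x * importance for x in row] for row in part] for part in a]
    return [
        [[x * m for x, m in zip(row, mask_row)] for row, mask_row in zip(part, mask_part)]
        for part, mask_part in zip(a, importance)
    ]


# Toy adjacency stack: K = 1 partition over V = 2 joints (illustrative values only).
A = [[[0.5, 0.5], [0.5, 0.5]]]
ones_mask = [[[1.0, 1.0], [1.0, 1.0]]]
learned_mask = [[[2.0, 0.0], [1.0, 1.0]]]  # hypothetical values after training

print(apply_edge_importance(A, 1) == A)          # True: weighting disabled
print(apply_edge_importance(A, ones_mask) == A)  # True: weighting enabled, at init
print(apply_edge_importance(A, learned_mask))    # [[[1.0, 0.0], [0.5, 0.5]]]
```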

◆ graph

stgcn.STGCN.graph

Definition at line 98 of file stgcn.py.

◆ input_key

stgcn.STGCN.input_key

Definition at line 130 of file stgcn.py.

◆ st_gcn_networks

stgcn.STGCN.st_gcn_networks

Definition at line 108 of file stgcn.py.


The documentation for this class was generated from the file stgcn.py.