10 def __init__(self, in_channels, bn_norm, reduc_ratio=2):
17 kernel_size=1, stride=1, padding=0)
19 self.
W = nn.Sequential(
21 kernel_size=1, stride=1, padding=0),
24 nn.init.constant_(self.
W[1].weight, 0.0)
25 nn.init.constant_(self.
W[1].bias, 0.0)
28 kernel_size=1, stride=1, padding=0)
31 kernel_size=1, stride=1, padding=0)
35 :param x: (b, t, h, w)
36 :return x: (b, t, h, w)
38 batch_size = x.size(0)
40 g_x = g_x.permute(0, 2, 1)
43 theta_x = theta_x.permute(0, 2, 1)
45 f = torch.matmul(theta_x, phi_x)
49 y = torch.matmul(f_div_C, g_x)
50 y = y.permute(0, 2, 1).contiguous()