def __init__(self, in_channels, bn_norm, reduc_ratio=2):
    """Build the projections of a non-local (self-attention style) block.

    Creates three 1x1 convolutions (``g``, ``theta``, ``phi``) that project
    ``in_channels`` down to ``inter_channels``. It also creates an output
    projection ``W`` (1x1 conv back to ``in_channels`` followed by a
    normalization layer from ``get_norm``).

    Args:
        in_channels: Number of channels of the input feature map.
        bn_norm: Normalization spec forwarded unchanged to
            ``get_norm(bn_norm, in_channels)``. Presumably a norm-type name;
            confirm against ``get_norm``.
        reduc_ratio: Channel reduction factor for the bottleneck width
            ``inter_channels = in_channels // reduc_ratio``. The width is
            clamped to at least 1.

    Raises:
        ValueError: If ``reduc_ratio`` is not positive.
    """
    super(Non_local, self).__init__()

    if reduc_ratio <= 0:
        raise ValueError(
            "reduc_ratio must be a positive integer, got {}".format(reduc_ratio)
        )

    self.in_channels = in_channels
    # Clamp to 1 so that in_channels < reduc_ratio does not yield zero-width
    # convolutions (an empty, silently degenerate block).
    self.inter_channels = max(in_channels // reduc_ratio, 1)

    self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                       kernel_size=1, stride=1, padding=0)

    # Output projection: back up to in_channels, then normalize.
    self.W = nn.Sequential(
        nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                  kernel_size=1, stride=1, padding=0),
        get_norm(bn_norm, self.in_channels),
    )
    # Zero the norm layer's affine weight and bias. Assuming get_norm returns
    # an affine norm layer, W then outputs zeros at initialization. Presumably
    # forward adds W's output to the input, so the block starts as an
    # identity mapping (confirm in forward).
    nn.init.constant_(self.W[1].weight, 0.0)
    nn.init.constant_(self.W[1].bias, 0.0)

    self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

    self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)