26 fusion_types=('channel_add',)):
27 super(ContextBlock, self).__init__()
28 assert pooling_type in ['avg', 'att']
29 assert isinstance(fusion_types, (list, tuple))
30 valid_fusion_types = ['channel_add', 'channel_mul']
31 assert all([f in valid_fusion_types for f in fusion_types])
32 assert len(fusion_types) > 0, 'at least one fusion should be used'
33 self.inplanes = inplanes
34 self.ratio = ratio
35 self.planes = int(inplanes * ratio)
36 self.pooling_type = pooling_type
37 self.fusion_types = fusion_types
38 if pooling_type == 'att':
39 self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
40 self.softmax = nn.Softmax(dim=2)
41 else:
42 self.avg_pool = nn.AdaptiveAvgPool2d(1)
43 if 'channel_add' in fusion_types:
44 self.channel_add_conv = nn.Sequential(
45 nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
46 nn.LayerNorm([self.planes, 1, 1]),
47 nn.ReLU(inplace=True),
48 nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
49 else:
50 self.channel_add_conv = None
51 if 'channel_mul' in fusion_types:
52 self.channel_mul_conv = nn.Sequential(
53 nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
54 nn.LayerNorm([self.planes, 1, 1]),
55 nn.ReLU(inplace=True),
56 nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
57 else:
58 self.channel_mul_conv = None
59 self.reset_parameters()
60