# Safemotion Lib — context_block.py
# (doxygen page chrome — "Loading... / Searching... / No Matches" navigation
#  text — removed; converted to a comment so this file is valid Python)
1# copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
2
3import torch
4from torch import nn
5
6__all__ = ['ContextBlock']
7
8
def last_zero_init(m):
    """Zero-initialize ``m``'s weights and bias.

    If ``m`` is an ``nn.Sequential``, only its LAST layer is zeroed; otherwise
    ``m`` itself is. Zeroing the final layer makes a residual-style branch a
    no-op at the start of training.

    NOTE(review): the ``def`` header line was lost in the doc extraction and is
    reconstructed here; the body below is unchanged from SOURCE.
    """
    if isinstance(m, nn.Sequential):
        nn.init.constant_(m[-1].weight, val=0)
        # Conv layers created with bias=False have m.bias is None.
        if hasattr(m[-1], 'bias') and m[-1].bias is not None:
            nn.init.constant_(m[-1].bias, 0)
    else:
        nn.init.constant_(m.weight, val=0)
        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias, 0)
18
19
class ContextBlock(nn.Module):
    """Global Context (GC) block from GCNet.

    Pools the input feature map into a single global context vector
    (attention pooling or average pooling), transforms it through a
    bottleneck, and fuses it back into every spatial position by
    channel-wise addition and/or multiplication.

    Args:
        inplanes (int): Number of input channels C.
        ratio (float): Bottleneck ratio; hidden width is ``int(inplanes * ratio)``.
        pooling_type (str): ``'att'`` (1x1 conv + softmax attention pooling)
            or ``'avg'`` (global average pooling).
        fusion_types (list | tuple): Non-empty subset of
            ``('channel_add', 'channel_mul')``.

    NOTE(review): the doc extraction dropped the ``def reset_parameters``
    header and the two zero-init call lines; they are reconstructed below,
    with the zero-init made a private staticmethod so the class does not
    depend on the (also-garbled) module-level helper.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add',)):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            # 1x1 conv produces one attention logit per spatial position.
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # The two fusion branches use identical bottleneck transforms.
        self.channel_add_conv = (self._make_transform()
                                 if 'channel_add' in fusion_types else None)
        self.channel_mul_conv = (self._make_transform()
                                 if 'channel_mul' in fusion_types else None)
        self.reset_parameters()

    def _make_transform(self):
        # Bottleneck: C -> planes -> C, LayerNorm + ReLU in between.
        # LayerNorm over [planes, 1, 1] normalizes the (N, planes, 1, 1)
        # context vector per sample.
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    @staticmethod
    def _zero_init_last(m):
        # Zero the final layer (of a Sequential) so the fusion branch is a
        # no-op (add: +0; mul: sigmoid(0)=0.5 gate) at initialization.
        last = m[-1] if isinstance(m, nn.Sequential) else m
        nn.init.constant_(last.weight, val=0)
        if hasattr(last, 'bias') and last.bias is not None:
            nn.init.constant_(last.bias, 0)

    def reset_parameters(self):
        """(Re-)initialize all weights; also called from ``__init__``."""
        if self.pooling_type == 'att':
            nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
            if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
                nn.init.constant_(self.conv_mask.bias, 0)
            self.conv_mask.inited = True
        if self.channel_add_conv is not None:
            self._zero_init_last(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            self._zero_init_last(self.channel_mul_conv)

    def spatial_pool(self, x):
        """Pool ``x`` of shape (N, C, H, W) into a (N, C, 1, 1) context."""
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            # [N, 1, C, H*W]
            input_x = x.view(batch, channel, height * width).unsqueeze(1)
            # [N, 1, H*W]: softmax over the spatial dimension
            context_mask = self.softmax(
                self.conv_mask(x).view(batch, 1, height * width))
            # [N, 1, H*W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]: attention-weighted sum over positions
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        """Fuse global context into ``x``; output shape equals input shape."""
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        out = x
        if self.channel_mul_conv is not None:
            # Per-channel multiplicative gate in (0, 1), broadcast over H, W.
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # Per-channel additive term, broadcast over H, W.
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out
# (doxygen signature residue, kept as a comment)
# ContextBlock.__init__(self, inplanes, ratio, pooling_type='att', fusion_types=('channel_add',))