Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | List of all members
fastreid.layers.context_block.ContextBlock Class Reference
Inheritance diagram for fastreid.layers.context_block.ContextBlock:

Public Member Functions

 __init__ (self, inplanes, ratio, pooling_type='att', fusion_types=('channel_add',))
 
 reset_parameters (self)
 
 spatial_pool (self, x)
 
 forward (self, x)
 

Public Attributes

 inplanes
 
 ratio
 
 planes
 
 pooling_type
 
 fusion_types
 
 conv_mask
 
 softmax
 
 avg_pool
 
 channel_add_conv
 
 channel_mul_conv
 

Detailed Description

Definition at line 20 of file context_block.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.layers.context_block.ContextBlock.__init__ ( self,
inplanes,
ratio,
pooling_type = 'att',
fusion_types = ('channel_add',) )

Definition at line 22 of file context_block.py.

26 fusion_types=('channel_add',)):
27 super(ContextBlock, self).__init__()
28 assert pooling_type in ['avg', 'att']
29 assert isinstance(fusion_types, (list, tuple))
30 valid_fusion_types = ['channel_add', 'channel_mul']
31 assert all([f in valid_fusion_types for f in fusion_types])
32 assert len(fusion_types) > 0, 'at least one fusion should be used'
33 self.inplanes = inplanes
34 self.ratio = ratio
35 self.planes = int(inplanes * ratio)
36 self.pooling_type = pooling_type
37 self.fusion_types = fusion_types
38 if pooling_type == 'att':
39 self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
40 self.softmax = nn.Softmax(dim=2)
41 else:
42 self.avg_pool = nn.AdaptiveAvgPool2d(1)
43 if 'channel_add' in fusion_types:
44 self.channel_add_conv = nn.Sequential(
45 nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
46 nn.LayerNorm([self.planes, 1, 1]),
47 nn.ReLU(inplace=True), # yapf: disable
48 nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
49 else:
50 self.channel_add_conv = None
51 if 'channel_mul' in fusion_types:
52 self.channel_mul_conv = nn.Sequential(
53 nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
54 nn.LayerNorm([self.planes, 1, 1]),
55 nn.ReLU(inplace=True), # yapf: disable
56 nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
57 else:
58 self.channel_mul_conv = None
59 self.reset_parameters()
60

Member Function Documentation

◆ forward()

fastreid.layers.context_block.ContextBlock.forward ( self,
x )

Definition at line 99 of file context_block.py.

def forward(self, x):
    """Fuse a pooled global context back into the input feature map.

    The context produced by ``self.spatial_pool(x)`` is fed through whichever
    fusion branches are configured (not ``None``):

    * ``channel_mul_conv``: its output is squashed with a sigmoid and
      multiplied into the features.
    * ``channel_add_conv``: its output is added to the features.

    The multiplicative branch is applied before the additive one.

    Args:
        x: Input feature tensor; passed unchanged to ``spatial_pool``.

    Returns:
        The fused tensor. When both branches are ``None`` the input ``x``
        itself is returned.
    """
    # Global context, expected to broadcast against x (e.g. [N, C, 1, 1]).
    pooled = self.spatial_pool(x)

    fused = x
    mul_branch = self.channel_mul_conv
    if mul_branch is not None:
        # Sigmoid gate derived from the context, applied multiplicatively.
        gate = torch.sigmoid(mul_branch(pooled))
        fused = fused * gate

    add_branch = self.channel_add_conv
    if add_branch is not None:
        # Additive bias derived from the context.
        fused = fused + add_branch(pooled)

    return fused

◆ reset_parameters()

fastreid.layers.context_block.ContextBlock.reset_parameters ( self)

Definition at line 61 of file context_block.py.

def reset_parameters(self):
    """(Re-)initialise the parameters of the block's sub-layers.

    For attention pooling (``pooling_type == 'att'``) the mask convolution
    weight gets Kaiming-normal initialisation (fan-in, ReLU), its bias, if
    present, is set to zero, and the layer is flagged with ``inited = True``.

    Each configured fusion branch (``channel_add_conv`` first, then
    ``channel_mul_conv``) is handed to ``last_zero_init``.
    """
    if self.pooling_type == 'att':
        mask_conv = self.conv_mask
        nn.init.kaiming_normal_(mask_conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        # The bias may be absent or disabled; only zero it when it exists.
        if getattr(mask_conv, 'bias', None) is not None:
            nn.init.constant_(mask_conv.bias, 0)
        mask_conv.inited = True

    for fusion_conv in (self.channel_add_conv, self.channel_mul_conv):
        if fusion_conv is not None:
            last_zero_init(fusion_conv)
72

◆ spatial_pool()

fastreid.layers.context_block.ContextBlock.spatial_pool ( self,
x )

Definition at line 73 of file context_block.py.

def spatial_pool(self, x):
    """Collapse the spatial dimensions of ``x`` into one context vector per sample.

    With ``pooling_type == 'att'`` every spatial position is scored by
    ``self.conv_mask``, the scores are normalised by ``self.softmax`` and
    the features are summed with those weights. Any other pooling type
    delegates to ``self.avg_pool``.

    ``reshape`` is used instead of ``view`` so that inputs whose height and
    width cannot be merged without a copy (for example a transposed or
    spatially sliced feature map) are accepted instead of raising
    ``RuntimeError``. Inputs that already worked give identical results.

    Args:
        x: 4-D feature tensor, unpacked as (batch, channel, height, width).

    Returns:
        For attention pooling, a tensor of shape (batch, channel, 1, 1).
        Otherwise whatever ``self.avg_pool`` returns (expected to be
        [N, C, 1, 1]).
    """
    batch, channel, height, width = x.size()
    if self.pooling_type == 'att':
        # [N, C, H * W]; reshape copies only when the layout requires it.
        input_x = x.reshape(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]: one attention logit per spatial position.
        context_mask = self.conv_mask(x)
        # [N, 1, H * W]
        context_mask = context_mask.reshape(batch, 1, height * width)
        # [N, 1, H * W]: normalised over positions (assumes softmax over dim 2).
        context_mask = self.softmax(context_mask)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]: attention-weighted sum of the features.
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.reshape(batch, channel, 1, 1)
    else:
        # [N, C, 1, 1]
        context = self.avg_pool(x)

    return context
98

Member Data Documentation

◆ avg_pool

fastreid.layers.context_block.ContextBlock.avg_pool

Definition at line 42 of file context_block.py.

◆ channel_add_conv

fastreid.layers.context_block.ContextBlock.channel_add_conv

Definition at line 44 of file context_block.py.

◆ channel_mul_conv

fastreid.layers.context_block.ContextBlock.channel_mul_conv

Definition at line 52 of file context_block.py.

◆ conv_mask

fastreid.layers.context_block.ContextBlock.conv_mask

Definition at line 39 of file context_block.py.

◆ fusion_types

fastreid.layers.context_block.ContextBlock.fusion_types

Definition at line 37 of file context_block.py.

◆ inplanes

fastreid.layers.context_block.ContextBlock.inplanes

Definition at line 33 of file context_block.py.

◆ planes

fastreid.layers.context_block.ContextBlock.planes

Definition at line 35 of file context_block.py.

◆ pooling_type

fastreid.layers.context_block.ContextBlock.pooling_type

Definition at line 36 of file context_block.py.

◆ ratio

fastreid.layers.context_block.ContextBlock.ratio

Definition at line 34 of file context_block.py.

◆ softmax

fastreid.layers.context_block.ContextBlock.softmax

Definition at line 40 of file context_block.py.


The documentation for this class was generated from the following file: