182 ):
183 super(ChannelGate, self).__init__()
184 if num_gates is None: num_gates = in_channels
185 self.return_gates = return_gates
186
187 self.global_avgpool = nn.AdaptiveAvgPool2d(1)
188
189 self.fc1 = nn.Conv2d(
190 in_channels,
191 in_channels // reduction,
192 kernel_size=1,
193 bias=True,
194 padding=0
195 )
196 self.norm1 = None
197 if layer_norm: self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
198 self.relu = nn.ReLU(inplace=True)
199 self.fc2 = nn.Conv2d(
200 in_channels // reduction,
201 num_gates,
202 kernel_size=1,
203 bias=True,
204 padding=0
205 )
206 if gate_activation == 'sigmoid':
207 self.gate_activation = nn.Sigmoid()
208 elif gate_activation == 'relu':
209 self.gate_activation = nn.ReLU(inplace=True)
210 elif gate_activation == 'linear':
211 self.gate_activation = nn.Identity()
212 else:
213 raise RuntimeError(
214 "Unknown gate activation: {}".format(gate_activation)
215 )
216