240 ):
241 super(OSBlock, self).__init__()
242 mid_channels = out_channels // bottleneck_reduction
243 self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm)
244 self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm)
245 self.conv2b = nn.Sequential(
246 LightConv3x3(mid_channels, mid_channels, bn_norm),
247 LightConv3x3(mid_channels, mid_channels, bn_norm),
248 )
249 self.conv2c = nn.Sequential(
250 LightConv3x3(mid_channels, mid_channels, bn_norm),
251 LightConv3x3(mid_channels, mid_channels, bn_norm),
252 LightConv3x3(mid_channels, mid_channels, bn_norm),
253 )
254 self.conv2d = nn.Sequential(
255 LightConv3x3(mid_channels, mid_channels, bn_norm),
256 LightConv3x3(mid_channels, mid_channels, bn_norm),
257 LightConv3x3(mid_channels, mid_channels, bn_norm),
258 LightConv3x3(mid_channels, mid_channels, bn_norm),
259 )
260 self.gate = ChannelGate(mid_channels)
261 self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm)
262 self.downsample = None
263 if in_channels != out_channels:
264 self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm)
265 self.IN = None
266 if IN: self.IN = nn.InstanceNorm2d(out_channels, affine=True)
267 self.relu = nn.ReLU(True)
268