53 dropblock_prob=0.0, last_gamma=False):
54 super(Bottleneck, self).__init__()
55 group_width = int(planes * (bottleneck_width / 64.)) * cardinality
56 self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
57 if with_ibn:
58 self.bn1 = IBN(group_width, bn_norm)
59 else:
60 self.bn1 = get_norm(bn_norm, group_width)
61 self.dropblock_prob = dropblock_prob
62 self.radix = radix
63 self.avd = avd and (stride > 1 or is_first)
64 self.avd_first = avd_first
65
66 if self.avd:
67 self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
68 stride = 1
69
70 if radix > 1:
71 self.conv2 = SplAtConv2d(
72 group_width, group_width, kernel_size=3,
73 stride=stride, padding=dilation,
74 dilation=dilation, groups=cardinality, bias=False,
75 radix=radix, rectify=rectified_conv,
76 rectify_avg=rectify_avg,
77 norm_layer=bn_norm,
78 dropblock_prob=dropblock_prob)
79 elif rectified_conv:
80 from rfconv import RFConv2d
81 self.conv2 = RFConv2d(
82 group_width, group_width, kernel_size=3, stride=stride,
83 padding=dilation, dilation=dilation,
84 groups=cardinality, bias=False,
85 average_mode=rectify_avg)
86 self.bn2 = get_norm(bn_norm, group_width)
87 else:
88 self.conv2 = nn.Conv2d(
89 group_width, group_width, kernel_size=3, stride=stride,
90 padding=dilation, dilation=dilation,
91 groups=cardinality, bias=False)
92 self.bn2 = get_norm(bn_norm, group_width)
93
94 self.conv3 = nn.Conv2d(
95 group_width, planes * 4, kernel_size=1, bias=False)
96 self.bn3 = get_norm(bn_norm, planes * 4)
97
98 if last_gamma:
99 from torch.nn.init import zeros_
100 zeros_(self.bn3.weight)
101 self.relu = nn.ReLU(inplace=True)
102 self.downsample = downsample
103 self.dilation = dilation
104 self.stride = stride
105