# Excerpt: SplAtConv2d ("Split-Attention Conv2d") from ResNeSt. Assumed
# module-level imports for this excerpt; in the original detectron2 port,
# Conv2d and get_norm come from detectron2.layers. DropBlock2D is defined
# elsewhere in the module; a sketch of rSoftMax is appended after the class.
import torch
import torch.nn.functional as F
from torch.nn import Module, ReLU
from torch.nn.modules.utils import _pair
from detectron2.layers import Conv2d, get_norm


class SplAtConv2d(Module):
    """Split-Attention Conv2d."""

    def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
                 dilation=(1, 1), groups=1, bias=True,
                 radix=2, reduction_factor=4,
                 rectify=False, rectify_avg=False, norm_layer=None, num_splits=1,
                 dropblock_prob=0.0, **kwargs):
        super().__init__()
        padding = _pair(padding)
        # Rectified convolution only matters when there is padding to rectify.
        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
        self.rectify_avg = rectify_avg
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.cardinality = groups
        self.channels = channels
        self.dropblock_prob = dropblock_prob
        if self.rectify:
            from rfconv import RFConv2d
            self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
                                 groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs)
        else:
            self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
                               groups=groups * radix, bias=bias, **kwargs)
        self.use_bn = norm_layer is not None
        if self.use_bn:
            self.bn0 = get_norm(norm_layer, channels * radix)
        self.relu = ReLU(inplace=True)
        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
        if self.use_bn:
            self.bn1 = get_norm(norm_layer, inter_channels)
        self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
        if dropblock_prob > 0.0:
            self.dropblock = DropBlock2D(dropblock_prob, 3)
        self.rsoftmax = rSoftMax(radix, groups)
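
    # forward() implements split attention: the grouped conv emits radix * channels
    # feature maps; per-split global context is pooled, squeezed through the
    # fc1/fc2 bottleneck, normalized across the radix splits by rSoftMax, and
    # used to reweight and fuse the splits back down to `channels` maps.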
    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn0(x)
        if self.dropblock_prob > 0.0:
            x = self.dropblock(x)
        x = self.relu(x)

        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            splited = torch.split(x, rchannel // self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)
        if self.use_bn:
            gap = self.bn1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = torch.split(atten, rchannel // self.radix, dim=1)
            out = sum([att * split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()
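

# self.rsoftmax references a helper defined elsewhere in the module. A minimal
# sketch is included here for self-containedness, reconstructed under the
# assumption that it matches upstream ResNeSt: softmax over the radix splits,
# falling back to a sigmoid gate when radix == 1.
class rSoftMax(Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            # View as (B, cardinality, radix, C'), move radix to dim 1, and
            # softmax over it so the radix splits of each channel sum to 1.
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x


if __name__ == "__main__":
    # Smoke test with illustrative (assumed) shapes. norm_layer=None and
    # dropblock_prob=0.0 keep the get_norm/DropBlock2D paths disabled.
    layer = SplAtConv2d(64, 64, kernel_size=3, padding=1, groups=2, radix=2)
    out = layer(torch.randn(2, 64, 32, 32))
    print(out.shape)  # torch.Size([2, 64, 32, 32])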