Safemotion Lib
fastreid.layers.splat.SplAtConv2d Class Reference
Public Member Functions

 __init__ (self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, bias=True, radix=2, reduction_factor=4, rectify=False, rectify_avg=False, norm_layer=None, num_splits=1, dropblock_prob=0.0, **kwargs)
 
 forward (self, x)
 

Public Attributes

 rectify
 
 rectify_avg
 
 radix
 
 cardinality
 
 channels
 
 dropblock_prob
 
 conv
 
 use_bn
 
 bn0
 
 relu
 
 fc1
 
 bn1
 
 fc2
 
 rsoftmax
 

Detailed Description

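Split-Attention Conv2d, the split-attention convolution introduced in ResNeSt. A grouped convolution produces radix splits of channels feature maps each. The splits are summed, globally average-pooled, and passed through two 1x1 convolutions (fc1 and fc2, each with groups=cardinality) to obtain attention logits, which rsoftmax normalizes across the splits. The output is the attention-weighted sum of the splits and has channels output channels.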

Definition at line 15 of file splat.py.
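The steps above can be traced in terms of tensor shapes. The following plain-Python sketch (no PyTorch required) reports the shape after each stage of forward for one configuration. The helper names conv_out_size and splat_shapes are hypothetical, and the sketch assumes square kernels, integer stride/padding/dilation, and the standard torch.nn.Conv2d output-size formula.

```python
def conv_out_size(size, kernel, stride, padding, dilation):
    # Hypothetical helper: standard Conv2d output-size formula (one spatial dim).
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def splat_shapes(batch, in_channels, height, width, channels, kernel_size,
                 stride=1, padding=0, dilation=1, radix=2, reduction_factor=4):
    """Hypothetical helper: trace shapes through SplAtConv2d.forward (square kernels)."""
    # Same formula as the constructor uses for the attention bottleneck.
    inter_channels = max(in_channels * radix // reduction_factor, 32)
    h = conv_out_size(height, kernel_size, stride, padding, dilation)
    w = conv_out_size(width, kernel_size, stride, padding, dilation)
    return {
        "conv": (batch, channels * radix, h, w),   # radix splits stacked on dim 1
        "split": (batch, channels, h, w),          # each of the radix chunks
        "gap": (batch, channels, 1, 1),            # sum of splits, then pooled
        "fc1": (batch, inter_channels, 1, 1),
        "fc2": (batch, channels * radix, 1, 1),    # attention logits
        "atten": (batch, channels * radix, 1, 1),  # after rsoftmax and view
        "out": (batch, channels, h, w),            # fused output
    }


shapes = splat_shapes(8, 64, 56, 56, 64, 3, padding=1)
for stage, shape in shapes.items():
    print(stage, shape)
```

Whatever the radix, the block returns channels maps; radix only widens the intermediate conv output and the attention vector.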

Constructor & Destructor Documentation

◆ __init__()

fastreid.layers.splat.SplAtConv2d.__init__ ( self,
in_channels,
channels,
kernel_size,
stride = (1, 1),
padding = (0, 0),
dilation = (1, 1),
groups = 1,
bias = True,
radix = 2,
reduction_factor = 4,
rectify = False,
rectify_avg = False,
norm_layer = None,
num_splits = 1,
dropblock_prob = 0.0,
** kwargs )

Definition at line 19 of file splat.py.

def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
             dilation=(1, 1), groups=1, bias=True, radix=2, reduction_factor=4,
             rectify=False, rectify_avg=False, norm_layer=None, num_splits=1,
             dropblock_prob=0.0, **kwargs):
    super(SplAtConv2d, self).__init__()
    padding = _pair(padding)
    # Rectified convolution is only used when the input is actually padded.
    self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
    self.rectify_avg = rectify_avg
    # Hidden width of the attention branch (fc1 -> fc2), never below 32.
    inter_channels = max(in_channels * radix // reduction_factor, 32)
    self.radix = radix
    self.cardinality = groups
    self.channels = channels
    # Note: this constructor creates no self.dropblock module, and
    # num_splits is accepted but not used here.
    self.dropblock_prob = dropblock_prob
    if self.rectify:
        from rfconv import RFConv2d
        self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
                             groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs)
    else:
        self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
                           groups=groups * radix, bias=bias, **kwargs)
    self.use_bn = norm_layer is not None
    if self.use_bn:
        self.bn0 = get_norm(norm_layer, channels * radix)
    self.relu = ReLU(inplace=True)
    # Attention MLP implemented as 1x1 convolutions, grouped by cardinality.
    self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
    if self.use_bn:
        self.bn1 = get_norm(norm_layer, inter_channels)
    self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)

    self.rsoftmax = rSoftMax(radix, groups)
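The constructor's bookkeeping (the effective rectify flag, inter_channels, and the channel/group settings of conv, fc1 and fc2) can be reproduced without PyTorch, which is handy for checking a configuration before building the module. This is a sketch: splat_layer_specs is a hypothetical helper, and its divisibility check is an assumption based on the general rule that grouped convolutions must split input and output channels evenly across groups.

```python
def splat_layer_specs(in_channels, channels, padding=(0, 0), groups=1,
                      radix=2, reduction_factor=4, rectify=False):
    """Hypothetical helper mirroring the arithmetic in SplAtConv2d.__init__."""
    if isinstance(padding, int):
        padding = (padding, padding)  # loosely mimics _pair for int input
    # rectify is only honoured when some padding is present.
    use_rectify = rectify and (padding[0] > 0 or padding[1] > 0)
    # Bottleneck width of the attention branch, floored at 32.
    inter_channels = max(in_channels * radix // reduction_factor, 32)
    specs = {
        # (in_channels, out_channels, groups) for each convolution
        "conv": (in_channels, channels * radix, groups * radix),
        "fc1": (channels, inter_channels, groups),
        "fc2": (inter_channels, channels * radix, groups),
    }
    # Assumed constraint: grouped convs need channel counts divisible by groups.
    for name, (c_in, c_out, g) in specs.items():
        if c_in % g or c_out % g:
            raise ValueError(
                f"{name}: channels ({c_in}, {c_out}) not divisible by groups={g}")
    return use_rectify, inter_channels, specs


print(splat_layer_specs(64, 64, padding=(1, 1), rectify=True))
```

Note that inter_channels is derived from in_channels rather than channels, even though fc1 receives channels inputs; with the defaults (radix=2, reduction_factor=4), an in_channels of 64 or fewer gives the floor value of 32.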

Member Function Documentation

◆ forward()

fastreid.layers.splat.SplAtConv2d.forward ( self,
x )

Definition at line 51 of file splat.py.

def forward(self, x):
    x = self.conv(x)
    if self.use_bn:
        x = self.bn0(x)
    if self.dropblock_prob > 0.0:
        x = self.dropblock(x)  # note: __init__ does not create self.dropblock
    x = self.relu(x)

    batch, rchannel = x.shape[:2]
    if self.radix > 1:
        # Cut the channels * radix maps into radix chunks of channels maps each.
        splited = torch.split(x, rchannel // self.radix, dim=1)
        gap = sum(splited)
    else:
        gap = x
    # Global context: average-pool to (batch, channels, 1, 1).
    gap = F.adaptive_avg_pool2d(gap, 1)
    gap = self.fc1(gap)

    if self.use_bn:
        gap = self.bn1(gap)
    gap = self.relu(gap)

    # One attention weight per (split, channel).
    atten = self.fc2(gap)
    atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

    if self.radix > 1:
        attens = torch.split(atten, rchannel // self.radix, dim=1)
        out = sum([att * split for (att, split) in zip(attens, splited)])
    else:
        out = atten * x
    return out.contiguous()
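To illustrate the split, pool, and weighted-sum steps of forward without tensors, the sketch below works on a single sample represented as nested Python lists. Since fc1, fc2 and rsoftmax are learned or not shown here, the attention weights are passed in directly. radix_fusion is a hypothetical helper; its slicing into consecutive chunks mirrors the torch.split(..., rchannel // self.radix, dim=1) calls in the listing.

```python
def radix_fusion(x, radix, atten):
    """Hypothetical pure-Python version of the split/pool/fuse steps of forward.

    x:     channels * radix feature maps for one sample (each a flat pixel list),
           i.e. the output of conv/bn0/relu.
    atten: channels * radix attention weights (the rsoftmax output).
    Returns (gap, out): the pooled descriptor fed to fc1 and the fused output.
    """
    channels = len(x) // radix
    # Like torch.split(x, channels, dim=1): consecutive chunks of channels maps.
    splits = [x[r * channels:(r + 1) * channels] for r in range(radix)]
    # sum(splited): per-pixel sum of corresponding channels across splits.
    summed = [[sum(vals) for vals in zip(*(s[c] for s in splits))]
              for c in range(channels)]
    # adaptive_avg_pool2d(..., 1): per-channel spatial mean.
    gap = [sum(m) / len(m) for m in summed]
    # Split the weights the same way, then take the weighted sum of the splits.
    attens = [atten[r * channels:(r + 1) * channels] for r in range(radix)]
    out = [[sum(attens[r][c] * splits[r][c][p] for r in range(radix))
            for p in range(len(x[c]))]
           for c in range(channels)]
    return gap, out


gap, out = radix_fusion([[1, 3], [2, 2], [5, 7], [0, 4]], 2, [0.25, 0.5, 0.75, 0.5])
print(gap, out)
```

With radix=1 the same helper reduces to out = atten * x, matching the else branches of forward.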

Member Data Documentation

◆ bn0

fastreid.layers.splat.SplAtConv2d.bn0

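Normalization layer built with get_norm(norm_layer, channels * radix) and applied to the output of conv; created only when norm_layer is given.

Definition at line 42 of file splat.py.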

◆ bn1

fastreid.layers.splat.SplAtConv2d.bn1

Normalization layer over the inter_channels outputs of fc1; created only when norm_layer is given.

Definition at line 46 of file splat.py.

◆ cardinality

fastreid.layers.splat.SplAtConv2d.cardinality

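Number of cardinal groups, taken from the groups argument. fc1 and fc2 use groups=cardinality, while conv uses groups * radix.

Definition at line 30 of file splat.py.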

◆ channels

fastreid.layers.splat.SplAtConv2d.channels

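Number of output channels of the block. Internally, conv produces channels * radix maps, which are fused back to channels maps.

Definition at line 31 of file splat.py.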

◆ conv

fastreid.layers.splat.SplAtConv2d.conv

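Grouped convolution (groups * radix groups) producing channels * radix feature maps. It is an RFConv2d from the rfconv package when rectify is True, and a Conv2d otherwise.

Definition at line 35 of file splat.py.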
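Because conv is built with groups=groups * radix, widening its output to channels * radix does not multiply the number of weights by radix. The sketch below counts parameters, assuming Conv2d behaves like torch.nn.Conv2d (weight shape (out_channels, in_channels // groups, kH, kW), plus one bias per output channel); conv2d_param_count is a hypothetical helper.

```python
def conv2d_param_count(in_channels, out_channels, kernel_size, groups=1, bias=True):
    """Hypothetical helper: parameters of a torch.nn.Conv2d-style layer (square kernel)."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("channel counts must be divisible by groups")
    # Each output channel only sees in_channels // groups input channels.
    weights = out_channels * (in_channels // groups) * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


# SplAtConv2d.conv with in_channels = channels = 64, radix = 2, groups = 1
print(conv2d_param_count(64, 64 * 2, 3, groups=1 * 2))
# Plain 64 -> 64 3x3 convolution for comparison
print(conv2d_param_count(64, 64, 3))
```

For example, with in_channels = channels = 64, radix = 2 and groups = 1, conv holds 128 * 32 * 3 * 3 = 36,864 weights, exactly as many as a plain 64-to-64 3x3 convolution; only the bias grows from 64 to 128 entries.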

◆ dropblock_prob

fastreid.layers.splat.SplAtConv2d.dropblock_prob

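DropBlock probability. When it is greater than 0.0, forward calls self.dropblock after bn0, but the constructor does not create that module.

Definition at line 32 of file splat.py.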

◆ fc1

fastreid.layers.splat.SplAtConv2d.fc1

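1x1 convolution (groups=cardinality) mapping the pooled channels descriptor to inter_channels = max(in_channels * radix // reduction_factor, 32).

Definition at line 44 of file splat.py.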

◆ fc2

fastreid.layers.splat.SplAtConv2d.fc2

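1x1 convolution (groups=cardinality) mapping inter_channels back to channels * radix, producing the attention logits for every split.

Definition at line 47 of file splat.py.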

◆ radix

fastreid.layers.splat.SplAtConv2d.radix

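Number of splits per cardinal group. With radix == 1, no split is performed and the attention weights multiply x directly.

Definition at line 29 of file splat.py.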

◆ rectify

fastreid.layers.splat.SplAtConv2d.rectify

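True only if rectify=True was requested and padding is non-zero in at least one dimension; in that case conv is an RFConv2d.

Definition at line 26 of file splat.py.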

◆ rectify_avg

fastreid.layers.splat.SplAtConv2d.rectify_avg

Passed to RFConv2d as average_mode when rectify is active; otherwise unused by the constructor.

Definition at line 27 of file splat.py.

◆ relu

fastreid.layers.splat.SplAtConv2d.relu

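In-place ReLU shared by the main path (after conv and bn0) and the attention path (after fc1 and bn1).

Definition at line 43 of file splat.py.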

◆ rsoftmax

fastreid.layers.splat.SplAtConv2d.rsoftmax

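rSoftMax(radix, groups) module that normalizes the fc2 logits across the radix splits before they are reshaped to (batch, -1, 1, 1).

Definition at line 49 of file splat.py.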
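The listing does not show rSoftMax itself. The sketch below assumes it follows the ResNeSt reference implementation: the fc2 logits are viewed as (cardinality, radix, k), a softmax is taken across the radix axis, the result is flattened radix-first, and a sigmoid is used instead when radix == 1. r_softmax is a hypothetical pure-Python stand-in for one sample, not the fastreid class.

```python
import math


def r_softmax(logits, radix, cardinality):
    """Hypothetical stand-in for rSoftMax on one sample (assumed ResNeSt behaviour).

    logits: flat list of length cardinality * radix * k (the fc2 output),
            interpreted as [cardinality][radix][k].
    Returns a flat list laid out as [radix][cardinality][k].
    """
    if radix == 1:
        # Single split: independent sigmoid gates instead of a softmax.
        return [1.0 / (1.0 + math.exp(-v)) for v in logits]
    k = len(logits) // (cardinality * radix)
    out = [0.0] * len(logits)
    for g in range(cardinality):
        for j in range(k):
            # The radix logits that compete for the same (group, channel) slot.
            vals = [logits[(g * radix + r) * k + j] for r in range(radix)]
            m = max(vals)  # subtract the max for numerical stability
            exps = [math.exp(v - m) for v in vals]
            total = sum(exps)
            for r in range(radix):
                # Write in transposed (radix-first) order.
                out[(r * cardinality + g) * k + j] = exps[r] / total
    return out


print(r_softmax([0.0, math.log(3), 0.0, 0.0], radix=2, cardinality=2))
```

Under this assumption, the radix-first output order is what lets forward cut atten into radix consecutive chunks of channels weights that line up with the splits.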

◆ use_bn

fastreid.layers.splat.SplAtConv2d.use_bn

True when a norm_layer was given; controls whether bn0 and bn1 are created and applied.

Definition at line 40 of file splat.py.


The documentation for this class was generated from the following file: