19 inflate_style: str = '3x1x1') -> None:
20 super().__init__()
21
22 self.inplanes = inplanes
23 self.planes = planes
24 self.spatial_stride = spatial_stride
25 self.temporal_stride = temporal_stride
26 self.dilation = dilation
27
28 self.inflate = inflate
29 self.inflate_style = inflate_style
30
31 self.conv1_stride_s = 1
32 self.conv2_stride_s = spatial_stride
33 self.conv1_stride_t = 1
34 self.conv2_stride_t = temporal_stride
35
36 conv1_stride = (self.conv1_stride_t, self.conv1_stride_s, self.conv1_stride_s)
37 conv2_stride = (self.conv2_stride_t, self.conv2_stride_s, self.conv2_stride_s)
38
39 if self.inflate:
40 if inflate_style == '3x1x1':
41 conv1_kernel_size = (3, 1, 1)
42 conv1_padding = (1, 0, 0)
43 conv2_kernel_size = (1, 3, 3)
44 conv2_padding = (0, dilation, dilation)
45 else:
46 conv1_kernel_size = (1, 1, 1)
47 conv1_padding = (0, 0, 0)
48 conv2_kernel_size = (3, 3, 3)
49 conv2_padding = (1, dilation, dilation)
50 else:
51 conv1_kernel_size = (1, 1, 1)
52 conv1_padding = (0, 0, 0)
53 conv2_kernel_size = (1, 3, 3)
54 conv2_padding = (0, dilation, dilation)
55
56
57 self.conv1 = nn.Sequential(nn.Conv3d(inplanes, planes, kernel_size=conv1_kernel_size, stride=conv1_stride, padding=conv1_padding, bias=False),
58 nn.BatchNorm3d(planes),
59 nn.ReLU(inplace=True))
60
61 self.conv2 = nn.Sequential(nn.Conv3d(planes, planes, kernel_size=conv2_kernel_size, stride=conv2_stride, padding=conv2_padding, bias=False),
62 nn.BatchNorm3d(planes),
63 nn.ReLU(inplace=True))
64
65 self.conv3 = nn.Sequential(nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
66 nn.BatchNorm3d(planes * self.expansion))
67
68 self.downsample = downsample
69 self.relu = nn.ReLU(inplace=True)
70
71