Safemotion Lib
Loading...
Searching...
No Matches
resnet3d.py
Go to the documentation of this file.
1from typing import Dict, List, Optional, Sequence, Tuple, Union
2
3import torch
4import torch.nn as nn
5from torch.nn.modules.utils import _ntuple, _triple
6
class Bottleneck3d(nn.Module):
    """Bottleneck residual block built from 3D convolutions.

    The block applies conv1 -> conv2 -> conv3 (each a Conv3d followed by
    BatchNorm3d; conv1 and conv2 are also followed by ReLU), adds the
    (optionally downsampled) input and applies a final ReLU. The output has
    ``planes * expansion`` channels.

    Kernel layout depends on ``inflate`` / ``inflate_style``:

    * ``inflate=False``: conv1 is 1x1x1, conv2 is 1x3x3 (purely spatial).
    * ``inflate=True, inflate_style='3x1x1'``: conv1 is 3x1x1 (temporal),
      conv2 is 1x3x3.
    * ``inflate=True`` with any other style: conv1 is 1x1x1, conv2 is 3x3x3.

    All striding happens in conv2; conv1 and conv3 always use stride 1.

    Args:
        inplanes: Number of input channels.
        planes: Number of bottleneck channels; the block outputs
            ``planes * expansion`` channels.
        spatial_stride: Stride of conv2 along H and W.
        temporal_stride: Stride of conv2 along T.
        dilation: Spatial dilation of conv2 (its spatial padding matches it,
            so the spatial size is preserved when the stride is 1).
        downsample: Module applied to the input to build the residual branch.
            It must produce the same shape as conv3's output. ``None`` means
            the input is added unchanged.
        inflate: Whether to use a temporal kernel (see above).
        inflate_style: Selects which conv carries the temporal kernel when
            ``inflate`` is True.
    """

    # Channel multiplier between the bottleneck width and the block output.
    expansion = 4

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 spatial_stride: int = 1,
                 temporal_stride: int = 1,
                 dilation: int = 1,
                 downsample: Optional[nn.Module] = None,
                 inflate: bool = True,
                 inflate_style: str = '3x1x1') -> None:
        super().__init__()

        self.inplanes = inplanes
        self.planes = planes
        self.spatial_stride = spatial_stride
        self.temporal_stride = temporal_stride
        self.dilation = dilation

        self.inflate = inflate
        self.inflate_style = inflate_style

        # conv1 never strides; conv2 carries the block's whole stride. The
        # conv1 attributes were read below but never assigned before.
        self.conv1_stride_s = 1
        self.conv2_stride_s = spatial_stride
        self.conv1_stride_t = 1
        self.conv2_stride_t = temporal_stride

        # Strides are ordered (T, H, W) to match Conv3d's input layout.
        conv1_stride = (self.conv1_stride_t, self.conv1_stride_s, self.conv1_stride_s)
        conv2_stride = (self.conv2_stride_t, self.conv2_stride_s, self.conv2_stride_s)

        if self.inflate:
            if inflate_style == '3x1x1':
                conv1_kernel_size = (3, 1, 1)
                conv1_padding = (1, 0, 0)
                conv2_kernel_size = (1, 3, 3)
                conv2_padding = (0, dilation, dilation)
            else:
                conv1_kernel_size = (1, 1, 1)
                conv1_padding = (0, 0, 0)
                conv2_kernel_size = (3, 3, 3)
                conv2_padding = (1, dilation, dilation)
        else:
            conv1_kernel_size = (1, 1, 1)
            conv1_padding = (0, 0, 0)
            conv2_kernel_size = (1, 3, 3)
            conv2_padding = (0, dilation, dilation)

        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=conv1_kernel_size,
                      stride=conv1_stride, padding=conv1_padding, bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True))

        # The padding above is sized for a spatially dilated kernel, so the
        # dilation must also be applied; otherwise dilation > 1 would enlarge
        # the output and break the residual addition.
        self.conv2 = nn.Sequential(
            nn.Conv3d(planes, planes, kernel_size=conv2_kernel_size,
                      stride=conv2_stride, padding=conv2_padding,
                      dilation=(1, dilation, dilation), bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True))

        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion))

        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the block on ``x`` and return ``relu(conv3(conv2(conv1(x))) + identity)``.

        ``identity`` is ``downsample(x)`` when a downsample module was given,
        otherwise ``x`` itself.
        """
        identity = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Out-of-place add so ``x`` (which may be ``identity``) is not modified.
        out = out + identity

        out = self.relu(out)

        return out
class ResNet3d(nn.Module):
    """ResNet backbone built from 3D convolutions.

    Based on the SlowOnly backbone from mmaction2. The network is a stem
    (conv1 -> BatchNorm3d -> ReLU -> MaxPool3d) followed by
    ``len(stage_blocks)`` residual stages made of ``Bottleneck3d`` blocks.

    Args:
        in_channels (int): Number of input channels.
        base_channels (int): Output channels of the stem conv. Stage ``i``
            uses ``base_channels * 2**i`` bottleneck planes, so its output
            has ``base_channels * 2**i * block.expansion`` channels.
        stage_blocks (tuple): Number of blocks repeated in each stage.
        out_indices (Sequence[int]): Indices of the stages whose outputs are
            returned by ``forward``.
        spatial_strides (Sequence[int]): Spatial stride of each stage.
        temporal_strides (Sequence[int]): Temporal stride of each stage.
        dilations (Sequence[int]): Dilation passed to the blocks of each stage.
        conv1_kernel (Sequence[int]): Kernel size (T, H, W) of the stem conv.
        conv1_stride_s (int): Spatial stride of the stem conv.
        conv1_stride_t (int): Temporal stride of the stem conv.
        pool1_stride_s (int): Spatial stride of the max-pool after the stem conv.
        pool1_stride_t (int): Temporal stride of the max-pool after the stem conv.
        inflate (int | Sequence): Per-stage inflation flag selecting the blocks'
            conv kernel types (1 = inflated/temporal kernel, 0 = spatial only).
            A single int is broadcast to every stage; a per-stage entry may
            itself be a per-block sequence.
        inflate_style (str): Kernel style used by inflated blocks.
            TODO: add more inflate styles to experiment with other kernel types.
        input_key (str): Key of the input tensor in the dict given to ``forward``.
        **kwargs: Forwarded to ``make_res_layer`` and from there to every block
            constructor. NOTE(review): ``Bottleneck3d.__init__`` takes no extra
            keyword arguments, so any value passed here will raise TypeError.
    """

    def __init__(self,
                 in_channels: int = 17,
                 base_channels: int = 64,
                 stage_blocks: Optional[Tuple] = (4, 6, 3),
                 out_indices: Sequence[int] = (2, ),
                 spatial_strides: Sequence[int] = (2, 2, 2),
                 temporal_strides: Sequence[int] = (1, 1, 2),
                 dilations: Sequence[int] = (1, 1, 1),
                 conv1_kernel: Sequence[int] = (1, 7, 7),
                 conv1_stride_s: int = 1,
                 conv1_stride_t: int = 1,
                 pool1_stride_s: int = 1,
                 pool1_stride_t: int = 1,
                 inflate: Sequence[int] = (0, 1, 1),
                 inflate_style: str = '3x1x1',
                 input_key: str = 'keypoint_heatmap',
                 **kwargs) -> None:
        super().__init__()
        self.input_key = input_key

        self.in_channels = in_channels
        self.base_channels = base_channels
        self.num_stages = len(stage_blocks)
        self.stage_blocks = stage_blocks
        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages, \
            f'out_indices {out_indices} must be < num_stages ({self.num_stages})'
        self.spatial_strides = spatial_strides
        self.temporal_strides = temporal_strides
        self.dilations = dilations
        assert len(spatial_strides) == len(temporal_strides) == len(
            dilations) == self.num_stages, \
            'spatial_strides, temporal_strides and dilations need one entry per stage'

        self.conv1_kernel = conv1_kernel
        self.conv1_stride_s = conv1_stride_s
        self.conv1_stride_t = conv1_stride_t
        self.pool1_stride_s = pool1_stride_s
        self.pool1_stride_t = pool1_stride_t
        # Accept a single flag for all stages as well as one entry per stage.
        self.stage_inflations = _ntuple(self.num_stages)(inflate)
        self.inflate_style = inflate_style

        self.block = Bottleneck3d

        # (T, H, W) strides of the stem conv and max-pool. _make_stem_layer
        # reads these; previously they were never assigned.
        self.conv1_stride = (self.conv1_stride_t, self.conv1_stride_s, self.conv1_stride_s)
        self.pool1_stride = (self.pool1_stride_t, self.pool1_stride_s, self.pool1_stride_s)
        # "Same" padding for odd kernel sizes in every dimension.
        self.conv1_padding = tuple((k - 1) // 2 for k in _triple(self.conv1_kernel))
        # Channels entering the next stage; starts at the stem's output width.
        self.inplanes = self.base_channels

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            planes = self.base_channels * 2**i
            res_layer = self.make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                spatial_stride=spatial_strides[i],
                temporal_stride=temporal_strides[i],
                dilation=dilations[i],
                inflate=self.stage_inflations[i],
                inflate_style=self.inflate_style,
                **kwargs)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

    @staticmethod
    def make_res_layer(block: nn.Module,
                       inplanes: int,
                       planes: int,
                       blocks: int,
                       spatial_stride: Union[int, Sequence[int]] = 1,
                       temporal_stride: Union[int, Sequence[int]] = 1,
                       dilation: int = 1,
                       inflate: Union[int, Sequence[int]] = 1,
                       inflate_style: str = '3x1x1',
                       **kwargs) -> nn.Module:
        """Build one residual stage as an ``nn.Sequential`` of ``blocks`` blocks.

        Only the first block strides and changes the channel count; it gets a
        1x1x1 conv + BatchNorm3d downsample branch whenever its output shape
        differs from its input shape. The remaining blocks use stride 1.

        Args:
            block: Block class; must expose an ``expansion`` attribute.
            inplanes: Input channels of the stage.
            planes: Bottleneck width; the stage outputs ``planes * block.expansion``.
            blocks: Number of blocks in the stage.
            spatial_stride: Spatial stride of the first block.
            temporal_stride: Temporal stride of the first block.
            dilation: Dilation passed to every block.
            inflate: One flag for all blocks or one flag per block (1 = inflate).
            inflate_style: Passed to every block.
            **kwargs: Passed to every block constructor.

        Returns:
            nn.Sequential containing the stage's blocks.
        """
        inflate = inflate if not isinstance(inflate, int) \
            else (inflate,) * blocks

        downsample = None
        # A temporal stride alone also changes the output shape, so it needs
        # a projection on the identity path too.
        if (spatial_stride != 1 or temporal_stride != 1
                or inplanes != planes * block.expansion):
            stride = (temporal_stride, spatial_stride, spatial_stride)
            downsample = nn.Sequential(
                nn.Conv3d(inplanes, planes * block.expansion, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion))

        layers = [
            block(
                inplanes,
                planes,
                spatial_stride=spatial_stride,
                temporal_stride=temporal_stride,
                dilation=dilation,
                downsample=downsample,
                inflate=(inflate[0] == 1),
                inflate_style=inflate_style,
                **kwargs)
        ]
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    spatial_stride=1,
                    temporal_stride=1,
                    dilation=dilation,
                    inflate=(inflate[i] == 1),
                    inflate_style=inflate_style,
                    **kwargs))

        return nn.Sequential(*layers)

    def _make_stem_layer(self) -> None:
        """Create the stem: ``conv1`` (Conv3d + BatchNorm3d + ReLU) and ``maxpool``.

        The max-pool uses a 1x3x3 kernel, so it pools only spatially; its
        stride is ``self.pool1_stride``.
        """
        self.conv1 = nn.Sequential(
            nn.Conv3d(self.in_channels, self.base_channels,
                      kernel_size=self.conv1_kernel, stride=self.conv1_stride,
                      padding=self.conv1_padding, bias=False),
            nn.BatchNorm3d(self.base_channels),
            nn.ReLU(inplace=True))

        self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=self.pool1_stride,
                                    padding=(0, 1, 1))

    def forward(self, sample_dict):
        """Run the backbone on ``sample_dict[self.input_key]``.

        Args:
            sample_dict (dict): Input data. It must contain ``self.input_key``,
                whose value is a Tensor of shape (B, C, T, H, W):
                B = batch size, C = input channels, T = time, H/W = spatial size.

        Returns:
            Tensor or tuple[Tensor]: The feature maps of the stages listed in
            ``self.out_indices``, each of shape (B, C_o, T_o, H_o, W_o).
            A single Tensor is returned when exactly one stage is selected.
        """
        x = sample_dict[self.input_key]

        x = self.conv1(x)
        x = self.maxpool(x)

        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)

            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]

        return tuple(outs)
None __init__(self, int inplanes, int planes, int spatial_stride=1, int temporal_stride=1, int dilation=1, Optional[nn.Module] downsample=None, bool inflate=True, str inflate_style='3x1x1')
Definition resnet3d.py:19
nn.Module make_res_layer(nn.Module block, int inplanes, int planes, int blocks, Union[int, Sequence[int]] spatial_stride=1, Union[int, Sequence[int]] temporal_stride=1, int dilation=1, Union[int, Sequence[int]] inflate=1, str inflate_style='3x1x1', **kwargs)
Definition resnet3d.py:200
None __init__(self, int in_channels=17, int base_channels=64, Optional[Tuple] stage_blocks=(4, 6, 3), Sequence[int] out_indices=(2,), Sequence[int] spatial_strides=(2, 2, 2), Sequence[int] temporal_strides=(1, 1, 2), Sequence[int] dilations=(1, 1, 1), Sequence[int] conv1_kernel=(1, 7, 7), int conv1_stride_s=1, int conv1_stride_t=1, int pool1_stride_s=1, int pool1_stride_t=1, Sequence[int] inflate=(0, 1, 1), str inflate_style='3x1x1', input_key='keypoint_heatmap', **kwargs)
Definition resnet3d.py:111
None _make_stem_layer(self)
Definition resnet3d.py:239
forward(self, sample_dict)
Definition resnet3d.py:247