Safemotion Lib
resnet3d.ResNet3d Class Reference
[Figure: Inheritance diagram for resnet3d.ResNet3d]

Public Member Functions

None __init__ (self, int in_channels=17, int base_channels=64, Optional[Tuple] stage_blocks=(4, 6, 3), Sequence[int] out_indices=(2,), Sequence[int] spatial_strides=(2, 2, 2), Sequence[int] temporal_strides=(1, 1, 2), Sequence[int] dilations=(1, 1, 1), Sequence[int] conv1_kernel=(1, 7, 7), int conv1_stride_s=1, int conv1_stride_t=1, int pool1_stride_s=1, int pool1_stride_t=1, Sequence[int] inflate=(0, 1, 1), str inflate_style='3x1x1', input_key='keypoint_heatmap', **kwargs)
 
 forward (self, sample_dict)
 

Static Public Member Functions

nn.Module make_res_layer (nn.Module block, int inplanes, int planes, int blocks, Union[int, Sequence[int]] spatial_stride=1, Union[int, Sequence[int]] temporal_stride=1, int dilation=1, Union[int, Sequence[int]] inflate=1, str inflate_style='3x1x1', **kwargs)
 

Public Attributes

 input_key
 
 in_channels
 
 base_channels
 
 num_stages
 
 stage_blocks
 
 out_indices
 
 spatial_strides
 
 temporal_strides
 
 dilations
 
 conv1_kernel
 
 conv1_stride_s
 
 conv1_stride_t
 
 pool1_stride_s
 
 pool1_stride_t
 
 stage_inflations
 
 inflate_style
 
 block
 
 inplanes
 
 conv1_stride
 
 conv1_padding
 
 pool1_stride
 
 res_layers
 
 conv1
 
 maxpool
 

Protected Member Functions

None _make_stem_layer (self)
 

Detailed Description

ResNet built on 3D convolutions.
Based on SlowOnly from mmaction2.

Definition at line 90 of file resnet3d.py.

Constructor & Destructor Documentation

◆ __init__()

None resnet3d.ResNet3d.__init__ ( self,
int in_channels = 17,
int base_channels = 64,
Optional[Tuple] stage_blocks = (4, 6, 3),
Sequence[int] out_indices = (2, ),
Sequence[int] spatial_strides = (2, 2, 2),
Sequence[int] temporal_strides = (1, 1, 2),
Sequence[int] dilations = (1, 1, 1),
Sequence[int] conv1_kernel = (1, 7, 7),
int conv1_stride_s = 1,
int conv1_stride_t = 1,
int pool1_stride_s = 1,
int pool1_stride_t = 1,
Sequence[int] inflate = (0, 1, 1),
str inflate_style = '3x1x1',
input_key = 'keypoint_heatmap',
** kwargs )

Definition at line 95 of file resnet3d.py.

111 **kwargs) -> None:
112 super().__init__()
113 """
114 args:
115 in_channels (int) : number of input channels
116 base_channels (int) : output channels of the initial conv; the output channels of the conv blocks are multiples of base_channels
117 stage_blocks (tuple) : number of blocks repeated in each stage
118 out_indices (Sequence[int]) : indices of the stages whose features are output
119 spatial_strides (Sequence[int]) : spatial stride of each stage
120 temporal_strides (Sequence[int]) : temporal stride of each stage
121 dilations (Sequence[int]) : dilation of each stage
122 conv1_kernel (Sequence[int]) : kernel shape of the initial conv
123 conv1_stride_s (int) : spatial stride of the initial conv
124 conv1_stride_t (int) : temporal stride of the initial conv
125 pool1_stride_s (int) : spatial stride of the pooling layer after the initial conv
126 pool1_stride_t (int) : temporal stride of the pooling layer after the initial conv
127 inflate (Sequence[int]) : per-stage flag that sets the conv kernel type of the blocks (1: inflated, 0: not inflated)
128 inflate_style (str) : kernel style of the inflated blocks (e.g. '3x1x1')
129 TODO : add more inflate_style options to try other kernel types
130 input_key (str) : key of the input data that the module reads from sample_dict at inference
131 """
132 self.input_key = input_key
133
134 self.in_channels = in_channels
135 self.base_channels = base_channels
136 self.num_stages = len(stage_blocks)
137 self.stage_blocks = stage_blocks
138 self.out_indices = out_indices
139 assert max(out_indices) < self.num_stages
140 self.spatial_strides = spatial_strides
141 self.temporal_strides = temporal_strides
142 self.dilations = dilations
143 assert len(spatial_strides) == len(temporal_strides) == len(
144 dilations) == self.num_stages
145
146 self.conv1_kernel = conv1_kernel
147 self.conv1_stride_s = conv1_stride_s
148 self.conv1_stride_t = conv1_stride_t
149 self.pool1_stride_s = pool1_stride_s
150 self.pool1_stride_t = pool1_stride_t
151 self.stage_inflations = inflate
152 self.inflate_style = inflate_style
153
154
155 self.block = Bottleneck3d
156
157 self.inplanes = self.base_channels
158
159 self.conv1_stride = (self.conv1_stride_t, self.conv1_stride_s, self.conv1_stride_s)
160 self.conv1_padding = tuple([(k - 1) // 2 for k in _triple(self.conv1_kernel)])
161 self.pool1_stride = (self.pool1_stride_t, self.pool1_stride_s, self.pool1_stride_s)
162
163 self._make_stem_layer()
164
165 self.res_layers = []
166 for i, num_blocks in enumerate(self.stage_blocks):
167 spatial_stride = spatial_strides[i]
168 temporal_stride = temporal_strides[i]
169 dilation = dilations[i]
170 planes = self.base_channels * 2**i
171 res_layer = self.make_res_layer(
172 self.block,
173 self.inplanes,
174 planes,
175 num_blocks,
176 spatial_stride=spatial_stride,
177 temporal_stride=temporal_stride,
178 dilation=dilation,
179 inflate=self.stage_inflations[i],
180 inflate_style=self.inflate_style,
181 **kwargs)
182 self.inplanes = planes * self.block.expansion
183 layer_name = f'layer{i + 1}'
184 self.add_module(layer_name, res_layer)
185 self.res_layers.append(layer_name)
186
187 # self.feat_dim = self.block.expansion * \
188 # self.base_channels * 2 ** (len(self.stage_blocks) - 1)
189
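The constructor's per-stage bookkeeping (the two assertions, `planes = base_channels * 2**i`, and `inplanes` being carried from one stage to the next) can be replayed without building any layers. This is a minimal sketch: `stage_plan` and `EXPANSION` are hypothetical names, and `EXPANSION = 4` is an assumption taken from the standard ResNet bottleneck, since `Bottleneck3d` is not shown on this page.

```python
from typing import List, Sequence

# Assumption: Bottleneck3d.expansion == 4, as in the standard ResNet bottleneck.
# This page does not show Bottleneck3d, so check its `expansion` attribute.
EXPANSION = 4


def stage_plan(base_channels: int = 64,
               stage_blocks: Sequence[int] = (4, 6, 3),
               out_indices: Sequence[int] = (2,),
               spatial_strides: Sequence[int] = (2, 2, 2),
               temporal_strides: Sequence[int] = (1, 1, 2),
               dilations: Sequence[int] = (1, 1, 1),
               inflate: Sequence[int] = (0, 1, 1),
               expansion: int = EXPANSION) -> List[dict]:
    """Replay the per-stage bookkeeping of ResNet3d.__init__ without building layers."""
    num_stages = len(stage_blocks)
    # Same sanity checks as the constructor.
    assert max(out_indices) < num_stages
    assert len(spatial_strides) == len(temporal_strides) == len(dilations) == num_stages

    plan = []
    inplanes = base_channels  # channels produced by the stem (conv1 + maxpool)
    for i, num_blocks in enumerate(stage_blocks):
        planes = base_channels * 2 ** i  # bottleneck width doubles at every stage
        plan.append({
            "name": f"layer{i + 1}",
            "blocks": num_blocks,
            "in_channels": inplanes,
            "out_channels": planes * expansion,
            "spatial_stride": spatial_strides[i],
            "temporal_stride": temporal_strides[i],
            "dilation": dilations[i],
            "inflate": inflate[i],
        })
        inplanes = planes * expansion  # the next stage consumes this stage's output
    return plan


for stage in stage_plan():
    print(stage["name"], stage["in_channels"], "->", stage["out_channels"])
```

With the defaults and the assumed expansion of 4, the three stages map 64 -> 256, 256 -> 512 and 512 -> 1024 channels.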

Member Function Documentation

◆ _make_stem_layer()

None resnet3d.ResNet3d._make_stem_layer ( self)
protected

Definition at line 239 of file resnet3d.py.

239 def _make_stem_layer(self) -> None:
240
241 self.conv1 = nn.Sequential(nn.Conv3d(self.in_channels, self.base_channels, kernel_size=self.conv1_kernel, stride=self.conv1_stride, padding=self.conv1_padding, bias=False),
242 nn.BatchNorm3d(self.base_channels),
243 nn.ReLU(inplace=True))
244
245 self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=self.pool1_stride, padding=(0, 1, 1))
246
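The stem's effect on the input shape follows from the standard Conv3d/MaxPool3d output-length rule `(n + 2p - k) // s + 1` (dilation 1), applied per (T, H, W) axis with the `conv1_padding` rule `(k - 1) // 2`. This is a minimal sketch: `triple` is a hypothetical stand-in for `_triple`, and the 48-frame, 56x56 input is an illustrative size, not one taken from this library.

```python
from typing import Sequence, Tuple, Union


def triple(x: Union[int, Sequence[int]]) -> Tuple[int, int, int]:
    """Hypothetical stand-in for _triple: broadcast an int to a (t, h, w) triple."""
    if isinstance(x, int):
        return (x, x, x)
    t, h, w = x
    return (t, h, w)


def axis_out(n: int, k: int, s: int, p: int) -> int:
    """Output length of one Conv3d/MaxPool3d axis (dilation 1, floor rounding)."""
    return (n + 2 * p - k) // s + 1


def stem_output_shape(in_shape: Tuple[int, int, int, int, int],
                      base_channels: int = 64,
                      conv1_kernel: Sequence[int] = (1, 7, 7),
                      conv1_stride_s: int = 1, conv1_stride_t: int = 1,
                      pool1_stride_s: int = 1, pool1_stride_t: int = 1):
    b, _, t, h, w = in_shape
    kernel = triple(conv1_kernel)
    padding = tuple((k - 1) // 2 for k in kernel)               # conv1_padding rule
    stride = (conv1_stride_t, conv1_stride_s, conv1_stride_s)   # (T, H, W) order
    dims = [axis_out(n, k, s, p)
            for n, k, s, p in zip((t, h, w), kernel, stride, padding)]
    # maxpool: kernel (1, 3, 3), padding (0, 1, 1), stride from pool1_stride_*
    pool_stride = (pool1_stride_t, pool1_stride_s, pool1_stride_s)
    dims = [axis_out(n, k, s, p)
            for n, k, s, p in zip(dims, (1, 3, 3), pool_stride, (0, 1, 1))]
    return (b, base_channels, *dims)


print(stem_output_shape((2, 17, 48, 56, 56)))  # → (2, 64, 48, 56, 56)
```

With the default strides of 1, the stem only changes the channel count. The `(k - 1) // 2` padding preserves length only for odd kernel sizes; an even kernel would shorten the axis by one.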

◆ forward()

resnet3d.ResNet3d.forward ( self,
sample_dict )
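args:
    sample_dict (dict) : input data; must contain the key given by self.input_key
        The item under self.input_key is a Tensor -> shape (B, C, T, H, W)
            B : batch size
            C : input channels
            T : time
            H : height
            W : width

return (Tensor or tuple[Tensor]):
    feature map at a specific resolution -> shape (B, C_o, T_o, H_o, W_o)
    (a tuple of such feature maps when out_indices selects more than one stage)
        B : batch size
        C_o : channels
        T_o : time
        H_o : height
        W_o : width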
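How (C_o, T_o, H_o, W_o) relate to the constructor arguments can be traced stage by stage. This is a minimal sketch under stated assumptions: `Bottleneck3d.expansion` is 4, only the first block of each stage strides (as in `make_res_layer`), and the block convs are padded so that a stride-s conv maps a length n to ceil(n/s). The stem output size `(2, 64, 48, 56, 56)` is illustrative, and `stage_shapes` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

EXPANSION = 4  # assumption: Bottleneck3d.expansion (not shown on this page)

Shape = Tuple[int, int, int, int, int]


def ceil_div(n: int, s: int) -> int:
    """ceil(n / s): the length a 'same'-padded conv with stride s produces."""
    return -(-n // s)


def stage_shapes(stem_shape: Shape,
                 base_channels: int = 64,
                 num_stages: int = 3,
                 spatial_strides: Sequence[int] = (2, 2, 2),
                 temporal_strides: Sequence[int] = (1, 1, 2)) -> List[Shape]:
    """Feature shape after each residual stage (assumes same-padded block convs)."""
    b, _, t, h, w = stem_shape
    shapes = []
    for i in range(num_stages):
        # Only the first block of a stage strides, so each stage divides once.
        t = ceil_div(t, temporal_strides[i])
        h = ceil_div(h, spatial_strides[i])
        w = ceil_div(w, spatial_strides[i])
        shapes.append((b, base_channels * 2 ** i * EXPANSION, t, h, w))
    return shapes


shapes = stage_shapes((2, 64, 48, 56, 56))
print(shapes[2])  # with out_indices=(2,), forward() would return a feature of this shape
```

Under these assumptions the default configuration halves H and W in every stage and halves T only in the last stage.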

Definition at line 247 of file resnet3d.py.

247 def forward(self, sample_dict):
248 """
249 args:
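250 sample_dict (dict) : input data; must contain the key given by self.input_key
251 the item under self.input_key is a Tensor -> shape (B, C, T, H, W)
252 B : batch size
253 C : input channels
254 T : time, H : height, W : width
255
256 return (Tensor or tuple[Tensor]):
257 feature map at a specific resolution -> shape (B, C_o, T_o, H_o, W_o); a tuple of such maps if out_indices selects several stages
258 B : batch size
259 C_o : channels
260 T_o : time, H_o : height, W_o : width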
261 """
262 x = sample_dict[self.input_key]
263
264 x = self.conv1(x)
265 x = self.maxpool(x)
266
267 outs = []
268 for i, layer_name in enumerate(self.res_layers):
269 res_layer = getattr(self, layer_name)
270 x = res_layer(x)
271
272 if i in self.out_indices:
273 outs.append(x)
274
275 if len(outs) == 1:
276 return outs[0]
277
278 return tuple(outs)
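The stage loop at the end of `forward` decides the return type: a single selected stage yields a bare value, several yield a tuple. The same control flow is shown below with plain callables in place of `layer1`..`layer3`; `run_stages` and `toy_stages` are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Sequence, Tuple, Union


def run_stages(x: Any, stages: Sequence[Callable[[Any], Any]],
               out_indices: Sequence[int]) -> Union[Any, Tuple[Any, ...]]:
    """Same control flow as the stage loop in ResNet3d.forward."""
    outs = []
    for i, stage in enumerate(stages):
        x = stage(x)          # every stage runs, even after the last selected index
        if i in out_indices:
            outs.append(x)    # keep the features of the selected stages
    if len(outs) == 1:
        return outs[0]        # single index: bare value, not a 1-tuple
    return tuple(outs)


# Toy "stages" that double a number, standing in for layer1..layer3.
toy_stages = [lambda v: v * 2] * 3
print(run_stages(1, toy_stages, (2,)))     # 8
print(run_stages(1, toy_stages, (0, 2)))   # (2, 8)
```

Callers that pass several `out_indices` therefore have to unpack a tuple, while the default `(2,)` returns one tensor directly.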

◆ make_res_layer()

nn.Module resnet3d.ResNet3d.make_res_layer ( nn.Module block,
int inplanes,
int planes,
int blocks,
Union[int, Sequence[int]] spatial_stride = 1,
Union[int, Sequence[int]] temporal_stride = 1,
int dilation = 1,
Union[int, Sequence[int]] inflate = 1,
str inflate_style = '3x1x1',
** kwargs )
static

Definition at line 191 of file resnet3d.py.

200 **kwargs) -> nn.Module:
201
202 inflate = inflate if not isinstance(inflate, int) \
203 else (inflate,) * blocks
204
205 downsample = None
206 if spatial_stride != 1 or inplanes != planes * block.expansion:
207 stride = (temporal_stride, spatial_stride, spatial_stride)
208 downsample = nn.Sequential(nn.Conv3d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
209 nn.BatchNorm3d(planes * block.expansion))
210
211 layers = []
212 layers.append(
213 block(
214 inplanes,
215 planes,
216 spatial_stride=spatial_stride,
217 temporal_stride=temporal_stride,
218 dilation=dilation,
219 downsample=downsample,
220 inflate=(inflate[0] == 1),
221 inflate_style=inflate_style,
222 **kwargs))
223 inplanes = planes * block.expansion
224 for i in range(1, blocks):
225 layers.append(
226 block(
227 inplanes,
228 planes,
229 spatial_stride=1,
230 temporal_stride=1,
231 dilation=dilation,
232 inflate=(inflate[i] == 1),
233 inflate_style=inflate_style,
234 **kwargs))
235
236 return nn.Sequential(*layers)
237
238
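`make_res_layer` broadcasts an integer `inflate` to every block, gives only the first block the strides and the projection shortcut, and feeds `planes * block.expansion` channels into the remaining blocks. The sketch below records those per-block decisions as dictionaries instead of building modules; `block_configs` is a hypothetical helper and `expansion=4` is an assumption about `Bottleneck3d`.

```python
from typing import List, Sequence, Union


def block_configs(inplanes: int, planes: int, blocks: int,
                  spatial_stride: int = 1, temporal_stride: int = 1,
                  inflate: Union[int, Sequence[int]] = 1,
                  expansion: int = 4) -> List[dict]:
    """Per-block settings chosen by make_res_layer (expansion=4 is an assumption)."""
    # An int inflate flag is broadcast to every block in the stage.
    inflate = inflate if not isinstance(inflate, int) else (inflate,) * blocks
    out_channels = planes * expansion
    # Projection shortcut when the first block changes resolution or width.
    needs_downsample = spatial_stride != 1 or inplanes != out_channels
    configs = []
    for i in range(blocks):
        first = i == 0
        configs.append({
            "inplanes": inplanes if first else out_channels,
            "spatial_stride": spatial_stride if first else 1,
            "temporal_stride": temporal_stride if first else 1,
            "downsample": needs_downsample if first else False,
            "inflate": inflate[i] == 1,
        })
    return configs


# Third stage of the default ResNet3d: 512 -> 1024 channels, strides (2, 2, 2).
for cfg in block_configs(512, 256, 3, spatial_stride=2, temporal_stride=2, inflate=1):
    print(cfg)
```

Within `ResNet3d`, `inplanes` never equals `planes * block.expansion` at a stage boundary (64 vs 256, 256 vs 512, 512 vs 1024 under the assumed expansion), so the first block of every stage gets the projection shortcut, and that shortcut also applies the temporal stride.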

Member Data Documentation

◆ base_channels

resnet3d.ResNet3d.base_channels

Definition at line 135 of file resnet3d.py.

◆ block

resnet3d.ResNet3d.block

Definition at line 155 of file resnet3d.py.

◆ conv1

resnet3d.ResNet3d.conv1

Definition at line 241 of file resnet3d.py.

◆ conv1_kernel

resnet3d.ResNet3d.conv1_kernel

Definition at line 146 of file resnet3d.py.

◆ conv1_padding

resnet3d.ResNet3d.conv1_padding

Definition at line 160 of file resnet3d.py.

◆ conv1_stride

resnet3d.ResNet3d.conv1_stride

Definition at line 159 of file resnet3d.py.

◆ conv1_stride_s

resnet3d.ResNet3d.conv1_stride_s

Definition at line 147 of file resnet3d.py.

◆ conv1_stride_t

resnet3d.ResNet3d.conv1_stride_t

Definition at line 148 of file resnet3d.py.

◆ dilations

resnet3d.ResNet3d.dilations

Definition at line 142 of file resnet3d.py.

◆ in_channels

resnet3d.ResNet3d.in_channels

Definition at line 134 of file resnet3d.py.

◆ inflate_style

resnet3d.ResNet3d.inflate_style

Definition at line 152 of file resnet3d.py.

◆ inplanes

resnet3d.ResNet3d.inplanes

Definition at line 157 of file resnet3d.py.

◆ input_key

resnet3d.ResNet3d.input_key

Definition at line 132 of file resnet3d.py.

◆ maxpool

resnet3d.ResNet3d.maxpool

Definition at line 245 of file resnet3d.py.

◆ num_stages

resnet3d.ResNet3d.num_stages

Definition at line 136 of file resnet3d.py.

◆ out_indices

resnet3d.ResNet3d.out_indices

Definition at line 138 of file resnet3d.py.

◆ pool1_stride

resnet3d.ResNet3d.pool1_stride

Definition at line 161 of file resnet3d.py.

◆ pool1_stride_s

resnet3d.ResNet3d.pool1_stride_s

Definition at line 149 of file resnet3d.py.

◆ pool1_stride_t

resnet3d.ResNet3d.pool1_stride_t

Definition at line 150 of file resnet3d.py.

◆ res_layers

resnet3d.ResNet3d.res_layers

Definition at line 165 of file resnet3d.py.

◆ spatial_strides

resnet3d.ResNet3d.spatial_strides

Definition at line 140 of file resnet3d.py.

◆ stage_blocks

resnet3d.ResNet3d.stage_blocks

Definition at line 137 of file resnet3d.py.

◆ stage_inflations

resnet3d.ResNet3d.stage_inflations

Definition at line 151 of file resnet3d.py.

◆ temporal_strides

resnet3d.ResNet3d.temporal_strides

Definition at line 141 of file resnet3d.py.


The documentation for this class was generated from the following file: resnet3d.py