111 **kwargs) -> None:
112 super().__init__()
113 """
114 args:
115 in_channels (int) : 입력 채널
116 base_channels (int) : 초기 conv의 출력 채널, base_channels의 배수로 conv 블럭들의 출력 채널이 결정됨
117 stage_blocks (tuple) : 각 스테이지에서 반복하는 블럭수
118 out_indices (Sequence[int]) : 출력 하려는 특징 스테이지 인덱스
119 spatial_strides (Sequence[int]) : 각 스테이지별 공간축 stride
120 temporal_strides (Sequence[int]) : 각 스테이지별 시간축 stride
121 dilations (Sequence[int]) : 각 스테이지별 dilation
122 conv1_kernel (Sequence[int]) : 초기 conv의 커널 shape
123 conv1_stride_s (int) : 초기 conv의 공간축 stride
124 conv1_stride_t (int) : 초기 conv의 시간축 stride
125 pool1_stride_s (int) : 초기 conv이후의 풀링 레이어의 공간축 stride
126 pool1_stride_t (int) : 초기 conv이후의 풀링 레이어의 시간축 stride
127 inflate (Sequence[int]) : 블럭의 conv 커널 타입 설정
128 inflate_style (str) : 블럭의 conv 커널 타입 설정
129 TODO : inflate_style을 추가해서 커널 종류를 다양하게 사용해 볼 수 있음
130 input_key (str) : 모듈의 inference에서 사용하는 입력데이터의 키값
131 """
132 self.input_key = input_key
133
134 self.in_channels = in_channels
135 self.base_channels = base_channels
136 self.num_stages = len(stage_blocks)
137 self.stage_blocks = stage_blocks
138 self.out_indices = out_indices
139 assert max(out_indices) < self.num_stages
140 self.spatial_strides = spatial_strides
141 self.temporal_strides = temporal_strides
142 self.dilations = dilations
143 assert len(spatial_strides) == len(temporal_strides) == len(
144 dilations) == self.num_stages
145
146 self.conv1_kernel = conv1_kernel
147 self.conv1_stride_s = conv1_stride_s
148 self.conv1_stride_t = conv1_stride_t
149 self.pool1_stride_s = pool1_stride_s
150 self.pool1_stride_t = pool1_stride_t
151 self.stage_inflations = inflate
152 self.inflate_style = inflate_style
153
154
155 self.block = Bottleneck3d
156
157 self.inplanes = self.base_channels
158
159 self.conv1_stride = (self.conv1_stride_t, self.conv1_stride_s, self.conv1_stride_s)
160 self.conv1_padding = tuple([(k - 1) // 2 for k in _triple(self.conv1_kernel)])
161 self.pool1_stride = (self.pool1_stride_t, self.pool1_stride_s, self.pool1_stride_s)
162
163 self._make_stem_layer()
164
165 self.res_layers = []
166 for i, num_blocks in enumerate(self.stage_blocks):
167 spatial_stride = spatial_strides[i]
168 temporal_stride = temporal_strides[i]
169 dilation = dilations[i]
170 planes = self.base_channels * 2**i
171 res_layer = self.make_res_layer(
172 self.block,
173 self.inplanes,
174 planes,
175 num_blocks,
176 spatial_stride=spatial_stride,
177 temporal_stride=temporal_stride,
178 dilation=dilation,
179 inflate=self.stage_inflations[i],
180 inflate_style=self.inflate_style,
181 **kwargs)
182 self.inplanes = planes * self.block.expansion
183 layer_name = f'layer{i + 1}'
184 self.add_module(layer_name, res_layer)
185 self.res_layers.append(layer_name)
186
187
188
189