args:
sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
self.input_key의 아이템은 Tensor 타입 -> shape (B, C, T, H, W) or (B, C, H(or T), W)
B : 배치 크기
C : 입력 채널
T : 시간
return (Tensor):
1차원 특징 벡터 -> shape (B, C_o)
B : 배치 크기
C_o : 채널 크기, 입력 채널의 합
Definition at line 25 of file cat_layer.py.
25 def forward(self, feat_dict):
26 """
27 args:
28 sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
29 self.input_key의 아이템은 Tensor 타입 -> shape (B, C, T, H, W) or (B, C, H(or T), W)
30 B : 배치 크기
31 C : 입력 채널
32 T : 시간
33
34 return (Tensor):
35 1차원 특징 벡터 -> shape (B, C_o)
36 B : 배치 크기
37 C_o : 채널 크기, 입력 채널의 합
38 """
39 feats = []
40 for key in self.input_key:
41 x = feat_dict[key]
42 x = self.avg_pool_dict[key](x)
43 x = x.view(x.shape[0], -1)
44 feats.append(x)
45
46 cat_feat = torch.cat(feats, dim=1)
47
48 return cat_feat