Safemotion Lib
Loading...
Searching...
No Matches
smaction
models
heads
linear_head.py
Go to the documentation of this file.
1
import
torch
2
import
torch.nn
as
nn
3
4
class
LinearHead
(nn.Module):
5
"""
6
1차원 텐서를 입력으로 클래스 스코어를 출력하는 기능의 모듈
7
1차원 입력 텐서에 fc 연산을 적용해서 클래스 스코어를 출력함
8
args:
9
in_channels (int) : 입력 텐서의 채널 수
10
num_classes (int) : 클래스 수, 출력 텐서의 채널 수
11
dropout_ratio (float) : 드랍아웃 비율을 설정하는 파라미터
12
input_key (str) : inference에 사용하는 입력데이터의 키값
13
"""
14
def
__init__
(self, in_channels, num_classes, dropout_ratio, input_key):
15
super().
__init__
()
16
17
self.
dropout_ratio
= dropout_ratio
18
self.
input_key
= input_key
19
20
self.
fc_cls
= nn.Linear(in_channels, num_classes)
21
22
if
self.
dropout_ratio
!= 0:
23
self.
dropout
= nn.Dropout(p=self.
dropout_ratio
)
24
else
:
25
self.
dropout
=
None
26
27
28
def
forward
(self, feat_dict):
29
"""
30
args:
31
sample (dict)) : 입력 데이터, self.input_key에 해당하는 키가 있어야함
32
self.input_key의 아이템은 Tensor 타입 -> shape (B, in_channels)
33
B : 배치 크기
34
35
return (Tensor):
36
1차원 특징 벡터 -> shape (B, num_classes)
37
"""
38
x = feat_dict[self.
input_key
]
39
40
if
self.
dropout
is
not
None
:
41
x = self.
dropout
(x)
42
43
cls_score = self.
fc_cls
(x)
44
45
return
cls_score
46
linear_head.LinearHead
Definition
linear_head.py:4
linear_head.LinearHead.forward
forward(self, feat_dict)
Definition
linear_head.py:28
linear_head.LinearHead.dropout_ratio
dropout_ratio
Definition
linear_head.py:17
linear_head.LinearHead.dropout
dropout
Definition
linear_head.py:23
linear_head.LinearHead.__init__
__init__(self, in_channels, num_classes, dropout_ratio, input_key)
Definition
linear_head.py:14
linear_head.LinearHead.input_key
input_key
Definition
linear_head.py:18
linear_head.LinearHead.fc_cls
fc_cls
Definition
linear_head.py:20
torch.nn
Generated by
1.10.0