Safemotion Lib
linear_head.LinearHead Class Reference

Public Member Functions

 __init__ (self, in_channels, num_classes, dropout_ratio, input_key)
 
 forward (self, feat_dict)
 

Public Attributes

 dropout_ratio
 
 input_key
 
 fc_cls
 
 dropout
 

Detailed Description

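A module that takes a 1-D feature tensor as input and outputs class scores.
It applies a fully connected (fc) layer to the 1-D input tensor of each sample (batched shape (B, in_channels)) to produce the class scores.
args:
    in_channels (int): number of channels of the input tensor
    num_classes (int): number of classes, i.e., the number of channels of the output tensor
    dropout_ratio (float): dropout probability; if 0, no dropout layer is created
    input_key (str): key of the input data used for inference, i.e., the entry of the input dict that forward() reads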

Definition at line 4 of file linear_head.py.
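The source listings on this page are excerpts and cannot run on their own (they omit the imports and the class header). The following is a minimal standalone sketch that mirrors the listed code, assuming the base class is torch.nn.Module (the inheritance diagram is not available here, but super().__init__() and nn.Linear suggest it). The sizes (256 channels, 10 classes) and the key "feat" are hypothetical.

```python
import torch
from torch import nn


class LinearHead(nn.Module):
    """Sketch mirroring linear_head.py; nn.Module base class is an assumption."""

    def __init__(self, in_channels, num_classes, dropout_ratio, input_key):
        super().__init__()
        self.dropout_ratio = dropout_ratio
        self.input_key = input_key
        self.fc_cls = nn.Linear(in_channels, num_classes)
        # A dropout layer is created only for a non-zero ratio
        self.dropout = nn.Dropout(p=dropout_ratio) if dropout_ratio != 0 else None

    def forward(self, feat_dict):
        x = feat_dict[self.input_key]      # (B, in_channels)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.fc_cls(x)              # (B, num_classes)


# Hypothetical configuration: 256-dim features, 10 classes, key "feat"
head = LinearHead(in_channels=256, num_classes=10, dropout_ratio=0.5, input_key="feat")
head.eval()  # puts nn.Dropout into its pass-through inference mode

feat_dict = {"feat": torch.randn(4, 256)}  # batch of B=4 feature vectors
with torch.no_grad():
    cls_score = head(feat_dict)
print(cls_score.shape)  # torch.Size([4, 10])
```

Calling head.eval() before inference matters here: in training mode nn.Dropout randomly zeroes input features, so repeated calls on the same input would give different scores.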

Constructor & Destructor Documentation

◆ __init__()

linear_head.LinearHead.__init__ (self, in_channels, num_classes, dropout_ratio, input_key)

Definition at line 14 of file linear_head.py.

14 def __init__(self, in_channels, num_classes, dropout_ratio, input_key):
15     super().__init__()
16
17     self.dropout_ratio = dropout_ratio
18     self.input_key = input_key
19
20     self.fc_cls = nn.Linear(in_channels, num_classes)
21
22     if self.dropout_ratio != 0:
23         self.dropout = nn.Dropout(p=self.dropout_ratio)
24     else:
25         self.dropout = None
26
27
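The constructor's only learnable layer is fc_cls. A small worked calculation of its parameter shapes and count follows; it relies on the documented PyTorch convention that nn.Linear(in_features, out_features) stores its weight as (out_features, in_features) and, by default, a bias of shape (out_features,). The helper names and the 256/10 sizes are hypothetical.

```python
def fc_cls_param_shapes(in_channels, num_classes):
    """Shapes of nn.Linear(in_channels, num_classes) parameters (default bias=True)."""
    weight_shape = (num_classes, in_channels)  # nn.Linear stores weight as (out, in)
    bias_shape = (num_classes,)
    return weight_shape, bias_shape


def fc_cls_param_count(in_channels, num_classes):
    """Total learnable parameters: in_channels * num_classes weights + num_classes biases."""
    weight_shape, bias_shape = fc_cls_param_shapes(in_channels, num_classes)
    return weight_shape[0] * weight_shape[1] + bias_shape[0]


# Hypothetical head: 256 input channels, 10 classes
print(fc_cls_param_shapes(256, 10))  # ((10, 256), (10,))
print(fc_cls_param_count(256, 10))   # 2570
```

nn.Dropout has no learnable parameters, so the dropout_ratio setting does not change this count.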

Member Function Documentation

◆ forward()

linear_head.LinearHead.forward (self, feat_dict)
args:
    feat_dict (dict): input data; must contain the key given by self.input_key
        The item stored under self.input_key is a Tensor of shape (B, in_channels)
            B: batch size

return (Tensor):
    class scores (one 1-D score vector per sample) -> shape (B, num_classes)
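To make the shape contract concrete, here is a pure-Python sketch of what forward computes in inference mode, where dropout is the identity: it looks up feat_dict[input_key] and applies y = x W^T + b row by row, mapping (B, in_channels) to (B, num_classes). This is an illustrative re-implementation, not the module's code; the function name, the toy weights, and the key "feat" are hypothetical.

```python
def linear_head_forward_eval(feat_dict, input_key, weight, bias):
    """Pure-Python sketch of LinearHead.forward with dropout disabled.

    weight: num_classes rows of in_channels values (nn.Linear layout)
    bias:   num_classes values
    """
    x = feat_dict[input_key]  # B rows of in_channels values
    scores = []
    for row in x:
        # y_j = sum_i x_i * W[j][i] + b_j, i.e. y = x W^T + b
        scores.append([sum(xi * wji for xi, wji in zip(row, w_row)) + b_j
                       for w_row, b_j in zip(weight, bias)])
    return scores  # B rows of num_classes scores


# Hypothetical toy sizes: B=2, in_channels=3, num_classes=2
W = [[1.0, 0.0, -1.0],
     [0.5, 0.5, 0.5]]
b = [0.0, 1.0]
feats = {"feat": [[1.0, 2.0, 3.0],
                  [0.0, 0.0, 0.0]]}
print(linear_head_forward_eval(feats, "feat", W, b))  # [[-2.0, 4.0], [0.0, 1.0]]
```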

Definition at line 28 of file linear_head.py.

28 def forward(self, feat_dict):
29     """
30     args:
31         feat_dict (dict): input data; must contain the key given by self.input_key
32             The item stored under self.input_key is a Tensor of shape (B, in_channels)
33             B: batch size
34
35     return (Tensor):
36         class scores (one 1-D score vector per sample) -> shape (B, num_classes)
37     """
38     x = feat_dict[self.input_key]
39
40     if self.dropout is not None:
41         x = self.dropout(x)
42
43     cls_score = self.fc_cls(x)
44
45     return cls_score
46
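forward applies self.dropout only when it exists, and what it does then depends on the module's mode. The sketch below reproduces the documented nn.Dropout semantics (inverted dropout) on a plain list: in training mode each element is zeroed with probability p and the survivors are scaled by 1/(1-p); in eval mode the input passes through unchanged. The function name and the rng parameter are hypothetical, added for illustration.

```python
import random


def inverted_dropout(x, p, training, rng=random):
    """Sketch of nn.Dropout semantics on a flat list of floats.

    Training: each element is zeroed with probability p; survivors are scaled by 1/(1-p).
    Eval (or p == 0): identity.
    """
    if not training or p == 0:
        return list(x)
    if p == 1:
        return [0.0 for _ in x]  # every element is dropped
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else xi * scale for xi in x]


feat = [1.0, 2.0, 3.0, 4.0]
print(inverted_dropout(feat, p=0.5, training=False))  # [1.0, 2.0, 3.0, 4.0]
print(inverted_dropout(feat, p=0.5, training=True, rng=random.Random(0)))
```

The 1/(1-p) scaling keeps the expected value of each feature the same in both modes, which is why fc_cls can be used unchanged at inference time.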

Member Data Documentation

◆ dropout

linear_head.LinearHead.dropout

Definition at line 23 of file linear_head.py.

◆ dropout_ratio

linear_head.LinearHead.dropout_ratio

Definition at line 17 of file linear_head.py.

◆ fc_cls

linear_head.LinearHead.fc_cls

Definition at line 20 of file linear_head.py.

◆ input_key

linear_head.LinearHead.input_key

Definition at line 18 of file linear_head.py.


The documentation for this class was generated from the following file: