Safemotion Lib
Loading...
Searching...
No Matches
resnet2d.py
Go to the documentation of this file.
1import torch
2import torch.nn as nn
3
# Basic residual block used by ResNet2d.
class BasicBlock(nn.Module):
    """Residual block made of two (3, 1) convolutions plus a shortcut.

    The (3, 1) kernels with (1, 0) padding convolve along the first spatial
    axis only. For the T x 1 inputs that ResNet2d is built for, that is the
    time axis T; a width-1 second axis stays at width 1.
    """

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample=None,
    ) -> None:
        """
        Args:
            inplanes (int): number of input channels.
            planes (int): number of output channels of both convolutions.
            stride (int): stride of the first convolution. It is passed as a
                single int, so it applies to both spatial axes.
            downsample (nn.Module, optional): module applied to the input to
                build the shortcut. It must map the input to the residual
                branch's output shape, so it is needed whenever
                ``stride != 1`` or ``inplanes != planes``.
        """
        super().__init__()

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=(3, 1), stride=stride, padding=(1, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv2 consumes conv1's output, so its input channel count is
        # `planes`, not `inplanes`. Using `inplanes` here breaks every block
        # with inplanes != planes.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ReLU(bn2(conv2(ReLU(bn1(conv1(x))))) + shortcut(x)).

        The shortcut is ``x`` itself, or ``self.downsample(x)`` when a
        downsample module was supplied.
        """
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return self.relu(out)


class ResNet2d(nn.Module):
    """Model for consuming image features.

    Built on the assumption that the input is T x 1 (a time axis by a
    width-1 axis).
    """

    def __init__(
        self,
        in_channels,
        stage_blocks,
        input_key,
    ) -> None:
        """
        Args:
            in_channels (int): number of input channels.
            stage_blocks (List): number of blocks stacked in each stage. The
                first four entries are used.
            input_key (str): key of the input data that this module reads
                from the sample dict at inference time.
        """
        super().__init__()

        block_cls = BasicBlock

        # Running channel count; _make_layer reads it and updates it.
        self.inplanes = in_channels
        self.input_key = input_key

        # Stem: 1x1 conv -> BN -> ReLU. The channel count is unchanged.
        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=1, stride=1, padding=0, bias=True)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages. Every stage keeps `in_channels` planes; only
        # stages 2 and 3 use stride 2.
        self.layer1 = self._make_layer(block_cls, planes=in_channels, blocks=stage_blocks[0], stride=1)
        self.layer2 = self._make_layer(block_cls, planes=in_channels, blocks=stage_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_cls, planes=in_channels, blocks=stage_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_cls, planes=in_channels, blocks=stage_blocks[3], stride=1)

    def _make_layer(self, block, planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
        """Build one residual stage and advance ``self.inplanes`` to ``planes``.

        The first block receives ``stride`` and, when the shape changes, a
        1x1-conv + BN projection for its shortcut. The remaining
        ``blocks - 1`` blocks use the default stride and no projection. At
        least one block is always built.
        """
        # A projection shortcut is needed when the stride or channel count changes.
        needs_projection = stride != 1 or self.inplanes != planes
        projection = (
            nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
            if needs_projection
            else None
        )

        head = block(self.inplanes, planes, stride, projection)
        self.inplanes = planes
        tail = [block(self.inplanes, planes) for _ in range(blocks - 1)]

        return nn.Sequential(head, *tail)

    def forward(self, sample):
        """Run the stem and the four residual stages on ``sample[self.input_key]``.

        Args:
            sample (dict): input data. It must contain the key
                ``self.input_key``, whose item is a Tensor of shape
                (B, C, T, 1) (B: batch size, C: input channels, T: time).

        Returns:
            Tensor: feature map at a particular resolution, of shape
            (B, C_o, T_o, 1) (B: batch size, C_o: channels, T_o: time).
        """
        features = sample[self.input_key]
        features = self.relu(self.bn1(self.conv1(features)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        return features
torch.Tensor forward(self, torch.Tensor x)
Definition resnet2d.py:23
None __init__(self, int inplanes, int planes, int stride=1, downsample=None)
Definition resnet2d.py:12
nn.Sequential _make_layer(self, block, int planes, int blocks, int stride=1)
Definition resnet2d.py:79
forward(self, sample)
Definition resnet2d.py:98
None __init__(self, in_channels, stage_blocks, input_key)
Definition resnet2d.py:52