Safemotion Lib
Loading...
Searching...
No Matches
smaction
runner
stgcn_runner.py
Go to the documentation of this file.
1
import
torch
2
import
torch.nn
as
nn
3
from
smaction.builder.model_builder
import
build_action_model
4
5
from
smaction.utils.transforms
import
*
6
import
smaction.utils.action_utils
as
utils
7
8
# TODO: split into separate skeleton-, image-, and fusion-based versions
9
class STGCNRunner(nn.Module):
    """Action-recognition runner wrapping a backbone model and a head model.

    This class is effectively a base class: ``make_action_input`` is not
    implemented here, so ``run_recognizer`` only works in a subclass that
    overrides it.

    Subclasses may configure ``run_recognizer`` through these attributes:

    * ``use_valid_check`` -- if true, ``run_recognizer`` returns ``None`` when
      ``utils.check_valid(pose_results, self.k)`` is falsy. A subclass that
      enables it must also set ``k``.
    * ``use_dummy_pose`` -- if true, ``pose_results`` is passed through
      ``utils.insert_dummy_pose`` before the model input is built.
    """

    # Class-level defaults so ``run_recognizer`` does not fail with
    # AttributeError when a subclass leaves these flags unset. Instance
    # attributes assigned by a subclass (before or after ``super().__init__``)
    # take precedence over these. ``k`` deliberately has no default: it is only
    # read when ``use_valid_check`` is enabled, and must then be set explicitly.
    use_valid_check = False
    use_dummy_pose = False

    def __init__(self, backbone, head, device='cuda:0'):
        """Build the backbone and head sub-models.

        Args:
            backbone: spec passed to ``build_action_model`` to build
                ``self.backbone`` (presumably a config dict -- confirm).
            head: spec passed to ``build_action_model`` to build ``self.head``.
            device: stored on ``self.device``. NOTE(review): this class only
                stores it; neither the models nor the inputs are moved to it here.
        """
        super().__init__()
        self.backbone = build_action_model(backbone)
        self.head = build_action_model(head)
        self.device = device

    def make_action_input(self, pose_list, img_shape):
        """Convert pose results into the tensor consumed by ``inference``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class. Failing here is
                clearer than returning ``None`` and crashing inside the backbone.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement make_action_input()')

    def run_recognizer(self, pose_results, img_shape=None):
        """Recognize an action from pose results.

        Args:
            pose_results: pose-estimation output, passed to the ``utils``
                helpers and to ``make_action_input``.
            img_shape: optional image shape forwarded to ``make_action_input``.

        Returns:
            ``None`` if validity checking is enabled and the check fails,
            otherwise the dict returned by ``inference``.
        """
        if self.use_valid_check and not utils.check_valid(pose_results, self.k):
            return None

        if self.use_dummy_pose:
            pose_results = utils.insert_dummy_pose(pose_results)

        data = self.make_action_input(pose_results, img_shape)
        return self.inference(data)

    def inference(self, x):
        """Run the backbone and head on ``x``.

        Returns:
            dict with ``'action_labels'`` (``scores.argmax()``) and
            ``'scores'`` (the raw head output).
        """
        feat = self.backbone(x)
        scores = self.head(feat)
        # NOTE(review): argmax() without ``dim`` returns an index into the
        # flattened tensor, so for batches larger than one it is not a
        # per-sample label. Kept as-is for backward compatibility.
        action_labels = scores.argmax()
        return {'action_labels': action_labels, 'scores': scores}

    def forward(self, x):
        """``nn.Module`` entry point: run ``inference`` on ``x['keypoints']``."""
        return self.inference(x['keypoints'])
stgcn_runner.STGCNRunner
Definition
stgcn_runner.py:9
stgcn_runner.STGCNRunner.__init__
__init__(self, backbone, head, device='cuda:0')
Definition
stgcn_runner.py:10
stgcn_runner.STGCNRunner.forward
forward(self, x)
Definition
stgcn_runner.py:43
stgcn_runner.STGCNRunner.run_recognizer
run_recognizer(self, pose_results, img_shape=None)
Definition
stgcn_runner.py:26
stgcn_runner.STGCNRunner.k
k
Definition
stgcn_runner.py:27
stgcn_runner.STGCNRunner.inference
inference(self, x)
Definition
stgcn_runner.py:37
stgcn_runner.STGCNRunner.head
head
Definition
stgcn_runner.py:19
stgcn_runner.STGCNRunner.device
device
Definition
stgcn_runner.py:20
stgcn_runner.STGCNRunner.backbone
backbone
Definition
stgcn_runner.py:18
stgcn_runner.STGCNRunner.make_action_input
make_action_input(self, pose_list, img_shape)
Definition
stgcn_runner.py:23
transforms
torch.nn
Generated by
1.10.0