Safemotion Lib
Loading...
Searching...
No Matches
stgcn_runner.py
Go to the documentation of this file.
1import torch
2import torch.nn as nn
3from smaction.builder.model_builder import build_action_model
4
6import smaction.utils.action_utils as utils
7
# TODO: split this runner into separate skeleton, image, and fusion variants.
class STGCNRunner(nn.Module):
    """Action-recognition runner composed of a backbone and a classification head.

    This base class does not know how to turn raw pose results into model
    input: subclasses must implement :meth:`make_action_input`.

    Optional pre-processing in :meth:`run_recognizer` is controlled by the
    ``use_valid_check`` / ``k`` and ``use_dummy_pose`` attributes. Subclasses
    may set them as class or instance attributes; the defaults below disable
    both steps.
    """

    # Defaults so run_recognizer() does not raise AttributeError when a
    # subclass never sets these; subclass/instance attributes override them.
    use_valid_check = False  # if True, return None when utils.check_valid rejects the input
    use_dummy_pose = False  # if True, pass pose_results through utils.insert_dummy_pose first
    k = None  # second argument forwarded to utils.check_valid (semantics defined there)

    def __init__(self, backbone, head, device='cuda:0'):
        """Build the backbone and head modules.

        Args:
            backbone: specification passed to ``build_action_model`` to build
                the feature extractor (presumably a config dict — confirm
                against the builder).
            head: specification passed to ``build_action_model`` to build the
                classification head.
            device: device identifier stored on ``self.device``. The modules
                are not moved to this device here.
        """
        super().__init__()
        self.backbone = build_action_model(backbone)
        self.head = build_action_model(head)
        self.device = device

    def make_action_input(self, pose_list, img_shape):
        """Convert raw pose results into the tensor fed to :meth:`inference`.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # The original `pass` returned None, which only failed later and
        # obscurely inside the backbone; fail loudly at the real cause.
        raise NotImplementedError(
            f"{type(self).__name__} must implement make_action_input()"
        )

    def run_recognizer(self, pose_results, img_shape=None):
        """Run the full recognition pipeline on raw pose results.

        Args:
            pose_results: pose estimation output, in whatever format
                ``utils.check_valid``, ``utils.insert_dummy_pose`` and
                :meth:`make_action_input` expect.
            img_shape: optional image shape forwarded to
                :meth:`make_action_input`.

        Returns:
            ``None`` if validity checking is enabled and the input is
            rejected; otherwise the dict returned by :meth:`inference`.
        """
        if self.use_valid_check and not utils.check_valid(pose_results, self.k):
            return None

        if self.use_dummy_pose:
            pose_results = utils.insert_dummy_pose(pose_results)

        data = self.make_action_input(pose_results, img_shape)
        return self.inference(data)

    def inference(self, x):
        """Run the backbone and head, then pick the top-scoring class.

        Args:
            x: model input accepted by the backbone.

        Returns:
            dict with ``'scores'`` (raw head output) and ``'action_labels'``
            (argmax of ``scores`` over its last dimension). Assumes the head
            emits class scores along the last dimension — TODO confirm.
            For a batched ``(N, C)`` output this yields one label per sample.
        """
        feat = self.backbone(x)
        scores = self.head(feat)
        # Reduce over the class dimension only: a bare argmax() flattens the
        # tensor and returns a meaningless flat index when the batch has N > 1.
        action_labels = scores.argmax(dim=-1)
        return {'action_labels': action_labels, 'scores': scores}

    def forward(self, x):
        """nn.Module entry point: run :meth:`inference` on ``x['keypoints']``."""
        return self.inference(x['keypoints'])
__init__(self, backbone, head, device='cuda:0')
run_recognizer(self, pose_results, img_shape=None)
make_action_input(self, pose_list, img_shape)