Safemotion Lib
Loading...
Searching...
No Matches
model_runner_builder.py
Go to the documentation of this file.
2from smtrack.builder.runner_builder import build_track_runner
3from smtrack.builder.runner_builder import __runner_builders__ as track_runner_builders
4from smdetect.builder.runner_builder import build_detect_runner
5from smdetect.builder.runner_builder import __runner_builders__ as det_runner_builders
6from smpose.builder.runner_builder import build_pose_runner
7from smpose.builder.runner_builder import __runner_builders__ as pose_runner_builders
8from smaction.builder.runner_builder import build_action_runner
9from smaction.builder.runner_builder import __runner_builders__ as action_runner_builders
10from smrunner.utils import load_model
11
12from mmengine.config import Config
13
# TODO: Currently unused. Remove once it is confirmed that nothing relies on it.
# Registry mapping a runner type name (the value expected in ``cfg.model.type``)
# to the family-level builder function that constructs that runner.
__model_builder__ = dict(
    ByteTrackerRunner=build_track_runner,
    MMDetPyTorchRunner=build_detect_runner,
    MMDetTRTRunner=build_detect_runner,
    YoloV8Runner=build_detect_runner,
    MMPoseTRTRunner=build_pose_runner,
    MMActionPyTorchRunner=build_action_runner,
    STGCNRunner=build_action_runner,
    MCSTGCNRunner=build_action_runner,
)
25
26
def build_model(cfg, device='cpu'):
    """Build a Safemotion library model runner from a config.

    Detection, pose-estimation, action-recognition and tracking runners are
    supported. The runner family is selected by looking up ``cfg.model.type``
    in each package's ``__runner_builders__`` registry, checked in the order
    detection -> pose -> action -> tracking; the first match wins.

    Args:
        cfg (str or Config): Path to a model config file, or a
            ``mmengine.config.Config`` object (e.g. created with
            ``Config.fromfile()``).
        device (str): Device passed to ``load_model`` when loading weights for
            an action-recognition runner. It is not forwarded to the other
            runner builders by this function.

    Returns:
        The built runner, or ``None`` if ``cfg.model.type`` is not found in any
        of the runner registries.
    """
    # A string argument is treated as a config file path.
    if isinstance(cfg, str):
        cfg = Config.fromfile(cfg)

    model_cfg = cfg.model
    model_type = model_cfg.type
    # Builders get a (shallow) copy, so top-level changes they make to their
    # arguments do not leak back into cfg.model.
    model_args = model_cfg.copy()

    if model_type in det_runner_builders:  # object detection
        return build_detect_runner(model_args)

    if model_type in pose_runner_builders:  # pose estimation
        return build_pose_runner(model_args)

    if model_type in action_runner_builders:  # action recognition
        runner = build_action_runner(model_args)
        # Weights come from cfg.test.model_path. A missing `test` section or
        # missing `model_path` key is treated like an explicit None (no weight
        # loading) instead of raising AttributeError.
        test_cfg = cfg.get('test', None)
        model_path = None if test_cfg is None else test_cfg.get('model_path', None)
        if model_path is not None:
            load_model(runner, model_path, device=device)
        return runner

    if model_type in track_runner_builders:  # object tracking
        return build_track_runner(model_args)

    # Unregistered model type: callers are expected to check for None.
    return None
69
70