Safemotion Lib
mmdet_trt_runner.MMDetTRTRunner Class Reference

Public Member Functions

 __init__ (self, deploy_cfg, model_cfg, backend_model, device='cuda')
 
 run_detector (self, image)
 

Public Attributes

 deploy_cfg
 
 model_cfg
 
 task_processor
 
 model
 
 input_shape
 

Detailed Description

Class for running an mmdetection Deploy (TensorRT) model.
args:
    deploy_cfg (str): path to the config file that defines the mmdetection deploy settings
    model_cfg (str): path to the config file that defines the mmdetection model
    backend_model (str): path to the saved TensorRT model file
    device (str): device on which the model runs

Definition at line 5 of file mmdet_trt_runner.py.
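
Since `deploy_cfg`, `model_cfg`, and `backend_model` are all file paths, checking them before constructing the runner can help catch a wrong path early. The sketch below is illustrative only: `missing_runner_files` is a hypothetical helper that is not part of this library, the three paths are placeholders, and the import path `mmdet_trt_runner` is assumed from the module name shown on this page.

```python
from pathlib import Path
from typing import Dict, List


def missing_runner_files(paths: Dict[str, str]) -> List[str]:
    """Return the argument names whose path is not an existing file.

    Hypothetical helper, not part of the library.
    """
    return [name for name, p in paths.items() if not Path(p).is_file()]


# Placeholder paths, shown only to illustrate the three file arguments.
runner_args = {
    'deploy_cfg': 'configs/deploy_tensorrt.py',
    'model_cfg': 'configs/model.py',
    'backend_model': 'work_dir/end2end.engine',
}

missing = missing_runner_files(runner_args)
if missing:
    print('missing files for:', ', '.join(missing))
else:
    # Imported lazily so the path check runs even without the library installed.
    # The import path is an assumption based on the module name on this page.
    from mmdet_trt_runner import MMDetTRTRunner

    runner = MMDetTRTRunner(device='cuda', **runner_args)
```

The constructor call itself only uses the signature documented above; everything else here is scaffolding around it.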

Constructor & Destructor Documentation

◆ __init__()

mmdet_trt_runner.MMDetTRTRunner.__init__ ( self,
deploy_cfg,
model_cfg,
backend_model,
device = 'cuda' )

Definition at line 14 of file mmdet_trt_runner.py.

    def __init__(self, deploy_cfg, model_cfg, backend_model, device='cuda'):

        self.deploy_cfg, self.model_cfg = load_config(deploy_cfg, model_cfg)
        self.task_processor = build_task_processor(self.model_cfg, self.deploy_cfg, device)
        self.model = self.task_processor.build_backend_model(backend_model)
        self.input_shape = get_input_shape(deploy_cfg)

Member Function Documentation

◆ run_detector()

mmdet_trt_runner.MMDetTRTRunner.run_detector ( self,
image )
Runs the object detection model.
args:
    image (np.array): RGB image
return (dict): object detection result (a list of dicts when a list or tuple of images is passed)

Definition at line 21 of file mmdet_trt_runner.py.

    def run_detector(self, image):
        """
        Runs the object detection model.
        args:
            image (np.array): RGB image
        return (dict): object detection result
        """

        # Check whether the input is a batch
        is_batch = False
        if isinstance(image, (list, tuple)):
            is_batch = True

        # Build the model input
        input_data, _ = self.task_processor.create_input(image, self.input_shape)

        # Run the model
        with torch.no_grad():
            result = self.model.test_step(input_data)

        # Set the output format to match the format of earlier versions of mmaction
        # TODO:
        ret_data = []
        for d in result:
            masks = d.pred_instances.masks  # TODO: change so this only runs for instance models
            bboxes = d.pred_instances.bboxes.cpu()  # boxes
            scores = d.pred_instances.scores.cpu()  # scores
            labels = d.pred_instances.labels.cpu()  # labels

            # Convert to the legacy output format
            scores = scores.unsqueeze(1)
            bboxes = torch.cat((bboxes, scores), dim=1)
            data = {}
            data['det_bboxes'] = bboxes
            data['det_labels'] = labels
            data['masks'] = masks
            ret_data.append(data)

        if is_batch:  # batch input
            return ret_data
        else:  # single image
            return ret_data[0]
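
The listing above returns a single dict for one image and a list of dicts for a batch, and each row of `det_bboxes` is the four box values with the score appended as a fifth column. A minimal sketch of how a caller might consume that format follows. `as_batch` and `split_boxes` are hypothetical helpers, not part of the library, and the `[x1, y1, x2, y2]` ordering of the first four values is an assumption. The helpers work on plain nested lists, so a tensor result would first be converted with `.tolist()`.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

Detection = Dict[str, Any]


def as_batch(result: Union[Detection, List[Detection]]) -> List[Detection]:
    """Wrap a single-image result in a list so callers can always iterate."""
    return result if isinstance(result, list) else [result]


def split_boxes(
    det_bboxes: Sequence[Sequence[float]], score_thr: float = 0.0
) -> Tuple[List[List[float]], List[float], List[int]]:
    """Split rows of [box..., score] into boxes, scores, and kept row indices.

    Rows whose score is below score_thr are dropped. The kept indices can be
    used to select the matching entries of det_labels.
    """
    boxes, scores, keep = [], [], []
    for i, row in enumerate(det_bboxes):
        *box, score = row  # the last column is the score appended by run_detector
        if score >= score_thr:
            boxes.append(list(box))
            scores.append(score)
            keep.append(i)
    return boxes, scores, keep


# Hand-written stand-in for a single-image result (values are made up).
single = {
    'det_bboxes': [[10.0, 20.0, 50.0, 80.0, 0.9], [5.0, 5.0, 15.0, 15.0, 0.2]],
    'det_labels': [0, 2],
    'masks': None,
}

for det in as_batch(single):
    boxes, scores, keep = split_boxes(det['det_bboxes'], score_thr=0.5)
    labels = [det['det_labels'][i] for i in keep]
    print(boxes, scores, labels)
```

Normalizing to a list first means the same loop handles both return shapes of `run_detector`.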

Member Data Documentation

◆ deploy_cfg

mmdet_trt_runner.MMDetTRTRunner.deploy_cfg

Definition at line 16 of file mmdet_trt_runner.py.

◆ input_shape

mmdet_trt_runner.MMDetTRTRunner.input_shape

Definition at line 19 of file mmdet_trt_runner.py.

◆ model

mmdet_trt_runner.MMDetTRTRunner.model

Definition at line 18 of file mmdet_trt_runner.py.

◆ model_cfg

mmdet_trt_runner.MMDetTRTRunner.model_cfg

Definition at line 16 of file mmdet_trt_runner.py.

◆ task_processor

mmdet_trt_runner.MMDetTRTRunner.task_processor

Definition at line 17 of file mmdet_trt_runner.py.


The documentation for this class was generated from the following file: