21 def run_detector(self, image):
22 """
23 객체 검출 모델을 구동시키는 기능
24 args:
25 image (np.array): RGB 이미지
26 return (dict): 객체 검출 결과
27 """
28
29
30 is_batch = False
31 if isinstance(image, (list, tuple)):
32 is_batch = True
33
34
35 input_data, _ = self.task_processor.create_input(image, self.input_shape)
36
37
38 with torch.no_grad():
39 result = self.model.test_step(input_data)
40
41
42
43 ret_data = []
44 for d in result:
45 masks = d.pred_instances.masks
46 bboxes = d.pred_instances.bboxes.cpu()
47 scores = d.pred_instances.scores.cpu()
48 labels = d.pred_instances.labels.cpu()
49
50
51 scores = scores.unsqueeze(1)
52 bboxes = torch.cat((bboxes, scores), dim=1)
53 data = {}
54 data['det_bboxes'] = bboxes
55 data['det_labels'] = labels
56 data['masks'] = masks
57 ret_data.append(data)
58
59 if is_batch:
60 return ret_data
61 else:
62 return ret_data[0]
63