Safemotion Lib
mmpose_pytorch_runner.py
import numpy as np
import torch
from mmengine.dataset import pseudo_collate
from mmpose.apis import init_model

from smutils.utils_image import crop_image

class MMPosePyTorchRunner(object):
    """Runs a top-down mmpose model on person crops and groups the
    results per input image."""

    def __init__(self, model_cfg, checkpoint, device='cuda', batch_size=1):
        """
        input:
            model_cfg  : mmpose model config file path or Config object
            checkpoint : path to the model checkpoint (.pth)
            device     : inference device, e.g. 'cuda' or 'cpu'
            batch_size : number of crops per forward pass
        """
        self.model = init_model(model_cfg, checkpoint, device=device)
        self.batch_size = batch_size
        self.test_pipeline = self.build_test_pipeline(self.model.cfg)

    def build_test_pipeline(self, cfg):
        # Compose the dataset's test-time transforms so raw ndarrays can be
        # turned into model inputs without going through a dataloader.
        from mmcv.transforms import Compose

        test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
        return test_pipeline

    def make_src(self, image, image_id, bboxes=None):
        """
        return:
            crop_images : list of per-person image crops
            bboxes      : (N, 4) array of xyxy boxes
            track_ids   : (N,) array of track ids
            image_ids   : list with the source image id repeated per person
        """
        if bboxes is None:
            # No detections: treat the whole image as a single person box.
            h, w = image.shape[:2]
            return [image], np.array([[0, 0, w, h]]), np.array([0]), [image_id]

        num_person, dim = bboxes.shape
        # Rows with 6 columns are assumed to carry a leading track id,
        # i.e. (track_id, x1, y1, x2, y2, score); otherwise the box starts
        # at column 0 and track ids are synthesized.
        box_index = 1 if dim == 6 else 0
        track_ids = bboxes[:, 0] if dim == 6 else np.arange(num_person)
        image_ids = [image_id] * num_person
        crop_images = crop_image(image, bboxes[:, box_index:box_index + 4])

        return crop_images, bboxes[:, box_index:box_index + 4], track_ids, image_ids

    def divide_into_batches(self, data, batch_size):
        # Always return a list of batches (batch_size == 1 yields
        # single-element batches) so callers can iterate uniformly.
        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    def run_detector(self, images, bboxes=None):
        is_batch = isinstance(images, (list, tuple))

        images = [images] if not is_batch else images
        if not is_batch:
            bboxes = [bboxes]
        elif bboxes is None:
            bboxes = [None] * len(images)

        crop_images, bbox_list, track_ids, image_ids = [], [], [], []
        for image_id, (image, bbox) in enumerate(zip(images, bboxes)):
            crops, bbs, tids, imids = self.make_src(image, image_id, bbox)
            crop_images.extend(crops)
            bbox_list.append(bbs)
            track_ids.append(tids)
            image_ids.extend(imids)

        bbox_list = np.vstack(bbox_list)
        track_ids = np.hstack(track_ids)

        crop_batches = self.divide_into_batches(crop_images, self.batch_size)
        with torch.no_grad():
            results = []
            for batch in crop_batches:
                data_list = []
                for crop in batch:
                    # Each crop is fed as a full-frame top-down input: the
                    # bbox spans the whole crop. The input keys follow
                    # mmpose 1.x (`inference_topdown`) and may need
                    # adjusting for other versions.
                    h, w = crop.shape[:2]
                    data_info = dict(
                        img=crop,
                        bbox=np.array([[0, 0, w, h]], dtype=np.float32),
                        bbox_score=np.ones(1, dtype=np.float32))
                    data_info.update(self.model.dataset_meta)
                    data_list.append(self.test_pipeline(data_info))
                results.extend(self.model.test_step(pseudo_collate(data_list)))

        pose_results = []
        pre_img_id = -1
        for res, bbox, track_id, img_id in zip(results, bbox_list, track_ids, image_ids):
            keypoints = res.pred_instances.keypoints
            keypoint_scores = np.expand_dims(res.pred_instances.keypoint_scores, axis=-1)
            # Append the per-keypoint score as a third channel: (N, K, 3).
            keypoints = np.dstack((keypoints, keypoint_scores))

            # image_ids is ordered by construction, so a change in img_id
            # starts the result list for the next image.
            if pre_img_id != img_id:
                pre_img_id = img_id
                pose_results.append([])

            data = {'track_id': track_id, 'bbox': bbox, 'keypoints': keypoints}
            pose_results[-1].append(data)

        return pose_results if is_batch else pose_results[0]
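

# Minimal usage sketch of driving the runner end to end. The config,
# checkpoint, and image paths are placeholders, and the detection array
# follows the (track_id, x1, y1, x2, y2, score) layout assumed by make_src;
# any top-down mmpose config/checkpoint pair fits the same pattern.
if __name__ == '__main__':
    import cv2

    runner = MMPosePyTorchRunner(
        model_cfg='path/to/pose_config.py',        # placeholder
        checkpoint='path/to/pose_checkpoint.pth',  # placeholder
        device='cuda',
        batch_size=4)

    image = cv2.imread('path/to/image.jpg')        # placeholder
    # One detection per row: (track_id, x1, y1, x2, y2, score).
    dets = np.array([[0, 50, 40, 210, 400, 0.9]], dtype=np.float32)

    poses = runner.run_detector(image, bboxes=dets)
    for person in poses:
        print(person['track_id'], person['bbox'], person['keypoints'].shape)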