Safemotion Lib
Loading...
Searching...
No Matches
action_demo_with_video.py
Go to the documentation of this file.
1#system path 설정
2import sys
3sys_path = ['/workspace/smlab', '/workspace']
4for path in sys_path:
5 if path not in sys.path:
6 sys.path.append(path)
7
8#import package
9import os
10import cv2
11import smrunner
12from smutils.utils_vis import draw_single_bbox_and_label
13from smutils.utils_vis import vis_pose_coco_skeleton
14from smutils.utils_os import search_file, create_directory
15from smutils.utils_data import load_labelmap
16from smutils.utils_video import make_video
17from smutils.utils_data import save_pkl_data
18import shutil
19
20#cfg 경로 설정
21det_cfg_path = '/workspace/smlab/smdetect/configs/yolo/yolov8.py' # yolov8
22track_cfg_path = '/workspace/smlab/smtrack/configs/bytetrack/bytetrack_base.py'
23pose_cfg_path = '/workspace/smlab/smpose/configs/mmpose/hrnet_trt.py' # hrnet
24action_cfg_path = '/workspace/InnoTest/models/posec3d_action.py'
25posture_cfg_path = '/workspace/InnoTest/models/posec3d_pose.py'
26
27# 모델 빌드
28device = 'cuda:0'
29det_model = smrunner.build_model(det_cfg_path)
30track_model = smrunner.build_model(track_cfg_path)
31pose_model = smrunner.build_model(pose_cfg_path)
32action_model = smrunner.build_model(action_cfg_path).to(device)
33posture_model = smrunner.build_model(posture_cfg_path).to(device)
34action_model.eval()
35posture_model.eval()
36
37#샘플링 파라미터
38#TODO: cfg로 옮겨서 runner의 파라미터로 설정할 필요가 있음 또는 pipline에 설정
39action_sample = 20
40pose_sample = 6
41
42#레이블 맵 로드
43labelmap = dict()
44labelmap['action_upper'] = load_labelmap('/workspace/smlab/smaction/datasets/safemotion_v22_upper_action.txt')
45labelmap['action_lower'] = load_labelmap('/workspace/smlab/smaction/datasets/safemotion_v22_lower_action.txt')
46labelmap['pose'] = load_labelmap('/workspace/smlab/smaction/datasets/safemotion_v22_pose.txt')
47labelmap['hand'] = load_labelmap('/workspace/smlab/smaction/datasets/safemotion_v22_hand.txt')
48
49#레이블맵 변환 테이블
50cvt_lower_map = [0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 0, 11]
51cvt_upper_map = [0, 0, 0, 3, 0, 0]
52cvt_pose_map = [0, 0, 0, 0, 4, 4, 6, 0]
53cvt_hand_map = [0, 0, 0, 3]
54
55#시각화 파라미터
56action_vis_param = dict(
57 box_color = (0, 255, 0),
58 box_thk = 3,
59 txt_color = (255, 255, 255),
60 txt_thk=3,
61 txt_scale=1.5,
62 box_type='xyxy'
63)
64
65vis_param = dict(
66 box_color = (0, 255, 0),
67 box_thk = 3,
68 txt_color = (255, 255, 255),
69 txt_thk=3,
70 txt_scale=1.5,
71 box_type='xyxy'
72)
73
74#비디오 경로 설정
75video_folder = '/media/safemotion/HDD5/pjm_test/ai_park_test_video'
76_, video_path_list = search_file(video_folder, '.mp4')
77
78#결과 저장 위치
79save_folder = '/media/safemotion/HDD5/pjm_test/ai_park_test_fps20'
80create_directory(save_folder)
81
82#임시 폴더
83tmp_save_folder = '/media/safemotion/HDD5/pjm_test/tmp_clip_images_fps20'
84create_directory(tmp_save_folder)
85shutil.rmtree(tmp_save_folder)
86
87#데모
88for video_idx, video_path in enumerate(video_path_list):
89
90 #임시 폴더 생성
91 create_directory(tmp_save_folder)
92
93 #저장 경로 설정
94 video_name = video_path.split('/')[-1]
95 save_video_path = os.path.join(save_folder, video_name)
96
97 #비디오 로드
98 cap = cv2.VideoCapture(video_path)
99
100 if not cap.isOpened():
101 print(f"Error: Could not open video({video_path}).")
102 continue
103
104 #비디오 정보 출력
105 total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
106 fps = cap.get(cv2.CAP_PROP_FPS)
107 print(f'{video_path} : {total_frames} frames, {fps} FPS')
108
109 #행동인식 범위 설정, 행동은 2초 구간 포즈는 0.5초 구간
110 action_k = int(fps+0.5)*2
111 pose_k = int(fps+0.5)//2
112
113 #초기화
114 track_data = {} #추적 객체 저장 변수, 추적아이디별 박스, 스켈레톤 정보 저장
115 frame_id = 0 #프레임 아이디, 0부터 1씩 증가
116 track_model.tracker.reset() #트래커 리셋, 여러 동영상을 하나의 트래커로 사용하기 때문에 동영상 시작전에 리셋함
117 while True:
118 #비디오에서 이미지 로드
119 ret, image = cap.read()
120
121 #마지막 프레임
122 if not ret:
123 print("Reached the end of the video or encountered an error.")
124 break
125
126 #진행 정도 출력
127 print(f'{video_name} : {frame_id+1} / {total_frames}', end='\r')
128 img_shape = image.shape[:2]
129
130 #가시화 이미지 생성
131 vis_img = image.copy()
132
133 # 모델 inference(검출, 추적, 포즈)
134 det_result = det_model.run_detector(image)
135 track_result = track_model.run_tracker(det_result['det_bboxes'], det_result['det_labels'], frame_id)
136 pose_result = pose_model.run_detector(image, track_result['track_bboxes'][0])
137
138 #추적 아이디별로 포즈 데이터 저장
139 for pose in pose_result:
140 track_id = int(pose['track_id'])
141 if track_id not in track_data:
142 track_data[track_id] = []
143 pose['frame_id'] = frame_id
144 track_data[track_id].append(pose)
145
146 #스켈레톤 시각화
147 vis_img = vis_pose_coco_skeleton(vis_img, pose_result)
148
149 delete_list = []
150 txt_pos_y = 50
151 for track_id, pose_q in track_data.items():
152 #박스 시각화
153 bbox = pose_q[-1]['bbox']
154
155 last_frame = pose_q[-1]['frame_id']
156 if last_frame == frame_id:
157 label = f'{track_id:3d}'
158 vis_img = draw_single_bbox_and_label(vis_img, bbox, label, **vis_param)
159
160
161 #최근 1초동안 검출 안되면 제거
162 if frame_id - last_frame > action_k*0.5:
163 delete_list.append(track_id)
164 continue
165
166 #최근 15프레임동안 검출 안되면 넘김
167 #포스처 때문에 넣은 조건, 하나의 모델로 변경되면 조건 변경 필요함
168 if frame_id - last_frame > 15:
169 continue
170
171 #액션 모델 inference
172 action_result = action_model.run_recognizer(pose_q, action_k, action_sample, device=device)
173
174 #결과가 None면 유효성 체크에서 탈락한 것임
175 if action_result is None:
176 continue
177
178 #포스처 모델 inference
179 posture_result = posture_model.run_recognizer(pose_q[-pose_k:], pose_k, pose_sample, False, device=device)
180
181 #첫 프레임과 마지막 프레임 간격이 설정한 구간을 넘기면 맨 앞쪽 포즈 제거
182 #상황에 따라 루프로 제거해야 할 수도 있음
183 if last_frame - pose_q[0]['frame_id'] >= action_k:
184 del pose_q[0]
185
186 #추론된 결과 변환
187 pose_label = cvt_pose_map[posture_result['pred_pose']]
188 upper_label = cvt_upper_map[action_result['pred_action_upper']]
189 lower_label = cvt_lower_map[action_result['pred_action_lower']]
190 hand_label = cvt_hand_map[posture_result['pred_hand']]
191
192 #행동 라벨 및 박스 시각화
193 label = f"{track_id:3d}: {labelmap['pose'][pose_label]}"
194
195 if hand_label != 0:
196 label += f"/{labelmap['hand'][hand_label]}"
197 if upper_label != 0:
198 label += f"/{labelmap['action_upper'][upper_label]}"
199 if lower_label != 0:
200 label += f"/{labelmap['action_lower'][lower_label]}"
201
202 cv2.putText(vis_img, label, (10, txt_pos_y), cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 0), 10, 1)
203 cv2.putText(vis_img, label, (10, txt_pos_y), cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 255, 255), 2, 1)
204 txt_pos_y += 45
205
206 vis_img = draw_single_bbox_and_label(vis_img, bbox, f'{track_id:3d}', **action_vis_param)
207
208 #긴시간 미검출된 추적 아이디 제거
209 for track_id in delete_list:
210 del track_data[track_id]
211
212 #프레임 아이디 증가
213 frame_id+=1
214
215 #시각화 이미지 저장
216 name = f'{frame_id:09d}.jpg'
217 save_path = os.path.join(tmp_save_folder, name)
218 cv2.imwrite(save_path, vis_img)
219
220 # 시각화 이미지 비디오로 저장
221 make_video(tmp_save_folder, save_video_path, fps=fps, half=False) #동영상 생성
222 shutil.rmtree(tmp_save_folder) #임시 폴더 삭제