Safemotion Lib
Loading...
Searching...
No Matches
Functions
generate_image_features Namespace Reference

Functions

 inference_one_frame (model, image, bboxes, device)
 
 gen_image_features (model, data_manager, save_path, device)
 
 run_gen_image_features (args)
 

Function Documentation

◆ gen_image_features()

generate_image_features.gen_image_features ( model,
data_manager,
save_path,
device )
하나의 어노테이션 파일을 기반으로 사람영역에 대해 이미지 특징 생성 모델을 사용하여 예측한 이미지 특징을 저장하는 기능
args:
    model: pytorch 모델
    data_manager: AnnotationDataManager 객체, 어노테이션 데이터파일로 초기화까지 진행
    save_path (str): 이미지 특징을 저장할 경로
    device (str): 모델이 구동될 디바이스

Definition at line 56 of file generate_image_features.py.

56def gen_image_features(model, data_manager, save_path, device):
57 """
58 하나의 어노테이션 파일을 기반으로 사람영역에 대해 이미지 특징 생성 모델을 사용하여 예측한 이미지 특징을 저장하는 기능
59 args:
60 model: pytorch 모델
61 data_manager: AnnotationDataManager 객체, 어노테이션 데이터파일로 초기화까지 진행
62 save_path (str): 이미지 특징을 저장할 경로
63 device (str): 모델이 구동될 디바이스
64 """
65
66 #이미지 숫자 체크
67 img_num = data_manager.get_number_of_image()
68 image_feats_data = {}
69 for i in tqdm(range(img_num)):
70 anno_data = data_manager.get_anno_data_in_image() #이미지에 있는 어노테이션 데이터를 가져옴
71 img_path = data_manager.get_image_path() #이미지 경로를 가져옴
72
73 #어노테이션 박스 구조 변경 (x, y, w, h) -> (x, y, x, y)
74 bbox_in_image = []
75 for anno in anno_data:
76 bbox_tmp = copy.deepcopy(anno['bbox'])
77
78 #마진 추가
79 margin_x = bbox_tmp[2]*0.15
80 margin_y = bbox_tmp[3]*0.15
81
82 bbox_tmp[0] -= margin_x
83 bbox_tmp[1] -= margin_y
84 bbox_tmp[2] *= 1.3
85 bbox_tmp[3] *= 1.3
86
87 #박스 구조 변경
88 bbox_tmp[2] += bbox_tmp[0] #TODO: 영상 크기로 최대치 체크
89 bbox_tmp[3] += bbox_tmp[1] #TODO: 영상 크기로 최대치 체크
90 bbox_tmp[0] = max(bbox_tmp[0], 0)
91 bbox_tmp[1] = max(bbox_tmp[1], 0)
92
93 bbox_in_image.append(bbox_tmp)
94
95 #박스가 없으면 다음 이미지로 넘어감
96 if len(bbox_in_image) == 0:
97 data_manager.move_image_right()
98 continue
99
100 bboxes = np.array(bbox_in_image) #박스 타입 변경
101 image = cv2.imread(img_path) #이미지 로드
102 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
103
104 #이미지 특징 생성
105 image_feats = inference_one_frame(model, image, bboxes, device)
106
107 #이미지 특징을 추적 아이디 별로 저장
108 for anno, feat in zip(anno_data, image_feats):
109 anno_id = anno['id']
110 image_id = anno['image_id']
111 track_id = anno['track_id']
112
113 if track_id not in image_feats_data:
114 image_feats_data[track_id] = {}
115
116 image_feats_data[track_id][image_id] = feat
117
118 #다음 이미지로 이동
119 data_manager.move_image_right()
120
121 #이미지 특징 저장
122 save_pkl_data(image_feats_data, save_path)
123

◆ inference_one_frame()

generate_image_features.inference_one_frame ( model,
image,
bboxes,
device )
한프레임에서 검출된 박스에대해 이미지 특징 생성 모델을 inference하는 기능
args:
    model: pytorch 모델
    image (np.array): 원본 이미지
    bboxes (np.array): 검출 박스(x,y,x,y), shape (num_person, 4 or 5 or 6)
    device (str): 모델이 구동될 디바이스
return (Tensor): 이미지 특징, shape  (num_person, C)

Definition at line 19 of file generate_image_features.py.

19def inference_one_frame(model, image, bboxes, device):
20 """
21 한프레임에서 검출된 박스에대해 이미지 특징 생성 모델을 inference하는 기능
22 args:
23 model: pytorch 모델
24 image (np.array): 원본 이미지
25 bboxes (np.array): 검출 박스(x,y,x,y), shape (num_person, 4 or 5 or 6)
26 device (str): 모델이 구동될 디바이스
27 return (Tensor): 이미지 특징, shape (num_person, C)
28 """
29
30 #이미지 전처리
31 preprocess = transforms.Compose([
32 transforms.ToPILImage(), #구조변환, np.array -> PILImage
33 transforms.Resize((224, 224)), #리사이즈
34 transforms.ToTensor(), #텐서로 변환
35 transforms.Normalize(mean=[0.4815, 0.4578, 0.4082], std=[0.2686, 0.2613, 0.2758]), #가우시안 정규화
36 ])
37
38 crop_images = crop_image(image, bboxes) #박스영역 잘라내기
39
40 #이미지 전처리
41 transformed_images = []
42 for img in crop_images:
43 transformed_images.append(preprocess(img))
44
45 image_batch = torch.stack(transformed_images) #모델 입력 데이터 생성, 모델 inference를 위한 배치단위 변환
46 image_batch = image_batch.to(device) #입력 데이터를 디바이스로 전송
47
48 #inference
49 with torch.no_grad():
50 results = model(image_batch) #모델 구동
51 results = results.cpu() #결과(이미지 특징) cpu로 전송
52
53 return results
54
55

◆ run_gen_image_features()

generate_image_features.run_gen_image_features ( args)
어노테이션 파일들을 기반으로 사람영역에 대해 이미지 특징 생성 모델을 사용하여 예측한 이미지 특징을 저장하는 기능
이미지 특징 생성 모델은 timm 패키지에서 제공하고 있는 eva02_base_patch16_clip_224 모델(사전학습 파라미터 이용)을 사용함.
args:
    save_root (str): 이미지 특징을 저장할 폴더
    device (str): 모델이 구동될 디바이스
    dataset_folder_list (list[str]): 어노테이션 파일이 저장된 폴더 리스트

Definition at line 124 of file generate_image_features.py.

124def run_gen_image_features(args):
125 """
126 어노테이션 파일들을 기반으로 사람영역에 대해 이미지 특징 생성 모델을 사용하여 예측한 이미지 특징을 저장하는 기능
127 이미지 특징 생성 모델은 timm 패키지에서 제공하고 있는 eva02_base_patch16_clip_224 모델(사전학습 파라미터 이용)을 사용함.
128 args:
129 save_root (str): 이미지 특징을 저장할 폴더
130 device (str): 모델이 구동될 디바이스
131 dataset_folder_list (list[str]): 어노테이션 파일이 저장된 폴더 리스트
132 """
133 print('gen image features start')
134 #json 저장 폴더 생성
135
136 #이미지 특징이 저장될 폴더 생성
137 create_directory(args.save_root)
138
139 #모델이 구동될 디바이스 설정
140 device = args.device
141
142 #모델 생성
143 model = timm.create_model('eva02_base_patch16_clip_224', pretrained=True) #timm 모델
144 model.head = nn.Identity() #헤더를 제거
145 model.to(device) #모델을 디바이스로 전송
146 model.eval() #eval 모드 전환
147
148 #어노테이션 관리자 생성
149 data_manager = AnnotationDataManager()
150
151 #어노테이션 리스트
152 json_name_list = []
153 json_path_list = []
154 for dataset_folder in args.dataset_folder_list:
155 name_list, path_list = search_file(dataset_folder, '.json')
156 json_name_list.extend(name_list)
157 json_path_list.extend(path_list)
158
159 file_num = len(json_path_list) #파일 수량, 진행 정도 표시를 위함
160 for i, (json_name, json_path) in enumerate(zip(json_name_list, json_path_list)):
161
162 #저장 경로 설정
163 save_path = os.path.join(args.save_root, json_name)
164 save_path = save_path.replace('.json', '.pkl')
165
166 #어노테이션 관리자 초기화
167 data_manager.load_annotation(json_path)
168 data_manager.init_annotation()
169
170 #진행도 출력
171 print(f"[ {i+1} / {file_num}] : load path : {json_path}")
172 print(f"[ {i+1} / {file_num}] : save path : {save_path}")
173
174 #어노테이션 파일 하나에 대해서 이미지 특징 생성 진행
175 gen_image_features(model, data_manager, save_path, device)
176 print('')