Safemotion Lib
Loading...
Searching...
No Matches
reid_api.py
Go to the documentation of this file.
1import os
2import cv2
3import operator
4
5import numpy as np
6from tqdm import tqdm
7
8import torch
9from fastreid.engine import DefaultPredictor
10from fastreid.config import get_cfg
11
12# reid cfg 만들기
13def make_reid_cfg(cfg_path, checkpoint_path, device='cuda:0'):
14 """
15 fast-reid 모델 빌드를 위한 cfg 생성 기능
16 args:
17 cfg_path (str): 모델이 정의된 config 파일 경로
18 checkpoint_path (str): 모델의 파라미터가 저장된 경로
19 device (str): 모델이 구동될 디바이스
20 return: cfg
21 """
22 cfg = get_cfg()
23 cfg.merge_from_file(cfg_path)
24 cfg.merge_from_list(['MODEL.DEVICE', device])
25 cfg.merge_from_list(['MODEL.WEIGHTS', checkpoint_path])
26
27 return cfg
28
29# reid model build
31 """
32 fast-reid 모델 빌드 기능
33 args:
34 cfg_path (str): 모델이 정의된 config 파일 경로
35 checkpoint_path (str): 모델의 파라미터가 저장된 경로
36 device (str): 모델이 구동될 디바이스
37 return: reid 모델
38 """
39 reid_cfg = make_reid_cfg(args.cfg_path, args.checkpoint_path, args.device)
40 model = DefaultPredictor(reid_cfg)
41
42 return model
43
44# image : BGR type numpy array, using opecv read
45def inference_reid_model(model, image):
46 """
47 이미지 한장에 대해 reid 모델을 inference 하는 기능
48 args:
49 model: reid 모델
50 image (np.array): BGR 이미지
51 return:
52 reid features
53 """
54 image = image[:, :, ::-1] #RGB 변환
55 image = cv2.resize(image, tuple(model.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC) #리사이즈
56 image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None] #텐서 변환
57
58 #구동
59 return model(image)
60
61# 이미지 path들에서 이미지를 불러와 reid 모델에 넣고, 해당 값들을 feats에 저장
62def get_reid_feats(model, crop_images):
63 """
64 여러장의 이미지에 대해 reid 특징을 추출하는 기능
65 args:
66 model: reid 모델
67 crop_images (list[str or np.array]): 사람 영역만 자른 이미지 저장 경로 또는 이미지
68 return (list[Tensor]):
69 각 이미지에 대한 reid 특징
70 """
71 feats = []
72
73 for img in crop_images:
74 #입력이 이미지 경로일 경우 로드
75 if isinstance(img, str):
76 img = img = cv2.imread(img)
77
78 # model inference
79 inference_result = inference_reid_model(model, img)
80 feats.append(inference_result)
81
82 return feats
83
84# 쿼리와 갤러기 간의 유사도 행렬 계산
85# def get_reid_matrix(feats):
86# if not isinstance(feats, type(torch.float32)):
87# feats = torch.cat(feats, dim=0)
88
89# # compute cosine distance
90# simmat = torch.mm(feats, feats.t())
91# simmat = simmat.numpy()
92# return simmat
93
94def get_reid_matrix(q_feats, g_feats=None):
95 """
96 쿼리와 갤러리 간의 유사도 행렬을 계산하는 기능
97 args:
98 q_feats (list[tensor] or Tensor): 쿼리, reid 특징 벡터
99 g_feats (Tensor): 갤러리, reid 특징 벡터
100 return (np.array):
101 유사도 행렬
102 """
103 # q_feats, g_feats ( list[tensor] : tensor shape [1, 2048 or any ])
104 # tensor : tensor shape [N, 2048 or any]
105 if isinstance(q_feats, list):
106 q_feats = torch.cat(q_feats, dim=0)
107
108 #갤러리의 입력이 없으면 쿼리원소들 간의 유사도 측정
109 if g_feats == None:
110 g_feats = q_feats
111
112 if isinstance(g_feats, list):
113 g_feats = torch.cat(g_feats, dim=0)
114
115 # compute cosine distance
116 simmat = torch.mm(q_feats, g_feats.t())
117 simmat = simmat.numpy()
118 return simmat
119
120# 찾고자 하는 id와 넘어야 하는 유사도로 원하는 feature 찾기
121def get_reid_image_index(simmat, select_id = 1, threshold=0.8):
122 """
123 유사도 매트릭스를 기반으로 임계치 이상의 유사도를 가지는 feature를 찾는 기능
124 args:
125 simmat (np.array): 유사도 매트릭스
126 select_id (int): 유사도를 비교하려는 인덱스
127 threshold (float): 유사도 임계치
128 return:
129 s_sim (np.array):select_id와 다른 특징과의 유사도
130 keep_index (np.array): 임계치 이상의 유사도를 가지는 인덱스
131 """
132 # select_id에 대한 유사도
133 s_sim = simmat[select_id, :]
134
135 # 임계값을 넘는 feat들만 받기
136 keep = s_sim > threshold
137 keep_index = np.where(keep)[0]
138
139 return s_sim, keep_index
140
141
142# 유사도 score에 따라 정렬
143def index_sort(keep_index, s_sim):
144 """
145 유사도를 기준으로 정렬하는 기능
146 args:
147 keep_index (np.array): 인덱스
148 s_sim (np.array): 유사도
149 return (list[tuple]): (인덱스, 유사도) 튜플의 내림차순 정렬된 리스트
150
151 """
152 sort_index = []
153
154 #(인덱스, 유사도)
155 for i in range(len(keep_index)):
156 index = keep_index[i]
157 sim = s_sim[index]
158 sort_index.append((index,sim))
159
160 #내림차순 정렬
161 sorted_keep_index = sorted(sort_index, key=operator.itemgetter(1), reverse=True)
162
163 return sorted_keep_index
164
165
make_reid_cfg(cfg_path, checkpoint_path, device='cuda:0')
Definition reid_api.py:13
get_reid_matrix(q_feats, g_feats=None)
Definition reid_api.py:94
get_reid_feats(model, crop_images)
Definition reid_api.py:62
inference_reid_model(model, image)
Definition reid_api.py:45
build_reid_model(args)
Definition reid_api.py:30
get_reid_image_index(simmat, select_id=1, threshold=0.8)
Definition reid_api.py:121
index_sort(keep_index, s_sim)
Definition reid_api.py:143