def assign_ids(self,
               ids,
               det_bboxes,
               det_labels,
               weight_iou_with_det_scores=False,
               match_iou_thr=0.5):
    """Assign ids to detections.

    Args:
        ids (list[int]): Tracking ids of the candidate tracks.
        det_bboxes (Tensor): of shape (N, 5), where the last column is
            the detection score.
        det_labels (Tensor): of shape (N, ).
        weight_iou_with_det_scores (bool, optional): Whether to weight
            the IoU used for matching by the detection scores.
            Defaults to False.
        match_iou_thr (float, optional): Matching IoU threshold.
            Defaults to 0.5.

    Returns:
        tuple(np.ndarray): Matched indices ``(row, col)``, where
            ``row[i]`` is the detection index matched to the i-th track
            and ``col[j]`` is the track index matched to the j-th
            detection; -1 denotes no match.
    """
    # gather the predicted track boxes from the Kalman filter states,
    # stored as (cx, cy, aspect ratio, h), and convert them to
    # (x1, y1, x2, y2) for the IoU computation
    track_bboxes = np.zeros((0, 4))
    for id in ids:
        track_bboxes = np.concatenate(
            (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
    track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
    track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)
    ious = bbox_overlaps(track_bboxes, det_bboxes[:, :4])
    if weight_iou_with_det_scores:
        ious *= det_bboxes[:, 4][None]
    # support multi-class association
    track_labels = torch.tensor([
        self.tracks[id]['labels'][-1] for id in ids
    ]).to(det_bboxes.device)

    cate_match = det_labels[None, :] == track_labels[:, None]
    # a large penalty to avoid matching detections and tracks of
    # different categories
    cate_cost = (1 - cate_match.int()) * 1e6

    dists = (1 - ious + cate_cost).cpu().numpy()
    # bipartite matching; pairs with cost above 1 - match_iou_thr are
    # left unmatched
    if dists.size > 0:
        cost, row, col = lap.lapjv(
            dists, extend_cost=True, cost_limit=1 - match_iou_thr)
    else:
        row = np.zeros(len(ids)).astype(np.int32) - 1
        col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
    return row, col
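
# A minimal, self-contained sketch of the ``lap.lapjv`` semantics that
# ``assign_ids`` relies on. The helper name and the 3x3 cost matrix are
# made up purely for illustration; only the module-level ``numpy as np``
# and ``lap`` imports of this file are assumed.
def _lapjv_demo():
    demo_dists = np.array([[0.1, 0.9, 0.9],
                           [0.9, 0.2, 0.9],
                           [0.9, 0.9, 0.8]])
    # extend_cost=True lets lapjv handle non-square matrices; any pair
    # whose cost exceeds cost_limit is left unassigned and reported as -1
    _, row, col = lap.lapjv(demo_dists, extend_cost=True, cost_limit=0.5)
    # row[i] is the column matched to row i (here: track i -> detection),
    # col[j] is the row matched to column j; -1 means unmatched
    assert list(row) == [0, 1, -1]
    assert list(col) == [0, 1, -1]
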
159 """Tracking forward function.
162 img (Tensor): of shape (N, C, H, W) encoding input images.
163 Typically these should be mean centered and std scaled.
164 img_metas (list[dict]): list of image info dict where each dict
165 has: 'img_shape', 'scale_factor', 'flip', and may also contain
166 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
167 model (nn.Module): MOT model.
168 bboxes (Tensor): of shape (N, 5).
169 labels (Tensor): of shape (N, ).
170 frame_id (int): The id of current frame, 0-index.
171 rescale (bool, optional): If True, the bounding boxes should be
172 rescaled to fit the original scale of the image. Defaults to
175 tuple: Tracking results.
    if not hasattr(self, 'kf'):
        self.kf = model.motion

    if self.empty or bboxes.size(0) == 0:
        valid_inds = bboxes[:, -1] > self.init_track_thr
        bboxes = bboxes[valid_inds]
        labels = labels[valid_inds]
        num_new_tracks = bboxes.size(0)
        ids = torch.arange(
            self.num_tracks,
            self.num_tracks + num_new_tracks,
            dtype=torch.long)
        self.num_tracks += num_new_tracks
    else:
        # 0. initialize the ids of all detections as -1 (unmatched)
        ids = torch.full((bboxes.size(0), ),
                         -1,
                         dtype=labels.dtype,
                         device=labels.device)
        # get the detection bboxes for the first association (high scores)
        first_det_inds = bboxes[:, -1] > self.obj_score_thrs['high']
        first_det_bboxes = bboxes[first_det_inds]
        first_det_labels = labels[first_det_inds]
        first_det_ids = ids[first_det_inds]

        # get the detection bboxes for the second association (low scores)
        second_det_inds = (~first_det_inds) & (
            bboxes[:, -1] > self.obj_score_thrs['low'])
        second_det_bboxes = bboxes[second_det_inds]
        second_det_labels = labels[second_det_inds]
        second_det_ids = ids[second_det_inds]
        # 1. use the Kalman filter to predict the current location of
        # every confirmed track
        for id in self.confirmed_ids:
            # the track was lost in the previous frame
            if self.tracks[id].frame_ids[-1] != frame_id - 1:
                self.tracks[id].mean[7] = 0
            (self.tracks[id].mean,
             self.tracks[id].covariance) = self.kf.predict(
                 self.tracks[id].mean, self.tracks[id].covariance)

        # 2. first association: high-score detections vs. confirmed tracks
        first_match_track_inds, first_match_det_inds = self.assign_ids(
            self.confirmed_ids, first_det_bboxes, first_det_labels,
            self.weight_iou_with_det_scores, self.match_iou_thrs['high'])
        # '-1' means that a detection box did not match any track from
        # the previous frame
        valid = first_match_det_inds > -1
        first_det_ids[valid] = torch.tensor(
            self.confirmed_ids)[first_match_det_inds[valid]].to(labels)
        first_match_det_bboxes = first_det_bboxes[valid]
        first_match_det_labels = first_det_labels[valid]
        first_match_det_ids = first_det_ids[valid]
        assert (first_match_det_ids > -1).all()

        first_unmatch_det_bboxes = first_det_bboxes[~valid]
        first_unmatch_det_labels = first_det_labels[~valid]
        first_unmatch_det_ids = first_det_ids[~valid]
        assert (first_unmatch_det_ids == -1).all()
        # 3. use the unmatched detection bboxes from the first match to
        # match the unconfirmed (tentative) tracks
        (tentative_match_track_inds,
         tentative_match_det_inds) = self.assign_ids(
             self.unconfirmed_ids, first_unmatch_det_bboxes,
             first_unmatch_det_labels, self.weight_iou_with_det_scores,
             self.match_iou_thrs['tentative'])
        valid = tentative_match_det_inds > -1
        first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
            tentative_match_det_inds[valid]].to(labels)
        # 4. second association for the confirmed tracks left unmatched
        # by the first one
        first_unmatch_track_ids = []
        for i, id in enumerate(self.confirmed_ids):
            # the track was not matched in the first association
            case_1 = first_match_track_inds[i] == -1
            # the track was not lost in the previous frame
            case_2 = self.tracks[id].frame_ids[-1] == frame_id - 1
            if case_1 and case_2:
                first_unmatch_track_ids.append(id)
        second_match_track_inds, second_match_det_inds = self.assign_ids(
            first_unmatch_track_ids, second_det_bboxes, second_det_labels,
            False, self.match_iou_thrs['low'])
        valid = second_match_det_inds > -1
        second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
            second_match_det_inds[valid]].to(ids)
        # 5. gather all matched detection bboxes from steps 2-4; from the
        # second match, only the matched detections (id != -1) are kept
        valid = second_det_ids > -1
        bboxes = torch.cat(
            (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
        bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)

        labels = torch.cat(
            (first_match_det_labels, first_unmatch_det_labels), dim=0)
        labels = torch.cat((labels, second_det_labels[valid]), dim=0)

        ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
                        dim=0)
        ids = torch.cat((ids, second_det_ids[valid]), dim=0)
        # 6. assign new ids to the detections that are still unmatched
        new_track_inds = ids == -1
        ids[new_track_inds] = torch.arange(
            self.num_tracks,
            self.num_tracks + new_track_inds.sum(),
            dtype=torch.long).to(labels)
        self.num_tracks += new_track_inds.sum()

    self.update(ids=ids, bboxes=bboxes, labels=labels, frame_ids=frame_id)
    return bboxes, labels, ids
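
# A minimal sketch of the two-band score split that drives the cascaded
# association in ``track`` above. The helper name and the 0.6/0.1
# thresholds are made up for illustration (the real values come from
# ``self.obj_score_thrs``); only the module-level ``torch`` import of
# this file is assumed.
def _score_split_demo():
    scores = torch.tensor([0.9, 0.55, 0.2, 0.05])
    high_thr, low_thr = 0.6, 0.1
    # high-score detections feed the first association ...
    first_inds = scores > high_thr
    # ... and the remaining low-score (but not near-zero) detections
    # feed the second association, exactly as in `track`
    second_inds = (~first_inds) & (scores > low_thr)
    assert first_inds.tolist() == [True, False, False, False]
    assert second_inds.tolist() == [False, True, True, False]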