    def assign_ids(self,
                   ids,
                   det_bboxes,
                   det_labels,
                   weight_iou_with_det_scores=False,
                   match_iou_thr=0.5):
        """Assign ids.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): of shape (N, 5), where each row is
                (x1, y1, x2, y2, score).
            det_labels (Tensor): of shape (N, ).
            weight_iou_with_det_scores (bool, optional): Whether to weight
                the IoU used for matching by the detection scores.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(np.ndarray): A (row, col) pair of assignment indices:
                row[i] is the detection index matched to the i-th track and
                col[j] is the track index matched to the j-th detection;
                -1 means unmatched.
        """
        # stack the predicted boxes of all candidate tracks into one array
        track_bboxes = np.zeros((0, 4))
        for id in ids:
            track_bboxes = np.concatenate(
                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
        ious = bbox_overlaps(track_bboxes, det_bboxes[:, :4])
        if weight_iou_with_det_scores:
            ious *= det_bboxes[:, 4][None]
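        # Weighting IoU by detection confidence biases the assignment toward
        # high-confidence detections when several candidates overlap a track
        # about equally.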
        # support multi-class association
        track_labels = torch.tensor([
            self.tracks[id]['labels'][-1] for id in ids
        ]).to(det_bboxes.device)
        cate_match = det_labels[None, :] == track_labels[:, None]
        # a large penalty prevents matching across different categories
        cate_cost = (1 - cate_match.int()) * 1e6

        dists = (1 - ious + cate_cost).cpu().numpy()
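        # e.g. a same-class pair with IoU 0.8 costs 1 - 0.8 = 0.2, while any
        # cross-class pair costs at least 1e6 and can never clear the
        # cost_limit used below.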
        # bipartite matching (Jonker-Volgenant); cost_limit rejects any pair
        # whose cost exceeds 1 - match_iou_thr, i.e. whose (optionally
        # score-weighted) IoU falls below match_iou_thr
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            # no tracks or no detections: everything is unmatched
            row = np.zeros(len(ids)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
        return row, col
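    # Example: with two tracks and three detections, assign_ids may return
    # row = [2, -1] (track 0 took detection 2, track 1 stayed unmatched) and
    # col = [-1, -1, 0] (only detection 2 found a track).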
163 """Tracking forward function.
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dicts where each dict
                has 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            model (nn.Module): MOT model.
            bboxes (Tensor): of shape (N, 5).
            labels (Tensor): of shape (N, ).
            frame_id (int): The id of the current frame, 0-indexed.
            rescale (bool, optional): If True, the bounding boxes are
                rescaled to fit the original scale of the image. Defaults to
                False.

        Returns:
            tuple: Tracking results (bboxes, labels, ids).
        """
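        # BYTE-style association: after motion prediction, high-score
        # detections are matched to confirmed tracks first, leftovers are
        # tried against tentative tracks, low-score detections are matched
        # to the still-unmatched confirmed tracks, and whatever remains
        # unmatched among the high-score detections starts new tracks.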
        if not hasattr(self, 'of'):
            # assumed: the flow-based motion model attached to the MOT model
            self.of = model.motion
        self.of.preprocessing(img, mask)
        for id in self.tracks.keys():
            self.tracks[id].mean, self.tracks[id].valid = self.of.predict(
                self.tracks[id].mean)
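        # Propagating every track through the optical-flow model replaces the
        # stock Kalman-filter prediction step; `valid` presumably flags
        # tracks whose propagated state is still usable.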
        if self.empty or bboxes.size(0) == 0:
            # no existing tracks: keep only high-score detections and start
            # a new track for each of them
            valid_inds = bboxes[:, -1] > self.obj_score_thrs['high']
            bboxes = bboxes[valid_inds]
            labels = labels[valid_inds]
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(
                self.num_tracks,
                self.num_tracks + num_new_tracks,
                dtype=torch.long)
            self.num_tracks += num_new_tracks
        else:
            # -1 marks a detection not yet matched to any track
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)
            # detections for the first (high-score) association
            first_det_inds = bboxes[:, -1] > self.obj_score_thrs['high']
            first_det_bboxes = bboxes[first_det_inds]
            first_det_labels = labels[first_det_inds]
            first_det_ids = ids[first_det_inds]
            # detections for the second (low-score) association
            second_det_inds = (~first_det_inds) & (
                bboxes[:, -1] > self.obj_score_thrs['low'])
            second_det_bboxes = bboxes[second_det_inds]
            second_det_labels = labels[second_det_inds]
            second_det_ids = ids[second_det_inds]
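            # Example: with obj_score_thrs of high=0.6 and low=0.1
            # (illustrative values only), a detection scored 0.8 enters the
            # first association, 0.3 enters the second, and 0.05 is dropped.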
            # first match: high-score detections vs confirmed tracks
            first_match_track_inds, first_match_det_inds = self.assign_ids(
                self.confirmed_ids, first_det_bboxes, first_det_labels,
                self.weight_iou_with_det_scores, self.match_iou_thrs['high'])
            # -1 means a detection box is not matched with any tracklet from
            # the previous frame
            valid = first_match_det_inds > -1
            first_det_ids[valid] = torch.tensor(
                self.confirmed_ids)[first_match_det_inds[valid]].to(labels)

            first_match_det_bboxes = first_det_bboxes[valid]
            first_match_det_labels = first_det_labels[valid]
            first_match_det_ids = first_det_ids[valid]
            assert (first_match_det_ids > -1).all()
            first_unmatch_det_bboxes = first_det_bboxes[~valid]
            first_unmatch_det_labels = first_det_labels[~valid]
            first_unmatch_det_ids = first_det_ids[~valid]
            assert (first_unmatch_det_ids == -1).all()
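            # Detections that failed the first match are not discarded: they
            # are offered to the tentative tracks next, and any that still
            # end up with id == -1 will seed new tracks at the end.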
            # use the unmatched high-score detections to match the tentative
            # (not yet confirmed) tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.assign_ids(
                 self.unconfirmed_ids, first_unmatch_det_bboxes,
                 first_unmatch_det_labels, self.weight_iou_with_det_scores,
                 self.match_iou_thrs['tentative'])
            valid = tentative_match_det_inds > -1
            first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)
            # collect confirmed tracks that the first match left unmatched
            first_unmatch_track_ids = []
            for i, id in enumerate(self.confirmed_ids):
                # tracklet was not matched in the first match
                case_1 = first_match_track_inds[i] == -1
                # tracklet was not already lost before this frame
                case_2 = self.tracks[id].frame_ids[-1] == frame_id - 1
                if case_1 and case_2:
                    first_unmatch_track_ids.append(id)
            # second match: low-score detections vs the remaining confirmed
            # tracks, without score weighting
            second_match_track_inds, second_match_det_inds = self.assign_ids(
                first_unmatch_track_ids, second_det_bboxes, second_det_labels,
                False, self.match_iou_thrs['low'])
            valid = second_match_det_inds > -1
            second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
                second_match_det_inds[valid]].to(ids)
            # gather all detections kept by the stages above; from the second
            # match, only detections that found a track (id > -1) are kept
            valid = second_det_ids > -1
            bboxes = torch.cat(
                (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
            bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)

            labels = torch.cat(
                (first_match_det_labels, first_unmatch_det_labels), dim=0)
            labels = torch.cat((labels, second_det_labels[valid]), dim=0)

            ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
                            dim=0)
            ids = torch.cat((ids, second_det_ids[valid]), dim=0)
            # assign fresh ids to detections that are still unmatched; each
            # one starts a new track
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()
        self.update(ids=ids, bboxes=bboxes, labels=labels, frame_ids=frame_id)
        self.of.postprocessing()
        return bboxes, labels, ids
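# A minimal per-frame usage sketch (illustrative only: `tracker`, `model`,
# `loader`, and `detect` are assumed names, not part of this module):
#
#     for frame_id, (img, img_metas) in enumerate(loader):
#         det_bboxes, det_labels = detect(model, img, img_metas)
#         bboxes, labels, ids = tracker.track(
#             img, img_metas, model, det_bboxes, det_labels, frame_id)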