              frame_id):
        """Tracking forward function.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            model (nn.Module): MOT model.
            bboxes (Tensor): of shape (N, 5).
            labels (Tensor): of shape (N, ).
            frame_id (int): The id of the current frame, 0-indexed.
            rescale (bool, optional): If True, the bounding boxes should be
                rescaled to fit the original scale of the image. Defaults to
                False.

        Returns:
            tuple: Tracking results, i.e. the kept bboxes, labels and ids.
        """
        if not hasattr(self, 'of'):
            self.of = motion

        self.of.preprocessing(img, mask)
        for id in self.tracks.keys():
            (self.tracks[id].mean,
             self.tracks[id].valid) = self.of.predict(self.tracks[id].mean)

        if self.empty or bboxes.size(0) == 0:
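            # no existing tracks (or no detections): start new tracks from
            # detections whose score exceeds init_track_thr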
            valid_inds = bboxes[:, -1] > self.init_track_thr
            bboxes = bboxes[valid_inds]
            labels = labels[valid_inds]
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(self.num_tracks,
                               self.num_tracks + num_new_tracks).to(labels)
            self.num_tracks += num_new_tracks

        else:
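            # 0. init the ids container; -1 marks an unmatched detection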
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)

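            # detections for the first (high score) association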
            first_det_inds = bboxes[:, -1] > self.obj_score_thrs['high']
            first_det_bboxes = bboxes[first_det_inds]
            first_det_labels = labels[first_det_inds]
            first_det_ids = ids[first_det_inds]

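            # detections for the second (low score) association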
            second_det_inds = (~first_det_inds) & (
                bboxes[:, -1] > self.obj_score_thrs['low'])
            second_det_bboxes = bboxes[second_det_inds]
            second_det_labels = labels[second_det_inds]
            second_det_ids = ids[second_det_inds]

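            # 1. first association: match confirmed tracks against the
            # high score detections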
            first_match_track_inds, first_match_det_inds = self.assign_ids(
                self.confirmed_ids, first_det_bboxes, first_det_labels,
                self.weight_iou_with_det_scores, self.match_iou_thrs['high'])

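            # '-1' means the detection is not matched to any tracklet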
            valid = first_match_det_inds > -1
            first_det_ids[valid] = torch.tensor(
                self.confirmed_ids)[first_match_det_inds[valid]].to(labels)

            first_match_det_bboxes = first_det_bboxes[valid]
            first_match_det_labels = first_det_labels[valid]
            first_match_det_ids = first_det_ids[valid]
            assert (first_match_det_ids > -1).all()

            first_unmatch_det_bboxes = first_det_bboxes[~valid]
            first_unmatch_det_labels = first_det_labels[~valid]
            first_unmatch_det_ids = first_det_ids[~valid]
            assert (first_unmatch_det_ids == -1).all()

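            # 2. match the unmatched detections from the first association
            # against the unconfirmed (tentative) tracks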
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.assign_ids(
                 self.unconfirmed_ids, first_unmatch_det_bboxes,
                 first_unmatch_det_labels, self.weight_iou_with_det_scores,
                 self.match_iou_thrs['tentative'])
            valid = tentative_match_det_inds > -1
            first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)

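            # 3. second association: match the confirmed tracks left over
            # from the first association against the low score detections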
            first_unmatch_track_ids = []
            for i, id in enumerate(self.confirmed_ids):
                # tracklet was not matched in the first association
                case_1 = first_match_track_inds[i] == -1
                # tracklet has not been updated in the current frame
                case_2 = self.tracks[id].frame_ids[-1] != frame_id
                if case_1 and case_2:
                    first_unmatch_track_ids.append(id)

            second_match_track_inds, second_match_det_inds = self.assign_ids(
                first_unmatch_track_ids, second_det_bboxes, second_det_labels,
                False, self.match_iou_thrs['low'])
            valid = second_match_det_inds > -1
            second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
                second_match_det_inds[valid]].to(ids)

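            # 4. gather all matched detections; low score detections are
            # kept only if they were matched to a track (id != -1)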
            valid = second_det_ids > -1
            bboxes = torch.cat(
                (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
            bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)

            labels = torch.cat(
                (first_match_det_labels, first_unmatch_det_labels), dim=0)
            labels = torch.cat((labels, second_det_labels[valid]), dim=0)

            ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
                            dim=0)
            ids = torch.cat((ids, second_det_ids[valid]), dim=0)

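            # 5. assign new ids to the detections that are still unmatched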
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

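        # update the tracks with the kept detections and run the motion
        # model's post-processing for this frame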
        self.update(ids=ids, bboxes=bboxes, labels=labels, frame_ids=frame_id)
        self.of.postprocessing()

        return bboxes, labels, ids