Safemotion Lib
byte_tracker_mod.py
import lap
import numpy as np
import torch

from smutils.bbox.iou_calculator import bbox_overlaps
from smtrack.models.trackers.base_tracker import BaseTracker

class ByteTrackerMod(BaseTracker):
    """Tracker for ByteTrack.

    Args:
        obj_score_thrs (dict): Detection score thresholds for matching
            objects.
            - high (float): Threshold of the first matching. Defaults to 0.6.
            - low (float): Threshold of the second matching. Defaults to 0.1.
        init_track_thr (float): Detection score threshold for initializing a
            new tracklet. Defaults to 0.7.
        weight_iou_with_det_scores (bool): Whether to weight the IOU used
            for matching with the detection scores. Defaults to True.
        match_iou_thrs (dict): IOU distance thresholds for matching between
            two frames.
            - high (float): Threshold of the first matching. Defaults to 0.1.
            - low (float): Threshold of the second matching. Defaults to 0.5.
            - tentative (float): Threshold of the matching for tentative
                tracklets. Defaults to 0.3.
        num_tentatives (int, optional): Number of continuous frames to
            confirm a track. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 obj_score_thrs=dict(high=0.6, low=0.1),
                 init_track_thr=0.7,
                 weight_iou_with_det_scores=True,
                 match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
                 num_tentatives=3,
                 **kwargs):
        super().__init__(**kwargs)
        self.obj_score_thrs = obj_score_thrs
        self.init_track_thr = init_track_thr

        self.weight_iou_with_det_scores = weight_iou_with_det_scores
        self.match_iou_thrs = match_iou_thrs

        self.num_tentatives = num_tentatives

        # self.kf = KalmanFilter()
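        # Note: unlike stock ByteTrack, no Kalman filter is created here.
        # The motion model is an optical-flow module that is injected
        # lazily on the first call to `track()` and stored as `self.of`.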

    @property
    def confirmed_ids(self):
        """Confirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if not track.tentative]
        return ids

    @property
    def unconfirmed_ids(self):
        """Unconfirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if track.tentative]
        return ids

    def init_track(self, id, obj):
        """Initialize a track."""
        super().init_track(id, obj)
        if self.tracks[id].frame_ids[-1] == 0:
            self.tracks[id].tentative = False
        else:
            self.tracks[id].tentative = True
        bbox = self.tracks[id].bboxes[-1][:, :4]
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].valid = self.of.initiate(bbox)

    def update_track(self, id, obj):
        """Update a track."""
        super().update_track(id, obj)
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = self.tracks[id].bboxes[-1]  # [:, :4] slice disabled; the full (1, 5) row is used
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        track_label = self.tracks[id]['labels'][-1]
        label_idx = self.memo_items.index('labels')
        obj_label = obj[label_idx]
        assert obj_label == track_label
        self.tracks[id].mean = self.of.update(self.tracks[id].mean, bbox)

    def pop_invalid_tracks(self, frame_id):
        """Pop out invalid tracks."""
        invalid_ids = []
        for k, v in self.tracks.items():
            # case1: disappeared frames >= self.num_frames_retain
            case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
            # case2: tentative track unmatched for more than a third of the
            # retain window
            case2 = v.tentative and \
                frame_id - v['frame_ids'][-1] >= self.num_frames_retain // 3
            # case3: track flagged by the optical-flow validity check and
            # unmatched for the same window
            case3 = v.valid and \
                frame_id - v['frame_ids'][-1] >= self.num_frames_retain // 3

            if case1 or case2 or case3:
                invalid_ids.append(k)

        for invalid_id in invalid_ids:
            self.tracks.pop(invalid_id)

    def assign_ids(self,
                   ids,
                   det_bboxes,
                   det_labels,
                   weight_iou_with_det_scores=False,
                   match_iou_thr=0.5):
        """Assign ids.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): of shape (N, 5).
            det_labels (Tensor): of shape (N, ).
            weight_iou_with_det_scores (bool, optional): Whether to weight
                the IOU used for matching with the detection scores.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(np.ndarray): The assignment indices (row, col): row[i] is
                the detection index matched to the i-th track and col[j] is
                the track index matched to the j-th detection; -1 means
                unmatched.
        """
        # get track_bboxes
        track_bboxes = np.zeros((0, 4))
        for id in ids:
            track_bboxes = np.concatenate(
                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
        # track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)

        # compute distance
        ious = bbox_overlaps(track_bboxes, det_bboxes[:, :4])
        if weight_iou_with_det_scores:
            ious *= det_bboxes[:, 4][None]

        # support multi-class association
        track_labels = torch.tensor([
            self.tracks[id]['labels'][-1] for id in ids
        ]).to(det_bboxes.device)

        cate_match = det_labels[None, :] == track_labels[:, None]
        # avoid matching detections and tracks of different categories
        cate_cost = (1 - cate_match.int()) * 1e6

        dists = (1 - ious + cate_cost).cpu().numpy()

        # bipartite match
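        # `lap.lapjv` solves the linear assignment problem; with
        # `extend_cost=True` the rectangular cost matrix is padded to a
        # square one, and `cost_limit=1 - match_iou_thr` leaves any pair
        # whose cost exceeds the limit unassigned (returned as -1).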
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(ids)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
        return row, col

    def track(self,
              img,
              mask,
              motion,
              bboxes,
              labels,
              frame_id):
        """Tracking forward function.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            mask (Tensor): Mask passed to the motion model's preprocessing.
            motion (nn.Module): Optical-flow motion model providing the
                initiate / predict / update interface for track states.
            bboxes (Tensor): of shape (N, 5).
            labels (Tensor): of shape (N, ).
            frame_id (int): The id of current frame, 0-index.

        Returns:
            tuple: Tracking results (bboxes, labels, ids).
        """
        if not hasattr(self, 'of'):
            self.of = motion

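        # Prepare the motion model for this frame once, then advance every
        # existing track's state with the flow-based prediction.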
        self.of.preprocessing(img, mask)
        for id in self.tracks.keys():
            self.tracks[id].mean, self.tracks[id].valid = self.of.predict(
                self.tracks[id].mean)

        if self.empty or bboxes.size(0) == 0:
            valid_inds = bboxes[:, -1] > self.init_track_thr
            bboxes = bboxes[valid_inds]
            labels = labels[valid_inds]
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(self.num_tracks,
                               self.num_tracks + num_new_tracks).to(labels)
            self.num_tracks += num_new_tracks

        else:
            # 0. init
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)

            # get the detection bboxes for the first association
            first_det_inds = bboxes[:, -1] > self.obj_score_thrs['high']
            first_det_bboxes = bboxes[first_det_inds]
            first_det_labels = labels[first_det_inds]
            first_det_ids = ids[first_det_inds]

            # get the detection bboxes for the second association
            second_det_inds = (~first_det_inds) & (
                bboxes[:, -1] > self.obj_score_thrs['low'])
            second_det_bboxes = bboxes[second_det_inds]
            second_det_labels = labels[second_det_inds]
            second_det_ids = ids[second_det_inds]

            # 1. use optical flow to predict current location
            # (already done at the top of this method, before the filtering)
            # for id in self.tracks.keys():
            #     self.tracks[id].mean = self.of.predict(self.tracks[id].mean)

            # 2. first match
            first_match_track_inds, first_match_det_inds = self.assign_ids(
                self.confirmed_ids, first_det_bboxes, first_det_labels,
                self.weight_iou_with_det_scores, self.match_iou_thrs['high'])
            # '-1' means a detection box is not matched with tracklets in
            # the previous frame
            valid = first_match_det_inds > -1
            first_det_ids[valid] = torch.tensor(
                self.confirmed_ids)[first_match_det_inds[valid]].to(labels)

            first_match_det_bboxes = first_det_bboxes[valid]
            first_match_det_labels = first_det_labels[valid]
            first_match_det_ids = first_det_ids[valid]
            assert (first_match_det_ids > -1).all()

            first_unmatch_det_bboxes = first_det_bboxes[~valid]
            first_unmatch_det_labels = first_det_labels[~valid]
            first_unmatch_det_ids = first_det_ids[~valid]
            assert (first_unmatch_det_ids == -1).all()

            # 3. use unmatched detection bboxes from the first match to match
            # the unconfirmed tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.assign_ids(
                 self.unconfirmed_ids, first_unmatch_det_bboxes,
                 first_unmatch_det_labels, self.weight_iou_with_det_scores,
                 self.match_iou_thrs['tentative'])
            valid = tentative_match_det_inds > -1
            first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)

            # 4. second match for unmatched tracks from the first match
            first_unmatch_track_ids = []
            for i, id in enumerate(self.confirmed_ids):
                # tracklet is not matched in the first match
                case_1 = first_match_track_inds[i] == -1
                # tracklet is not lost in the previous frame (check disabled)
                case_2 = True  # self.tracks[id].frame_ids[-1] == frame_id - 1
                if case_1 and case_2:
                    first_unmatch_track_ids.append(id)

            second_match_track_inds, second_match_det_inds = self.assign_ids(
                first_unmatch_track_ids, second_det_bboxes, second_det_labels,
                False, self.match_iou_thrs['low'])
            valid = second_match_det_inds > -1
            second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
                second_match_det_inds[valid]].to(ids)

            # 5. gather all matched detection bboxes from step 2-4
            # we only keep matched detection bboxes in the second match,
            # i.e. those with id != -1
            valid = second_det_ids > -1
            bboxes = torch.cat(
                (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
            bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)

            labels = torch.cat(
                (first_match_det_labels, first_unmatch_det_labels), dim=0)
            labels = torch.cat((labels, second_det_labels[valid]), dim=0)

            ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
                            dim=0)
            ids = torch.cat((ids, second_det_ids[valid]), dim=0)

            # 6. assign new ids
            new_track_inds = ids == -1
            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

        self.update(ids=ids, bboxes=bboxes, labels=labels, frame_ids=frame_id)
        self.of.postprocessing()

        return bboxes, labels, ids
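
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept as comments). It assumes a
# detector and an optical-flow motion model exposing the `preprocessing` /
# `initiate` / `predict` / `update` / `postprocessing` interface used above;
# `OpticalFlowMotion`, `detector`, and `video` are hypothetical names.
#
# tracker = ByteTrackerMod(
#     obj_score_thrs=dict(high=0.6, low=0.1),
#     init_track_thr=0.7,
#     weight_iou_with_det_scores=True,
#     match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3),
#     num_tentatives=3)
# motion = OpticalFlowMotion()  # hypothetical motion model
# for frame_id, (img, mask) in enumerate(video):
#     bboxes, labels = detector(img)  # bboxes: (N, 5), labels: (N, )
#     bboxes, labels, ids = tracker.track(
#         img, mask, motion, bboxes, labels, frame_id)
# ---------------------------------------------------------------------------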