Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | List of all members
byte_tracker.ByteTracker Class Reference
Inheritance diagram for byte_tracker.ByteTracker:
base_tracker.BaseTracker

Public Member Functions

 __init__ (self, obj_score_thrs=dict(high=0.6, low=0.1), init_track_thr=0.7, weight_iou_with_det_scores=True, match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3), num_tentatives=3, **kwargs)
 
 confirmed_ids (self)
 
 unconfirmed_ids (self)
 
 init_track (self, id, obj)
 
 update_track (self, id, obj)
 
 pop_invalid_tracks (self, frame_id)
 
 assign_ids (self, ids, det_bboxes, det_labels, weight_iou_with_det_scores=False, match_iou_thr=0.5)
 
 track (self, motion, bboxes, labels, frame_id)
 
- Public Member Functions inherited from base_tracker.BaseTracker
 reset (self)
 
 empty (self)
 
 ids (self)
 
 with_reid (self)
 
 update (self, **kwargs)
 

Public Attributes

 obj_score_thrs
 
 init_track_thr
 
 weight_iou_with_det_scores
 
 match_iou_thrs
 
 num_tentatives
 
 kf
 
 confirmed_ids
 
 unconfirmed_ids
 
 num_tracks
 
- Public Attributes inherited from base_tracker.BaseTracker
 momentums
 
 num_frames_retain
 
 fp16_enabled
 
 num_tracks
 
 tracks
 
 memo_items
 

Detailed Description

Tracker for ByteTrack.
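Args:
    obj_score_thrs (dict): Detection score thresholds that split detections
        between the two association stages.
        - high (float): Detections scoring above this take part in the first
            matching. Defaults to 0.6.
        - low (float): Detections scoring above this, but not above `high`,
            take part in the second matching. Defaults to 0.1.
    init_track_thr (float): Detection score threshold for initializing new
        tracklets. In `track()` it is applied only when the tracker is empty
        (or there are no detections). In later frames, new tracklets are
        started from high-score detections that remain unmatched.
        Defaults to 0.7.
    weight_iou_with_det_scores (bool): Whether to multiply the IoU by the
        detection scores before matching. This is used in the first matching
        and the tentative matching, but not in the second matching.
        Defaults to True.
    match_iou_thrs (dict): IoU thresholds for matching detections to
        tracklets across frames.
        - high (float): Threshold of the first matching (confirmed
            tracklets vs. high-score detections). Defaults to 0.1.
        - low (float): Threshold of the second matching (remaining
            tracklets vs. low-score detections). Defaults to 0.5.
        - tentative (float): Threshold for matching tentative tracklets
            against detections left unmatched by the first matching.
            Defaults to 0.3.
    num_tentatives (int, optional): Number of consecutive frames a tracklet
        must be matched before it is confirmed. Defaults to 3.
    init_cfg (dict or list[dict], optional): Initialization config dict,
        forwarded to `BaseTracker` through `**kwargs`. Defaults to None.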

Definition at line 9 of file byte_tracker.py.

Constructor & Destructor Documentation

◆ __init__()

byte_tracker.ByteTracker.__init__ ( self,
obj_score_thrs = dict(high=0.6, low=0.1),
init_track_thr = 0.7,
weight_iou_with_det_scores = True,
match_iou_thrs = dict(high=0.1, low=0.5, tentative=0.3),
num_tentatives = 3,
** kwargs )

Reimplemented from base_tracker.BaseTracker.

Definition at line 31 of file byte_tracker.py.

37 **kwargs):
38 super().__init__(**kwargs)
39 self.obj_score_thrs = obj_score_thrs
40 self.init_track_thr = init_track_thr
41
42 self.weight_iou_with_det_scores = weight_iou_with_det_scores
43 self.match_iou_thrs = match_iou_thrs
44
45 self.num_tentatives = num_tentatives
46
47 # self.kf = KalmanFilter()
48

Member Function Documentation

◆ assign_ids()

byte_tracker.ByteTracker.assign_ids ( self,
ids,
det_bboxes,
det_labels,
weight_iou_with_det_scores = False,
match_iou_thr = 0.5 )
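Assign detections to the given tracklets by bipartite matching.

Args:
    ids (list[int]): Ids of the tracklets to match against.
    det_bboxes (Tensor): Detections of shape (N, 5). The first four columns
        are the box coordinates and the last column is the detection score.
    det_labels (Tensor): Class labels of the detections, shape (N, ).
        A tracklet and a detection of different classes are never matched.
    weight_iou_with_det_scores (bool, optional): Whether to multiply the IoU
        by the detection scores before matching. Defaults to False.
    match_iou_thr (float, optional): Matching IoU threshold. A pair whose
        cost (1 - IoU) exceeds 1 - match_iou_thr is left unmatched.
        Defaults to 0.5.
Returns:
    tuple(np.ndarray, np.ndarray): ``(row, col)``.
        - ``row[i]`` is the index of the detection matched to tracklet
          ``ids[i]``.
        - ``col[j]`` is the index into ``ids`` of the tracklet matched to
          detection ``j``.
        - -1 marks an unmatched entry.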

Definition at line 102 of file byte_tracker.py.

107 match_iou_thr=0.5):
108 """Assign ids.
109
110 Args:
111 ids (list[int]): Tracking ids.
112 det_bboxes (Tensor): of shape (N, 5)
113 weight_iou_with_det_scores (bool, optional): Whether using
114 detection scores to weight IOU which is used for matching.
115 Defaults to False.
116 match_iou_thr (float, optional): Matching threshold.
117 Defaults to 0.5.
118 Returns:
119 tuple(int): The assigning ids.
120 """
121 # get track_bboxes
122 track_bboxes = np.zeros((0, 4))
123 for id in ids:
124 track_bboxes = np.concatenate(
125 (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
126 track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
127 track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)
128
129 # compute distance
130 ious = bbox_overlaps(track_bboxes, det_bboxes[:, :4])
131 if weight_iou_with_det_scores:
132 ious *= det_bboxes[:, 4][None]
133
134 # support multi-class association
135 track_labels = torch.tensor([
136 self.tracks[id]['labels'][-1] for id in ids
137 ]).to(det_bboxes.device)
138
139 cate_match = det_labels[None, :] == track_labels[:, None]
140 # to avoid det and track of different categories are matched
141 cate_cost = (1 - cate_match.int()) * 1e6
142
143 dists = (1 - ious + cate_cost).cpu().numpy()
144
145 # bipartite match
146 if dists.size > 0:
147 cost, row, col = lap.lapjv(
148 dists, extend_cost=True, cost_limit=1 - match_iou_thr)
149 else:
150 row = np.zeros(len(ids)).astype(np.int32) - 1
151 col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
152 return row, col
153

◆ confirmed_ids()

byte_tracker.ByteTracker.confirmed_ids ( self)
Confirmed ids in the tracker.

Definition at line 50 of file byte_tracker.py.

50 def confirmed_ids(self):
51 """Confirmed ids in the tracker."""
52 ids = [id for id, track in self.tracks.items() if not track.tentative]
53 return ids
54

◆ init_track()

byte_tracker.ByteTracker.init_track ( self,
id,
obj )
Initialize a track.

Reimplemented from base_tracker.BaseTracker.

Definition at line 61 of file byte_tracker.py.

61 def init_track(self, id, obj):
62 """Initialize a track."""
63 super().init_track(id, obj)
64 if self.tracks[id].frame_ids[-1] == 0:
65 self.tracks[id].tentative = False
66 else:
67 self.tracks[id].tentative = True
68 bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4)
69 assert bbox.ndim == 2 and bbox.shape[0] == 1
70 bbox = bbox.squeeze(0).cpu().numpy()
71 self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(bbox)
72

◆ pop_invalid_tracks()

byte_tracker.ByteTracker.pop_invalid_tracks ( self,
frame_id )
Pop out invalid tracks.

Reimplemented from base_tracker.BaseTracker.

Definition at line 89 of file byte_tracker.py.

89 def pop_invalid_tracks(self, frame_id):
90 """Pop out invalid tracks."""
91 invalid_ids = []
92 for k, v in self.tracks.items():
93 # case1: disappeared frames >= self.num_frames_retrain
94 case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
95 # case2: tentative tracks but not matched in this frame
96 case2 = v.tentative and v['frame_ids'][-1] != frame_id
97 if case1 or case2:
98 invalid_ids.append(k)
99 for invalid_id in invalid_ids:
100 self.tracks.pop(invalid_id)
101

◆ track()

byte_tracker.ByteTracker.track ( self,
motion,
bboxes,
labels,
frame_id )
Tracking forward function.

Args:
    motion (object): Motion model, stored as ``self.kf`` on the first call
        and ignored on later calls. It must provide ``initiate``, ``predict``
        and ``update`` methods (presumably a Kalman filter).
    bboxes (Tensor): Detections of shape (N, 5). The last column is the
        detection score.
    labels (Tensor): Class labels of the detections, shape (N, ).
    frame_id (int): The id of the current frame, 0-indexed.
Returns:
    tuple: ``(bboxes, labels, ids)`` for the detections kept in this frame,
    together with their assigned track ids.

Reimplemented from base_tracker.BaseTracker.

Definition at line 154 of file byte_tracker.py.

158 frame_id):
159 """Tracking forward function.
160
161 Args:
162 img (Tensor): of shape (N, C, H, W) encoding input images.
163 Typically these should be mean centered and std scaled.
164 img_metas (list[dict]): list of image info dict where each dict
165 has: 'img_shape', 'scale_factor', 'flip', and may also contain
166 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
167 model (nn.Module): MOT model.
168 bboxes (Tensor): of shape (N, 5).
169 labels (Tensor): of shape (N, ).
170 frame_id (int): The id of current frame, 0-index.
171 rescale (bool, optional): If True, the bounding boxes should be
172 rescaled to fit the original scale of the image. Defaults to
173 False.
174 Returns:
175 tuple: Tracking results.
176 """
177 if not hasattr(self, 'kf'):
178 self.kf = motion
179
180 if self.empty or bboxes.size(0) == 0:
181 valid_inds = bboxes[:, -1] > self.init_track_thr
182 bboxes = bboxes[valid_inds]
183 labels = labels[valid_inds]
184 num_new_tracks = bboxes.size(0)
185 ids = torch.arange(self.num_tracks,
186 self.num_tracks + num_new_tracks).to(labels)
187 self.num_tracks += num_new_tracks
188
189 else:
190 # 0. init
191 ids = torch.full((bboxes.size(0), ),
192 -1,
193 dtype=labels.dtype,
194 device=labels.device)
195
196 # get the detection bboxes for the first association
197 first_det_inds = bboxes[:, -1] > self.obj_score_thrs['high']
198 first_det_bboxes = bboxes[first_det_inds]
199 first_det_labels = labels[first_det_inds]
200 first_det_ids = ids[first_det_inds]
201
202 # get the detection bboxes for the second association
203 second_det_inds = (~first_det_inds) & (
204 bboxes[:, -1] > self.obj_score_thrs['low'])
205 second_det_bboxes = bboxes[second_det_inds]
206 second_det_labels = labels[second_det_inds]
207 second_det_ids = ids[second_det_inds]
208
209 # 1. use Kalman Filter to predict current location
210 for id in self.confirmed_ids:
211 # track is lost in previous frame
212 if self.tracks[id].frame_ids[-1] != frame_id - 1:
213 self.tracks[id].mean[7] = 0
214 (self.tracks[id].mean,
215 self.tracks[id].covariance) = self.kf.predict(
216 self.tracks[id].mean, self.tracks[id].covariance)
217
218 # 2. first match
219 first_match_track_inds, first_match_det_inds = self.assign_ids(
220 self.confirmed_ids, first_det_bboxes, first_det_labels,
221 self.weight_iou_with_det_scores, self.match_iou_thrs['high'])
222 # '-1' mean a detection box is not matched with tracklets in
223 # previous frame
224 valid = first_match_det_inds > -1
225 first_det_ids[valid] = torch.tensor(
226 self.confirmed_ids)[first_match_det_inds[valid]].to(labels)
227
228 first_match_det_bboxes = first_det_bboxes[valid]
229 first_match_det_labels = first_det_labels[valid]
230 first_match_det_ids = first_det_ids[valid]
231 assert (first_match_det_ids > -1).all()
232
233 first_unmatch_det_bboxes = first_det_bboxes[~valid]
234 first_unmatch_det_labels = first_det_labels[~valid]
235 first_unmatch_det_ids = first_det_ids[~valid]
236 assert (first_unmatch_det_ids == -1).all()
237
238 # 3. use unmatched detection bboxes from the first match to match
239 # the unconfirmed tracks
240 (tentative_match_track_inds,
241 tentative_match_det_inds) = self.assign_ids(
242 self.unconfirmed_ids, first_unmatch_det_bboxes,
243 first_unmatch_det_labels, self.weight_iou_with_det_scores,
244 self.match_iou_thrs['tentative'])
245 valid = tentative_match_det_inds > -1
246 first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
247 tentative_match_det_inds[valid]].to(labels)
248
249 # 4. second match for unmatched tracks from the first match
250 first_unmatch_track_ids = []
251 for i, id in enumerate(self.confirmed_ids):
252 # tracklet is not matched in the first match
253 case_1 = first_match_track_inds[i] == -1
254 # tracklet is not lost in the previous frame
255 case_2 = self.tracks[id].frame_ids[-1] == frame_id - 1
256 if case_1 and case_2:
257 first_unmatch_track_ids.append(id)
258
259 second_match_track_inds, second_match_det_inds = self.assign_ids(
260 first_unmatch_track_ids, second_det_bboxes, second_det_labels,
261 False, self.match_iou_thrs['low'])
262 valid = second_match_det_inds > -1
263 second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
264 second_match_det_inds[valid]].to(ids)
265
266 # 5. gather all matched detection bboxes from step 2-4
267 # we only keep matched detection bboxes in second match, which
268 # means the id != -1
269 valid = second_det_ids > -1
270 bboxes = torch.cat(
271 (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
272 bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)
273
274 labels = torch.cat(
275 (first_match_det_labels, first_unmatch_det_labels), dim=0)
276 labels = torch.cat((labels, second_det_labels[valid]), dim=0)
277
278 ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
279 dim=0)
280 ids = torch.cat((ids, second_det_ids[valid]), dim=0)
281
282 # 6. assign new ids
283 new_track_inds = ids == -1
284 ids[new_track_inds] = torch.arange(
285 self.num_tracks,
286 self.num_tracks + new_track_inds.sum()).to(labels)
287 self.num_tracks += new_track_inds.sum()
288
289 self.update(ids=ids, bboxes=bboxes, labels=labels, frame_ids=frame_id)
290 return bboxes, labels, ids

◆ unconfirmed_ids()

byte_tracker.ByteTracker.unconfirmed_ids ( self)
Unconfirmed ids in the tracker.

Definition at line 56 of file byte_tracker.py.

56 def unconfirmed_ids(self):
57 """Unconfirmed ids in the tracker."""
58 ids = [id for id, track in self.tracks.items() if track.tentative]
59 return ids
60

◆ update_track()

byte_tracker.ByteTracker.update_track ( self,
id,
obj )
Update a track.

Reimplemented from base_tracker.BaseTracker.

Definition at line 73 of file byte_tracker.py.

73 def update_track(self, id, obj):
74 """Update a track."""
75 super().update_track(id, obj)
76 if self.tracks[id].tentative:
77 if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
78 self.tracks[id].tentative = False
79 bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4)
80 assert bbox.ndim == 2 and bbox.shape[0] == 1
81 bbox = bbox.squeeze(0).cpu().numpy()
82 track_label = self.tracks[id]['labels'][-1]
83 label_idx = self.memo_items.index('labels')
84 obj_label = obj[label_idx]
85 assert obj_label == track_label
86 self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
87 self.tracks[id].mean, self.tracks[id].covariance, bbox)
88

Member Data Documentation

◆ confirmed_ids

byte_tracker.ByteTracker.confirmed_ids

Definition at line 220 of file byte_tracker.py.

◆ init_track_thr

byte_tracker.ByteTracker.init_track_thr

Definition at line 40 of file byte_tracker.py.

◆ kf

byte_tracker.ByteTracker.kf

Definition at line 178 of file byte_tracker.py.

◆ match_iou_thrs

byte_tracker.ByteTracker.match_iou_thrs

Definition at line 43 of file byte_tracker.py.

◆ num_tentatives

byte_tracker.ByteTracker.num_tentatives

Definition at line 45 of file byte_tracker.py.

◆ num_tracks

byte_tracker.ByteTracker.num_tracks

Definition at line 285 of file byte_tracker.py.

◆ obj_score_thrs

byte_tracker.ByteTracker.obj_score_thrs

Definition at line 39 of file byte_tracker.py.

◆ unconfirmed_ids

byte_tracker.ByteTracker.unconfirmed_ids

Definition at line 242 of file byte_tracker.py.

◆ weight_iou_with_det_scores

byte_tracker.ByteTracker.weight_iou_with_det_scores

Definition at line 42 of file byte_tracker.py.


The documentation for this class was generated from the following file: