Safemotion Lib
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | List of all members
byte_tracker_mod.ByteTrackerMod Class Reference
Inheritance diagram for byte_tracker_mod.ByteTrackerMod:

Public Member Functions

 __init__ (self, obj_score_thrs=dict(high=0.6, low=0.1), init_track_thr=0.7, weight_iou_with_det_scores=True, match_iou_thrs=dict(high=0.1, low=0.5, tentative=0.3), num_tentatives=3, **kwargs)
 
 confirmed_ids (self)
 
 unconfirmed_ids (self)
 
 init_track (self, id, obj)
 
 update_track (self, id, obj)
 
 pop_invalid_tracks (self, frame_id)
 
 assign_ids (self, ids, det_bboxes, det_labels, weight_iou_with_det_scores=False, match_iou_thr=0.5)
 
 track (self, img, mask, motion, bboxes, labels, frame_id)
 

Public Attributes

 obj_score_thrs
 
 init_track_thr
 
 weight_iou_with_det_scores
 
 match_iou_thrs
 
 num_tentatives
 
 of
 
 confirmed_ids
 
 unconfirmed_ids
 
 num_tracks
 

Detailed Description

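Tracker for ByteTrack, modified to use a caller-supplied motion model
(stored as `of`, e.g. optical flow) instead of a Kalman filter to predict
and update tracklet locations.
Args:
    obj_score_thrs (dict): Detection score thresholds for matching objects.
        - high (float): Detections scoring above this take part in the
          first matching. Defaults to 0.6.
        - low (float): Detections scoring above this (but not above `high`)
          take part in the second matching. Defaults to 0.1.
    init_track_thr (float): Detection score threshold for initializing new
        tracklets while the tracker is empty. Once tracks exist, unmatched
        detections above `obj_score_thrs['high']` start new tracklets
        instead. Defaults to 0.7.
    weight_iou_with_det_scores (bool): Whether to weight the IoU used for
        matching by the detection scores. Applies to the first matching and
        the tentative matching; the second matching never weights.
        Defaults to True.
    match_iou_thrs (dict): Minimum IoU required to match a tracklet with a
        detection (the assignment cost is 1 - IoU, limited to 1 - thr).
        - high (float): Threshold of the first matching. Defaults to 0.1.
        - low (float): Threshold of the second matching. Defaults to 0.5.
        - tentative (float): Threshold of the matching for tentative
            tracklets. Defaults to 0.3.
    num_tentatives (int, optional): Number of boxes a tentative tracklet
        must accumulate before it is confirmed. Defaults to 3.
    **kwargs: Forwarded to the base tracker's constructor, e.g.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        Defaults to None.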

Definition at line 8 of file byte_tracker_mod.py.

Constructor & Destructor Documentation

◆ __init__()

byte_tracker_mod.ByteTrackerMod.__init__ ( self,
obj_score_thrs = dict(high=0.6, low=0.1),
init_track_thr = 0.7,
weight_iou_with_det_scores = True,
match_iou_thrs = dict(high=0.1, low=0.5, tentative=0.3),
num_tentatives = 3,
** kwargs )

Definition at line 30 of file byte_tracker_mod.py.

36 **kwargs):
37 super().__init__(**kwargs)
38 self.obj_score_thrs = obj_score_thrs
39 self.init_track_thr = init_track_thr
40
41 self.weight_iou_with_det_scores = weight_iou_with_det_scores
42 self.match_iou_thrs = match_iou_thrs
43
44 self.num_tentatives = num_tentatives
45
46 # self.kf = KalmanFilter()
47

Member Function Documentation

◆ assign_ids()

byte_tracker_mod.ByteTrackerMod.assign_ids ( self,
ids,
det_bboxes,
det_labels,
weight_iou_with_det_scores = False,
match_iou_thr = 0.5 )
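Match tracklets to detections with linear assignment (lap.lapjv).

The cost of pairing tracklet i with detection j is 1 - IoU, optionally
weighted by the detection score. A large penalty (1e6) is added when the
labels differ, so a tracklet and a detection of different categories are
never matched.

Args:
    ids (list[int]): Tracking ids of the candidate tracklets. Each
        tracklet's predicted box is read from self.tracks[id].mean[:4].
    det_bboxes (Tensor): of shape (N, 5); the last column is the
        detection score.
    det_labels (Tensor): of shape (N, ); class labels of the detections.
    weight_iou_with_det_scores (bool, optional): Whether to weight the
        IoU by the detection scores. Defaults to False.
    match_iou_thr (float, optional): Matching threshold; pairs whose cost
        exceeds 1 - match_iou_thr are left unmatched. Defaults to 0.5.
Returns:
    tuple(ndarray, ndarray): (row, col). row[i] is the index of the
    detection assigned to ids[i]; col[j] is the index (into ids) of the
    tracklet assigned to detection j. -1 marks an unmatched entry.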

Definition at line 104 of file byte_tracker_mod.py.

109 match_iou_thr=0.5):
110 """Assign ids.
111
112 Args:
113 ids (list[int]): Tracking ids.
114 det_bboxes (Tensor): of shape (N, 5)
115 weight_iou_with_det_scores (bool, optional): Whether using
116 detection scores to weight IOU which is used for matching.
117 Defaults to False.
118 match_iou_thr (float, optional): Matching threshold.
119 Defaults to 0.5.
120 Returns:
121 tuple(int): The assigning ids.
122 """
123 # get track_bboxes
124 track_bboxes = np.zeros((0, 4))
125 for id in ids:
126 track_bboxes = np.concatenate(
127 (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
128 track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
129 # track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)
130
131 # compute distance
132 ious = bbox_overlaps(track_bboxes, det_bboxes[:, :4])
133 if weight_iou_with_det_scores:
134 ious *= det_bboxes[:, 4][None]
135
136 # support multi-class association
137 track_labels = torch.tensor([
138 self.tracks[id]['labels'][-1] for id in ids
139 ]).to(det_bboxes.device)
140
141 cate_match = det_labels[None, :] == track_labels[:, None]
142 # to avoid det and track of different categories are matched
143 cate_cost = (1 - cate_match.int()) * 1e6
144
145 dists = (1 - ious + cate_cost).cpu().numpy()
146
147 # bipartite match
148 if dists.size > 0:
149 cost, row, col = lap.lapjv(
150 dists, extend_cost=True, cost_limit=1 - match_iou_thr)
151 else:
152 row = np.zeros(len(ids)).astype(np.int32) - 1
153 col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
154 return row, col
155

◆ confirmed_ids()

byte_tracker_mod.ByteTrackerMod.confirmed_ids ( self)
Ids of the confirmed (non-tentative) tracks in the tracker.

Definition at line 49 of file byte_tracker_mod.py.

49 def confirmed_ids(self):
50 """Confirmed ids in the tracker."""
51 ids = [id for id, track in self.tracks.items() if not track.tentative]
52 return ids
53

◆ init_track()

byte_tracker_mod.ByteTrackerMod.init_track ( self,
id,
obj )
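Initialize a track. Tracks first seen on frame 0 are confirmed immediately; all others start as tentative. The track's motion state (mean, valid) is initialized from its first box via the motion model's initiate().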

Definition at line 60 of file byte_tracker_mod.py.

60 def init_track(self, id, obj):
61 """Initialize a track."""
62 super().init_track(id, obj)
63 if self.tracks[id].frame_ids[-1] == 0:
64 self.tracks[id].tentative = False
65 else:
66 self.tracks[id].tentative = True
67 bbox = self.tracks[id].bboxes[-1][:, :4]
68 assert bbox.ndim == 2 and bbox.shape[0] == 1
69 bbox = bbox.squeeze(0).cpu().numpy()
70 self.tracks[id].mean, self.tracks[id].valid = self.of.initiate(bbox)
71

◆ pop_invalid_tracks()

byte_tracker_mod.ByteTrackerMod.pop_invalid_tracks ( self,
frame_id )
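Pop out invalid tracks. A track is removed when it has gone unmatched for at least num_frames_retain frames. It is also removed after only num_frames_retain // 3 unmatched frames if it is tentative or if its valid flag (set by the motion model's initiate()/predict()) is truthy. NOTE(review): confirm the intended polarity of the valid flag in the third case.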

Definition at line 87 of file byte_tracker_mod.py.

87 def pop_invalid_tracks(self, frame_id):
88 """Pop out invalid tracks."""
89 invalid_ids = []
90 for k, v in self.tracks.items():
91 # case1: disappeared frames >= self.num_frames_retrain
92 case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
93 # case2: tentative tracks but not matched in this frame
94 case2 = v.tentative and frame_id - v['frame_ids'][-1] >= self.num_frames_retain//3
95
96 case3 = v.valid and frame_id - v['frame_ids'][-1] >= self.num_frames_retain//3
97
98 if case1 or case2 or case3:
99 invalid_ids.append(k)
100
101 for invalid_id in invalid_ids:
102 self.tracks.pop(invalid_id)
103

◆ track()

byte_tracker_mod.ByteTrackerMod.track ( self,
img,
mask,
motion,
bboxes,
labels,
frame_id )
Tracking forward function: associate the current frame's detections with
existing tracklets and update the tracker.

Args:
    img (Tensor): Current frame; passed together with mask to the motion
        model's preprocessing(). (Presumably of shape (N, C, H, W) —
        confirm against the motion model.)
    mask: Passed to the motion model's preprocessing() along with img.
    motion: Motion model (e.g. optical flow) providing preprocessing(),
        initiate(), predict(), update() and postprocessing(). It is stored
        as self.of on the first call only; later calls ignore it.
    bboxes (Tensor): of shape (N, 5); the last column is the detection
        score.
    labels (Tensor): of shape (N, ).
    frame_id (int): The id of current frame, 0-index.
Returns:
    tuple: (bboxes, labels, ids) of the detections kept in this frame and
    their assigned track ids.

Definition at line 156 of file byte_tracker_mod.py.

162 frame_id):
163 """Tracking forward function.
164
165 Args:
166 img (Tensor): of shape (N, C, H, W) encoding input images.
167 Typically these should be mean centered and std scaled.
168 img_metas (list[dict]): list of image info dict where each dict
169 has: 'img_shape', 'scale_factor', 'flip', and may also contain
170 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
171 model (nn.Module): MOT model.
172 bboxes (Tensor): of shape (N, 5).
173 labels (Tensor): of shape (N, ).
174 frame_id (int): The id of current frame, 0-index.
175 rescale (bool, optional): If True, the bounding boxes should be
176 rescaled to fit the original scale of the image. Defaults to
177 False.
178 Returns:
179 tuple: Tracking results.
180 """
181 if not hasattr(self, 'of'):
182 self.of = motion
183
184 self.of.preprocessing(img, mask)
185 for id in self.tracks.keys():
186 self.tracks[id].mean, self.tracks[id].valid = self.of.predict(self.tracks[id].mean)
187
188 if self.empty or bboxes.size(0) == 0:
189 valid_inds = bboxes[:, -1] > self.init_track_thr
190 bboxes = bboxes[valid_inds]
191 labels = labels[valid_inds]
192 num_new_tracks = bboxes.size(0)
193 ids = torch.arange(self.num_tracks,
194 self.num_tracks + num_new_tracks).to(labels)
195 self.num_tracks += num_new_tracks
196
197 else:
198 # 0. init
199 ids = torch.full((bboxes.size(0), ),
200 -1,
201 dtype=labels.dtype,
202 device=labels.device)
203
204 # get the detection bboxes for the first association
205 first_det_inds = bboxes[:, -1] > self.obj_score_thrs['high']
206 first_det_bboxes = bboxes[first_det_inds]
207 first_det_labels = labels[first_det_inds]
208 first_det_ids = ids[first_det_inds]
209
210 # get the detection bboxes for the second association
211 second_det_inds = (~first_det_inds) & (
212 bboxes[:, -1] > self.obj_score_thrs['low'])
213 second_det_bboxes = bboxes[second_det_inds]
214 second_det_labels = labels[second_det_inds]
215 second_det_ids = ids[second_det_inds]
216
217 # 1. use optical flow to predict current location
218 # for id in self.tracks.keys():
219 # self.tracks[id].mean = self.of.predict(self.tracks[id].mean)
220
221
222 # 2. first match
223 first_match_track_inds, first_match_det_inds = self.assign_ids(
224 self.confirmed_ids, first_det_bboxes, first_det_labels,
225 self.weight_iou_with_det_scores, self.match_iou_thrs['high'])
226 # '-1' mean a detection box is not matched with tracklets in
227 # previous frame
228 valid = first_match_det_inds > -1
229 first_det_ids[valid] = torch.tensor(
230 self.confirmed_ids)[first_match_det_inds[valid]].to(labels)
231
232 first_match_det_bboxes = first_det_bboxes[valid]
233 first_match_det_labels = first_det_labels[valid]
234 first_match_det_ids = first_det_ids[valid]
235 assert (first_match_det_ids > -1).all()
236
237 first_unmatch_det_bboxes = first_det_bboxes[~valid]
238 first_unmatch_det_labels = first_det_labels[~valid]
239 first_unmatch_det_ids = first_det_ids[~valid]
240 assert (first_unmatch_det_ids == -1).all()
241
242 # 3. use unmatched detection bboxes from the first match to match
243 # the unconfirmed tracks
244 (tentative_match_track_inds,
245 tentative_match_det_inds) = self.assign_ids(
246 self.unconfirmed_ids, first_unmatch_det_bboxes,
247 first_unmatch_det_labels, self.weight_iou_with_det_scores,
248 self.match_iou_thrs['tentative'])
249 valid = tentative_match_det_inds > -1
250 first_unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
251 tentative_match_det_inds[valid]].to(labels)
252
253 # 4. second match for unmatched tracks from the first match
254 first_unmatch_track_ids = []
255 for i, id in enumerate(self.confirmed_ids):
256 # tracklet is not matched in the first match
257 case_1 = first_match_track_inds[i] == -1
258 # tracklet is not lost in the previous frame
259 case_2 = True#self.tracks[id].frame_ids[-1] == frame_id - 1
260 if case_1 and case_2:
261 first_unmatch_track_ids.append(id)
262
263 second_match_track_inds, second_match_det_inds = self.assign_ids(
264 first_unmatch_track_ids, second_det_bboxes, second_det_labels,
265 False, self.match_iou_thrs['low'])
266 valid = second_match_det_inds > -1
267 second_det_ids[valid] = torch.tensor(first_unmatch_track_ids)[
268 second_match_det_inds[valid]].to(ids)
269
270 # 5. gather all matched detection bboxes from step 2-4
271 # we only keep matched detection bboxes in second match, which
272 # means the id != -1
273 valid = second_det_ids > -1
274 bboxes = torch.cat(
275 (first_match_det_bboxes, first_unmatch_det_bboxes), dim=0)
276 bboxes = torch.cat((bboxes, second_det_bboxes[valid]), dim=0)
277
278 labels = torch.cat(
279 (first_match_det_labels, first_unmatch_det_labels), dim=0)
280 labels = torch.cat((labels, second_det_labels[valid]), dim=0)
281
282 ids = torch.cat((first_match_det_ids, first_unmatch_det_ids),
283 dim=0)
284 ids = torch.cat((ids, second_det_ids[valid]), dim=0)
285
286 # 6. assign new ids
287 new_track_inds = ids == -1
288 ids[new_track_inds] = torch.arange(
289 self.num_tracks,
290 self.num_tracks + new_track_inds.sum()).to(labels)
291 self.num_tracks += new_track_inds.sum()
292
293 self.update(ids=ids, bboxes=bboxes, labels=labels, frame_ids=frame_id)
294 self.of.postprocessing()
295
296 return bboxes, labels, ids

◆ unconfirmed_ids()

byte_tracker_mod.ByteTrackerMod.unconfirmed_ids ( self)
Ids of the unconfirmed (tentative) tracks in the tracker.

Definition at line 55 of file byte_tracker_mod.py.

55 def unconfirmed_ids(self):
56 """Unconfirmed ids in the tracker."""
57 ids = [id for id, track in self.tracks.items() if track.tentative]
58 return ids
59

◆ update_track()

byte_tracker_mod.ByteTrackerMod.update_track ( self,
id,
obj )
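Update a track. A tentative track is confirmed once it has accumulated at least num_tentatives boxes. The label of the new observation must match the track's label (asserted), and the track's motion state mean is refreshed from the latest box via the motion model's update().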

Definition at line 72 of file byte_tracker_mod.py.

72 def update_track(self, id, obj):
73 """Update a track."""
74 super().update_track(id, obj)
75 if self.tracks[id].tentative:
76 if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
77 self.tracks[id].tentative = False
78 bbox = self.tracks[id].bboxes[-1]#[:, :4] # size = (1, 4)
79 assert bbox.ndim == 2 and bbox.shape[0] == 1
80 bbox = bbox.squeeze(0).cpu().numpy()
81 track_label = self.tracks[id]['labels'][-1]
82 label_idx = self.memo_items.index('labels')
83 obj_label = obj[label_idx]
84 assert obj_label == track_label
85 self.tracks[id].mean = self.of.update(self.tracks[id].mean, bbox)
86

Member Data Documentation

◆ confirmed_ids

byte_tracker_mod.ByteTrackerMod.confirmed_ids

Definition at line 224 of file byte_tracker_mod.py.

◆ init_track_thr

byte_tracker_mod.ByteTrackerMod.init_track_thr

Definition at line 39 of file byte_tracker_mod.py.

◆ match_iou_thrs

byte_tracker_mod.ByteTrackerMod.match_iou_thrs

Definition at line 42 of file byte_tracker_mod.py.

◆ num_tentatives

byte_tracker_mod.ByteTrackerMod.num_tentatives

Definition at line 44 of file byte_tracker_mod.py.

◆ num_tracks

byte_tracker_mod.ByteTrackerMod.num_tracks

Definition at line 289 of file byte_tracker_mod.py.

◆ obj_score_thrs

byte_tracker_mod.ByteTrackerMod.obj_score_thrs

Definition at line 38 of file byte_tracker_mod.py.

◆ of

byte_tracker_mod.ByteTrackerMod.of

Definition at line 182 of file byte_tracker_mod.py.

◆ unconfirmed_ids

byte_tracker_mod.ByteTrackerMod.unconfirmed_ids

Definition at line 246 of file byte_tracker_mod.py.

◆ weight_iou_with_det_scores

byte_tracker_mod.ByteTrackerMod.weight_iou_with_det_scores

Definition at line 41 of file byte_tracker_mod.py.


The documentation for this class was generated from the following file: