Safemotion Lib
Loading...
Searching...
No Matches
base_tracker.py
Go to the documentation of this file.
1from abc import abstractmethod
2from addict import Dict
3import torch
4
"""Base tracker model.

Args:
    momentums (dict[str:float], optional): Momentums to update the buffers.
        The `str` indicates the name of the buffer while the `float`
        indicates the momentum. Defaults to None.
    num_frames_retain (int, optional): If a track has disappeared for
        `num_frames_retain` frames or more, it is deleted from the memo.
        Defaults to 30.

Note:
    An `init_cfg` argument is not accepted: the base-class initialisation
    call (`super().__init__(init_cfg)`) is commented out in `__init__`.
"""
17
def __init__(self, momentums=None, num_frames_retain=30):
    """Store the tracker configuration and start with an empty buffer.

    Args:
        momentums (dict[str:float] | None): Buffer name -> momentum used
            when blending new values into an existing track.
        num_frames_retain (int): Number of frames a track may go unseen
            before it is removed.
    """
    # NOTE: no base-class ``__init__`` is invoked; the former
    # ``super().__init__(init_cfg)`` call is disabled in this version.
    assert momentums is None or isinstance(momentums, dict), \
        'momentums must be a dict'

    self.momentums = momentums
    self.num_frames_retain = num_frames_retain
    self.fp16_enabled = False

    # Begin with no tracks stored.
    self.reset()
27
def reset(self):
    """Drop every stored track and set ``num_tracks`` back to 0."""
    self.tracks = {}
    self.num_tracks = 0
32
@property
def empty(self):
    """bool: ``True`` when ``self.tracks`` holds no track."""
    return not self.tracks
37
@property
def ids(self):
    """list: The key of every stored track, in ``self.tracks`` order."""
    return list(self.tracks)
42
@property
def with_reid(self):
    """bool: Whether a non-``None`` ``reid`` attribute is attached."""
    # ``reid`` is not assigned by any code visible here; presumably a
    # subclass or caller sets it — verify against subclasses.
    return getattr(self, 'reid', None) is not None
47
def update(self, **kwargs):
    """Update the tracker with the objects observed in one frame.

    Args:
        kwargs (dict[str: Tensor | int]): The `str` indicates the name of
            the input variable. `ids` and `frame_ids` are obligatory in
            the keys. Entries whose value is ``None`` are ignored. A
            scalar `frame_ids` (a Python ``int`` or a 0-d tensor) is
            broadcast to one entry per object; every other value must
            have exactly ``len(ids)`` entries.

    Raises:
        AssertionError: If `ids` or `frame_ids` is missing, or if the
            non-``None`` keys (and their order) differ from those of the
            first successful call.
        ValueError: If a value's length differs from the number of ids.
    """
    # ``None`` entries are not memorized at all.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    memo_items = list(kwargs)

    assert 'ids' in memo_items, '`ids` must be given to update()'
    assert 'frame_ids' in memo_items, '`frame_ids` must be given to update()'
    num_objs = len(kwargs['ids'])
    # Position of ``ids`` inside each zipped per-object tuple below.
    id_indice = memo_items.index('ids')

    frame_ids = kwargs['frame_ids']
    frame_id = int(frame_ids)
    is_scalar = isinstance(frame_ids, int) or (
        torch.is_tensor(frame_ids) and frame_ids.dim() == 0)
    if is_scalar:
        # One frame id shared by every object in this call.
        kwargs['frame_ids'] = torch.tensor([frame_id] * num_objs)

    for k, v in kwargs.items():
        if len(v) != num_objs:
            raise ValueError(
                f'`{k}` has {len(v)} entries but `ids` has {num_objs}')

    # Record the key layout only once the inputs are known to be valid,
    # so a malformed first call cannot break every later call.
    if not hasattr(self, 'memo_items'):
        self.memo_items = memo_items
    else:
        assert memo_items == self.memo_items, (
            f'update() keys {memo_items} differ from the recorded '
            f'keys {self.memo_items}')

    for obj in zip(*kwargs.values()):
        track_id = int(obj[id_indice])
        if track_id in self.tracks:
            self.update_track(track_id, obj)
        else:
            self.init_track(track_id, obj)

    self.pop_invalid_tracks(frame_id)
86
def pop_invalid_tracks(self, frame_id):
    """Remove tracks that have gone unseen for too long.

    A track is removed when ``frame_id`` minus the last entry of its
    ``frame_ids`` buffer is at least ``self.num_frames_retain``.
    """
    stale_ids = [
        track_id
        for track_id, track in self.tracks.items()
        if frame_id - track['frame_ids'][-1] >= self.num_frames_retain
    ]
    # Deleting happens after the scan; the dict must not change size
    # while it is being iterated.
    for track_id in stale_ids:
        del self.tracks[track_id]
95
def update_track(self, id, obj):
    """Fold one object's values into the existing track ``id``.

    Args:
        id (int): Key of an entry already present in ``self.tracks``.
        obj (tuple): One value per name in ``self.memo_items``, in the
            same order; each value must support ``value[None]`` (assumed
            tensor-like — confirm against callers).
    """
    momentums = self.momentums if self.momentums is not None else {}
    for name, value in zip(self.memo_items, obj):
        # Prepend a length-1 axis to the value.
        value = value[None]
        if name not in momentums:
            self.tracks[id][name].append(value)
            continue
        # Momentum buffers keep a single blended value, not a history.
        m = momentums[name]
        self.tracks[id][name] = (1 - m) * self.tracks[id][name] + m * value
105
def init_track(self, id, obj):
    """Create track ``id`` from one object's values.

    Buffers listed in ``self.momentums`` start as the single value;
    every other buffer starts as a one-element list.

    Args:
        id (int): Key for the new entry in ``self.tracks``.
        obj (tuple): One value per name in ``self.memo_items``, in the
            same order; each value must support ``value[None]``.
    """
    track = Dict()
    self.tracks[id] = track
    for name, value in zip(self.memo_items, obj):
        value = value[None]
        blended = self.momentums is not None and name in self.momentums
        track[name] = value if blended else [value]
115
@abstractmethod
def track(self, *args, **kwargs):
    """Tracking forward function; subclasses provide the implementation."""
update_track(self, id, obj)
track(self, *args, **kwargs)
update(self, **kwargs)
__init__(self, momentums=None, num_frames_retain=30)
pop_invalid_tracks(self, frame_id)
init_track(self, id, obj)