Safemotion Lib
Loading...
Searching...
No Matches
bases.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: sherlock
4@contact: sherlockliao01@gmail.com
5"""
6
7import copy
8import logging
9import os
10from tabulate import tabulate
11from termcolor import colored
12
13logger = logging.getLogger(__name__)
14
15
class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Args:
        train (list): contains tuples of (img_path(s), pid, camid).
        query (list): contains tuples of (img_path(s), pid, camid).
        gallery (list): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        mode (str): 'train', 'query' or 'gallery'; selects which split
            ``self.data`` refers to.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information.
        **kwargs: accepted for subclass compatibility; ignored here.

    Raises:
        ValueError: if ``mode`` is not one of 'train', 'query', 'gallery'.
    """
    _junk_pids = []  # contains useless person IDs, e.g. background, false detections

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        self.train = train
        self.query = query
        self.gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        if self.combineall:
            # Runs before ``self.data`` is selected so that mode='train'
            # refers to the merged list built by combine_all().
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

    def __getitem__(self, index):
        """Return one sample; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples in the split selected by ``mode``."""
        return len(self.data)

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3]).

        ``sum`` starts from the integer ``0``; that start value is absorbed by
        returning ``self`` unchanged. Any other left operand is delegated to
        ``self.__add__``.
        """
        if other == 0:
            return self
        # NOTE(review): Dataset itself defines no ``__add__``; a subclass must
        # provide one, otherwise this raises AttributeError.
        return self.__add__(other)

    def parse_data(self, data):
        """Parses data list and returns the number of person IDs
        and the number of camera views.

        Args:
            data (list): contains tuples of (img_path(s), pid, camid).

        Returns:
            tuple: (number of distinct pids, number of distinct camids).
        """
        pids = set()
        cams = set()
        for _, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
        return len(pids), len(cams)

    def get_num_pids(self, data):
        """Returns the number of distinct person identities in ``data``."""
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        """Returns the number of distinct cameras in ``data``."""
        return self.parse_data(data)[1]

    def show_summary(self):
        """Shows dataset statistics (no-op in the base class)."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training.

        Query and gallery samples whose pid is in ``_junk_pids`` are dropped;
        the remaining ones get their pid and camid prefixed with
        ``self.dataset_name + "_"`` and are appended to a deep copy of the
        train list. Sets ``self.train`` and ``self.num_train_pids``.
        """
        # NOTE(review): ``self.dataset_name`` is not set in this class; it is
        # assumed to be defined by subclasses. Train entries are copied
        # verbatim (not prefixed) — presumably subclasses already namespace
        # train pids; verify.
        combined = copy.deepcopy(self.train)

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str, os.PathLike or list): file name(s) or
                path(s) that must exist.

        Raises:
            RuntimeError: for the first path that does not exist.
        """
        # A single str or path-like object (e.g. pathlib.Path) is wrapped in
        # a list; iterating a Path directly would raise TypeError.
        if isinstance(required_files, (str, os.PathLike)):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))
124
125
class ImageDataset(Dataset):
    """A base class representing ImageDataset.

    All other image datasets should subclass it.
    Subclasses implement ``__getitem__`` so that it returns an image given
    index: ``img``, ``pid``, ``camid`` and ``img_path``, where ``img`` has
    shape (channel, height, width). As a result, data in each batch has
    shape (batch_size, channel, height, width).
    """

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def _log_table(self, csv_results):
        """Log ``csv_results`` rows as a pipe-format table with the stats header."""
        headers = ['subset', '# ids', '# images', '# cameras']

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

    def show_train(self):
        """Log the number of ids, images and cameras in the train split."""
        num_train_pids, num_train_cams = self.parse_data(self.train)
        self._log_table([['train', num_train_pids, len(self.train), num_train_cams]])

    def show_test(self):
        """Log the number of ids, images and cameras in the query and gallery splits."""
        # Previously read the nonexistent attributes ``queryquery`` /
        # ``gallerygallery``, which raised AttributeError.
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
        self._log_table([
            ['query', num_query_pids, len(self.query), num_query_cams],
            ['gallery', num_gallery_pids, len(self.gallery), num_gallery_cams],
        ])
check_before_run(self, required_files)
Definition bases.py:113
__init__(self, train, query, gallery, transform=None, mode='train', combineall=False, verbose=True, **kwargs)
Definition bases.py:32
__init__(self, train, query, gallery, **kwargs)
Definition bases.py:135