triplet_sampler.py — identity-balanced triplet samplers (fastreid data loading).
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: liaoxingyu2@jd.com
5"""
6
7import copy
8import itertools
9from collections import defaultdict
10from typing import Optional
11
12import numpy as np
13from torch.utils.data.sampler import Sampler
14
15from fastreid.utils import comm
16
17
def no_index(a, b):
    """Return the indices ``i`` of list ``a`` for which ``a[i] != b``."""
    assert isinstance(a, list)
    return [idx for idx, item in enumerate(a) if item != b]
21
22
def __init__(self, data_source: list, batch_size: int, num_instances: int, seed: Optional[int] = None):
    """Index the dataset by identity and camera for balanced sampling.

    Args:
        data_source: list of ``(img_path, pid, camid)`` records
            (the original ``str`` annotation was wrong — the value is
            enumerated and indexed below).
        batch_size: number of examples per batch.
        num_instances: instances drawn per identity within a batch.
        seed: RNG seed; when ``None`` a seed shared across distributed
            workers is generated so every rank shuffles identically.
    """
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    # Number of distinct identities contributing to each batch.
    self.num_pids_per_batch = batch_size // self.num_instances

    # NOTE(review): index_pid is declared defaultdict(list) but is only
    # ever assigned scalar pids below; kept as-is to preserve behavior.
    self.index_pid = defaultdict(list)  # dataset index -> pid
    self.pid_cam = defaultdict(list)    # pid -> camids of its images
    self.pid_index = defaultdict(list)  # pid -> dataset indices of its images

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = sorted(list(self.pid_index.keys()))
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
50
def __iter__(self):
    """Yield an endless stream of dataset indices, sharded across ranks."""
    # Round-robin shard: rank r takes indices r, r + world_size, r + 2*world_size, ...
    yield from itertools.islice(self._infinite_indices(), self._rank, None, self._world_size)
54
def _infinite_indices(self):
    """Endlessly yield dataset indices in camera-balanced batches.

    Each batch covers ``num_pids_per_batch`` identities. For every chosen
    anchor image, ``num_instances - 1`` companions are preferred from
    *other* cameras of the same identity, falling back to other images of
    that identity (or repeating the anchor when it is alone).

    NOTE(review): the extraction had dropped this ``def`` line (original
    line 55), leaving the body dangling under ``__iter__``; the header is
    restored here and the logic is otherwise unchanged.
    """
    np.random.seed(self._seed)
    while True:
        # Shuffle identity list
        identities = np.random.permutation(self.num_identities)

        # If remaining identities cannot be enough for a batch,
        # just drop the remaining parts
        drop_indices = self.num_identities % self.num_pids_per_batch
        if drop_indices:
            identities = identities[:-drop_indices]

        ret = []
        for kid in identities:
            # Anchor: one random image of this identity.
            i = np.random.choice(self.pid_index[self.pids[kid]])
            _, i_pid, i_cam = self.data_source[i]
            ret.append(i)
            pid_i = self.index_pid[i]
            cams = self.pid_cam[pid_i]
            index = self.pid_index[pid_i]
            # Positions (within this pid's lists) captured by another camera.
            select_cams = no_index(cams, i_cam)

            if select_cams:
                # Prefer cross-camera companions; sample with replacement
                # only when there are not enough distinct ones.
                if len(select_cams) >= self.num_instances:
                    cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
                else:
                    cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
                for kk in cam_indexes:
                    ret.append(index[kk])
            else:
                # No other camera available: fall back to other images of
                # the same identity.
                select_indexes = no_index(index, i)
                if not select_indexes:
                    # Only one image for this identity
                    ind_indexes = [0] * (self.num_instances - 1)
                elif len(select_indexes) >= self.num_instances:
                    ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
                else:
                    ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)

                for kk in ind_indexes:
                    ret.append(index[kk])

            if len(ret) == self.batch_size:
                yield from ret
                ret = []
100
class NaiveIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - data_source (list): list of (img_path, pid, camid).
    - num_instances (int): number of instances per identity in a batch.
    - batch_size (int): number of examples in a batch.
    """

    def __init__(self, data_source: list, batch_size: int, num_instances: int, seed: Optional[int] = None):
        """Index the dataset by identity; ``data_source`` annotation fixed
        from the erroneous ``str`` — the value is a list of records."""
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        # Number of distinct identities contributing to each batch.
        self.num_pids_per_batch = batch_size // self.num_instances

        self.index_pid = defaultdict(list)  # dataset index -> pid
        self.pid_cam = defaultdict(list)    # pid -> camids of its images
        self.pid_index = defaultdict(list)  # pid -> dataset indices of its images

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = sorted(list(self.pid_index.keys()))
        self.num_identities = len(self.pids)

        if seed is None:
            # Shared seed keeps the shuffle identical on every rank.
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        """Yield an endless stream of indices, round-robin sharded by rank."""
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Endlessly yield index batches of ``num_pids_per_batch`` identities
        with ``num_instances`` images each.

        NOTE(review): the extraction had dropped this ``def`` line (original
        line 142); the header is restored here, logic unchanged.
        """
        np.random.seed(self._seed)
        while True:
            # One "epoch": consume identities until fewer than a batch remain.
            avai_pids = copy.deepcopy(self.pids)
            batch_idxs_dict = {}

            batch_indices = []
            while len(avai_pids) >= self.num_pids_per_batch:
                selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
                for pid in selected_pids:
                    # Register pid in batch_idxs_dict if not
                    if pid not in batch_idxs_dict:
                        idxs = copy.deepcopy(self.pid_index[pid])
                        # Oversample with replacement when the identity has
                        # fewer images than num_instances.
                        if len(idxs) < self.num_instances:
                            idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
                        np.random.shuffle(idxs)
                        batch_idxs_dict[pid] = idxs

                    avai_idxs = batch_idxs_dict[pid]
                    for _ in range(self.num_instances):
                        batch_indices.append(avai_idxs.pop(0))

                    # Identity exhausted for this epoch: retire it.
                    if len(avai_idxs) < self.num_instances:
                        avai_pids.remove(pid)

                assert len(batch_indices) == self.batch_size, f"batch indices have wrong " \
                                                              f"length with {len(batch_indices)}!"
                yield from batch_indices
                batch_indices = []
__init__(self, data_source: list, batch_size: int, num_instances: int, seed: Optional[int] = None)
__init__(self, data_source: list, batch_size: int, num_instances: int, seed: Optional[int] = None)