Safemotion Lib
Loading...
Searching...
No Matches
smreid
fastreid
data
samplers
data_sampler.py
Go to the documentation of this file.
1
# encoding: utf-8
2
"""
3
@author: l1aoxingyu
4
@contact: sherlockliao01@gmail.com
5
"""
6
import
itertools
7
from
typing
import
Optional
8
9
import
numpy
as
np
10
from
torch.utils.data
import
Sampler
11
12
from
fastreid.utils
import
comm
13
14
15
class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)

    The shuffle order is drawn from a generator private to each iterator, so
    iterating this sampler neither reseeds nor consumes NumPy's global random
    state.
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        """Yield every `_world_size`-th index of the stream, starting at offset `_rank`."""
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Yield the full (un-sharded) endless index stream described on the class."""
        # A dedicated RandomState produces the same permutations as seeding the
        # global generator with `_seed`, but other np.random users can no longer
        # perturb this stream, and this stream no longer resets theirs.
        rng = np.random.RandomState(self._seed)
        while True:
            if self._shuffle:
                yield from rng.permutation(self._size)
            else:
                yield from np.arange(self._size)
56
57
58
class InferenceSampler(Sampler):
    """
    Index sampler for inference.

    Inference has to visit the __exact__ set of samples, so each worker is
    handed one contiguous slice of `range(size)`. When `size` is not divisible
    by the number of workers, the workers receive different numbers of samples.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
        """
        assert size > 0
        self._size = size
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # Ceiling division: the smallest per-worker slice length that still
        # covers all `size` indices.
        per_worker = -(-self._size // self._world_size)
        first = per_worker * self._rank
        # Clamp the upper bound so trailing workers get a short or empty slice.
        last = min(first + per_worker, self._size)
        self._local_indices = range(first, last)

    def __iter__(self):
        """Yield the indices assigned to this worker, in ascending order."""
        for index in self._local_indices:
            yield index

    def __len__(self):
        """Return how many indices this worker was assigned."""
        return len(self._local_indices)
fastreid.data.samplers.data_sampler.InferenceSampler
Definition
data_sampler.py:58
fastreid.data.samplers.data_sampler.InferenceSampler.__iter__
__iter__(self)
Definition
data_sampler.py:81
fastreid.data.samplers.data_sampler.InferenceSampler._local_indices
_local_indices
Definition
data_sampler.py:79
fastreid.data.samplers.data_sampler.InferenceSampler._world_size
_world_size
Definition
data_sampler.py:74
fastreid.data.samplers.data_sampler.InferenceSampler.__len__
__len__(self)
Definition
data_sampler.py:84
fastreid.data.samplers.data_sampler.InferenceSampler._size
_size
Definition
data_sampler.py:71
fastreid.data.samplers.data_sampler.InferenceSampler.__init__
__init__(self, int size)
Definition
data_sampler.py:66
fastreid.data.samplers.data_sampler.InferenceSampler._rank
_rank
Definition
data_sampler.py:73
fastreid.data.samplers.data_sampler.TrainingSampler
Definition
data_sampler.py:15
fastreid.data.samplers.data_sampler.TrainingSampler._shuffle
_shuffle
Definition
data_sampler.py:37
fastreid.data.samplers.data_sampler.TrainingSampler._infinite_indices
_infinite_indices(self)
Definition
data_sampler.py:49
fastreid.data.samplers.data_sampler.TrainingSampler._size
_size
Definition
data_sampler.py:35
fastreid.data.samplers.data_sampler.TrainingSampler.__iter__
__iter__(self)
Definition
data_sampler.py:45
fastreid.data.samplers.data_sampler.TrainingSampler.__init__
__init__(self, int size, bool shuffle=True, Optional[int] seed=None)
Definition
data_sampler.py:26
fastreid.data.samplers.data_sampler.TrainingSampler._rank
_rank
Definition
data_sampler.py:42
fastreid.data.samplers.data_sampler.TrainingSampler._world_size
_world_size
Definition
data_sampler.py:43
fastreid.data.samplers.data_sampler.TrainingSampler._seed
_seed
Definition
data_sampler.py:40
fastreid.utils
Definition
__init__.py:1
torch.utils.data
Generated by
1.10.0