Safemotion Lib
data_sampler.py
# encoding: utf-8
"""
@author:  l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import itertools
from typing import Optional

import numpy as np
from torch.utils.data import Sampler

from fastreid.utils import comm


class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices, and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The sampler in each worker effectively produces `indices[worker_id::num_workers]`,
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False).
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of samples in the underlying dataset
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, a random seed shared among all
                workers will be used (this requires synchronization among them).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        # Each worker takes every `world_size`-th index, offset by its rank,
        # so the workers jointly cover the full stream without overlap.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        # Seed identically on every worker so all ranks generate the same
        # stream; the rank-based slicing in __iter__ then stays consistent.
        np.random.seed(self._seed)
        while True:
            if self._shuffle:
                yield from np.random.permutation(self._size)
            else:
                yield from np.arange(self._size)


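# A minimal usage sketch (not part of the original module): it shows how
# TrainingSampler typically feeds a PyTorch DataLoader. The `dataset` and
# `max_iter` names below are hypothetical stand-ins for illustration only.
def _example_training_loop(dataset, max_iter: int = 1000):
    from torch.utils.data import DataLoader

    # The sampler yields an infinite index stream, so the loop itself
    # must bound the number of iterations.
    sampler = TrainingSampler(size=len(dataset), shuffle=True, seed=42)
    loader = DataLoader(dataset, sampler=sampler, batch_size=64)
    for _, batch in zip(range(max_iter), loader):
        pass  # one training step per batch

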
class InferenceSampler(Sampler):
    """
    Produce indices for inference.
    Inference needs to run on the __exact__ set of samples;
    therefore, when the total number of samples is not divisible by the
    number of workers, this sampler produces a different number of samples
    on different workers.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of samples in the underlying dataset
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # Ceiling division: every rank gets at most `shard_size` indices,
        # and the last rank's shard may be smaller (or empty).
        shard_size = (self._size - 1) // self._world_size + 1
        begin = shard_size * self._rank
        end = min(shard_size * (self._rank + 1), self._size)
        self._local_indices = range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
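

# A worked sketch (not part of the original module) of the shard arithmetic
# above. With size=10 and world_size=4, shard_size = ceil(10/4) = 3, so the
# ranks receive the ranges [0, 3), [3, 6), [6, 9), and [9, 10): the last
# shard is smaller, but every sample is covered exactly once.
def _example_inference_shards(size: int = 10, world_size: int = 4):
    shards = []
    for rank in range(world_size):
        shard_size = (size - 1) // world_size + 1  # ceiling division
        begin = shard_size * rank
        end = min(shard_size * (rank + 1), size)
        shards.append(list(range(begin, end)))
    return shards  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]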