Safemotion Lib
build.py
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import os
import torch
from collections import abc as container_abcs
from torch.utils.data import DataLoader

# torch._six was removed in recent PyTorch releases, so define the
# string/int class aliases locally instead of importing them.
string_classes = (str, bytes)
int_classes = int

from fastreid.utils import comm

from . import samplers
from .common import CommDataset, LMDBDataset, KidDataset
from .datasets import DATASET_REGISTRY
from .transforms import build_transforms


def parser_kwargs(additional_args):
    """Parse a 'key1: value1 + key2: value2' string into a kwargs dict,
    casting numeric-looking values to float or int."""
    kwargs = {}
    if len(additional_args):
        args = [x.strip() for x in additional_args.split('+')]
        for arg in args:
            key, value = [x.strip() for x in arg.split(':')]
            if '.' in value and value.replace('.', '').isdigit():
                kwargs[key] = float(value)
            elif 'e-' in value and value.replace('e-', '').isdigit():
                kwargs[key] = float(value)
            elif 'e' in value and value.replace('e', '').isdigit():
                kwargs[key] = float(value)
            elif value.isdigit():
                kwargs[key] = int(value)
            else:
                kwargs[key] = value
    return kwargs
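
# Illustrative input/output (the actual keys are dataset-specific and not shown here):
#   parser_kwargs("min_size: 128 + scale: 0.5 + mode: train")
#   -> {"min_size": 128, "scale": 0.5, "mode": "train"}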


def build_reid_train_loader(cfg):
    cfg = cfg.clone()
    cfg.defrost()

    train_items = list()
    for d in cfg.DATASETS.NAMES:
        kwargs = parser_kwargs(cfg.DATASETS.KWARGS)
        dataset = DATASET_REGISTRY.get(d)(root=cfg.DATASETS.ROOT,
                                          combineall=cfg.DATASETS.COMBINEALL, **kwargs)
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)
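    # train_items now aggregates the training entries of every configured dataset
    # (typically (img_path, pid, camid) tuples in fastreid-style datasets).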

    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
    # MAX_ITER is configured in epochs and scaled to iterations here.
    cfg.SOLVER.MAX_ITER *= iters_per_epoch
    train_transforms = build_transforms(cfg, is_train=True)
    if not cfg.DATASETS.IS_LMDB:
        train_set = CommDataset(train_items, train_transforms, relabel=True)
    else:
        # Note: this wraps only the last dataset constructed in the loop above.
        train_set = LMDBDataset(dataset, train_transforms)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    if cfg.DATALOADER.PK_SAMPLER:
        if cfg.DATALOADER.NAIVE_WAY:
            data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,
                                                         cfg.SOLVER.IMS_PER_BATCH, num_instance)
        else:
            data_sampler = samplers.BalancedIdentitySampler(train_set.img_items,
                                                            cfg.SOLVER.IMS_PER_BATCH, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
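    # The identity samplers above implement P x K sampling (P identities per batch,
    # K = NUM_INSTANCE images each); BatchSampler's trailing True is drop_last, so
    # incomplete mini-batches are discarded during training.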

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
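
# Illustrative usage (not part of this file; assumes a fastreid-style cfg with the
# DATASETS/DATALOADER/SOLVER keys read above):
#   train_loader = build_reid_train_loader(cfg)
#   batch = next(iter(train_loader))   # collated by fast_batch_collator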


def build_reid_test_loader(cfg, dataset_name):
    cfg = cfg.clone()
    cfg.defrost()

    kwargs = parser_kwargs(cfg.DATASETS.KWARGS)
    dataset = DATASET_REGISTRY.get(dataset_name)(root=cfg.DATASETS.ROOT, **kwargs)

    if comm.is_main_process():
        dataset.show_test()
    # Query items come first so that len(dataset.query) can be used to split
    # query from gallery again at evaluation time.
    test_items = dataset.query + dataset.gallery

    print(f'test_items = {test_items}\n')

    test_transforms = build_transforms(cfg, is_train=False)

    print(f'test_transforms = {test_transforms}\n')

    test_set = CommDataset(test_items, test_transforms, relabel=False)

    print(f'test_set = {test_set}\n')

    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()

    print(f'mini_batch_size = {mini_batch_size}\n')

    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader, len(dataset.query)
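
# Illustrative usage (not part of this file; "Market1501" stands for any registered dataset):
#   test_loader, num_query = build_reid_test_loader(cfg, "Market1501")
#   num_query tells the evaluator how many of the concatenated items are query images.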


def build_kid_reid_test_loader(cfg, test_items):
    cfg = cfg.clone()
    cfg.defrost()

    test_transforms = build_transforms(cfg, is_train=False)

    test_set = KidDataset(test_items, test_transforms, relabel=False)

    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()

    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
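
# Unlike build_reid_test_loader, this variant takes pre-assembled test_items, wraps
# them in KidDataset, and returns only the loader (no query count).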


def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.
    """
    return batch


def fast_batch_collator(batched_inputs):
    """
    A simple batch collator for most common reid tasks
    """
    elem = batched_inputs[0]
    if isinstance(elem, torch.Tensor):
        out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype)
        for i, tensor in enumerate(batched_inputs):
            out[i] += tensor
        return out

    elif isinstance(elem, container_abcs.Mapping):
        return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}

    elif isinstance(elem, float):
        return torch.tensor(batched_inputs, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batched_inputs)
    elif isinstance(elem, string_classes):
        return batched_inputs
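
# Illustrative behaviour of fast_batch_collator (keys are examples, not mandated here):
#   batch = [{"images": torch.rand(3, 256, 128), "targets": 1},
#            {"images": torch.rand(3, 256, 128), "targets": 2}]
#   out = fast_batch_collator(batch)
#   out["images"].shape == (2, 3, 256, 128); out["targets"].tolist() == [1, 2]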