44 for d
in cfg.DATASETS.NAMES:
46 dataset = DATASET_REGISTRY.get(d)(root=cfg.DATASETS.ROOT,
47 combineall=cfg.DATASETS.COMBINEALL, **kwargs)
48 if comm.is_main_process():
50 train_items.extend(dataset.train)
52 iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
53 cfg.SOLVER.MAX_ITER *= iters_per_epoch
54 train_transforms = build_transforms(cfg, is_train=
True)
55 if not cfg.DATASETS.IS_LMDB:
56 train_set =
CommDataset(train_items, train_transforms, relabel=
True)
60 num_workers = cfg.DATALOADER.NUM_WORKERS
61 num_instance = cfg.DATALOADER.NUM_INSTANCE
62 mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
64 if cfg.DATALOADER.PK_SAMPLER:
65 if cfg.DATALOADER.NAIVE_WAY:
67 cfg.SOLVER.IMS_PER_BATCH, num_instance)
70 cfg.SOLVER.IMS_PER_BATCH, num_instance)
73 batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size,
True)
75 train_loader = torch.utils.data.DataLoader(
77 num_workers=num_workers,
78 batch_sampler=batch_sampler,
79 collate_fn=fast_batch_collator,
90 dataset = DATASET_REGISTRY.get(dataset_name)(root=cfg.DATASETS.ROOT, **kwargs)
92 if comm.is_main_process():
94 test_items = dataset.query + dataset.gallery
96 print(f
'test_items = {test_items}\n')
98 test_transforms = build_transforms(cfg, is_train=
False)
100 print(f
'test_transforms = {test_transforms}\n')
102 test_set =
CommDataset(test_items, test_transforms, relabel=
False)
104 print(f
'test_set = {test_set}\n')
106 mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
108 print(f
'mini_batch_size = {mini_batch_size}\n')
111 batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size,
False)
112 test_loader = DataLoader(
114 batch_sampler=batch_sampler,
116 collate_fn=fast_batch_collator,
119 return test_loader, len(dataset.query)
126 test_transforms = build_transforms(cfg, is_train=
False)
128 test_set =
KidDataset(test_items, test_transforms, relabel=
False)
130 mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
133 batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size,
False)
134 test_loader = DataLoader(
136 batch_sampler=batch_sampler,
138 collate_fn=fast_batch_collator,
153 A simple batch collator for most common reid tasks
155 elem = batched_inputs[0]
156 if isinstance(elem, torch.Tensor):
157 out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype)
158 for i, tensor
in enumerate(batched_inputs):
162 elif isinstance(elem, container_abcs.Mapping):
165 elif isinstance(elem, float):
166 return torch.tensor(batched_inputs, dtype=torch.float64)
167 elif isinstance(elem, int_classes):
168 return torch.tensor(batched_inputs)
169 elif isinstance(elem, string_classes):
170 return batched_inputs