def build_reid_train_loader(cfg):
    """Build the training ``DataLoader`` described by ``cfg``.

    The function works on ``cfg.clone()`` (defrosted), instantiates every
    dataset named in ``cfg.DATASETS.NAMES`` through ``DATASET_REGISTRY`` and
    merges their ``train`` items into a single list. Depending on
    ``cfg.DATASETS.IS_LMDB`` the training set is either a ``CommDataset``
    over the merged items or an ``LMDBDataset`` over the last dataset built.

    Args:
        cfg: Config node. The ``DATASETS``, ``SOLVER`` and ``DATALOADER``
            sections are read here.

    Returns:
        torch.utils.data.DataLoader: Loader over the training set, driven by
        a ``BatchSampler`` with ``drop_last=True`` and collated with
        ``fast_batch_collator``.

    Raises:
        ValueError: If ``cfg.DATASETS.IS_LMDB`` is set while
            ``cfg.DATASETS.NAMES`` is empty, so there is no dataset to wrap.
    """
    cfg = cfg.clone()
    cfg.defrost()

    # The extra dataset kwargs do not depend on the dataset name, so parse
    # them once instead of once per loop iteration.
    dataset_kwargs = parser_kwargs(cfg.DATASETS.KWARGS)

    train_items = []
    last_dataset = None
    for name in cfg.DATASETS.NAMES:
        last_dataset = DATASET_REGISTRY.get(name)(
            root=cfg.DATASETS.ROOT,
            combineall=cfg.DATASETS.COMBINEALL,
            **dataset_kwargs
        )
        # Only the main process prints the dataset summary.
        if comm.is_main_process():
            last_dataset.show_train()
        train_items.extend(last_dataset.train)

    # MAX_ITER is rescaled on the local clone only; the clone is never
    # returned, so the new value is visible just to code below that receives
    # ``cfg`` (i.e. ``build_transforms``).
    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
    cfg.SOLVER.MAX_ITER *= iters_per_epoch
    train_transforms = build_transforms(cfg, is_train=True)

    if cfg.DATASETS.IS_LMDB:
        if last_dataset is None:
            # Previously this surfaced as an UnboundLocalError on the leaked
            # loop variable; fail with an actionable message instead.
            raise ValueError(
                "cfg.DATASETS.IS_LMDB is set but cfg.DATASETS.NAMES is empty; "
                "at least one dataset is required to build the LMDB training set"
            )
        # NOTE(review): only the LAST dataset in cfg.DATASETS.NAMES is handed
        # to LMDBDataset, while ``train_items`` merges all of them. Confirm
        # whether LMDB mode is meant to support several datasets.
        train_set = LMDBDataset(last_dataset, train_transforms)
    else:
        train_set = CommDataset(train_items, train_transforms, relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    global_batch_size = cfg.SOLVER.IMS_PER_BATCH
    # Each process loads its share of the global batch.
    mini_batch_size = global_batch_size // comm.get_world_size()

    if cfg.DATALOADER.PK_SAMPLER:
        # Both identity samplers take the same arguments; note they receive
        # the global batch size, not the per-process one.
        if cfg.DATALOADER.NAIVE_WAY:
            sampler_cls = samplers.NaiveIdentitySampler
        else:
            sampler_cls = samplers.BalancedIdentitySampler
        data_sampler = sampler_cls(train_set.img_items, global_batch_size, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, drop_last=True
    )

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
83
84