5This file contains components with some default boilerplate logic users may need
6in training / testing. They will not work for everyone, but many users may find them useful.
7The behavior of functions/classes in this file is subject to change,
8since they are meant to represent the "common default behavior" people need in their projects.
15from collections
import OrderedDict
18import torch.nn.functional
as F
19from torch.nn.parallel
import DistributedDataParallel
21from fastreid.data import build_reid_test_loader, build_reid_train_loader
23 inference_on_dataset, print_csv_format)
34from .train_loop
import SimpleTrainer
36__all__ = [
"default_argument_parser",
"default_setup",
"DefaultPredictor",
"DefaultTrainer"]
41 Create a parser with some common arguments used by fastreid users.
43 argparse.ArgumentParser:
45 parser = argparse.ArgumentParser(description=
"fastreid Training")
46 parser.add_argument(
"--config-file", default=
"", metavar=
"FILE", help=
"path to config file")
50 help=
"whether to attempt to finetune from the trained model",
55 help=
"whether to attempt to resume from the checkpoint directory",
57 parser.add_argument(
"--eval-only", action=
"store_true", help=
"perform evaluation only")
58 parser.add_argument(
"--num-gpus", type=int, default=1, help=
"number of gpus *per machine*")
59 parser.add_argument(
"--num-machines", type=int, default=1, help=
"total number of machines")
61 "--machine-rank", type=int, default=0, help=
"the rank of this machine (unique per machine)"
68 port = 30000 + random.randint(1, 10000) + random.randint(1, 5000)
69 parser.add_argument(
"--dist-url", default=
"tcp://127.0.0.1:{}".format(port))
72 help=
"Modify config options using the command-line",
74 nargs=argparse.REMAINDER,
81 Perform some basic common setups at the beginning of a job, including:
82 1. Set up the loggers (the default logger and the "fvcore" logger)
83 2. Log basic information about environment, cmdline arguments, and config
84 3. Backup the config to the output directory
86 cfg (CfgNode): the full config to be used
87 args (argparse.Namespace): the command line arguments to be logged
89 output_dir = cfg.OUTPUT_DIR
90 if comm.is_main_process()
and output_dir:
91 PathManager.mkdirs(output_dir)
93 rank = comm.get_rank()
94 setup_logger(output_dir, distributed_rank=rank, name=
"fvcore")
95 logger = setup_logger(output_dir, distributed_rank=rank)
97 logger.info(
"Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
98 logger.info(
"Environment info:\n" + collect_env_info())
100 logger.info(
"Command line arguments: " + str(args))
101 if hasattr(args,
"config_file")
and args.config_file !=
"":
103 "Contents of args.config_file={}:\n{}".format(
104 args.config_file, PathManager.open(args.config_file,
"r").read()
108 logger.info(
"Running with full config:\n{}".format(cfg))
109 if comm.is_main_process()
and output_dir:
112 path = os.path.join(output_dir,
"config.yaml")
113 with PathManager.open(path,
"w")
as f:
115 logger.info(
"Full config saved to {}".format(os.path.abspath(path)))
122 if not (hasattr(args,
"eval_only")
and args.eval_only):
123 torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
128 Create a simple end-to-end predictor with the given config.
129 The predictor takes a BGR image, resizes it to the specified resolution,
130 runs the model and produces a dict of predictions.
131 This predictor takes care of model loading and input preprocessing for you.
132 If you'd like to do anything more fancy, please refer to its source code
133 as examples to build and use the model manually.
136 .. code-block:: python
137 pred = DefaultPredictor(cfg)
138 inputs = cv2.imread("input.jpg")
139 outputs = pred(inputs)
145 self.
cfg.MODEL.BACKBONE.PRETRAIN =
False
154 image (torch.tensor): an image tensor of shape (B, C, H, W).
156 predictions (torch.tensor): the output features of the model
158 inputs = {
"images": image}
159 with torch.no_grad():
160 predictions = self.
model(inputs)
162 features = F.normalize(predictions)
163 features = features.cpu().data
169 A trainer with default training logic. Compared to `SimpleTrainer`, it
170 contains the following logic in addition:
171 1. Create model, optimizer, scheduler, dataloader from the given config.
172 2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.
173 3. Register a few common hooks.
174 It is created to simplify the **standard model training workflow** and reduce code boilerplate
175 for users who only need the standard training workflow, with standard features.
176 It means this class makes *many assumptions* about your training logic that
177 may easily become invalid in new research. In fact, any assumptions beyond those made in the
178 :class:`SimpleTrainer` are too much for research.
179 The code of this class has been annotated with the restrictive assumptions it makes.
180 When they do not work for you, you're encouraged to:
181 1. Overwrite methods of this class, OR:
182 2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
183 nothing else. You can then add your own hooks if needed. OR:
184 3. Write your own training loop similar to `tools/plain_train_net.py`.
185 Also note that the behavior of this class, like other functions/classes in
186 this file, is not stable, since it is meant to represent the "common default behavior".
187 It is only guaranteed to work well with the standard models and training workflow in fastreid.
188 To obtain more stable behavior, write your own training logic with other public APIs.
194 .. code-block:: python
195 trainer = DefaultTrainer(cfg)
196 trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
205 logger = logging.getLogger(
"fastreid")
206 if not logger.isEnabledFor(logging.INFO):
216 if comm.get_world_size() > 1:
219 model = DistributedDataParallel(
220 model, device_ids=[comm.get_local_rank()], broadcast_buffers=
False
223 super().
__init__(model, data_loader, optimizer, cfg.SOLVER.AMP_ENABLED)
232 save_to_disk=comm.is_main_process(),
237 if cfg.SOLVER.SWA.ENABLED:
248 If `resume==True`, and last checkpoint exists, resume from it.
249 Otherwise, load a model specified by the config.
251 resume (bool): whether to do resume or not
258 self.
start_iter = checkpoint.get(
"iteration", -1) + 1
264 Build a list of default hooks, including timing, evaluation,
265 checkpointing, lr scheduling, precise BN, writing events.
269 logger = logging.getLogger(__name__)
270 cfg = self.
cfg.clone()
272 cfg.DATALOADER.NUM_WORKERS = 0
273 cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET])
280 if cfg.SOLVER.SWA.ENABLED:
284 cfg.SOLVER.SWA.PERIOD,
285 cfg.SOLVER.SWA.LR_FACTOR,
286 cfg.SOLVER.SWA.ETA_MIN_LR,
287 cfg.SOLVER.SWA.LR_SCHED,
291 if cfg.TEST.PRECISE_BN.ENABLED
and hooks.get_bn_modules(self.
modelmodel):
292 logger.info(
"Prepare precise BN dataset")
298 cfg.TEST.PRECISE_BN.NUM_ITER,
301 if cfg.MODEL.FREEZE_LAYERS != [
'']
and cfg.SOLVER.FREEZE_ITERS > 0:
302 freeze_layers =
",".join(cfg.MODEL.FREEZE_LAYERS)
303 logger.info(f
'Freeze layer group "{freeze_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iterations')
307 cfg.MODEL.FREEZE_LAYERS,
308 cfg.SOLVER.FREEZE_ITERS,
314 if comm.is_main_process():
317 def test_and_save_results():
323 ret.append(
hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
325 if comm.is_main_process():
333 Build a list of writers to be used. By default it contains
334 writers that write metrics to the screen,
335 a json file, and a tensorboard event file respectively.
336 If you'd like a different list of writers, you can overwrite it in
339 list[EventWriter]: a list of :class:`EventWriter` objects.
340 It is now implemented by:
341 .. code-block:: python
343 CommonMetricPrinter(self.max_iter),
344 JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
345 TensorboardXWriter(self.cfg.OUTPUT_DIR),
352 JSONWriter(os.path.join(self.
cfg.OUTPUT_DIR,
"metrics.json")),
360 OrderedDict of results, if evaluation is enabled. Otherwise None.
363 if comm.is_main_process():
365 self,
"_last_eval_results"
366 ),
"No evaluation results obtained during training!"
375 It now calls :func:`fastreid.modeling.build_model`.
376 Overwrite it if you'd like a different model.
378 model = build_model(cfg)
387 torch.optim.Optimizer:
388 It now calls :func:`fastreid.solver.build_optimizer`.
389 Overwrite it if you'd like a different optimizer.
391 return build_optimizer(cfg, model)
396 It now calls :func:`fastreid.solver.build_lr_scheduler`.
397 Overwrite it if you'd like a different scheduler.
399 return build_lr_scheduler(cfg, optimizer)
406 It now calls :func:`fastreid.data.build_reid_train_loader`.
407 Overwrite it if you'd like a different data loader.
409 logger = logging.getLogger(__name__)
410 logger.info(
"Prepare training set")
411 return build_reid_train_loader(cfg)
418 It now calls :func:`fastreid.data.build_reid_test_loader`.
419 Overwrite it if you'd like a different data loader.
421 return build_reid_test_loader(cfg, dataset_name)
426 return data_loader,
ReidEvaluator(cfg, num_query, output_dir)
435 dict: a dict of result metrics
437 logger = logging.getLogger(__name__)
439 results = OrderedDict()
440 for idx, dataset_name
in enumerate(cfg.DATASETS.TESTS):
441 logger.info(
"Prepare testing set")
444 except NotImplementedError:
446 "No evaluator found. implement its `build_evaluator` method."
448 results[dataset_name] = {}
450 results_i = inference_on_dataset(model, data_loader, evaluator)
451 results[dataset_name] = results_i
453 if comm.is_main_process():
456 ),
"Evaluator must return a dict on the main process. Got {} instead.".format(
459 print_csv_format(results)
461 if len(results) == 1: results = list(results.values())[0]
468 This is used to automatically compute the actual number of training iterations,
469 because some hyper-parameters, such as MAX_ITER, are given in training epochs rather than iterations,
470 so these hyper-parameters need to be converted to training iterations.
474 frozen = cfg.is_frozen()
477 iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
478 cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
479 cfg.SOLVER.MAX_ITER *= iters_per_epoch
480 cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
481 cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch
482 cfg.SOLVER.DELAY_ITERS *= iters_per_epoch
483 for i
in range(len(cfg.SOLVER.STEPS)):
484 cfg.SOLVER.STEPS[i] *= iters_per_epoch
485 cfg.SOLVER.SWA.ITER *= iters_per_epoch
486 cfg.SOLVER.SWA.PERIOD *= iters_per_epoch
488 ckpt_multiple = cfg.SOLVER.CHECKPOINT_PERIOD / cfg.TEST.EVAL_PERIOD
490 eval_num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
491 cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + eval_num_mod
493 cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.TEST.EVAL_PERIOD * ckpt_multiple)
495 logger = logging.getLogger(__name__)
497 f
"Auto-scaling the config to num_classes={cfg.MODEL.HEADS.NUM_CLASSES}, "
498 f
"max_Iter={cfg.SOLVER.MAX_ITER}, wamrup_Iter={cfg.SOLVER.WARMUP_ITERS}, "
499 f
"freeze_Iter={cfg.SOLVER.FREEZE_ITERS}, delay_Iter={cfg.SOLVER.DELAY_ITERS}, "
500 f
"step_Iter={cfg.SOLVER.STEPS}, ckpt_Iter={cfg.SOLVER.CHECKPOINT_PERIOD}, "
501 f
"eval_Iter={cfg.TEST.EVAL_PERIOD}."
504 if frozen: cfg.freeze()
auto_scale_hyperparams(cfg, data_loader)
build_test_loader(cls, cfg, dataset_name)
build_optimizer(cls, cfg, model)
build_evaluator(cls, cfg, dataset_name, output_dir=None)
resume_or_load(self, resume=True)
build_train_loader(cls, cfg)
build_lr_scheduler(cls, cfg, optimizer)
register_hooks(self, hooks)
default_argument_parser()