defaults.py
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
This file contains components with some default boilerplate logic that users may need
in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""

import argparse
import logging
import os, random
import sys
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from fastreid.data import build_reid_test_loader, build_reid_train_loader
from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
                                 inference_on_dataset, print_csv_format)
from fastreid.modeling.meta_arch import build_model
from fastreid.solver import build_lr_scheduler, build_optimizer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.collect_env import collect_env_info
from fastreid.utils.env import seed_all_rng
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.logger import setup_logger
from . import hooks
from .train_loop import SimpleTrainer

__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]


def default_argument_parser():
    """
    Create a parser with some common arguments used by fastreid users.
    Returns:
        argparse.ArgumentParser:
    """
    parser = argparse.ArgumentParser(description="fastreid Training")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--finetune",
        action="store_true",
        help="whether to attempt to finetune from the trained model",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # PyTorch may still leave orphan processes in multi-gpu training, so the port is chosen
    # explicitly here; users can then notice orphan processes by seeing the port occupied.
    # The deterministic variant below is kept for reference; the active code picks a random port.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    port = 30000 + random.randint(1, 10000) + random.randint(1, 5000)
    parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser

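# A minimal usage sketch (illustration only, not part of the original file): the parser is
# typically combined with a config file plus `opts` overrides on the command line, e.g.
#
#     args = default_argument_parser().parse_args()
#     # python tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
#     #     --num-gpus 1 SOLVER.IMS_PER_BATCH 64
#
# The script and config paths above are only assumed examples; see `default_setup` below for how
# `args` and the resulting config are usually consumed.
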
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:
    1. Set up the fastreid logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory
    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.Namespace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file, PathManager.open(args.config_file, "r").read()
            )
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng()

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK

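# A minimal sketch of how these helpers are usually wired together (illustration only; it assumes
# `get_cfg` from `fastreid.config`, which is how fastreid builds its default config):
#
#     from fastreid.config import get_cfg
#
#     def setup(args):
#         cfg = get_cfg()
#         cfg.merge_from_file(args.config_file)
#         cfg.merge_from_list(args.opts)
#         cfg.freeze()
#         default_setup(cfg, args)
#         return cfg
#
#     args = default_argument_parser().parse_args()
#     cfg = setup(args)
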
class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config.
    The predictor takes a batched image tensor of shape (B, C, H, W), runs the model
    and produces the normalized output features.
    This predictor takes care of model loading for you.
    If you'd like to do anything more fancy, please refer to its source code
    as examples to build and use the model manually.
    Examples:
    .. code-block:: python
        pred = DefaultPredictor(cfg)
        inputs = torch.rand(1, 3, 256, 128)  # (B, C, H, W) image tensor
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.cfg.defrost()
        self.cfg.MODEL.BACKBONE.PRETRAIN = False
        self.model = build_model(self.cfg)
        self.model.eval()

        Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)

    def __call__(self, image):
        """
        Args:
            image (torch.Tensor): an image tensor of shape (B, C, H, W).
        Returns:
            predictions (torch.Tensor): the output features of the model
        """
        inputs = {"images": image}
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            predictions = self.model(inputs)
            # Normalize features to compute cosine distance
            features = F.normalize(predictions)
            features = features.cpu().data
        return features

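# Usage sketch (illustration only, not part of the original file): the predictor expects a
# batched float tensor, so images must be resized and converted by the caller first, e.g.:
#
#     import cv2
#     import numpy as np
#     img = cv2.imread("input.jpg")                                # HWC, BGR
#     img = cv2.resize(img, tuple(cfg.INPUT.SIZE_TEST[::-1]))      # cfg stores (H, W)
#     batch = torch.as_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)),
#                             dtype=torch.float32)[None]           # (1, C, H, W)
#     features = DefaultPredictor(cfg)(batch)                      # (1, feature_dim), L2-normalized
#
# Whether BGR or RGB channel order is required depends on the training-time data pipeline.
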

class DefaultTrainer(SimpleTrainer):
    """
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:
    1. Create model, optimizer, scheduler, dataloader from the given config.
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if it exists.
    3. Register a few common hooks.
    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in new research. In fact, any assumptions beyond those made in
    :class:`SimpleTrainer` are too much for research.
    The code of this class has been annotated with the restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:
    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.
    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in fastreid.
    To obtain more stable behavior, write your own training logic with other public APIs.
    Attributes:
        scheduler:
        checkpointer:
        cfg (CfgNode):
    Examples:
    .. code-block:: python
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 on setting
            # `find_unused_parameters=True` when part of the parameters is not updated.
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )

        super().__init__(model, data_loader, optimizer, cfg.SOLVER.AMP_ENABLED)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER

        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True`, and last checkpoint exists, resume from it.
        Otherwise, load a model specified by the config.
        Args:
            resume (bool): whether to do resume or not
        """
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)

        if resume and self.checkpointer.has_checkpoint():
            self.start_iter = checkpoint.get("iteration", -1) + 1

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        logger = logging.getLogger(__name__)
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET])  # set dataset name for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        if cfg.SOLVER.SWA.ENABLED:
            ret.append(
                hooks.SWA(
                    cfg.SOLVER.MAX_ITER,
                    cfg.SOLVER.SWA.PERIOD,
                    cfg.SOLVER.SWA.LR_FACTOR,
                    cfg.SOLVER.SWA.ETA_MIN_LR,
                    cfg.SOLVER.SWA.LR_SCHED,
                )
            )

        if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
            logger.info("Prepare precise BN dataset")
            ret.append(hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            ))

        if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
            freeze_layers = ",".join(cfg.MODEL.FREEZE_LAYERS)
            logger.info(f'Freeze layer group "{freeze_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iterations')
            ret.append(hooks.FreezeLayer(
                self.model,
                self.optimizer,
                cfg.MODEL.FREEZE_LAYERS,
                cfg.SOLVER.FREEZE_ITERS,
            ))
        # Do PreciseBN before checkpointer, because it updates the model and needs to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), 200))

        return ret

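    # A minimal customization sketch (illustration only): per the class docstring, the intended
    # way to change this behavior is to override methods in a subclass, e.g.
    #
    #     class MyTrainer(DefaultTrainer):
    #         def build_hooks(self):
    #             ret = super().build_hooks()
    #             ret.insert(-1, MyCustomHook())  # `MyCustomHook` is a hypothetical HookBase subclass
    #             return ret
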
    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            # verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:
        It now calls :func:`fastreid.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        # logger = logging.getLogger(__name__)
        # logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger(__name__)
        logger.info("Prepare training set")
        return build_reid_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
        data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
        return data_loader, ReidEvaluator(cfg, num_query, output_dir)

    @classmethod
    def test(cls, cfg, model):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            logger.info("Prepare testing set")
            try:
                data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
            except NotImplementedError:
                logger.warning(
                    "No evaluator found. Implement `build_evaluator` for your trainer."
                )
                results[dataset_name] = {}
                continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i

        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results
            )
            print_csv_format(results)

        if len(results) == 1:
            results = list(results.values())[0]

        return results

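    # Evaluation-only sketch (illustration only): this mirrors the typical `--eval-only` path in
    # the training scripts; the exact checkpoint handling may differ per project.
    #
    #     model = DefaultTrainer.build_model(cfg)
    #     Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained weights
    #     results = DefaultTrainer.test(cfg, model)
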
    @staticmethod
    def auto_scale_hyperparams(cfg, data_loader):
        r"""
        Automatically compute the actual number of training iterations.
        Some hyper-parameters, such as MAX_ITER, are specified in training epochs rather than
        iterations, so they are converted to iterations here.
        """

        cfg = cfg.clone()
        frozen = cfg.is_frozen()
        cfg.defrost()

        iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
        cfg.SOLVER.MAX_ITER *= iters_per_epoch
        cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
        cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch
        cfg.SOLVER.DELAY_ITERS *= iters_per_epoch
        for i in range(len(cfg.SOLVER.STEPS)):
            cfg.SOLVER.STEPS[i] *= iters_per_epoch
        cfg.SOLVER.SWA.ITER *= iters_per_epoch
        cfg.SOLVER.SWA.PERIOD *= iters_per_epoch

        ckpt_multiple = cfg.SOLVER.CHECKPOINT_PERIOD / cfg.TEST.EVAL_PERIOD
        # The evaluation period must be divisible by 200 for metrics to be written into tensorboard.
        eval_num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
        cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + eval_num_mod
        # Change the checkpoint saving period to stay consistent with the evaluation period.
        cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.TEST.EVAL_PERIOD * ckpt_multiple)

        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to num_classes={cfg.MODEL.HEADS.NUM_CLASSES}, "
            f"max_Iter={cfg.SOLVER.MAX_ITER}, warmup_Iter={cfg.SOLVER.WARMUP_ITERS}, "
            f"freeze_Iter={cfg.SOLVER.FREEZE_ITERS}, delay_Iter={cfg.SOLVER.DELAY_ITERS}, "
            f"step_Iter={cfg.SOLVER.STEPS}, ckpt_Iter={cfg.SOLVER.CHECKPOINT_PERIOD}, "
            f"eval_Iter={cfg.TEST.EVAL_PERIOD}."
        )

        if frozen:
            cfg.freeze()

        return cfg
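
# Worked example of the auto-scaling above (numbers are illustrative only): with a training set
# of 12,936 images and SOLVER.IMS_PER_BATCH = 64, iters_per_epoch = 12936 // 64 = 202. An
# epoch-based MAX_ITER of 120 becomes 120 * 202 = 24,240 iterations, and an EVAL_PERIOD of
# 30 epochs becomes 30 * 202 = 6,060, which is then rounded up to 6,200 so that it is a
# multiple of 200 (the PeriodicWriter frequency used in `build_hooks`).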