Safemotion Lib
fastreid.engine.defaults.DefaultTrainer Class Reference
Inheritance diagram for fastreid.engine.defaults.DefaultTrainer:
fastreid.engine.train_loop.TrainerBase → fastreid.engine.train_loop.SimpleTrainer → fastreid.engine.defaults.DefaultTrainer

Public Member Functions

 __init__ (self, cfg)
 
 resume_or_load (self, resume=True)
 
 build_hooks (self)
 
 build_writers (self)
 
 train (self)
 
 build_model (cls, cfg)
 
 build_optimizer (cls, cfg, model)
 
 build_lr_scheduler (cls, cfg, optimizer)
 
 build_train_loader (cls, cfg)
 
 build_test_loader (cls, cfg, dataset_name)
 
 build_evaluator (cls, cfg, dataset_name, output_dir=None)
 
 test (cls, cfg, model)
 
- Public Member Functions inherited from fastreid.engine.train_loop.SimpleTrainer
 run_step (self)
 
- Public Member Functions inherited from fastreid.engine.train_loop.TrainerBase
 register_hooks (self, hooks)
 
 before_train (self)
 
 after_train (self)
 
 before_step (self)
 
 after_step (self)
 

Static Public Member Functions

 auto_scale_hyperparams (cfg, data_loader)
 

Public Attributes

 scheduler
 
 checkpointer
 
 start_iter
 
 max_iter
 
 cfg
 
 optimizer
 
 model
 
- Public Attributes inherited from fastreid.engine.train_loop.SimpleTrainer
 model
 
 data_loader
 
 optimizer
 
 amp_enabled
 
 scaler
 
 iter
 
- Public Attributes inherited from fastreid.engine.train_loop.TrainerBase
 iter
 
 max_iter
 

Protected Attributes

 _last_eval_results
 
- Protected Attributes inherited from fastreid.engine.train_loop.SimpleTrainer
 _data_loader_iter
 
- Protected Attributes inherited from fastreid.engine.train_loop.TrainerBase
 _hooks
 

Additional Inherited Members

- Protected Member Functions inherited from fastreid.engine.train_loop.SimpleTrainer
 _detect_anomaly (self, losses, loss_dict)
 
 _write_metrics (self, metrics_dict: dict)
 

Detailed Description

A trainer with default training logic. Compared to `SimpleTrainer`, it
contains the following logic in addition:
1. Create model, optimizer, scheduler, dataloader from the given config.
2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if it exists.
3. Register a few common hooks.
It is created to simplify the **standard model training workflow** and reduce
boilerplate for users who only need standard training with standard features.
This means the class makes *many assumptions* about your training logic that
may easily become invalid in new research. In fact, any assumptions beyond
those made in :class:`SimpleTrainer` are too much for research.
The code of this class has been annotated with the restrictive assumptions it makes.
When they do not work for you, you're encouraged to:
1. Overwrite methods of this class, OR:
2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
   nothing else. You can then add your own hooks if needed. OR:
3. Write your own training loop similar to `tools/plain_train_net.py`.
Also note that the behavior of this class, like other functions/classes in
this file, is not stable, since it is meant to represent the "common default behavior".
It is only guaranteed to work well with the standard models and training workflow in fastreid.
To obtain more stable behavior, write your own training logic with other public APIs.
Attributes:
    scheduler: the learning-rate scheduler built by :meth:`build_lr_scheduler`.
    checkpointer (Checkpointer): saves and loads model, optimizer and scheduler state.
    cfg (CfgNode): the (possibly auto-scaled) config used for training.
Examples:
.. code-block:: python
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
    trainer.train()
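For example, a minimal sketch of option 1 above: subclass the trainer and
override one of its factory methods (``build_my_train_loader`` is a
hypothetical replacement, not part of fastreid):
.. code-block:: python
    class MyTrainer(DefaultTrainer):
        @classmethod
        def build_train_loader(cls, cfg):
            # Swap in your own data pipeline; everything else
            # (model, optimizer, hooks) keeps the default behavior.
            return build_my_train_loader(cfg)

    trainer = MyTrainer(cfg)
    trainer.resume_or_load()
    trainer.train()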

Definition at line 167 of file defaults.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.engine.defaults.DefaultTrainer.__init__(self, cfg)
Args:
    cfg (CfgNode):

Reimplemented from fastreid.engine.train_loop.SimpleTrainer.

Definition at line 200 of file defaults.py.

def __init__(self, cfg):
    """
    Args:
        cfg (CfgNode):
    """
    logger = logging.getLogger("fastreid")
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
        setup_logger()

    # Assume these objects must be constructed in this order.
    data_loader = self.build_train_loader(cfg)
    cfg = self.auto_scale_hyperparams(cfg, data_loader)
    model = self.build_model(cfg)
    optimizer = self.build_optimizer(cfg, model)

    # For training, wrap with DDP. But this is not needed for inference.
    if comm.get_world_size() > 1:
        # See https://github.com/pytorch/pytorch/issues/22049: set
        # `find_unused_parameters=True` when part of the parameters are not updated.
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )

    super().__init__(model, data_loader, optimizer, cfg.SOLVER.AMP_ENABLED)

    self.scheduler = self.build_lr_scheduler(cfg, optimizer)
    # Assume no other objects need to be checkpointed.
    # We can later make it checkpoint the stateful hooks.
    self.checkpointer = Checkpointer(
        # Assume you want to save checkpoints together with logs/statistics
        model,
        cfg.OUTPUT_DIR,
        save_to_disk=comm.is_main_process(),
        optimizer=optimizer,
        scheduler=self.scheduler,
    )
    self.start_iter = 0
    if cfg.SOLVER.SWA.ENABLED:
        self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
    else:
        self.max_iter = cfg.SOLVER.MAX_ITER

    self.cfg = cfg

    self.register_hooks(self.build_hooks())

Member Function Documentation

◆ auto_scale_hyperparams()

fastreid.engine.defaults.DefaultTrainer.auto_scale_hyperparams(cfg, data_loader) [static]
    Automatically compute the actual number of training iterations:
    some hyper-parameters, such as MAX_ITER, are specified in epochs rather
    than iterations, and need to be converted to iterations.

Definition at line 466 of file defaults.py.

def auto_scale_hyperparams(cfg, data_loader):
    r"""
    Automatically compute the actual number of training iterations:
    some hyper-parameters, such as MAX_ITER, are specified in epochs rather
    than iterations, and need to be converted to iterations.
    """

    cfg = cfg.clone()
    frozen = cfg.is_frozen()
    cfg.defrost()

    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
    cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
    cfg.SOLVER.MAX_ITER *= iters_per_epoch
    cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
    cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch
    cfg.SOLVER.DELAY_ITERS *= iters_per_epoch
    for i in range(len(cfg.SOLVER.STEPS)):
        cfg.SOLVER.STEPS[i] *= iters_per_epoch
    cfg.SOLVER.SWA.ITER *= iters_per_epoch
    cfg.SOLVER.SWA.PERIOD *= iters_per_epoch

    ckpt_multiple = cfg.SOLVER.CHECKPOINT_PERIOD / cfg.TEST.EVAL_PERIOD
    # Evaluation period must be divisible by 200 for writing into tensorboard.
    eval_num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
    cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + eval_num_mod
    # Change the checkpoint saving period to be consistent with the evaluation period.
    cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.TEST.EVAL_PERIOD * ckpt_multiple)

    logger = logging.getLogger(__name__)
    logger.info(
        f"Auto-scaling the config to num_classes={cfg.MODEL.HEADS.NUM_CLASSES}, "
        f"max_Iter={cfg.SOLVER.MAX_ITER}, warmup_Iter={cfg.SOLVER.WARMUP_ITERS}, "
        f"freeze_Iter={cfg.SOLVER.FREEZE_ITERS}, delay_Iter={cfg.SOLVER.DELAY_ITERS}, "
        f"step_Iter={cfg.SOLVER.STEPS}, ckpt_Iter={cfg.SOLVER.CHECKPOINT_PERIOD}, "
        f"eval_Iter={cfg.TEST.EVAL_PERIOD}."
    )

    if frozen: cfg.freeze()

    return cfg
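A worked example under assumed config values (not taken from any real config):
with 12,800 training images and ``IMS_PER_BATCH = 64``, one epoch is 200
iterations, so epoch-based settings scale as follows:
.. code-block:: python
    iters_per_epoch = 12800 // 64              # 200
    max_iter = 60 * iters_per_epoch            # MAX_ITER: 60 epochs -> 12000 iters
    eval_period = 30 * iters_per_epoch         # EVAL_PERIOD: 30 epochs -> 6000 iters
    eval_period += (200 - eval_period) % 200   # 6000 is divisible by 200 -> +0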

◆ build_evaluator()

fastreid.engine.defaults.DefaultTrainer.build_evaluator(cls, cfg, dataset_name, output_dir=None)

Definition at line 424 of file defaults.py.

def build_evaluator(cls, cfg, dataset_name, output_dir=None):
    data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
    return data_loader, ReidEvaluator(cfg, num_query, output_dir)

◆ build_hooks()

fastreid.engine.defaults.DefaultTrainer.build_hooks(self)
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
    list[HookBase]:

Definition at line 262 of file defaults.py.

def build_hooks(self):
    """
    Build a list of default hooks, including timing, evaluation,
    checkpointing, lr scheduling, precise BN, writing events.
    Returns:
        list[HookBase]:
    """
    logger = logging.getLogger(__name__)
    cfg = self.cfg.clone()
    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
    cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET])  # set dataset name for PreciseBN

    ret = [
        hooks.IterationTimer(),
        hooks.LRScheduler(self.optimizer, self.scheduler),
    ]

    if cfg.SOLVER.SWA.ENABLED:
        ret.append(
            hooks.SWA(
                cfg.SOLVER.MAX_ITER,
                cfg.SOLVER.SWA.PERIOD,
                cfg.SOLVER.SWA.LR_FACTOR,
                cfg.SOLVER.SWA.ETA_MIN_LR,
                cfg.SOLVER.SWA.LR_SCHED,
            )
        )

    if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
        logger.info("Prepare precise BN dataset")
        ret.append(hooks.PreciseBN(
            # Run at the same freq as (but before) evaluation.
            self.model,
            # Build a new data loader to not affect training.
            self.build_train_loader(cfg),
            cfg.TEST.PRECISE_BN.NUM_ITER,
        ))

    if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
        freeze_layers = ",".join(cfg.MODEL.FREEZE_LAYERS)
        logger.info(f'Freeze layer group "{freeze_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iterations')
        ret.append(hooks.FreezeLayer(
            self.model,
            self.optimizer,
            cfg.MODEL.FREEZE_LAYERS,
            cfg.SOLVER.FREEZE_ITERS,
        ))
    # Do PreciseBN before checkpointer, because it updates the model and needs
    # to be saved by the checkpointer.
    # This is not always the best: if checkpointing has a different frequency,
    # some checkpoints may have more precise statistics than others.
    if comm.is_main_process():
        ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

    def test_and_save_results():
        self._last_eval_results = self.test(self.cfg, self.model)
        return self._last_eval_results

    # Do evaluation after checkpointer, because then if evaluation fails,
    # we can use the saved checkpoint to debug.
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

    if comm.is_main_process():
        # Run writers at the end, so that evaluation metrics are written.
        ret.append(hooks.PeriodicWriter(self.build_writers(), 200))

    return ret
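If the defaults are not enough, a minimal sketch of extending them, assuming
``HookBase`` is exposed by ``fastreid.engine.train_loop`` and that registered
hooks receive a ``trainer`` reference, as in the detectron2 API this mirrors:
.. code-block:: python
    from fastreid.engine.train_loop import HookBase

    class IterationPrinter(HookBase):  # hypothetical custom hook
        def after_step(self):
            if self.trainer.iter % 1000 == 0:
                print(f"finished iteration {self.trainer.iter}")

    class MyTrainer(DefaultTrainer):
        def build_hooks(self):
            ret = super().build_hooks()
            # Insert before the final hook so the writer hook stays last
            # (see the ordering comments in build_hooks above).
            ret.insert(-1, IterationPrinter())
            return ret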

◆ build_lr_scheduler()

fastreid.engine.defaults.DefaultTrainer.build_lr_scheduler(cls, cfg, optimizer)
It now calls :func:`fastreid.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.

Definition at line 394 of file defaults.py.

def build_lr_scheduler(cls, cfg, optimizer):
    """
    It now calls :func:`fastreid.solver.build_lr_scheduler`.
    Overwrite it if you'd like a different scheduler.
    """
    return build_lr_scheduler(cfg, optimizer)
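For instance, a sketch of overriding it with a plain PyTorch scheduler
instead of the fastreid solver factory:
.. code-block:: python
    import torch

    class CosineTrainer(DefaultTrainer):
        @classmethod
        def build_lr_scheduler(cls, cfg, optimizer):
            # Decay the LR over the full (already auto-scaled) iteration budget.
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=cfg.SOLVER.MAX_ITER
            )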

◆ build_model()

fastreid.engine.defaults.DefaultTrainer.build_model(cls, cfg)
Returns:
    torch.nn.Module:
It now calls :func:`fastreid.modeling.build_model`.
Overwrite it if you'd like a different model.

Definition at line 371 of file defaults.py.

def build_model(cls, cfg):
    """
    Returns:
        torch.nn.Module:
    It now calls :func:`fastreid.modeling.build_model`.
    Overwrite it if you'd like a different model.
    """
    model = build_model(cfg)
    # logger = logging.getLogger(__name__)
    # logger.info("Model:\n{}".format(model))
    return model

◆ build_optimizer()

fastreid.engine.defaults.DefaultTrainer.build_optimizer(cls, cfg, model)
Returns:
    torch.optim.Optimizer:
It now calls :func:`fastreid.solver.build_optimizer`.
Overwrite it if you'd like a different optimizer.

Definition at line 384 of file defaults.py.

def build_optimizer(cls, cfg, model):
    """
    Returns:
        torch.optim.Optimizer:
    It now calls :func:`fastreid.solver.build_optimizer`.
    Overwrite it if you'd like a different optimizer.
    """
    return build_optimizer(cfg, model)

◆ build_test_loader()

fastreid.engine.defaults.DefaultTrainer.build_test_loader(cls, cfg, dataset_name)
Returns:
    iterable
It now calls :func:`fastreid.data.build_reid_test_loader`.
Overwrite it if you'd like a different data loader.

Definition at line 414 of file defaults.py.

def build_test_loader(cls, cfg, dataset_name):
    """
    Returns:
        iterable
    It now calls :func:`fastreid.data.build_reid_test_loader`.
    Overwrite it if you'd like a different data loader.
    """
    return build_reid_test_loader(cfg, dataset_name)

◆ build_train_loader()

fastreid.engine.defaults.DefaultTrainer.build_train_loader(cls, cfg)
Returns:
    iterable
It now calls :func:`fastreid.data.build_reid_train_loader`.
Overwrite it if you'd like a different data loader.

Definition at line 402 of file defaults.py.

def build_train_loader(cls, cfg):
    """
    Returns:
        iterable
    It now calls :func:`fastreid.data.build_reid_train_loader`.
    Overwrite it if you'd like a different data loader.
    """
    logger = logging.getLogger(__name__)
    logger.info("Prepare training set")
    return build_reid_train_loader(cfg)

◆ build_writers()

fastreid.engine.defaults.DefaultTrainer.build_writers(self)
Build a list of writers to be used. By default it contains
writers that write metrics to the screen,
a json file, and a tensorboard event file respectively.
If you'd like a different list of writers, you can overwrite it in
your trainer.
Returns:
    list[EventWriter]: a list of :class:`EventWriter` objects.
It is now implemented by:
.. code-block:: python
    return [
        CommonMetricPrinter(self.max_iter),
        JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(self.cfg.OUTPUT_DIR),
    ]

Definition at line 331 of file defaults.py.

def build_writers(self):
    """
    Build a list of writers to be used. By default it contains
    writers that write metrics to the screen,
    a json file, and a tensorboard event file respectively.
    If you'd like a different list of writers, you can overwrite it in
    your trainer.
    Returns:
        list[EventWriter]: a list of :class:`EventWriter` objects.
    It is now implemented by:
    .. code-block:: python
        return [
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]
    """
    # Assume the default print/log frequency.
    return [
        # It may not always print what you want to see, since it prints "common" metrics only.
        CommonMetricPrinter(self.max_iter),
        JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(self.cfg.OUTPUT_DIR),
    ]
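A sketch of extending the default list in a subclass (``MyCustomWriter`` is a
hypothetical :class:`EventWriter` implementation, not part of fastreid):
.. code-block:: python
    class MyTrainer(DefaultTrainer):
        def build_writers(self):
            writers = super().build_writers()
            writers.append(MyCustomWriter(self.cfg.OUTPUT_DIR))
            return writers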

◆ resume_or_load()

fastreid.engine.defaults.DefaultTrainer.resume_or_load(self, resume=True)
If `resume==True` and the last checkpoint exists, resume from it.
Otherwise, load a model specified by the config.
Args:
    resume (bool): whether to resume or not

Definition at line 246 of file defaults.py.

def resume_or_load(self, resume=True):
    """
    If `resume==True` and the last checkpoint exists, resume from it.
    Otherwise, load a model specified by the config.
    Args:
        resume (bool): whether to resume or not
    """
    # The checkpoint stores the training iteration that just finished, thus we start
    # at the next iteration (or iter zero if there's no checkpoint).
    checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)

    if resume and self.checkpointer.has_checkpoint():
        self.start_iter = checkpoint.get("iteration", -1) + 1
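A short sketch of the two modes:
.. code-block:: python
    trainer = DefaultTrainer(cfg)

    # Fine-tune: ignore any previous run; load only cfg.MODEL.WEIGHTS
    # and start counting iterations from zero.
    trainer.resume_or_load(resume=False)

    # Resume: if a checkpoint exists in cfg.OUTPUT_DIR, restore the full
    # training state and continue from the following iteration.
    trainer.resume_or_load(resume=True)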

◆ test()

fastreid.engine.defaults.DefaultTrainer.test(cls, cfg, model)
Args:
    cfg (CfgNode):
    model (nn.Module):
Returns:
    dict: a dict of result metrics

Definition at line 429 of file defaults.py.

def test(cls, cfg, model):
    """
    Args:
        cfg (CfgNode):
        model (nn.Module):
    Returns:
        dict: a dict of result metrics
    """
    logger = logging.getLogger(__name__)

    results = OrderedDict()
    for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
        logger.info("Prepare testing set")
        try:
            data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
        except NotImplementedError:
            logger.warning(
                "No evaluator found; implement its `build_evaluator` method."
            )
            results[dataset_name] = {}
            continue
        results_i = inference_on_dataset(model, data_loader, evaluator)
        results[dataset_name] = results_i

    if comm.is_main_process():
        assert isinstance(
            results, dict
        ), "Evaluator must return a dict on the main process. Got {} instead.".format(
            results
        )
        print_csv_format(results)

    if len(results) == 1:
        results = list(results.values())[0]

    return results
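Since ``test`` is a classmethod, it also supports evaluation-only runs without
constructing a trainer, e.g. (a minimal sketch, assuming ``Checkpointer`` is
importable as in ``__init__`` above and ``cfg.DATASETS.TESTS`` is set):
.. code-block:: python
    model = DefaultTrainer.build_model(cfg)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained weights
    results = DefaultTrainer.test(cfg, model)    # dict of metrics per test set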

◆ train()

fastreid.engine.defaults.DefaultTrainer.train(self)
Run training.
Returns:
    OrderedDict of results, if evaluation is enabled. Otherwise None.

Reimplemented from fastreid.engine.train_loop.TrainerBase.

Definition at line 356 of file defaults.py.

def train(self):
    """
    Run training.
    Returns:
        OrderedDict of results, if evaluation is enabled. Otherwise None.
    """
    super().train(self.start_iter, self.max_iter)
    if comm.is_main_process():
        assert hasattr(
            self, "_last_eval_results"
        ), "No evaluation results obtained during training!"
        # verify_results(self.cfg, self._last_eval_results)
        return self._last_eval_results

Member Data Documentation

◆ _last_eval_results

fastreid.engine.defaults.DefaultTrainer._last_eval_results [protected]

Definition at line 318 of file defaults.py.

◆ cfg

fastreid.engine.defaults.DefaultTrainer.cfg

Definition at line 242 of file defaults.py.

◆ checkpointer

fastreid.engine.defaults.DefaultTrainer.checkpointer

Definition at line 228 of file defaults.py.

◆ max_iter

fastreid.engine.defaults.DefaultTrainer.max_iter

Definition at line 238 of file defaults.py.

◆ model

fastreid.engine.defaults.DefaultTrainer.model

Definition at line 291 of file defaults.py.

◆ optimizer

fastreid.engine.defaults.DefaultTrainer.optimizer

Definition at line 277 of file defaults.py.

◆ scheduler

fastreid.engine.defaults.DefaultTrainer.scheduler

Definition at line 225 of file defaults.py.

◆ start_iter

fastreid.engine.defaults.DefaultTrainer.start_iter

Definition at line 236 of file defaults.py.


The documentation for this class was generated from the following file: defaults.py