Safemotion Lib
fastreid.engine.hooks.PreciseBN Class Reference
Inheritance diagram for fastreid.engine.hooks.PreciseBN:
fastreid.engine.train_loop.HookBase

Public Member Functions

 __init__ (self, model, data_loader, num_iter)
 
 after_step (self)
 
 update_stats (self)
 
- Public Member Functions inherited from fastreid.engine.train_loop.HookBase
 before_train (self)
 
 after_train (self)
 
 before_step (self)
 

Protected Attributes

 _logger
 
 _disabled
 
 _model
 
 _data_loader
 
 _num_iter
 
 _data_iter
 

Detailed Description

The standard implementation of BatchNorm uses an exponential moving average (EMA) of batch statistics at inference time. This is sometimes suboptimal, because the EMA weights recent batches more heavily than earlier ones.
This class computes the true average of the statistics rather than the moving average,
and puts the true averages into every BN layer in the given model.
It is executed after the last iteration.
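To see why the two estimates differ, the sketch below runs the same sequence of per-batch means through two estimators. The first is a PyTorch-style EMA (`new = (1 - m) * old + m * batch`, with momentum 0.1 assumed as the default). The second is a cumulative true average. The averaging follows the momentum-1.0 trick used by fvcore's `update_bn_stats`, which this page calls but does not show. We assume fastreid's helper works the same way. With momentum 1.0, a BN layer's running value equals the current batch statistic, and that value is folded into an incremental mean. The batch values are made up for illustration.

```python
from typing import List


def ema(batch_stats: List[float], momentum: float = 0.1, init: float = 0.0) -> float:
    """Running estimate as BatchNorm keeps it: new = (1 - m) * old + m * batch."""
    running = init
    for stat in batch_stats:
        running = (1.0 - momentum) * running + momentum * stat
    return running


def precise_average(batch_stats: List[float]) -> float:
    """True average of per-batch statistics, accumulated incrementally.

    Assumed to mirror the momentum=1.0 trick: the BN running value is
    overwritten by each batch statistic, then folded into a cumulative mean.
    """
    mean = 0.0
    for i, stat in enumerate(batch_stats):
        running = stat  # with momentum=1.0, running value == batch statistic
        mean += (running - mean) / (i + 1)
    return mean


if __name__ == "__main__":
    # Hypothetical per-batch means; the last five batches are skewed upward.
    stats = [1.0] * 20 + [3.0] * 5
    print("EMA:", round(ema(stats, init=1.0), 4))
    print("precise:", round(precise_average(stats), 4))
```

Here the EMA ends near 1.82 because it is pulled toward the last five batches. The precise average is 1.4, which is the mean over all 25 batches.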

Definition at line 349 of file hooks.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.engine.hooks.PreciseBN.__init__(self, model, data_loader, num_iter)
Args:
    model (nn.Module): a module whose BN layers in training mode will all be
        updated by precise BN.
        Note that the user is responsible for ensuring that the BN layers to be
        updated are in training mode when this hook is triggered.
    data_loader (iterable): produces the data to be run through `model(data)`.
    num_iter (int): number of iterations used to compute the precise
        statistics.

Definition at line 358 of file hooks.py.

358 def __init__(self, model, data_loader, num_iter):
359     """
360     Args:
361         model (nn.Module): a module whose all BN layers in training mode will be
362             updated by precise BN.
363             Note that user is responsible for ensuring the BN layers to be
364             updated are in training mode when this hook is triggered.
365         data_loader (iterable): it will produce data to be run by `model(data)`.
366         num_iter (int): number of iterations used to compute the precise
367             statistics.
368     """
369     self._logger = logging.getLogger(__name__)
370     if len(get_bn_modules(model)) == 0:
371         self._logger.info(
372             "PreciseBN is disabled because model does not contain BN layers in training mode."
373         )
374         self._disabled = True
375         return
376
377     self._model = model
378     self._data_loader = data_loader
379     self._num_iter = num_iter
380     self._disabled = False
381
382     self._data_iter = None
383
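The constructor's disable rule can be illustrated without PyTorch. The sketch below uses a hypothetical `FakeModule` class in place of `nn.Module`. It also defines a stand-in `get_bn_modules` that keeps only BN layers whose `training` flag is set, which is what we assume the real helper does. Calling `eval()` on the model before the hook is built leaves no qualifying layers, so the hook disables itself. The check runs only once, in `__init__`.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeModule:
    """Hypothetical stand-in for nn.Module: a name, a BN flag, a training flag."""
    name: str
    is_bn: bool
    training: bool = True
    children: List["FakeModule"] = field(default_factory=list)

    def modules(self):
        # Depth-first walk over this module and all descendants,
        # loosely mimicking nn.Module.modules().
        yield self
        for child in self.children:
            yield from child.modules()

    def eval(self) -> "FakeModule":
        # Switch every module to inference mode, like nn.Module.eval().
        for m in self.modules():
            m.training = False
        return self


def get_bn_modules(model: FakeModule) -> List[FakeModule]:
    """Assumed behavior: keep only BN layers that are in training mode."""
    return [m for m in model.modules() if m.is_bn and m.training]


def precise_bn_disabled(model: FakeModule) -> bool:
    """Same rule as PreciseBN.__init__: no training-mode BN layers -> disabled."""
    return len(get_bn_modules(model)) == 0


if __name__ == "__main__":
    net = FakeModule("net", False, children=[FakeModule("conv", False), FakeModule("bn", True)])
    print(precise_bn_disabled(net))         # False: one BN layer in training mode
    print(precise_bn_disabled(net.eval()))  # True: eval() switched it off
```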

Member Function Documentation

◆ after_step()

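fastreid.engine.hooks.PreciseBN.after_step(self)
Called after each iteration. Calls update_stats() only once, after the final iteration (when `trainer.iter + 1 == trainer.max_iter`).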

Reimplemented from fastreid.engine.train_loop.HookBase.

Definition at line 384 of file hooks.py.

384 def after_step(self):
385     next_iter = self.trainer.iter + 1
386     is_final = next_iter == self.trainer.max_iter
387     if is_final:
388         self.update_stats()
389
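The trigger rule in `after_step` can be checked in isolation. This sketch assumes, as in detectron2-style trainers, that the loop runs `trainer.iter` from the start iteration up to `max_iter - 1` and calls the hook after each step. `FakeTrainer`, `FinalStepHook`, and `run` are hypothetical names. The condition is `iter + 1 == max_iter`, so the update fires exactly once, on the last iteration, even when training resumes partway through.

```python
from typing import List


class FakeTrainer:
    """Hypothetical trainer exposing the two attributes the hook reads."""

    def __init__(self, max_iter: int):
        self.iter = 0
        self.max_iter = max_iter


class FinalStepHook:
    """Replicates PreciseBN.after_step's trigger rule and records when it fires."""

    def __init__(self, trainer: FakeTrainer):
        self.trainer = trainer
        self.fired_at: List[int] = []

    def after_step(self) -> None:
        next_iter = self.trainer.iter + 1
        if next_iter == self.trainer.max_iter:
            self.fired_at.append(self.trainer.iter)  # update_stats() would run here


def run(max_iter: int, start_iter: int = 0) -> List[int]:
    """Assumed loop shape: iter goes from start_iter to max_iter - 1."""
    trainer = FakeTrainer(max_iter)
    hook = FinalStepHook(trainer)
    for it in range(start_iter, max_iter):
        trainer.iter = it
        # ... forward/backward/optimizer step would happen here ...
        hook.after_step()
    return hook.fired_at


if __name__ == "__main__":
    print(run(5))                # [4]: fires on the last iteration only
    print(run(5, start_iter=3))  # [4]: same when resuming
```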

◆ update_stats()

fastreid.engine.hooks.PreciseBN.update_stats(self)
Update the model with precise statistics. Users can manually call this method. It does nothing if the hook was disabled in the constructor.

Definition at line 390 of file hooks.py.

390 def update_stats(self):
391     """
392     Update the model with precise statistics. Users can manually call this method.
393     """
394     if self._disabled:
395         return
396
397     if self._data_iter is None:
398         self._data_iter = iter(self._data_loader)
399
400     def data_loader():
401         for num_iter in itertools.count(1):
402             if num_iter % 100 == 0:
403                 self._logger.info(
404                     "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
405                 )
406             # This way we can reuse the same iterator
407             yield next(self._data_iter)
408
409     with EventStorage():  # capture events in a new storage to discard them
410         self._logger.info(
411             "Running precise-BN for {} iterations... ".format(self._num_iter)
412             + "Note that this could produce different statistics every time."
413         )
414         update_bn_stats(self._model, data_loader(), self._num_iter)
415
416
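`update_stats` wraps `self._data_iter` in an unbounded generator, so repeated calls keep drawing fresh batches from the same underlying iterator instead of restarting the loader. This sketch reproduces that pattern with a hypothetical `SharedBatches` class. A stand-in `consume` plays the role of `update_bn_stats` and takes the first `num_iter` items, which we assume is how the real helper consumes them.

```python
import itertools
from typing import Iterable, Iterator, List, Optional


class SharedBatches:
    """Hands out batches from one iterator that persists across calls,
    like PreciseBN keeps self._data_iter between update_stats() calls."""

    def __init__(self, loader: Iterable):
        self._loader = loader
        self._data_iter: Optional[Iterator] = None

    def batches(self) -> Iterator:
        # Create the underlying iterator lazily, exactly once.
        if self._data_iter is None:
            self._data_iter = iter(self._loader)
        # Unbounded generator; the consumer decides how many items to take.
        while True:
            yield next(self._data_iter)


def consume(gen: Iterator, num_iter: int) -> List:
    """Hypothetical stand-in for update_bn_stats: take the first num_iter batches."""
    return list(itertools.islice(gen, num_iter))


if __name__ == "__main__":
    shared = SharedBatches(range(10))    # hypothetical finite loader of 10 "batches"
    print(consume(shared.batches(), 3))  # [0, 1, 2]
    print(consume(shared.batches(), 3))  # [3, 4, 5]: continues, does not restart
```

If the loader runs out mid-pass, the `StopIteration` raised by `next()` inside the generator surfaces as `RuntimeError` (PEP 479). The same applies to the original `data_loader()` closure, so the loader should yield at least `num_iter` batches.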

Member Data Documentation

◆ _data_iter

fastreid.engine.hooks.PreciseBN._data_iter
protected

Definition at line 382 of file hooks.py.

◆ _data_loader

fastreid.engine.hooks.PreciseBN._data_loader
protected

Definition at line 378 of file hooks.py.

◆ _disabled

fastreid.engine.hooks.PreciseBN._disabled
protected

Definition at line 374 of file hooks.py.

◆ _logger

fastreid.engine.hooks.PreciseBN._logger
protected

Definition at line 369 of file hooks.py.

◆ _model

fastreid.engine.hooks.PreciseBN._model
protected

Definition at line 377 of file hooks.py.

◆ _num_iter

fastreid.engine.hooks.PreciseBN._num_iter
protected

Definition at line 379 of file hooks.py.


The documentation for this class was generated from the following file: