fastreid.utils.precision_bn Namespace Reference

Functions

 update_bn_stats (model, data_loader, int num_iters=200)
 
 get_bn_modules (model)
 

Variables

tuple BN_MODULE_TYPES
 

Detailed Description

@author:  liaoxingyu
@contact: sherlockliao01@gmail.com

Function Documentation

◆ get_bn_modules()

fastreid.utils.precision_bn.get_bn_modules ( model)
Find all BatchNorm (BN) modules that are in training mode. See
fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
included in this search.
Args:
    model (nn.Module): a model possibly containing BN modules.
Returns:
    list[nn.Module]: all BN modules in the model.

Definition at line 80 of file precision_bn.py.

def get_bn_modules(model):
    """
    Find all BatchNorm (BN) modules that are in training mode. See
    fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
    included in this search.
    Args:
        model (nn.Module): a model possibly containing BN modules.
    Returns:
        list[nn.Module]: all BN modules in the model.
    """
    # Finds all the bn layers.
    bn_layers = [
        m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
    ]
    return bn_layers
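
As a usage illustration (not part of the library), the sketch below builds a toy torch.nn.Sequential model and assumes get_bn_modules has been imported from fastreid.utils.precision_bn; only BN layers that are in training mode are returned.

import torch

# Toy model for illustration only; any nn.Module containing BN layers works.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(8),
)

model.train()
print(len(get_bn_modules(model)))  # 2: both BN layers are in training mode

model.eval()
print(len(get_bn_modules(model)))  # 0: eval-mode BN layers are skipped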

◆ update_bn_stats()

fastreid.utils.precision_bn.update_bn_stats (model, data_loader, int num_iters = 200)
Recompute and update the batch norm stats to make them more precise. During
training, both the BN stats and the weights change after every iteration, so
the running average cannot precisely reflect the actual stats of the
current model.
In this function, the BN stats are recomputed with fixed weights to make
the running average more precise. Specifically, it computes the true average
of per-batch mean/variance instead of the running average.
Args:
    model (nn.Module): the model whose BN stats will be recomputed.
        Note that:
        1. This function will not alter the training mode of the given model.
           Users are responsible for setting the layers that need
           precise-BN to training mode prior to calling this function.
        2. Be careful if your model contains other stateful layers in
           addition to BN, i.e. layers whose state can change in forward
           iterations. This function will alter their state. If you wish
           them unchanged, you need to either pass in a submodule without
           those layers, or back up the states.
    data_loader (iterator): an iterator that produces data as inputs to the model.
    num_iters (int): number of iterations to compute the stats.

Definition at line 20 of file precision_bn.py.

def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training, both the BN stats and the weights change after every iteration, so
    the running average cannot precisely reflect the actual stats of the
    current model.
    In this function, the BN stats are recomputed with fixed weights to make
    the running average more precise. Specifically, it computes the true average
    of per-batch mean/variance instead of the running average.
    Args:
        model (nn.Module): the model whose BN stats will be recomputed.
            Note that:
            1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
               precise-BN to training mode prior to calling this function.
            2. Be careful if your model contains other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you wish
               them unchanged, you need to either pass in a submodule without
               those layers, or back up the states.
        data_loader (iterator): an iterator that produces data as inputs to the model.
        num_iters (int): number of iterations to compute the stats.
    """
    bn_layers = get_bn_modules(model)
    if len(bn_layers) == 0:
        return

    # In order to make the running stats reflect only the current batch, the
    # momentum is disabled:
    #   bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # Setting the momentum to 1.0 makes the running stats equal to the stats of
    # the last processed batch.
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance".
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
        inputs['targets'].fill_(-1)
        with torch.no_grad():  # No need to backward
            model(inputs)
        for i, bn in enumerate(bn_layers):
            # Accumulates the bn stats as a cumulative (true) average over iterations.
            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
            # We compute the "average of variance" across iterations.
    assert ind == num_iters - 1, (
        "update_bn_stats is meant to run for {} iterations, "
        "but the dataloader stops at {} iterations.".format(num_iters, ind)
    )

    for i, bn in enumerate(bn_layers):
        # Sets the precise bn stats and restores the original momentum.
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
        bn.momentum = momentum_actual[i]
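
A minimal calibration sketch follows. The wrapper model and the synthetic loader are illustrative assumptions, not fastreid APIs; the only constraints visible in the code above are that each batch is a dict whose 'targets' entry is a tensor and that model(inputs) runs a forward pass through the BN layers while they are in training mode. update_bn_stats is assumed imported from fastreid.utils.precision_bn.

import torch

class ToyModel(torch.nn.Module):
    # Hypothetical model that consumes fastreid-style dict batches.
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, padding=1),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
        )

    def forward(self, inputs):
        return self.backbone(inputs['images'])

def synthetic_loader():
    # Stand-in for a real data loader; yields fastreid-style dict batches.
    while True:
        yield {'images': torch.randn(8, 3, 64, 64),
               'targets': torch.zeros(8, dtype=torch.long)}

model = ToyModel()
model.train()  # precise-BN only touches BN layers that are in training mode
update_bn_stats(model, synthetic_loader(), num_iters=10)
model.eval()   # inference now uses the recomputed running_mean / running_var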

Variable Documentation

◆ BN_MODULE_TYPES

tuple fastreid.utils.precision_bn.BN_MODULE_TYPES
Initial value:
= (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)

Definition at line 11 of file precision_bn.py.