Recompute and update the batch norm stats to make them more precise. During
training, both the BN stats and the model weights change after every iteration,
so the running average cannot precisely reflect the actual stats of the
current model.

In this function, the BN stats are recomputed with fixed weights to make
the running average more precise. Specifically, it computes the true average
of per-batch mean/variance instead of the running average.
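(Equivalently, each sampled batch contributes its batch mean/variance with
equal weight 1 / num_iters, rather than the exponentially decayed weights of
the usual running average.)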

Args:
    model (nn.Module): the model whose BN stats will be recomputed.

        Note that:

        1. This function will not alter the training mode of the given model.
           Users are responsible for setting the layers that need
           precise-BN to training mode prior to calling this function.

        2. Be careful if your models contain other stateful layers in
           addition to BN, i.e. layers whose state can change in forward
           iterations. This function will alter their state. If you wish
           them to stay unchanged, you need to either pass in a submodule
           without those layers, or back up their states.
    data_loader (iterator): an iterator that produces data to be used as
        inputs to the model.
    num_iters (int): number of iterations to use when computing the stats.
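
Example (a minimal usage sketch; my_model, train_loader, and the choice of
200 iterations are illustrative placeholders, and the loader is assumed to
yield dict batches with a 'targets' key, as consumed in the body below):

    my_model.train()   # precise-BN requires the BN layers to be in training mode
    update_bn_stats(my_model, iter(train_loader), num_iters=200)
    my_model.eval()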

bn_layers = get_bn_modules(model)
if len(bn_layers) == 0:
    # Nothing to do if the model contains no BN layers.
    return
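
# Each BN layer keeps an exponential moving average of the batch stats:
#   running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# Temporarily setting momentum to 1.0 makes running_mean/running_var hold
# exactly the stats of the most recent batch, which are then averaged with
# equal weight across batches below.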
momentum_actual = [bn.momentum for bn in bn_layers]
for bn in bn_layers:
    bn.momentum = 1.0
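
# Note: in PyTorch, running_var tracks the (Bessel-corrected) batch variance,
# so the buffers below accumulate an average of per-batch variance.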
running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
    # Labels are not needed to recompute BN stats; fill them with a dummy value.
    inputs['targets'].fill_(-1)
    # Forward pass (no gradients needed) so each BN layer updates its running
    # stats with the current batch; the model is assumed to accept the batch
    # dict directly.
    with torch.no_grad():
        model(inputs)

    for i, bn in enumerate(bn_layers):
        # Keep a cumulative moving average of the per-batch stats:
        #   avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k
        running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
        running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)

assert ind == num_iters - 1, (
    "update_bn_stats is meant to run for {} iterations, "
    "but the dataloader stops at {} iterations.".format(num_iters, ind)
)

for i, bn in enumerate(bn_layers):
    # Write the averaged stats back and restore the original momentum.
    bn.running_mean = running_mean[i]
    bn.running_var = running_var[i]
    bn.momentum = momentum_actual[i]