def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training both BN stats and the weight are changing after every iteration, so
    the running average can not precisely reflect the actual stats of the
    current model.
    In this function, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true average
    of per-batch mean/variance instead of the running average.

    If the model contains no BN layers (as reported by ``get_bn_modules``),
    the function returns immediately without consuming ``data_loader``.

    Args:
        model (nn.Module): the model whose bn stats will be recomputed.
            Note that:
            1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
               precise-BN to training mode, prior to calling this function.
            2. Be careful if your models contain other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you wish
               them unchanged, you need to either pass in a submodule without
               those layers, or backup the states.
        data_loader (iterator): an iterator. Produce data as inputs to the model.
            Each produced batch is indexed with ``'targets'`` and that entry is
            filled with -1 in place before the forward pass.
        num_iters (int): number of iterations to compute the stats. Must be
            at least 1.

    Raises:
        ValueError: if ``num_iters`` is smaller than 1.
        AssertionError: if ``data_loader`` is exhausted before ``num_iters``
            batches were processed. The BN running stats are not overwritten
            with the averaged values in that case.
    """
    bn_layers = get_bn_modules(model)
    if len(bn_layers) == 0:
        return

    if num_iters < 1:
        # Averaging over zero batches is meaningless and would install
        # all-zero statistics below.
        raise ValueError(
            "num_iters must be a positive integer, got {}".format(num_iters)
        )

    # Remember the configured momentum of every layer so it can be restored.
    momentum_actual = [bn.momentum for bn in bn_layers]

    # Accumulators for the true average of the per-batch mean / variance.
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    # ``ind`` stays at -1 when the loader yields nothing, so the check after
    # the loop reports a clear error instead of hitting an unbound name.
    ind = -1
    try:
        # With momentum 1.0 the running stats after a forward pass are exactly
        # the stats of that single batch (no blending with older values).
        for bn in bn_layers:
            bn.momentum = 1.0

        for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
            inputs['targets'].fill_(-1)
            with torch.no_grad():
                model(inputs)
            for i, bn in enumerate(bn_layers):
                # Incremental mean: avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k
                running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
                running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
    finally:
        # Always restore the original momentum, even if the forward pass or
        # the data loader raised.
        for bn, momentum in zip(bn_layers, momentum_actual):
            bn.momentum = momentum

    # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
    if ind != num_iters - 1:
        raise AssertionError(
            "update_bn_stats is meant to run for {} iterations, "
            "but the dataloader stops at {} iterations.".format(num_iters, ind + 1)
        )

    # Install the precise stats.
    for i, bn in enumerate(bn_layers):
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
78
79