Safemotion Lib
precision_bn.py
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import itertools

import torch

# The BatchNorm module types whose running stats are handled by this module.
BN_MODULE_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)


@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training, both the BN stats and the weights change after every iteration, so
    the running average cannot precisely reflect the actual stats of the
    current model.
    In this function, the BN stats are recomputed with fixed weights to make
    the running average more precise. Specifically, it computes the true average
    of per-batch mean/variance instead of the running average.
    Args:
        model (nn.Module): the model whose bn stats will be recomputed.
            Note that:
            1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
               precise-BN to training mode prior to calling this function.
            2. Be careful if your models contain other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you wish
               them unchanged, you need to either pass in a submodule without
               those layers or back up the states.
        data_loader (iterator): an iterator that produces data as inputs to the model.
        num_iters (int): number of iterations used to compute the stats.
    """
    bn_layers = get_bn_modules(model)
    if len(bn_layers) == 0:
        return

    # In order to make the running stats reflect only the current batch, the
    # momentum is disabled:
    # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # Setting the momentum to 1.0 makes the running stats equal the per-batch stats.
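    # Worked example (illustration, not part of the original code): with
    # momentum = 1.0 the update above reduces to
    #   bn.running_mean = 0.0 * bn.running_mean + 1.0 * batch_mean = batch_mean
    # so after each forward pass running_mean/running_var hold exactly the
    # stats of the most recent batch, which the loop below then averages.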
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance".
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
        # Batches are dicts; overwrite the 'targets' tensor with -1, since
        # only the forward statistics matter here, not the labels.
        inputs['targets'].fill_(-1)
        # Gradients are already disabled by the @torch.no_grad() decorator,
        # so a plain forward pass is enough to refresh the per-batch stats.
        model(inputs)
        for i, bn in enumerate(bn_layers):
            # Accumulate the bn stats as a true (cumulative) average.
            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
            # We compute the "average of variance" across iterations.
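            # For reference (illustration only): the update
            #   m_new = m_old + (x_k - m_old) / k
            # is the streaming form of the arithmetic mean, so after the loop
            # running_mean[i] equals the plain average of the per-batch means.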
    assert ind == num_iters - 1, (
        "update_bn_stats is meant to run for {} iterations, "
        "but the dataloader stops at {} iterations.".format(num_iters, ind)
    )

    for i, bn in enumerate(bn_layers):
        # Set the precise bn stats and restore the original momentum.
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
        bn.momentum = momentum_actual[i]


def get_bn_modules(model):
    """
    Find all BatchNorm (BN) modules that are in training mode. See
    BN_MODULE_TYPES above for the module types included in this search.
    Args:
        model (nn.Module): a model possibly containing BN modules.
    Returns:
        list[nn.Module]: all BN modules in the model that are in training mode.
    """
    # Find all the bn layers.
    bn_layers = [
        m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
    ]
    return bn_layers
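A minimal usage sketch (not part of the original file; model and train_loader are hypothetical placeholders, and the loader is assumed to yield dict batches with a 'targets' tensor, as the forward loop above expects):

model.train()  # BN layers must be in training mode for get_bn_modules to find them
update_bn_stats(model, train_loader, num_iters=200)
model.eval()   # evaluate with the recomputed, more precise BN stats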