Safemotion Lib
batch_norm.py
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import logging

import torch
import torch.nn.functional as F
from torch import nn

__all__ = [
    "BatchNorm",
    "IBN",
    "GhostBatchNorm",
    "FrozenBatchNorm",
    "SyncBatchNorm",
    "get_norm",
]


class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)


class SyncBatchNorm(nn.SyncBatchNorm):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)


class IBN(nn.Module):
    def __init__(self, planes, bn_norm, **kwargs):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, **kwargs)

    def forward(self, x):
        # Instance-normalize the first half of the channels and batch-normalize the rest.
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out


class GhostBatchNorm(BatchNorm):
    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Normalize each of the num_splits virtual sub-batches independently,
            # then average the per-split running statistics back into a single set.
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = F.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            return F.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)


86 """
87 BatchNorm2d where the batch statistics and the affine parameters are fixed.
88 It contains non-trainable buffers called
89 "weight" and "bias", "running_mean", "running_var",
90 initialized to perform identity transformation.
91 The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
92 which are computed from the original four parameters of BN.
93 The affine transform `x * weight + bias` will perform the equivalent
94 computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
95 When loading a backbone model from Caffe2, "running_mean" and "running_var"
96 will be left unchanged as identity transformation.
97 Other pre-trained backbone models may contain all 4 parameters.
98 The forward is implemented by `F.batch_norm(..., training=False)`.
99 """
100
101 _version = 3
102
103 def __init__(self, num_features, eps=1e-5, **kwargs):
104 super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs)
105 self.num_features = num_features
106 self.eps = eps
107
108 def forward(self, x):
109 if x.requires_grad:
110 # When gradients are needed, F.batch_norm will use extra memory
111 # because its backward op computes gradients for weight/bias as well.
112 scale = self.weightweight * (self.running_var + self.eps).rsqrt()
113 bias = self.biasbias - self.running_mean * scale
114 scale = scale.reshape(1, -1, 1, 1)
115 bias = bias.reshape(1, -1, 1, 1)
116 return x * scale + bias
117 else:
118 # When gradients are not needed, F.batch_norm is a single fused op
119 # and provide more optimization opportunities.
120 return F.batch_norm(
121 x,
126 training=False,
127 eps=self.eps,
128 )

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # This silences the warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var is used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
        Args:
            module (torch.nn.Module):
        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.
        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res


def get_norm(norm, out_channels, **kwargs):
    """
    Args:
        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module
        out_channels: number of channels for normalization layer

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm,
            "GhostBN": GhostBatchNorm,
            "FrozenBN": FrozenBatchNorm,
            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
            "syncBN": SyncBatchNorm,
        }[norm]
    return norm(out_channels, **kwargs)
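
The get_norm docstring describes a string-or-callable factory; the following is a minimal usage sketch, not part of the file above. The import path batch_norm is an assumption and should be adjusted to wherever this module lives inside Safemotion Lib; the norm strings, the num_splits parameter, and the empty-string/callable behavior come from the source.

import torch
from torch import nn

# Assumed import path; adjust to the package layout of Safemotion Lib.
from batch_norm import get_norm

# Build 64-channel normalization layers from config strings.
bn = get_norm("BN", 64)                        # learnable affine BatchNorm
frozen = get_norm("FrozenBN", 64)              # fixed statistics and affine parameters
ghost = get_norm("GhostBN", 64, num_splits=4)  # normalizes 4 virtual sub-batches
gn = get_norm("GN", 64)                        # GroupNorm with 32 groups

x = torch.randn(8, 64, 32, 32)  # batch size must be divisible by num_splits
for layer in (bn, frozen, ghost, gn):
    assert layer(x).shape == x.shape

# An empty string disables normalization; a callable is used as-is.
assert get_norm("", 64) is None
identity = get_norm(lambda channels: nn.Identity(), 64)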
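
FrozenBatchNorm's docstring states that its affine transform x * weight + bias is equivalent to (x - running_mean) / sqrt(running_var) * weight + bias, and convert_frozen_batchnorm converts non-BN containers in place. The sketch below illustrates both points under the same assumed import path; the deep copy and the toy Sequential model are illustrative choices, not part of the library.

import copy

import torch
from torch import nn

# Assumed import path; adjust to the package layout of Safemotion Lib.
from batch_norm import FrozenBatchNorm

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)
model.eval()  # use running statistics, as a frozen backbone would

# convert_frozen_batchnorm converts container modules in place, so convert a copy.
frozen = FrozenBatchNorm.convert_frozen_batchnorm(copy.deepcopy(model))

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    # The frozen layer computes x * scale + bias with fixed running statistics,
    # which matches eval-mode BatchNorm2d up to floating-point precision.
    assert torch.allclose(model(x), frozen(x), atol=1e-5)

# The frozen affine parameters no longer receive gradients.
assert not any(p.requires_grad for p in frozen[1].parameters())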