"""
@contact: sherlockliao01@gmail.com
"""
import logging

import torch
import torch.nn.functional as F
from torch import nn
class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)
        # Optionally re-initialize the affine parameters, then freeze them on request.
        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)

class SyncBatchNorm(nn.SyncBatchNorm):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)

class IBN(nn.Module):
    """
    Instance-Batch Normalization: instance norm on the first half of the
    channels, and the normalization selected by `bn_norm` on the remaining half.
    """

    def __init__(self, planes, bn_norm, **kwargs):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, **kwargs)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out

class GhostBatchNorm(BatchNorm):
    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Normalize each of the `num_splits` virtual (ghost) batches separately,
            # then average the per-split statistics back into a single set of buffers.
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = F.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        return F.batch_norm(input, self.running_mean, self.running_var,
                            self.weight, self.bias, False, self.momentum, self.eps)

class FrozenBatchNorm(BatchNorm):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable "weight", "bias", "running_mean" and "running_var",
    initialized to perform the identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as the identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    (A short numerical sketch of the folding is given after this class.)
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs)
        self.num_features = num_features
        self.eps = eps

    def forward(self, x):
        if x.requires_grad:
            # Fold the frozen statistics into a single scale/bias so that
            # autograd does not compute gradients for the BN parameters.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            return x * scale + bias
        return F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            training=False, eps=self.eps)

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Checkpoints from versions < 2 have no running statistics;
            # fill them in with identity values to silence loading warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In versions < 3 the forward pass did not add eps to running_var;
            # subtract it from the stored value so the output stays unchanged.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module): module to convert, traversed recursively.

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        (A usage sketch is given after this class.)
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
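
# A minimal numerical sketch of the folding described in the FrozenBatchNorm
# docstring: folding the four BN parameters into a single scale/bias pair
# reproduces (x - running_mean) / sqrt(running_var + eps) * weight + bias.
# The helper below is hypothetical and exists only for illustration; it is not
# part of the library API.
def _frozen_bn_folding_sketch():
    torch.manual_seed(0)
    x = torch.randn(2, 8, 4, 4)

    # Reference: a regular BatchNorm2d in eval mode with non-trivial statistics.
    bn = nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    nn.init.uniform_(bn.weight, 0.5, 1.5)
    nn.init.uniform_(bn.bias, -0.5, 0.5)

    # Fold the statistics into scale/bias the same way FrozenBatchNorm.forward does.
    scale = bn.weight * (bn.running_var + bn.eps).rsqrt()
    bias = bn.bias - bn.running_mean * scale
    folded = x * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)

    # Both paths agree up to floating-point error.
    assert torch.allclose(folded, bn(x), atol=1e-5)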
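
# A small usage sketch for FrozenBatchNorm.convert_frozen_batchnorm, assuming a
# hypothetical toy nn.Sequential model (the model and the helper name are not
# part of the library). It shows that BN layers are replaced recursively and
# that the resulting FrozenBatchNorm parameters no longer require gradients.
def _convert_frozen_bn_sketch():
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8)),
    )
    model = FrozenBatchNorm.convert_frozen_batchnorm(model)

    # Every BatchNorm2d, including the nested one, is now a FrozenBatchNorm.
    assert all(isinstance(m, FrozenBatchNorm)
               for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    # The frozen affine parameters are excluded from optimization.
    assert all(not p.requires_grad
               for m in model.modules() if isinstance(m, FrozenBatchNorm)
               for p in m.parameters())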

def get_norm(norm, out_channels, **kwargs):
    """
    Args:
        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
            or a callable that takes a channel number and returns
            the normalization layer as an nn.Module.
        out_channels: number of channels for the normalization layer.

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm,
            "GhostBN": GhostBatchNorm,
            "FrozenBN": FrozenBatchNorm,
            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
            "syncBN": SyncBatchNorm,
        }[norm]
    return norm(out_channels, **kwargs)
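
# Usage sketch for get_norm, written only for illustration (the helper name is
# hypothetical, and the "BN" key is assumed to map to BatchNorm as reconstructed
# above): string keys select one of the classes in this file, extra keyword
# arguments are forwarded to the layer constructor, and a callable is used
# directly as a factory.
def _get_norm_usage_sketch():
    bn = get_norm("BN", 64)                  # BatchNorm(64)
    frozen = get_norm("FrozenBN", 64)        # FrozenBatchNorm(64)
    gn = get_norm("GN", 64)                  # nn.GroupNorm(32, 64)
    frozen_bias = get_norm("BN", 64, bias_freeze=True)
    ln = get_norm(lambda channels: nn.GroupNorm(1, channels), 64)

    assert isinstance(bn, BatchNorm) and isinstance(frozen, FrozenBatchNorm)
    assert isinstance(gn, nn.GroupNorm) and isinstance(ln, nn.GroupNorm)
    assert not frozen_bias.bias.requires_grad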