def get_norm(norm, out_channels, **kwargs):
    """
    Build a normalization layer.

    Args:
        norm (str or callable or None): either one of "BN", "GhostBN",
            "FrozenBN", "GN" or "SyncBN" (the spelling "syncBN" is also
            accepted for backward compatibility); or a callable that takes a
            channel number (plus any ``kwargs``) and returns the
            normalization layer as an nn.Module. An empty string or None
            means "no normalization".
        out_channels (int): number of channels for the normalization layer.
        **kwargs: forwarded to the layer constructor. Ignored for "GN".

    Returns:
        nn.Module or None: the normalization layer, or None when ``norm`` is
        an empty string or None.

    Raises:
        KeyError: if ``norm`` is a string that is not a supported name.
    """
    if norm is None:
        return None
    factory = norm
    if isinstance(norm, str):
        if not norm:
            return None
        factories = {
            "BN": BatchNorm,
            "GhostBN": GhostBatchNorm,
            "FrozenBN": FrozenBatchNorm,
            # GroupNorm is always built with 32 groups and any extra kwargs are
            # dropped; nn.GroupNorm requires out_channels divisible by 32.
            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
            # "SyncBN" is the documented name; "syncBN" is the historical key
            # and is kept so existing configs keep working.
            "SyncBN": SyncBatchNorm,
            "syncBN": SyncBatchNorm,
        }
        try:
            factory = factories[norm]
        except KeyError:
            raise KeyError(
                f"Unknown norm {norm!r}; expected one of {sorted(factories)}"
            ) from None
    return factory(out_channels, **kwargs)