Safemotion Lib
frn.py
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import ReLU, LeakyReLU
from torch.nn.parameter import Parameter


class TLU(nn.Module):
    def __init__(self, num_features):
        """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
        super(TLU, self).__init__()
        self.num_features = num_features
        self.tau = Parameter(torch.Tensor(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.tau)

    def extra_repr(self):
        return 'num_features={num_features}'.format(**self.__dict__)

    def forward(self, x):
        return torch.max(x, self.tau.view(1, self.num_features, 1, 1))
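
# Not part of the original file: a minimal usage sketch for TLU. With tau
# initialized to zero, TLU(x) matches ReLU(x); the function name and shapes
# below are illustrative assumptions, not values taken from the library.
def _demo_tlu():
    tlu = TLU(num_features=8)
    x = torch.randn(2, 8, 4, 4)  # (B, C, H, W)
    y = tlu(x)
    # tau starts at zero, so max(x, 0) is exactly ReLU(x).
    assert torch.allclose(y, torch.relu(x))
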

class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        """
        weight = gamma, bias = beta

        beta, gamma:
            Variables of shape [1, 1, 1, C] in TensorFlow
            Variables of shape [1, C, 1, 1] in PyTorch
        eps: A scalar constant or learnable variable.
        """
        super(FRN, self).__init__()

        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable

        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        if is_eps_leanable:
            self.eps = Parameter(torch.Tensor(1))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)

    def forward(self, x):
        """
        Axes 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
        Axes 0, 1, 2, 3 -> (B, C, H, W) in PyTorch

        Reference TensorFlow code:
            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
            x = x * tf.rsqrt(nu2 + tf.abs(eps))
            # This code includes the TLU function max(y, tau)
            return tf.maximum(gamma * x + beta, tau)
        """
        # Compute the mean squared norm of activations per channel.
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)

        # Perform FRN.
        x = x * torch.rsqrt(nu2 + self.eps.abs())

        # Scale and bias.
        x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1)
        # x = self.weight * x + self.bias
        return x
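
# Not part of the original file: a minimal usage sketch for FRN. With the
# default weight=1 and bias=0, each channel is divided by the root mean square
# of its activations, so the per-channel mean of y**2 should be close to 1.
# The function name and tensor sizes are illustrative assumptions.
def _demo_frn():
    frn = FRN(num_features=8)
    x = torch.randn(2, 8, 16, 16)
    y = frn(x)
    assert y.shape == x.shape
    per_channel_ms = y.pow(2).mean(dim=[2, 3])  # shape (B, C)
    assert torch.allclose(per_channel_ms, torch.ones_like(per_channel_ms), atol=1e-3)
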

def bnrelu_to_frn(module):
    """
    Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'
    """
    mod = module
    before_name = None
    before_child = None
    is_before_bn = False

    for name, child in module.named_children():
        if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
            # Convert BN to FRN
            if isinstance(before_child, BatchNorm2d):
                mod.add_module(
                    before_name, FRN(num_features=before_child.num_features))
            else:
                raise NotImplementedError()

            # Convert ReLU to TLU
            mod.add_module(name, TLU(num_features=before_child.num_features))
        else:
            mod.add_module(name, bnrelu_to_frn(child))

        before_name = name
        before_child = child
        is_before_bn = isinstance(child, BatchNorm2d)
    return mod
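
# Not part of the original file: a usage sketch for bnrelu_to_frn() on a toy
# Sequential block; the layer sizes are illustrative assumptions.
def _demo_bnrelu_to_frn():
    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )
    net = bnrelu_to_frn(net)
    # The BatchNorm2d/ReLU pair is replaced in place by FRN/TLU.
    assert isinstance(net[1], FRN) and isinstance(net[2], TLU)
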

def convert(module, flag_name):
    """Replace flagged BatchNorm2d / ReLU children with FRN / TLU."""
    mod = module
    before_ch = None
    for name, child in module.named_children():
        if hasattr(child, flag_name) and getattr(child, flag_name):
            if isinstance(child, BatchNorm2d):
                before_ch = child.num_features
                mod.add_module(name, FRN(num_features=child.num_features))
            # TODO bn is no good...
            if isinstance(child, (ReLU, LeakyReLU)):
                mod.add_module(name, TLU(num_features=before_ch))
        else:
            mod.add_module(name, convert(child, flag_name))
    return mod


def remove_flags(module, flag_name):
    """Recursively remove the temporary conversion flag from all submodules."""
    mod = module
    for name, child in module.named_children():
        # Drop the flag if it is still present, then recurse into the child.
        if hasattr(child, flag_name):
            delattr(child, flag_name)
        mod.add_module(name, remove_flags(child, flag_name))
    return mod
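
# Not part of the original file: convert() and remove_flags() are normally
# driven by bnrelu_to_frn2() below; here the flag is set by hand on a toy
# BatchNorm2d/ReLU pair so the two helpers can be exercised in isolation.
def _demo_convert_and_remove_flags(flag_name='is_convert_frn'):
    net = nn.Sequential(nn.BatchNorm2d(8), nn.ReLU())
    for child in net:
        setattr(child, flag_name, True)
    net = convert(net, flag_name=flag_name)
    net = remove_flags(net, flag_name=flag_name)
    assert isinstance(net[0], FRN) and isinstance(net[1], TLU)
    assert not any(hasattr(m, flag_name) for m in net.modules())
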

def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
    forward_hooks = list()
    backward_hooks = list()

    is_before_bn = [False]

    def register_forward_hook(module):
        def hook(self, input, output):
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_bn.append(False)
                return

            # input and output are required by the hook signature but unused here
            is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
            if is_converted:
                setattr(self, flag_name, True)
            is_before_bn.append(isinstance(self, BatchNorm2d))

        forward_hooks.append(module.register_forward_hook(hook))

    is_before_relu = [False]

    def register_backward_hook(module):
        def hook(self, input, output):
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_relu.append(False)
                return
            is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
            if is_converted:
                setattr(self, flag_name, True)
            is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))

        backward_hooks.append(module.register_backward_hook(hook))

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(batch_size, *in_size) for in_size in input_size]

    # register the hooks
    model.apply(register_forward_hook)
    model.apply(register_backward_hook)

    # make a forward pass
    output = model(*x)
    output.sum().backward()  # reduce to a scalar so backward() runs and the backward hooks fire

    # remove the hooks
    for h in forward_hooks:
        h.remove()
    for h in backward_hooks:
        h.remove()

    model = convert(model, flag_name=flag_name)
    model = remove_flags(model, flag_name=flag_name)
    return model
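
# Not part of the original file: a usage sketch for bnrelu_to_frn2(). The model
# must produce a differentiable output, because the helper runs one forward and
# one backward pass to set the conversion flags; the layer sizes and input_size
# below are illustrative assumptions.
def _demo_bnrelu_to_frn2():
    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )
    net = bnrelu_to_frn2(net, input_size=(3, 32, 32), batch_size=2)
    assert any(isinstance(m, FRN) for m in net.modules())
    assert any(isinstance(m, TLU) for m in net.modules())


if __name__ == '__main__':
    # Smoke-test the sketches above (these demos are assumptions, not part of
    # the original module).
    _demo_tlu()
    _demo_frn()
    _demo_bnrelu_to_frn()
    _demo_convert_and_remove_flags()
    _demo_bnrelu_to_frn2()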