def bnrelu_to_frn(module):
    """Recursively convert 'BatchNorm2d + ReLU' pairs to 'FRN + TLU'.

    Walks the direct children of ``module``; whenever a ReLU/LeakyReLU
    immediately follows a BatchNorm2d sibling, the BN is replaced by FRN
    and the activation by TLU (both sized to the BN's ``num_features``).
    Children that are not part of such a pair are converted recursively.

    Args:
        module: the ``nn.Module`` to convert (modified in place).

    Returns:
        The converted module (the same object as ``module``).

    Raises:
        NotImplementedError: if an activation is marked as following a BN
            but the remembered previous sibling is not a BatchNorm2d.
    """
    mod = module
    before_name = None    # name of the previously visited child
    before_child = None   # previously visited child module
    is_before_bn = False  # whether the previous child was a BatchNorm2d

    for name, child in module.named_children():
        if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
            # Convert the BN that immediately precedes this activation to FRN.
            if isinstance(before_child, BatchNorm2d):
                mod.add_module(
                    before_name, FRN(num_features=before_child.num_features))
            else:
                raise NotImplementedError()

            # Convert ReLU to TLU
            mod.add_module(name, TLU(num_features=before_child.num_features))
        else:
            # Not part of a BN+activation pair: recurse into the child.
            mod.add_module(name, bnrelu_to_frn(child))

        before_name = name
        before_child = child
        is_before_bn = isinstance(child, BatchNorm2d)
    return mod
def convert(module, flag_name):
    """Replace flagged BatchNorm2d/ReLU children with FRN/TLU.

    Children carrying a truthy attribute named ``flag_name`` (set beforehand
    by the tracing hooks in ``bnrelu_to_frn2``) are swapped:
    BatchNorm2d -> FRN, and ReLU/LeakyReLU -> TLU sized with the channel
    count of the most recently converted BN.  Unflagged children are
    converted recursively.

    Args:
        module: the ``nn.Module`` to convert (modified in place).
        flag_name: attribute name marking modules selected for conversion.

    Returns:
        The converted module (the same object as ``module``).
    """
    mod = module
    before_ch = None  # num_features of the last converted BN; feeds the TLU
    for name, child in module.named_children():
        if hasattr(child, flag_name) and getattr(child, flag_name):
            if isinstance(child, BatchNorm2d):
                before_ch = child.num_features
                mod.add_module(name, FRN(num_features=child.num_features))
            # TODO bn is no good...
            if isinstance(child, (ReLU, LeakyReLU)):
                # NOTE(review): assumes a flagged BN was visited before this
                # activation; before_ch is None otherwise — confirm upstream
                # hooks guarantee the pairing.
                mod.add_module(name, TLU(num_features=before_ch))
        else:
            mod.add_module(name, convert(child, flag_name))
    return mod
def remove_flags(module, flag_name):
    """Recursively strip the temporary conversion flag from all submodules.

    Bug fix: the original tested the hard-coded attribute name
    'is_convert_frn' instead of the ``flag_name`` parameter, so any custom
    flag name was never removed; both branches also performed the same
    recursion, which is now merged.

    Args:
        module: the ``nn.Module`` to clean (modified in place).
        flag_name: attribute name to delete from every flagged submodule.

    Returns:
        The cleaned module (the same object as ``module``).
    """
    mod = module
    for name, child in module.named_children():
        if hasattr(child, flag_name):
            delattr(child, flag_name)
        # Recurse regardless of whether this child carried the flag.
        mod.add_module(name, remove_flags(child, flag_name))
    return mod
def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2,
                   flag_name='is_convert_frn'):
    """Convert BN+ReLU pairs to FRN+TLU by tracing one forward/backward pass.

    Forward hooks fire in execution order and flag each ReLU/LeakyReLU that
    directly follows a BatchNorm2d; backward hooks fire in reverse order and
    flag each BatchNorm2d that directly precedes a ReLU/LeakyReLU.  The
    flagged modules are then swapped by ``convert`` and the temporary flag
    attributes stripped by ``remove_flags``.

    Args:
        model: network to convert (modified in place).
        input_size: one input shape ``(C, H, W)``, or a list of shapes for a
            multi-input network.
        batch_size: batch size of the dummy pass (>= 2 so batchnorm works).
        flag_name: attribute name used as the temporary conversion marker.

    Returns:
        The converted model.
    """
    forward_hooks = list()  # fixed typo: was 'forard_hooks'
    backward_hooks = list()

    is_before_bn = [False]  # trace of "was the previous module a BN?"

    def register_forward_hook(module):
        def hook(self, input, output):
            # Containers and the model itself are not real compute steps;
            # reset the marker and skip them.
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_bn.append(False)
                return
            # input and output is required in hook def
            is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
            if is_converted:
                setattr(self, flag_name, True)
            is_before_bn.append(isinstance(self, BatchNorm2d))
        forward_hooks.append(module.register_forward_hook(hook))

    is_before_relu = [False]  # trace of "was the previous module a ReLU?"

    def register_backward_hook(module):
        def hook(self, input, output):
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_relu.append(False)
                return
            # Backward order is reversed: seeing a BN right after a ReLU here
            # means the BN immediately precedes that ReLU in forward order.
            is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
            if is_converted:
                setattr(self, flag_name, True)
            is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))
        backward_hooks.append(module.register_backward_hook(hook))

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(batch_size, *in_size) for in_size in input_size]

    # register a hook on every submodule
    model.apply(register_forward_hook)
    model.apply(register_backward_hook)

    # make a forward pass
    output = model(*x)
    output.sum().backward()  # Raw output is not enabled to use backward()

    # remove the tracing hooks before converting
    for h in forward_hooks:
        h.remove()
    for h in backward_hooks:
        h.remove()

    model = convert(model, flag_name=flag_name)
    model = remove_flags(model, flag_name=flag_name)
    return model