import torch
import torch.nn as nn
from torch.nn import BatchNorm2d, LeakyReLU, ReLU


def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
    forward_hooks = list()
    backward_hooks = list()

    is_before_bn = [False]

    def register_forward_hook(module):
        def hook(self, input, output):
            # Containers and the root model are not part of the BN -> ReLU
            # pattern; reset the tracking state and bail out.
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_bn.append(False)
                return

            # Flag a ReLU/LeakyReLU that immediately follows a BatchNorm2d
            # in forward execution order.
            is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
            if is_converted:
                setattr(self, flag_name, True)
            is_before_bn.append(isinstance(self, BatchNorm2d))

        forward_hooks.append(module.register_forward_hook(hook))

    is_before_relu = [False]

    def register_backward_hook(module):
        def hook(self, input, output):
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_relu.append(False)
                return

            # The backward pass visits modules in roughly reverse order, so a
            # BatchNorm2d seen right after a ReLU/LeakyReLU here is the BN half
            # of a forward BN -> ReLU pair.
            is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
            if is_converted:
                setattr(self, flag_name, True)
            is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))

        # Note: register_backward_hook is deprecated in recent PyTorch in
        # favor of register_full_backward_hook.
        backward_hooks.append(module.register_backward_hook(hook))

    if isinstance(input_size, tuple):
        input_size = [input_size]

    x = [torch.rand(batch_size, *in_size) for in_size in input_size]

    model.apply(register_forward_hook)
    model.apply(register_backward_hook)

    # One dummy forward/backward pass lets the hooks observe the actual
    # execution order of the modules.
    output = model(*x)
    output.sum().backward()

    for h in forward_hooks:
        h.remove()
    for h in backward_hooks:
        h.remove()

    # `convert` and `remove_flags` are assumed to be defined earlier in this
    # module: `convert` swaps out the flagged modules and `remove_flags`
    # strips the temporary marker attributes.
    model = convert(model, flag_name=flag_name)
    model = remove_flags(model, flag_name=flag_name)
    return model
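

# A minimal usage sketch (an assumption for illustration, not from the
# original source). The toy model below has one BatchNorm2d immediately
# followed by a ReLU, which the hooks above should flag and `convert`
# (assumed to be defined earlier in this module) should then replace;
# the bare Conv2d layers are left untouched.
if __name__ == '__main__':
    demo = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
    )
    demo = bnrelu_to_frn2(demo, input_size=(3, 128, 128), batch_size=2)
    print(demo)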