Safemotion Lib
Loading...
Searching...
No Matches
Classes | Functions
fastreid.layers.frn Namespace Reference

Classes

class  FRN
 
class  TLU
 

Functions

 bnrelu_to_frn (module)
 
 convert (module, flag_name)
 
 remove_flags (module, flag_name)
 
 bnrelu_to_frn2 (model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn')
 

Detailed Description

@author:  liaoxingyu
@contact: sherlockliao01@gmail.com

Function Documentation

◆ bnrelu_to_frn()

fastreid.layers.frn.bnrelu_to_frn ( module)
Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'

Definition at line 86 of file frn.py.

def bnrelu_to_frn(module):
    """
    Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'

    Walks the direct children of ``module`` in registration order. When a
    ``ReLU``/``LeakyReLU`` child immediately follows a ``BatchNorm2d``
    sibling, the normalization child is replaced by ``FRN`` and the
    activation child by ``TLU``, both built with the ``num_features`` of the
    replaced ``BatchNorm2d``. Every other child is processed recursively.

    Only siblings registered next to each other under the same parent are
    paired; a ``BatchNorm2d`` that is not directly followed by an activation
    sibling is left untouched.

    Args:
        module: module whose children are rewritten in place.

    Returns:
        The same ``module`` object that was passed in.
    """
    prev_name = None
    # Previous sibling, remembered only when it is a BatchNorm2d.
    prev_bn = None

    for name, child in module.named_children():
        if prev_bn is not None and isinstance(child, (ReLU, LeakyReLU)):
            num_features = prev_bn.num_features
            # Re-registering under an existing name swaps the child while
            # keeping its position in the parent.
            module.add_module(prev_name, FRN(num_features=num_features))
            module.add_module(name, TLU(num_features=num_features))
        else:
            module.add_module(name, bnrelu_to_frn(child))

        prev_name = name
        prev_bn = child if isinstance(child, BatchNorm2d) else None
    return module
113
114

◆ bnrelu_to_frn2()

fastreid.layers.frn.bnrelu_to_frn2 ( model,
input_size = (3, 128, 128),
batch_size = 2,
flag_name = 'is_convert_frn' )

Definition at line 142 of file frn.py.

def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
    """
    Convert 'BatchNorm2d + ReLU' to 'FRN + TLU' using execution order.

    Temporary forward and backward hooks are registered on every module of
    ``model``, then one forward/backward pass is run on random input:

    * forward: a ``ReLU``/``LeakyReLU`` whose hook fires right after a
      ``BatchNorm2d`` hook gets ``flag_name`` set to ``True``;
    * backward: a ``BatchNorm2d`` whose hook fires right after a
      ``ReLU``/``LeakyReLU`` hook gets ``flag_name`` set to ``True``.

    Hooks firing on ``nn.Sequential``, ``nn.ModuleList`` or ``model`` itself
    reset the "previous module" state instead of flagging anything. The
    flagged modules are then replaced by ``convert`` and the flags are
    cleaned up by ``remove_flags``.

    NOTE(review): the dry run calls ``backward()`` on the real model, so
    whatever state the model updates during a forward/backward pass is
    touched by this call - confirm this is acceptable for callers.

    Args:
        model: module to rewrite in place.
        input_size: shape of one input sample without the batch dimension,
            or a list of such shapes when the model takes several inputs.
        batch_size: batch dimension of the random dry-run input.
        flag_name: attribute name used to mark modules for conversion.

    Returns:
        The converted model.
    """
    forward_hooks = []
    backward_hooks = []

    # History of "was the module whose forward hook just fired a BatchNorm2d?"
    is_before_bn = [False]

    def register_forward_hook(module):
        def hook(self, inputs, outputs):
            # inputs/outputs are unused but required by the hook signature.
            if isinstance(self, (nn.Sequential, nn.ModuleList)) or (self is model):
                is_before_bn.append(False)
                return

            is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
            if is_converted:
                setattr(self, flag_name, True)
            is_before_bn.append(isinstance(self, BatchNorm2d))

        forward_hooks.append(module.register_forward_hook(hook))

    # History of "was the module whose backward hook just fired an activation?"
    is_before_relu = [False]

    def register_backward_hook(module):
        def hook(self, grad_input, grad_output):
            if isinstance(self, (nn.Sequential, nn.ModuleList)) or (self is model):
                is_before_relu.append(False)
                return
            is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
            if is_converted:
                setattr(self, flag_name, True)
            is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))

        backward_hooks.append(module.register_backward_hook(hook))

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(batch_size, *in_size) for in_size in input_size]

    try:
        # register hooks on every module, including ``model`` itself
        model.apply(register_forward_hook)
        model.apply(register_backward_hook)

        # make a forward pass, then reduce to a scalar because backward()
        # cannot be called on the raw (non-scalar) output
        output = model(*x)
        output.sum().backward()
    finally:
        # Always detach the temporary hooks, even if the dry run fails, so
        # they never stay attached to the caller's model.
        for h in forward_hooks:
            h.remove()
        for h in backward_hooks:
            h.remove()

    model = convert(model, flag_name=flag_name)
    model = remove_flags(model, flag_name=flag_name)
    return model

◆ convert()

fastreid.layers.frn.convert ( module,
flag_name )

Definition at line 115 of file frn.py.

def convert(module, flag_name):
    """
    Replace children marked with ``flag_name`` by FRN / TLU layers.

    For each direct child of ``module``:

    * if the child has a truthy ``flag_name`` attribute and is a
      ``BatchNorm2d``, it is replaced by ``FRN`` with the same
      ``num_features``;
    * if it has a truthy ``flag_name`` attribute and is a
      ``ReLU``/``LeakyReLU``, it is replaced by ``TLU`` built with the
      ``num_features`` of the most recent flagged ``BatchNorm2d`` seen among
      the preceding siblings;
    * otherwise the child is processed recursively.

    NOTE(review): if a flagged activation has no flagged ``BatchNorm2d``
    before it under the same parent, ``TLU`` receives ``num_features=None``
    (or the channel count of an earlier, unrelated BatchNorm2d) - confirm
    whether this can happen for the models this is used on.

    Args:
        module: module whose children are rewritten in place.
        flag_name: attribute name that marks a child for conversion.

    Returns:
        The same ``module`` object that was passed in.
    """
    last_bn_channels = None
    for child_name, child in module.named_children():
        # Unflagged children are only descended into, never replaced.
        if not getattr(child, flag_name, False):
            module.add_module(child_name, convert(child, flag_name))
            continue

        if isinstance(child, BatchNorm2d):
            last_bn_channels = child.num_features
            module.add_module(child_name, FRN(num_features=child.num_features))
        # TODO bn is no good...
        if isinstance(child, (ReLU, LeakyReLU)):
            module.add_module(child_name, TLU(num_features=last_bn_channels))
    return module
129
130

◆ remove_flags()

fastreid.layers.frn.remove_flags ( module,
flag_name )

Definition at line 131 of file frn.py.

def remove_flags(module, flag_name):
    """
    Recursively delete the ``flag_name`` attribute from all descendants.

    Every direct child of ``module`` that carries an attribute named
    ``flag_name`` has it deleted; all children are then processed
    recursively. The attribute is not removed from ``module`` itself.

    Args:
        module: root module whose descendants are cleaned up in place.
        flag_name: name of the marker attribute to delete.

    Returns:
        The same ``module`` object that was passed in.
    """
    for name, child in module.named_children():
        # Check the flag that was actually requested, not a hard-coded name,
        # so custom flag names are cleaned up and delattr cannot fail.
        if hasattr(child, flag_name):
            delattr(child, flag_name)
        module.add_module(name, remove_flags(child, flag_name))
    return module
140
141