# @contact: sherlockliao01@gmail.com
import warnings
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class SWA(Optimizer):

    def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None):
18 r"""Implements Stochastic Weight Averaging (SWA).
19 Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
20 Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
21 Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson

        SWA is implemented as a wrapper class that takes an optimizer instance
        as input and applies SWA on top of that optimizer.

        SWA can be used in two modes: automatic and manual. In the automatic
        mode the SWA running averages are updated every :attr:`swa_freq` steps
        after :attr:`swa_start` steps of optimization. If :attr:`swa_lr_factor`
        is provided, the learning rate of the optimizer is scaled by
        :attr:`swa_lr_factor` starting from :attr:`swa_start`. To use SWA in
        automatic mode, provide values for both :attr:`swa_start` and
        :attr:`swa_freq` arguments.

        Alternatively, in the manual mode, use the :meth:`update_swa` or
        :meth:`update_swa_group` methods to update the SWA running averages.
        At the end of training, call :meth:`swap_swa_param` to set the
        optimized variables to the computed averages.

        Args:
            optimizer (torch.optim.Optimizer): optimizer to use with SWA
            swa_freq (int): number of steps between subsequent updates of
                the SWA running averages in automatic mode; if None, manual
                mode is selected (default: None)
            swa_lr_factor (float): factor by which the learning rate is scaled
                starting from step swa_start in automatic mode; if None, the
                learning rate is not changed (default: None)

        Examples:
            >>> # automatic mode
            >>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
            >>> opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr_factor=0.05)
            >>> for _ in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>> opt.swap_swa_param()
            >>> # manual mode
            >>> opt = SWA(base_opt)
            >>> for i in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>>     if i > 10 and i % 5 == 0:
            >>>         opt.update_swa()
            >>> opt.swap_swa_param()

        .. note::
            SWA does not support parameter-specific values of :attr:`swa_start`,
            :attr:`swa_freq` or :attr:`swa_lr_factor`. In automatic mode SWA
            uses the same :attr:`swa_start`, :attr:`swa_freq` and
            :attr:`swa_lr_factor` for all parameter groups. If needed, use
            manual mode with :meth:`update_swa_group` to apply different update
            schedules to different parameter groups.

        .. note::
            Call :meth:`swap_swa_param` at the end of training to use the
            computed running averages.

        .. note::
            If you are using SWA to optimize the parameters of a Neural Network
            containing Batch Normalization layers, you need to update the
            :attr:`running_mean` and :attr:`running_var` statistics of the
            Batch Normalization module. You can do so by using the
            `torchcontrib.optim.swa.bn_update` utility; a usage sketch follows
            this docstring.

        .. note::
            See the blogpost
            https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
            for an extended description of this SWA implementation.

        .. note::
            The repo https://github.com/izmailovpavel/contrib_swa_examples
            contains examples of using this SWA implementation.

        .. _Averaging Weights Leads to Wider Optima and Better Generalization:
            https://arxiv.org/abs/1803.05407
        .. _Improving Consistency-Based Semi-Supervised Learning with Weight
            Averaging:
            https://arxiv.org/abs/1806.05594
        """
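        # Usage sketch for the BatchNorm note above. Illustrative only: it
        # assumes a torchcontrib-style `bn_update(loader, model)` helper is
        # available on this class, and `train_loader` / `model` are
        # placeholders from the surrounding training script.
        #
        #   opt.swap_swa_param()                 # swap in the averaged weights
        #   opt.bn_update(train_loader, model)   # recompute BN running stats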
        if swa_freq is not None and swa_freq < 1:
            raise ValueError("Invalid swa_freq: {}".format(swa_freq))
        if swa_freq is None and swa_lr_factor is not None:
            warnings.warn("swa_freq is None, ignoring swa_lr_factor")
        if swa_lr_factor is not None and swa_lr_factor < 0:
            raise ValueError(
                "Invalid SWA learning rate factor: {}".format(swa_lr_factor))

        self.optimizer = optimizer
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        for group in self.param_groups:
            group['n_avg'] = 0
            group['step_counter'] = 0

    def _check_params(self, swa_start, swa_freq):
        params = [swa_start, swa_freq]
        params_none = [param is None for param in params]
        if not all(params_none) and any(params_none):
            warnings.warn(
                "Some of swa_start, swa_freq is None, ignoring other")
        for i, param in enumerate(params):
            if param is not None and not isinstance(param, int):
                params[i] = int(param)
                warnings.warn("Casting swa_start, swa_freq to int")
        return not any(params_none), params
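        # The first value returned is True only when every SWA parameter was
        # provided (automatic mode); the second is the parameter list with
        # non-int values cast to int.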

            param_group['initial_lr'] = self.swa_lr_factor * param_group['lr']
135 r"""Updates the SWA running averages for the given parameter group.
137 group (dict): Specifies for what parameter group SWA running
138 averages should be updated
141 >>> base_opt = torch.optim.SGD([{'params': [x]},
142 >>> {'params': [y], 'lr': 1e-3}], lr=1e-2, momentum=0.9)
143 >>> opt = torchcontrib.optim.SWA(base_opt)
144 >>> for i in range(100):
146 >>> loss_fn(model(input), target).backward()
148 >>> if i > 10 and i % 5 == 0:
149 >>> # Update SWA for the second parameter group
150 >>> opt.update_swa_group(opt.param_groups[1])
151 >>> opt.swap_swa_param()
        for p in group['params']:
            param_state = self.state[p]
            if 'swa_buffer' not in param_state:
                param_state['swa_buffer'] = torch.zeros_like(p.data)
            buf = param_state['swa_buffer']
            virtual_decay = 1 / float(group["n_avg"] + 1)
            diff = (p.data - buf) * virtual_decay
            buf.add_(diff)
        group["n_avg"] += 1
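        # The update above is the incremental mean: with n = n_avg existing
        # samples, buf <- buf + (p - buf) / (n + 1) equals the plain average of
        # the n + 1 parameter snapshots collected so far, so `swa_buffer`
        # always holds the running average described in the class docstring.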
164 r"""Updates the SWA running averages of all optimized parameters.
170 r"""Swaps the values of the optimized variables and swa buffers.
171 It's meant to be called in the end of training to use the collected
172 swa running averages. It can also be used to evaluate the running
173 averages during training; to continue training `swap_swa_sgd`
174 should be called again.
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'swa_buffer' not in param_state:
                    warnings.warn(
                        "SWA wasn't applied to param {}; skipping it".format(p))
                    continue
                buf = param_state['swa_buffer']
                tmp = torch.empty_like(p.data)
                tmp.copy_(p.data)
                p.data.copy_(buf)
                buf.copy_(tmp)
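        # Because the swap is symmetric, calling it a second time restores the
        # original (non-averaged) parameters. A mid-training evaluation can
        # therefore look like the sketch below (`validate` is a placeholder):
        #
        #   opt.swap_swa_param()   # parameters <- SWA running averages
        #   validate(model)
        #   opt.swap_swa_param()   # parameters <- original values, resume training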
191 r"""Performs a single optimization step.
192 In automatic mode also updates SWA running averages.
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group["step_counter"] += 1
            steps = group["step_counter"]
            if self._auto_mode and steps % self.swa_freq == 0:
                self.update_swa_group(group)
        return loss
204 r"""Returns the state of SWA as a :class:`dict`.
205 It contains three entries:
206 * opt_state - a dict holding current optimization state of the base
207 optimizer. Its content differs between optimizer classes.
208 * swa_state - a dict containing current state of SWA. For each
209 optimized variable it contains swa_buffer keeping the running
210 average of the variable
211 * param_groups - a dict containing all parameter groups
        opt_state_dict = self.optimizer.state_dict()
        swa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                     for k, v in self.state.items()}
        opt_state = opt_state_dict["state"]
        param_groups = opt_state_dict["param_groups"]
        return {"opt_state": opt_state, "swa_state": swa_state,
                "param_groups": param_groups}
222 r"""Loads the optimizer state.
224 state_dict (dict): SWA optimizer state. Should be an object returned
225 from a call to `state_dict`.
        swa_state_dict = {"state": state_dict["swa_state"],
                          "param_groups": state_dict["param_groups"]}
        opt_state_dict = {"state": state_dict["opt_state"],
                          "param_groups": state_dict["param_groups"]}
        super(SWA, self).load_state_dict(swa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
236 r"""Add a param group to the :class:`Optimizer` s `param_groups`.
237 This can be useful when fine tuning a pre-trained network as frozen
238 layers can be made trainable and added to the :class:`Optimizer` as
241 param_group (dict): Specifies what Tensors should be optimized along
242 with group specific optimization options.
        param_group['n_avg'] = 0
        param_group['step_counter'] = 0
        self.optimizer.add_param_group(param_group)