def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None):
    r"""Implements Stochastic Weight Averaging (SWA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    SWA is implemented as a wrapper class that takes an optimizer instance
    as input and applies SWA on top of that optimizer.

    SWA can be used in two modes: automatic and manual. In automatic mode
    the SWA running averages are updated every :attr:`swa_freq` steps of
    optimization. If :attr:`swa_lr_factor` is provided, it is used to
    adjust the learning rate of the optimizer in automatic mode. To use SWA
    in automatic mode, provide a value for the :attr:`swa_freq` argument.

    Alternatively, in manual mode, use the :meth:`update_swa` or
    :meth:`update_swa_group` methods to update the SWA running averages.
    At the end of training, call :meth:`swap_swa_param` to set the
    optimized variables to the computed averages.

    Args:
        swa_freq (int): number of steps between subsequent updates of the
            SWA running averages in automatic mode; if None, manual mode
            is selected (default: None)
        swa_lr_factor (float): factor used to adjust the optimizer's
            learning rate in automatic mode; if None, the learning rate is
            not changed (default: None)

    Examples:
        >>> # automatic mode
        >>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> opt = SWA(base_opt, swa_freq=5, swa_lr_factor=0.5)
        >>> for _ in range(100):
        >>>     opt.zero_grad()
        >>>     loss_fn(model(input), target).backward()
        >>>     opt.step()
        >>> opt.swap_swa_param()
        >>> # manual mode
        >>> opt = SWA(base_opt)
        >>> for i in range(100):
        >>>     opt.zero_grad()
        >>>     loss_fn(model(input), target).backward()
        >>>     opt.step()
        >>>     if i > 10 and i % 5 == 0:
        >>>         opt.update_swa()
        >>> opt.swap_swa_param()

    .. note::
        SWA does not support parameter-specific values of :attr:`swa_freq`
        or :attr:`swa_lr_factor`. In automatic mode SWA uses the same
        :attr:`swa_freq` and :attr:`swa_lr_factor` for all parameter
        groups. If needed, use manual mode with :meth:`update_swa_group`
        to apply different update schedules to different parameter groups,
        as in the sketch below.
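
        A minimal sketch of a per-group manual update (this assumes the
        group ordering of the wrapped optimizer):

        >>> opt.update_swa_group(opt.param_groups[0])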

    .. note::
        Call :meth:`swap_swa_param` at the end of training to use the
        computed running averages.

    .. note::
        If you are using SWA to optimize the parameters of a neural
        network containing Batch Normalization layers, you need to update
        the :attr:`running_mean` and :attr:`running_var` statistics of the
        Batch Normalization modules. You can do so with the
        `torchcontrib.optim.swa.bn_update` utility, for example as in the
        sketch below.
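
        A minimal sketch (this assumes ``bn_update`` accepts a data loader
        and the model, and that ``train_loader`` is your training loader):

        >>> opt.swap_swa_param()  # load the SWA averages into the model
        >>> torchcontrib.optim.swa.bn_update(train_loader, model)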

    .. note::
        See the blog post
        https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
        for an extended description of this SWA implementation.

    .. note::
        The repo https://github.com/izmailovpavel/contrib_swa_examples
        contains examples of using this SWA implementation.

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    .. _Improving Consistency-Based Semi-Supervised Learning with Weight
        Averaging:
        https://arxiv.org/abs/1806.05594
    """
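    # A non-None swa_freq selects automatic mode; otherwise SWA runs in
    # manual mode and the running averages must be updated explicitly.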
    self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq)
    self.swa_lr_factor = swa_lr_factor

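    # Validate the mode-specific arguments: automatic mode requires a
    # positive swa_freq, while manual mode ignores swa_freq and
    # swa_lr_factor.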
    if self._auto_mode:
        if swa_freq < 1:
            raise ValueError("Invalid swa_freq: {}".format(swa_freq))
    else:
        if self.swa_lr_factor is not None:
            warnings.warn(
                "swa_freq is None, ignoring swa_lr_factor")
        self.swa_lr_factor = None
        self.swa_freq = None

    if self.swa_lr_factor is not None and self.swa_lr_factor < 0:
        raise ValueError(
            "Invalid SWA learning rate factor: {}".format(swa_lr_factor))

    self.optimizer = optimizer

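    # Mirror the wrapped optimizer's defaults and param_groups so the
    # wrapper can be used like a regular optimizer; keep SWA buffers in a
    # separate state dict and alias the wrapped optimizer's state as
    # opt_state.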
    self.defaults = self.optimizer.defaults
    self.param_groups = self.optimizer.param_groups
    self.state = defaultdict(dict)
    self.opt_state = self.optimizer.state
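    # Per-group SWA bookkeeping: 'n_avg' counts how many times the running
    # average has been updated for the group, and 'step_counter' counts
    # optimizer steps (used to trigger updates in automatic mode).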
    for group in self.param_groups:
        group['n_avg'] = 0
        group['step_counter'] = 0
