Safemotion Lib
swa.py
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# based on:
# https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py

import warnings
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class SWA(Optimizer):
    def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None):
        r"""Implements Stochastic Weight Averaging (SWA).
        Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
        Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
        Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
        (UAI 2018).
        SWA is implemented as a wrapper class that takes an optimizer instance as
        input and applies SWA on top of that optimizer.
        SWA can be used in two modes: automatic and manual. In the automatic
        mode the SWA running averages are updated every :attr:`swa_freq` steps
        of optimization. If :attr:`swa_lr_factor` is provided, it is used by
        :meth:`reset_lr_to_swa` to reset each parameter group's ``initial_lr``
        to ``swa_lr_factor * lr``. To use SWA in automatic mode, provide a value
        for the :attr:`swa_freq` argument.
        Alternatively, in the manual mode, use the :meth:`update_swa` or
        :meth:`update_swa_group` methods to update the SWA running averages.
        At the end of training use the `swap_swa_param` method to set the
        optimized variables to the computed averages.
        Args:
            swa_freq (int): number of steps between subsequent updates of
                SWA running averages in automatic mode; if None, manual mode is
                selected (default: None)
            swa_lr_factor (float): factor by which each parameter group's
                learning rate is multiplied when :meth:`reset_lr_to_swa` is
                called; if None, the learning rate is not changed
                (default: None)
        Examples:
            >>> # automatic mode
            >>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
            >>> opt = SWA(base_opt, swa_freq=5, swa_lr_factor=0.05)
            >>> for _ in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>> opt.swap_swa_param()
            >>> # manual mode
            >>> opt = SWA(base_opt)
            >>> for i in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>>     if i > 10 and i % 5 == 0:
            >>>         opt.update_swa()
            >>> opt.swap_swa_param()
        .. note::
            SWA does not support parameter-specific values of :attr:`swa_freq`
            or :attr:`swa_lr_factor`. In automatic mode SWA uses the same
            :attr:`swa_freq` and :attr:`swa_lr_factor` for all parameter
            groups. If needed, use manual mode with :meth:`update_swa_group`
            to use different update schedules for different parameter groups.
        .. note::
            Call :meth:`swap_swa_param` at the end of training to use the
            computed running averages.
        .. note::
            If you are using SWA to optimize the parameters of a Neural Network
            containing Batch Normalization layers, you need to update the
            :attr:`running_mean` and :attr:`running_var` statistics of the
            Batch Normalization module. You can do so by using the
            `torchcontrib.optim.swa.bn_update` utility.
        .. note::
            See the blogpost
            https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
            for an extended description of this SWA implementation.
        .. note::
            The repo https://github.com/izmailovpavel/contrib_swa_examples
            contains examples of using this SWA implementation.
        .. _Averaging Weights Leads to Wider Optima and Better Generalization:
            https://arxiv.org/abs/1803.05407
        .. _Improving Consistency-Based Semi-Supervised Learning with Weight
            Averaging:
            https://arxiv.org/abs/1806.05594
        """
        self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq)
        self.swa_lr_factor = swa_lr_factor

        if self._auto_mode:
            if swa_freq < 1:
                raise ValueError("Invalid swa_freq: {}".format(swa_freq))
        else:
            if self.swa_lr_factor is not None:
                warnings.warn(
                    "swa_freq is None, ignoring swa_lr_factor")
            # If not in auto mode make all swa parameters None
            self.swa_lr_factor = None
            self.swa_freq = None

        if self.swa_lr_factor is not None and self.swa_lr_factor < 0:
            raise ValueError("Invalid SWA learning rate factor: {}".format(swa_lr_factor))

        self.optimizer = optimizer

        self.defaults = self.optimizer.defaults
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.opt_state = self.optimizer.state
        for group in self.param_groups:
            group['n_avg'] = 0
            group['step_counter'] = 0

    @staticmethod
    def _check_params(swa_freq):
        params = [swa_freq]
        params_none = [param is None for param in params]
        if not all(params_none) and any(params_none):
            warnings.warn(
                "Some of the SWA parameters are None, ignoring the others")
        for i, param in enumerate(params):
            if param is not None and not isinstance(param, int):
                params[i] = int(param)
                warnings.warn("Casting swa_freq to int")
        return not any(params_none), params

    def reset_lr_to_swa(self):
        # 'initial_lr' is the base value PyTorch lr schedulers resume from.
        for param_group in self.param_groups:
            param_group['initial_lr'] = self.swa_lr_factor * param_group['lr']

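    # Illustrative sketch (not part of the original file): a training loop might
    # call `reset_lr_to_swa` once when it enters the SWA phase, so that a lr
    # scheduler reading each group's 'initial_lr' restarts from
    # swa_lr_factor * lr. `train_one_step`, `max_iter` and `swa_start_iter` are
    # assumed names, not part of this module.
    #
    #   opt = SWA(base_opt, swa_freq=5, swa_lr_factor=0.05)
    #   for it in range(max_iter):
    #       train_one_step(opt)
    #       if it == swa_start_iter:
    #           opt.reset_lr_to_swa()
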
    def update_swa_group(self, group):
        r"""Updates the SWA running averages for the given parameter group.
        Arguments:
            group (dict): Specifies for what parameter group SWA running
                averages should be updated
        Examples:
            >>> # automatic mode
            >>> base_opt = torch.optim.SGD([{'params': [x]},
            >>>                             {'params': [y], 'lr': 1e-3}],
            >>>                            lr=1e-2, momentum=0.9)
            >>> opt = SWA(base_opt)
            >>> for i in range(100):
            >>>     opt.zero_grad()
            >>>     loss_fn(model(input), target).backward()
            >>>     opt.step()
            >>>     if i > 10 and i % 5 == 0:
            >>>         # Update SWA for the second parameter group
            >>>         opt.update_swa_group(opt.param_groups[1])
            >>> opt.swap_swa_param()
        """
        for p in group['params']:
            param_state = self.state[p]
            if 'swa_buffer' not in param_state:
                param_state['swa_buffer'] = torch.zeros_like(p.data)
            buf = param_state['swa_buffer']
            virtual_decay = 1 / float(group["n_avg"] + 1)
            diff = (p.data - buf) * virtual_decay
            buf.add_(diff)
        group["n_avg"] += 1

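    # Note on `update_swa_group` above: with n = group["n_avg"] snapshots already
    # averaged, the update buf <- buf + (p - buf) / (n + 1) equals
    # (n * buf + p) / (n + 1), so after the update `swa_buffer` holds the
    # equal-weight mean of the n + 1 parameter snapshots seen so far.
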
    def update_swa(self):
        r"""Updates the SWA running averages of all optimized parameters.
        """
        for group in self.param_groups:
            self.update_swa_group(group)

    def swap_swa_param(self):
        r"""Swaps the values of the optimized variables and swa buffers.
        It's meant to be called at the end of training to use the collected
        swa running averages. It can also be used to evaluate the running
        averages during training; to continue training, `swap_swa_param`
        should be called again.
        """
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'swa_buffer' not in param_state:
                    # If swa wasn't applied we don't swap params
                    warnings.warn(
                        "SWA wasn't applied to param {}; skipping it".format(p))
                    continue
                buf = param_state['swa_buffer']
                tmp = torch.empty_like(p.data)
                tmp.copy_(p.data)
                p.data.copy_(buf)
                buf.copy_(tmp)

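    # Illustrative sketch (not part of the original file): `swap_swa_param` can
    # also be used to evaluate the averaged weights mid-training and then swap
    # back, as the docstring notes; `model`, `val_loader` and `evaluate` are
    # assumed names.
    #
    #   opt.swap_swa_param()          # parameters now hold the SWA averages
    #   evaluate(model, val_loader)
    #   opt.swap_swa_param()          # swap back to continue training
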
    def step(self, closure=None):
        r"""Performs a single optimization step.
        In automatic mode also updates SWA running averages.
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group["step_counter"] += 1
            steps = group["step_counter"]
            if self._auto_mode:
                if steps % self.swa_freq == 0:
                    self.update_swa_group(group)
        return loss

    def state_dict(self):
        r"""Returns the state of SWA as a :class:`dict`.
        It contains three entries:
            * opt_state - a dict holding current optimization state of the base
                optimizer. Its content differs between optimizer classes.
            * swa_state - a dict containing current state of SWA. For each
                optimized variable it contains swa_buffer keeping the running
                average of the variable
            * param_groups - a dict containing all parameter groups
        """
        opt_state_dict = self.optimizer.state_dict()
        swa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                     for k, v in self.state.items()}
        opt_state = opt_state_dict["state"]
        param_groups = opt_state_dict["param_groups"]
        return {"opt_state": opt_state, "swa_state": swa_state,
                "param_groups": param_groups}

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.
        Args:
            state_dict (dict): SWA optimizer state. Should be an object returned
                from a call to `state_dict`.
        """
        swa_state_dict = {"state": state_dict["swa_state"],
                          "param_groups": state_dict["param_groups"]}
        opt_state_dict = {"state": state_dict["opt_state"],
                          "param_groups": state_dict["param_groups"]}
        super(SWA, self).load_state_dict(swa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
        self.opt_state = self.optimizer.state

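    # Illustrative sketch (not part of the original file): checkpointing with the
    # `state_dict` / `load_state_dict` pair defined above; `ckpt_path` is an
    # assumed name, and the wrapper is rebuilt with the same parameter groups
    # before loading.
    #
    #   torch.save(opt.state_dict(), ckpt_path)
    #   ...
    #   opt = SWA(torch.optim.SGD(model.parameters(), lr=0.1), swa_freq=5)
    #   opt.load_state_dict(torch.load(ckpt_path))
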
    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
        This can be useful when fine tuning a pre-trained network as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.
        Args:
            param_group (dict): Specifies what Tensors should be optimized along
                with group specific optimization options.
        """
        param_group['n_avg'] = 0
        param_group['step_counter'] = 0
        self.optimizer.add_param_group(param_group)
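
# Illustrative end-of-training sketch (not part of the original file). After
# swapping in the SWA averages, BatchNorm running statistics should be
# recomputed, as the class docstring notes; one option in recent PyTorch is
# torch.optim.swa_utils.update_bn. `model`, `train_loader`, `loss_fn`, `input`,
# `target` and `num_steps` are assumed names following the docstring example.
#
#   base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
#   opt = SWA(base_opt, swa_freq=5, swa_lr_factor=0.05)
#   for _ in range(num_steps):
#       opt.zero_grad()
#       loss_fn(model(input), target).backward()
#       opt.step()
#   opt.swap_swa_param()                                     # load the SWA averages
#   torch.optim.swa_utils.update_bn(train_loader, model)    # refresh BN statistics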