Safemotion Lib
fastreid.solver.optim.swa.SWA Class Reference
Inheritance diagram for fastreid.solver.optim.swa.SWA (SWA derives from torch.optim.Optimizer):

Public Member Functions

 __init__ (self, optimizer, swa_freq=None, swa_lr_factor=None)
 
 reset_lr_to_swa (self)
 
 update_swa_group (self, group)
 
 update_swa (self)
 
 swap_swa_param (self)
 
 step (self, closure=None)
 
 state_dict (self)
 
 load_state_dict (self, state_dict)
 
 add_param_group (self, param_group)
 

Public Attributes

 swa_freq
 
 swa_lr_factor
 
 optimizer
 
 defaults
 
 param_groups
 
 state
 
 opt_state
 

Static Protected Member Functions

 _check_params (swa_freq)
 

Protected Attributes

 _auto_mode
 

Detailed Description

Definition at line 16 of file swa.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.solver.optim.swa.SWA.__init__(self, optimizer, swa_freq=None, swa_lr_factor=None)
Implements Stochastic Weight Averaging (SWA).
Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
(UAI 2018).
SWA is implemented as a wrapper class that takes an optimizer instance as
input and applies SWA on top of that optimizer.
SWA can be used in two modes: automatic and manual. In automatic mode, the
SWA running averages are updated every :attr:`swa_freq` optimization steps.
If :attr:`swa_lr_factor` is provided, calling :meth:`reset_lr_to_swa` sets
each parameter group's ``initial_lr`` to :attr:`swa_lr_factor` times its
current learning rate. To use SWA in automatic mode, provide a value for
the :attr:`swa_freq` argument.
Alternatively, in manual mode, use the :meth:`update_swa` or
:meth:`update_swa_group` methods to update the SWA running averages.
At the end of training, use the :meth:`swap_swa_param` method to set the
optimized variables to the computed averages.
Args:
    swa_freq (int): number of steps between subsequent updates of the
        SWA running averages in automatic mode; if None, manual mode is
        selected (default: None)
    swa_lr_factor (float): factor by which :meth:`reset_lr_to_swa` scales
        the current learning rate of every parameter group; if None, the
        learning rate is not changed (default: None)
Examples:
    >>> # automatic mode
    >>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    >>> opt = SWA(base_opt, swa_freq=5, swa_lr_factor=0.5)
    >>> for _ in range(100):
    >>>     opt.zero_grad()
    >>>     loss_fn(model(input), target).backward()
    >>>     opt.step()
    >>> opt.swap_swa_param()
    >>> # manual mode
    >>> opt = SWA(base_opt)
    >>> for i in range(100):
    >>>     opt.zero_grad()
    >>>     loss_fn(model(input), target).backward()
    >>>     opt.step()
    >>>     if i > 10 and i % 5 == 0:
    >>>         opt.update_swa()
    >>> opt.swap_swa_param()
.. note::
    SWA does not support parameter-specific values of :attr:`swa_freq` or
    :attr:`swa_lr_factor`. In automatic mode SWA uses the same
    :attr:`swa_freq` and :attr:`swa_lr_factor` for all parameter groups.
    If needed, use manual mode with :meth:`update_swa_group` to apply
    different update schedules to different parameter groups.
.. note::
    Call :meth:`swap_swa_param` at the end of training to use the computed
    running averages.
.. note::
    If you are using SWA to optimize the parameters of a Neural Network
    containing Batch Normalization layers, you need to update the
    :attr:`running_mean` and :attr:`running_var` statistics of the
    Batch Normalization module. You can do so by using
    `torchcontrib.optim.swa.bn_update` utility.
.. note::
    See the blogpost
    https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
    for an extended description of this SWA implementation.
.. note::
    The repo https://github.com/izmailovpavel/contrib_swa_examples
    contains examples of using this SWA implementation.
.. _Averaging Weights Leads to Wider Optima and Better Generalization:
    https://arxiv.org/abs/1803.05407
.. _Improving Consistency-Based Semi-Supervised Learning with Weight
    Averaging:
    https://arxiv.org/abs/1806.05594

Definition at line 17 of file swa.py.

17 def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None):
18 r"""Implements Stochastic Weight Averaging (SWA).
19 Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
20 Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
21 Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
22 (UAI 2018).
23 SWA is implemented as a wrapper class taking optimizer instance as input
24 and applying SWA on top of that optimizer.
25 SWA can be used in two modes: automatic and manual. In the automatic
26 mode SWA running averages are automatically updated every
27 :attr:`swa_freq` steps after :attr:`swa_start` steps of optimization. If
28 :attr:`swa_lr` is provided, the learning rate of the optimizer is reset
29 to :attr:`swa_lr` at every step starting from :attr:`swa_start`. To use
30 SWA in automatic mode provide values for both :attr:`swa_start` and
31 :attr:`swa_freq` arguments.
32 Alternatively, in the manual mode, use :meth:`update_swa` or
33 :meth:`update_swa_group` methods to update the SWA running averages.
34 In the end of training use `swap_swa_sgd` method to set the optimized
35 variables to the computed averages.
36 Args:
37 swa_freq (int): number of steps between subsequent updates of
38 SWA running averages in automatic mode; if None, manual mode is
39 selected (default: None)
40 swa_lr (float): learning rate to use starting from step swa_start
41 in automatic mode; if None, learning rate is not changed
42 (default: None)
43 Examples:
44 >>> # automatic mode
45 >>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
46 >>> opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
47 >>> for _ in range(100):
48 >>> opt.zero_grad()
49 >>> loss_fn(model(input), target).backward()
50 >>> opt.step()
51 >>> opt.swap_swa_param()
52 >>> # manual mode
53 >>> opt = SWA(base_opt)
54 >>> for i in range(100):
55 >>> opt.zero_grad()
56 >>> loss_fn(model(input), target).backward()
57 >>> opt.step()
58 >>> if i > 10 and i % 5 == 0:
59 >>> opt.update_swa()
60 >>> opt.swap_swa_param()
61 .. note::
62 SWA does not support parameter-specific values of :attr:`swa_start`,
63 :attr:`swa_freq` or :attr:`swa_lr`. In automatic mode SWA uses the
64 same :attr:`swa_start`, :attr:`swa_freq` and :attr:`swa_lr` for all
65 parameter groups. If needed, use manual mode with
66 :meth:`update_swa_group` to use different update schedules for
67 different parameter groups.
68 .. note::
69 Call :meth:`swap_swa_sgd` in the end of training to use the computed
70 running averages.
71 .. note::
72 If you are using SWA to optimize the parameters of a Neural Network
73 containing Batch Normalization layers, you need to update the
74 :attr:`running_mean` and :attr:`running_var` statistics of the
75 Batch Normalization module. You can do so by using
76 `torchcontrib.optim.swa.bn_update` utility.
77 .. note::
78 See the blogpost
79 https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
80 for an extended description of this SWA implementation.
81 .. note::
82 The repo https://github.com/izmailovpavel/contrib_swa_examples
83 contains examples of using this SWA implementation.
84 .. _Averaging Weights Leads to Wider Optima and Better Generalization:
85 https://arxiv.org/abs/1803.05407
86 .. _Improving Consistency-Based Semi-Supervised Learning with Weight
87 Averaging:
88 https://arxiv.org/abs/1806.05594
89 """
90     self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq)
91     self.swa_lr_factor = swa_lr_factor
92 
93     if self._auto_mode:
94         if swa_freq < 1:
95             raise ValueError("Invalid swa_freq: {}".format(swa_freq))
96     else:
97         if self.swa_lr_factor is not None:
98             warnings.warn(
99                 "Swa_freq is None, ignoring swa_lr")
100         # If not in auto mode make all swa parameters None
101         self.swa_lr_factor = None
102         self.swa_freq = None
103 
104     if self.swa_lr_factor is not None and self.swa_lr_factor < 0:
105         raise ValueError("Invalid SWA learning rate factor: {}".format(swa_lr_factor))
106 
107     self.optimizer = optimizer
108 
109     self.defaults = self.optimizer.defaults
110     self.param_groups = self.optimizer.param_groups
111     self.state = defaultdict(dict)
112     self.opt_state = self.optimizer.state
113     for group in self.param_groups:
114         group['n_avg'] = 0
115         group['step_counter'] = 0
116 
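A minimal end-to-end sketch of how this wrapper is typically driven, assuming fastreid is importable from this path; the toy model, data and loss function are hypothetical stand-ins so the snippet is self-contained:

import torch
import torch.nn as nn
from fastreid.solver.optim.swa import SWA

# toy setup so the sketch runs on its own
model = nn.Linear(10, 2)
data = torch.randn(64, 10)
target = torch.randint(0, 2, (64,))
loss_fn = nn.CrossEntropyLoss()

base_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt = SWA(base_opt, swa_freq=5, swa_lr_factor=0.5)   # automatic mode

for _ in range(100):
    opt.zero_grad()
    loss_fn(model(data), target).backward()
    opt.step()            # every 5th call also updates the SWA buffers

opt.swap_swa_param()      # model parameters now hold the running averages

If the model contains BatchNorm layers, their running statistics still need to be recomputed after the swap, as the note above explains.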

Member Function Documentation

◆ _check_params()

fastreid.solver.optim.swa.SWA._check_params(swa_freq)  [static, protected]

Definition at line 118 of file swa.py.

118 def _check_params(swa_freq):
119     params = [swa_freq]
120     params_none = [param is None for param in params]
121     if not all(params_none) and any(params_none):
122         warnings.warn(
123             "Some of swa_start, swa_freq is None, ignoring other")
124     for i, param in enumerate(params):
125         if param is not None and not isinstance(param, int):
126             params[i] = int(param)
127             warnings.warn("Casting swa_start, swa_freq to int")
128     return not any(params_none), params
129 
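The helper normalizes the single swa_freq argument and returns a flag saying whether automatic mode is enabled, together with the (possibly int-cast) parameter list. A quick illustration of that contract, calling the protected helper directly only for demonstration:

from fastreid.solver.optim.swa import SWA

print(SWA._check_params(None))   # (False, [None])  -> manual mode
print(SWA._check_params(5))      # (True, [5])      -> automatic mode
print(SWA._check_params(5.0))    # (True, [5])      -> cast to int, with a warning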

◆ add_param_group()

fastreid.solver.optim.swa.SWA.add_param_group(self, param_group)
Add a param group to the :class:`Optimizer` s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen
layers can be made trainable and added to the :class:`Optimizer` as
training progresses.
Args:
param_group (dict): Specifies what Tensors should be optimized along
with group specific optimization options.

Definition at line 235 of file swa.py.

235 def add_param_group(self, param_group):
236     r"""Add a param group to the :class:`Optimizer` s `param_groups`.
237     This can be useful when fine tuning a pre-trained network as frozen
238     layers can be made trainable and added to the :class:`Optimizer` as
239     training progresses.
240     Args:
241         param_group (dict): Specifies what Tensors should be optimized along
242             with group specific optimization options.
243     """
244     param_group['n_avg'] = 0
245     param_group['step_counter'] = 0
246     self.optimizer.add_param_group(param_group)
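A hedged sketch of the fine-tuning scenario described above: a sub-module that starts out frozen is later unfrozen and registered with the wrapper. Because the new group enters with n_avg = 0 and step_counter = 0, its SWA average is built only from that point on. The two-layer model here is purely illustrative:

import torch
import torch.nn as nn
from fastreid.solver.optim.swa import SWA

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))
head, tail = model[0], model[1]
for p in tail.parameters():          # pretend the tail starts out frozen
    p.requires_grad_(False)

opt = SWA(torch.optim.SGD(head.parameters(), lr=1e-2), swa_freq=5)

# ... later in training, unfreeze the tail and hand it to the wrapper ...
for p in tail.parameters():
    p.requires_grad_(True)
opt.add_param_group({'params': tail.parameters(), 'lr': 1e-3})
# the new group starts with fresh n_avg and step_counter counters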

◆ load_state_dict()

fastreid.solver.optim.swa.SWA.load_state_dict(self, state_dict)
Loads the optimizer state.
Args:
state_dict (dict): SWA optimizer state. Should be an object returned
from a call to `state_dict`.

Definition at line 221 of file swa.py.

221 def load_state_dict(self, state_dict):
222     r"""Loads the optimizer state.
223     Args:
224         state_dict (dict): SWA optimizer state. Should be an object returned
225             from a call to `state_dict`.
226     """
227     swa_state_dict = {"state": state_dict["swa_state"],
228                       "param_groups": state_dict["param_groups"]}
229     opt_state_dict = {"state": state_dict["opt_state"],
230                       "param_groups": state_dict["param_groups"]}
231     super(SWA, self).load_state_dict(swa_state_dict)
232     self.optimizer.load_state_dict(opt_state_dict)
233     self.opt_state = self.optimizer.state
234 
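Continuing the constructor sketch above, checkpointing works the same way as with a plain torch optimizer, since state_dict() bundles opt_state, swa_state and param_groups into a single dict; the file name and checkpoint layout below are illustrative:

# save the model weights together with the combined SWA / base-optimizer state
torch.save({'model': model.state_dict(), 'opt': opt.state_dict()}, 'swa_ckpt.pth')

# restore later
ckpt = torch.load('swa_ckpt.pth', map_location='cpu')
model.load_state_dict(ckpt['model'])
opt.load_state_dict(ckpt['opt'])   # repopulates both the SWA buffers and the base state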

◆ reset_lr_to_swa()

fastreid.solver.optim.swa.SWA.reset_lr_to_swa ( self)

Definition at line 130 of file swa.py.

130 def reset_lr_to_swa(self):
131     for param_group in self.param_groups:
132         param_group['initial_lr'] = self.swa_lr_factor * param_group['lr']
133 
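Note that this method only writes each group's initial_lr; it does not modify lr itself. Presumably a learning-rate scheduler that reads initial_lr then restarts from the scaled value, but that usage is an assumption; what the code itself guarantees is just the product shown below:

import torch
import torch.nn as nn
from fastreid.solver.optim.swa import SWA

model = nn.Linear(10, 2)
opt = SWA(torch.optim.SGD(model.parameters(), lr=0.01), swa_freq=5, swa_lr_factor=10)

opt.reset_lr_to_swa()
print(opt.param_groups[0]['initial_lr'])   # 0.1  (= swa_lr_factor * current lr)
print(opt.param_groups[0]['lr'])           # 0.01 (unchanged)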

◆ state_dict()

fastreid.solver.optim.swa.SWA.state_dict ( self)
Returns the state of SWA as a :class:`dict`.
It contains three entries:
* opt_state - a dict holding current optimization state of the base
optimizer. Its content differs between optimizer classes.
* swa_state - a dict containing current state of SWA. For each
optimized variable it contains swa_buffer keeping the running
average of the variable
* param_groups - a dict containing all parameter groups

Definition at line 203 of file swa.py.

203 def state_dict(self):
204     r"""Returns the state of SWA as a :class:`dict`.
205     It contains three entries:
206         * opt_state - a dict holding current optimization state of the base
207             optimizer. Its content differs between optimizer classes.
208         * swa_state - a dict containing current state of SWA. For each
209             optimized variable it contains swa_buffer keeping the running
210             average of the variable
211         * param_groups - a dict containing all parameter groups
212     """
213     opt_state_dict = self.optimizer.state_dict()
214     swa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
215                  for k, v in self.state.items()}
216     opt_state = opt_state_dict["state"]
217     param_groups = opt_state_dict["param_groups"]
218     return {"opt_state": opt_state, "swa_state": swa_state,
219             "param_groups": param_groups}
220 

◆ step()

fastreid.solver.optim.swa.SWA.step(self, closure=None)
Performs a single optimization step.
In automatic mode also updates SWA running averages.

Definition at line 190 of file swa.py.

190 def step(self, closure=None):
191     r"""Performs a single optimization step.
192     In automatic mode also updates SWA running averages.
193     """
194     loss = self.optimizer.step(closure)
195     for group in self.param_groups:
196         group["step_counter"] += 1
197         steps = group["step_counter"]
198         if self._auto_mode:
199             if steps % self.swa_freq == 0:
200                 self.update_swa_group(group)
201     return loss
202 

◆ swap_swa_param()

fastreid.solver.optim.swa.SWA.swap_swa_param ( self)
Swaps the values of the optimized variables and swa buffers.
It is meant to be called at the end of training to use the collected
SWA running averages. It can also be used to evaluate the running
averages during training; to continue training, `swap_swa_param`
should be called again.

Definition at line 169 of file swa.py.

169 def swap_swa_param(self):
170     r"""Swaps the values of the optimized variables and swa buffers.
171     It's meant to be called in the end of training to use the collected
172     swa running averages. It can also be used to evaluate the running
173     averages during training; to continue training `swap_swa_param`
174     should be called again.
175     """
176     for group in self.param_groups:
177         for p in group['params']:
178             param_state = self.state[p]
179             if 'swa_buffer' not in param_state:
180                 # If swa wasn't applied we don't swap params
181                 warnings.warn(
182                     "SWA wasn't applied to param {}; skipping it".format(p))
183                 continue
184             buf = param_state['swa_buffer']
185             tmp = torch.empty_like(p.data)
186             tmp.copy_(p.data)
187             p.data.copy_(buf)
188             buf.copy_(tmp)
189 
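Because the buffers and parameters are swapped rather than copied, the averaged weights can be evaluated mid-training and then swapped back before resuming. A rough sketch continuing the earlier example, where evaluate() is a hypothetical validation routine:

opt.swap_swa_param()          # model now holds the SWA running averages
swa_score = evaluate(model)   # hypothetical validation helper
opt.swap_swa_param()          # swap back; training resumes from the raw weights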

◆ update_swa()

fastreid.solver.optim.swa.SWA.update_swa ( self)
Updates the SWA running averages of all optimized parameters.

Definition at line 163 of file swa.py.

163 def update_swa(self):
164     r"""Updates the SWA running averages of all optimized parameters.
165     """
166     for group in self.param_groups:
167         self.update_swa_group(group)
168 

◆ update_swa_group()

fastreid.solver.optim.swa.SWA.update_swa_group(self, group)
Updates the SWA running averages for the given parameter group.
Arguments:
    group (dict): Specifies for what parameter group SWA running
        averages should be updated
Examples:
    >>> # manual mode
    >>> base_opt = torch.optim.SGD([{'params': [x]},
    >>>             {'params': [y], 'lr': 1e-3}], lr=1e-2, momentum=0.9)
    >>> opt = SWA(base_opt)
    >>> for i in range(100):
    >>>     opt.zero_grad()
    >>>     loss_fn(model(input), target).backward()
    >>>     opt.step()
    >>>     if i > 10 and i % 5 == 0:
    >>>         # Update SWA for the second parameter group
    >>>         opt.update_swa_group(opt.param_groups[1])
    >>> opt.swap_swa_param()

Definition at line 134 of file swa.py.

134 def update_swa_group(self, group):
135     r"""Updates the SWA running averages for the given parameter group.
136     Arguments:
137         group (dict): Specifies for what parameter group SWA running
138             averages should be updated
139     Examples:
140         >>> # manual mode
141         >>> base_opt = torch.optim.SGD([{'params': [x]},
142         >>>             {'params': [y], 'lr': 1e-3}], lr=1e-2, momentum=0.9)
143         >>> opt = SWA(base_opt)
144         >>> for i in range(100):
145         >>>     opt.zero_grad()
146         >>>     loss_fn(model(input), target).backward()
147         >>>     opt.step()
148         >>>     if i > 10 and i % 5 == 0:
149         >>>         # Update SWA for the second parameter group
150         >>>         opt.update_swa_group(opt.param_groups[1])
151         >>> opt.swap_swa_param()
152     """
153     for p in group['params']:
154         param_state = self.state[p]
155         if 'swa_buffer' not in param_state:
156             param_state['swa_buffer'] = torch.zeros_like(p.data)
157         buf = param_state['swa_buffer']
158         virtual_decay = 1 / float(group["n_avg"] + 1)
159         diff = (p.data - buf) * virtual_decay
160         buf.add_(diff)
161         group["n_avg"] += 1
162 
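The update uses virtual_decay = 1 / (n_avg + 1), which is an incremental arithmetic mean: after k updates the buffer equals the mean of the k parameter snapshots seen so far. A tiny standalone check of that identity, independent of the SWA class:

import torch

snapshots = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([6.0])]
buf, n_avg = torch.zeros(1), 0
for p in snapshots:
    buf.add_((p - buf) / float(n_avg + 1))   # same rule as update_swa_group
    n_avg += 1
print(buf)  # tensor([3.]) == mean of 1, 2 and 6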

Member Data Documentation

◆ _auto_mode

fastreid.solver.optim.swa.SWA._auto_mode
protected

Definition at line 90 of file swa.py.

◆ defaults

fastreid.solver.optim.swa.SWA.defaults

Definition at line 109 of file swa.py.

◆ opt_state

fastreid.solver.optim.swa.SWA.opt_state

Definition at line 112 of file swa.py.

◆ optimizer

fastreid.solver.optim.swa.SWA.optimizer

Definition at line 107 of file swa.py.

◆ param_groups

fastreid.solver.optim.swa.SWA.param_groups

Definition at line 110 of file swa.py.

◆ state

fastreid.solver.optim.swa.SWA.state

Definition at line 111 of file swa.py.

◆ swa_freq

fastreid.solver.optim.swa.SWA.swa_freq

Definition at line 90 of file swa.py.

◆ swa_lr_factor

fastreid.solver.optim.swa.SWA.swa_lr_factor

Definition at line 91 of file swa.py.


The documentation for this class was generated from the following file: swa.py