Safemotion Lib
sgd.py
import torch
from torch.optim.optimizer import Optimizer, required


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Skip parameter groups marked as frozen; .get() is used so groups
            # without an explicit 'freeze' entry are updated as usual (the flag
            # is not part of the constructor defaults).
            if group.get('freeze', False):
                continue

            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-group['lr'])

        return loss
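
The main departure from the stock torch.optim.SGD code is the per-group 'freeze' flag read at the top of step(). Since 'freeze' is not part of the constructor defaults, the sketch below assumes it is supplied when building parameter groups; the model, layer split, and data are illustrative only.

import torch
import torch.nn as nn

# Hypothetical usage sketch of the per-group 'freeze' flag.
# SGD is the class defined in the listing above.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

optimizer = SGD(
    [
        {'params': model[0].parameters(), 'freeze': True},  # group skipped by step()
        {'params': model[2].parameters()},                  # updated as usual
    ],
    lr=0.1, momentum=0.9,
)

x, y = torch.randn(16, 4), torch.randn(16, 2)
optimizer.zero_grad()
nn.functional.mse_loss(model(x), y).backward()
optimizer.step()  # only the second group's parameters move

Note that gradients are still computed and accumulated for the frozen group; step() simply does not apply them, so the parameters and momentum buffers of that group stay untouched.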
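
The docstring's note on the momentum formulation can be checked numerically. A minimal sketch, assuming zero dampening and a constant learning rate: the velocity tracked by this implementation is the Sutskever-style velocity divided by lr, so both forms take identical parameter steps; they only diverge if the learning rate changes between steps. The gradient values below are made up.

import torch

lr, mu = 0.1, 0.9
grads = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(-0.2)]  # made-up gradients

v_here, p_here = torch.tensor(0.0), torch.tensor(1.0)  # update used in step() above
v_sut, p_sut = torch.tensor(0.0), torch.tensor(1.0)    # Sutskever et al. form

for g in grads:
    v_here = mu * v_here + g
    p_here = p_here - lr * v_here
    v_sut = mu * v_sut + lr * g
    p_sut = p_sut - v_sut

print(torch.allclose(p_here, p_sut))  # True as long as lr stays constant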