Safemotion Lib
adam.py
import math

import torch
from torch.optim.optimizer import Optimizer


class Adam(Optimizer):
    r"""Implements the Adam algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    The implementation of the L2 penalty follows changes proposed in
    `Decoupled Weight Decay Regularization`_. Parameter groups may also carry
    a boolean ``freeze`` entry; groups with ``freeze`` set to True are skipped
    by ``step``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
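
    Example:
        Illustrative sketch; ``model``, ``data``, ``target`` and ``loss_fn``
        are placeholders defined by the caller.

        >>> optimizer = Adam(model.parameters(), lr=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(data), target).backward()
        >>> optimizer.step()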

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
30 """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
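
        Example:
            Illustrative sketch of passing a closure that recomputes the loss;
            ``model``, ``data``, ``target`` and ``loss_fn`` are placeholders
            defined by the caller.

            >>> def closure():
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(data), target)
            ...     loss.backward()
            ...     return loss
            >>> optimizer.step(closure)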
59 """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Skip parameter groups marked as frozen. ``get`` is used because
            # 'freeze' is not part of the optimizer defaults and may be absent
            # from a group.
            if group.get('freeze', False):
                continue

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
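
A minimal usage sketch for the per-group freeze flag handled in step() above.
The two-layer model, the layer split into parameter groups, and the random
training data are illustrative placeholders, not part of this file; the sketch
assumes the Adam class defined above is in scope.

    import torch
    import torch.nn as nn

    # Illustrative model and data; any module and tensors work the same way.
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    inputs = torch.randn(16, 4)
    targets = torch.randn(16, 1)
    loss_fn = nn.MSELoss()

    # Two parameter groups: the first is frozen and skipped by step(), the
    # second uses the AMSGrad variant. 'freeze' is carried along as an extra
    # per-group entry, so groups that omit it are updated normally.
    optimizer = Adam(
        [
            {'params': model[0].parameters(), 'freeze': True},
            {'params': model[2].parameters(), 'amsgrad': True},
        ],
        lr=1e-3,
    )

    for _ in range(3):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

    # Un-freeze the first group to resume updating it.
    optimizer.param_groups[0]['freeze'] = False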