    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Custom behavior: skip parameter groups flagged as frozen.
            if group['freeze']:
                continue
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
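
# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming `torch` is installed and this module exposes
# the Adam class defined above. Because step() indexes group['freeze'] and
# 'freeze' is not part of the defaults dict, every param group passed in must
# provide the flag explicitly, as shown below. The model, tensor shapes, and
# learning rates are illustrative assumptions, not values from the original.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

    # One trainable group and one frozen group; only the first is updated.
    optimizer = Adam(
        [
            {'params': model[0].parameters(), 'freeze': False},
            {'params': model[2].parameters(), 'freeze': True, 'lr': 1e-4},
        ],
        lr=1e-3,
    )

    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # parameters in the frozen group keep their old values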