def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        # Re-enable autograd so the closure can recompute the loss.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        # Skip parameter groups that are marked as frozen.
        if group['freeze']:
            continue

        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError('Adam does not support sparse gradients, '
                                   'please consider SparseAdam instead')
            amsgrad = group['amsgrad']

            state = self.state[p]

            # Lazy state initialization on the first step for this parameter.
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values.
                state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                # Exponential moving average of squared gradient values.
                state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                if amsgrad:
                    # Running maximum of all second-moment estimates.
                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            if amsgrad:
                max_exp_avg_sq = state['max_exp_avg_sq']
            beta1, beta2 = group['betas']

            state['step'] += 1
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            # Classic (L2) weight decay: fold the penalty into the gradient.
            if group['weight_decay'] != 0:
                grad = grad.add(p, alpha=group['weight_decay'])

            # Decay the first and second moment running averages.
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            if amsgrad:
                # AMSGrad variant: normalize by the maximum of all second
                # moment estimates seen so far.
                torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

            step_size = group['lr'] / bias_correction1

            # Parameter update: p <- p - step_size * exp_avg / denom
            p.addcdiv_(exp_avg, denom, value=-step_size)

    return loss
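
# A minimal usage sketch (illustrative only). It assumes this `step` belongs to
# an Adam-style torch.optim.Optimizer subclass whose param groups carry the
# 'freeze', 'amsgrad', 'betas', 'eps', 'lr', and 'weight_decay' entries used
# above; the class name `FreezableAdam` and its constructor arguments are
# hypothetical, not defined in this file:
#
#     optimizer = FreezableAdam(model.parameters(), lr=1e-3, freeze=False)
#     for inputs, targets in loader:
#         optimizer.zero_grad()
#         loss = criterion(model(inputs), targets)
#         loss.backward()
#         optimizer.step()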