Source code for catalyst.contrib.nn.optimizers.qhadamw
import torch
from torch.optim.optimizer import Optimizer
[docs]class QHAdamW(Optimizer):
[docs] def __init__(
self,
params,
lr=1e-3,
betas=(0.995, 0.999),
nus=(0.7, 1.0),
weight_decay=0.0,
eps=1e-8
):
r"""
Combines the weight decay decoupling from AdamW (Decoupled Weight
Decay Regularization. Loshchilov and Hutter, 2019) with QHAdam
(Quasi-hyperbolic momentum and Adam for deep learning. Ma and
Yarats, 2019).
https://github.com/iprally/qhadamw-pytorch/blob/master/qhadamw.py
Args:
params (iterable):
iterable of parameters to optimize or dicts defining parameter
groups
lr (float, optional): learning rate (:math:`\alpha` from the paper)
(default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for
computing running averages of the gradient and its square
(default: (0.995, 0.999))
nus (Tuple[float, float], optional): immediate discount factors
used to estimate the gradient and its square
(default: (0.7, 1.0))
eps (float, optional): term added to the denominator to improve
numerical stability
(default: 1e-8)
weight_decay (float, optional): weight decay
(L2 regularization coefficient, times two)
(default: 0.0)
Example:
>>> optimizer = QHAdamW(
... model.parameters(),
... lr=3e-4, nus=(0.8, 1.0), betas=(0.99, 0.999))
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
QHAdam paper:
.. _`(Ma and Yarats, 2019)`: https://arxiv.org/abs/1810.06801
AdamW paper:
.. _`(Loshchilov and Hutter, 2019)`: https://arxiv.org/abs/1711.05101
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = {
"lr": lr,
"betas": betas,
"nus": nus,
"weight_decay": weight_decay,
"eps": eps
}
super(QHAdamW, self).__init__(params, defaults)
[docs] def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
lr = group["lr"]
beta1, beta2 = group["betas"]
nu1, nu2 = group["nus"]
weight_decay = group["weight_decay"]
eps = group["eps"]
for p in group["params"]:
if p.grad is None:
continue
d_p = p.grad.data
if d_p.is_sparse:
raise RuntimeError(
"QHAdamW does not support sparse gradients"
)
param_state = self.state[p]
# Original QHAdam implementation for weight decay:
# if weight_decay != 0:
# d_p.add_(weight_decay, p.data)
d_p_sq = d_p.mul(d_p)
if len(param_state) == 0:
param_state["beta1_weight"] = 0.0
param_state["beta2_weight"] = 0.0
param_state["exp_avg"] = torch.zeros_like(p.data)
param_state["exp_avg_sq"] = torch.zeros_like(p.data)
param_state["beta1_weight"] = \
1.0 + beta1 * param_state["beta1_weight"]
param_state["beta2_weight"] = \
1.0 + beta2 * param_state["beta2_weight"]
beta1_weight = param_state["beta1_weight"]
beta2_weight = param_state["beta2_weight"]
exp_avg = param_state["exp_avg"]
exp_avg_sq = param_state["exp_avg_sq"]
beta1_adj = 1.0 - (1.0 / beta1_weight)
beta2_adj = 1.0 - (1.0 / beta2_weight)
exp_avg.mul_(beta1_adj).add_(1.0 - beta1_adj, d_p)
exp_avg_sq.mul_(beta2_adj).add_(1.0 - beta2_adj, d_p_sq)
avg_grad = exp_avg.mul(nu1)
if nu1 != 1.0:
avg_grad.add_(1.0 - nu1, d_p)
avg_grad_rms = exp_avg_sq.mul(nu2)
if nu2 != 1.0:
avg_grad_rms.add_(1.0 - nu2, d_p_sq)
avg_grad_rms.sqrt_()
if eps != 0.0:
avg_grad_rms.add_(eps)
# Original QHAdam implementation:
# p.data.addcdiv_(-lr, avg_grad, avg_grad_rms)
# Implementation following AdamW paper:
p.data.add_(-weight_decay, p.data) \
.addcdiv_(-lr, avg_grad, avg_grad_rms)
return loss