Source code for catalyst.contrib.nn.optimizers.qhadamw
from typing import Callable, Optional
import torch
from torch.optim.optimizer import Optimizer
[docs]class QHAdamW(Optimizer):
"""Implements QHAdam algorithm.
Combines QHAdam algorithm that was proposed in `Quasi-hyperbolic momentum
and Adam for deep learning`_ with weight decay decoupling from
`Decoupled Weight Decay Regularization`_ paper.
Example:
>>> optimizer = QHAdamW(
... model.parameters(),
... lr=3e-4, nus=(0.8, 1.0), betas=(0.99, 0.999))
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Adapted from:
https://github.com/iprally/qhadamw-pytorch/blob/master/qhadamw.py
(MIT License)
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _Quasi-hyperbolic momentum and Adam for deep learning:
https://arxiv.org/abs/1810.06801
"""
[docs] def __init__(
self, params, lr=1e-3, betas=(0.995, 0.999), nus=(0.7, 1.0), weight_decay=0.0, eps=1e-8
):
r"""
Args:
params (iterable):
iterable of parameters to optimize or dicts defining parameter
groups
lr (float, optional): learning rate (:math:`\alpha` from the paper)
(default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for
computing running averages of the gradient and its square
(default: (0.995, 0.999))
nus (Tuple[float, float], optional): immediate discount factors
used to estimate the gradient and its square
(default: (0.7, 1.0))
eps (float, optional): term added to the denominator to improve
numerical stability
(default: 1e-8)
weight_decay (float, optional): weight decay
(L2 regularization coefficient, times two)
(default: 0.0)
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = {
"lr": lr,
"betas": betas,
"nus": nus,
"weight_decay": weight_decay,
"eps": eps,
}
super(QHAdamW, self).__init__(params, defaults)
def step(self, closure: Optional[Callable] = None):
"""Makes optimizer step.
Args:
closure (callable, optional): A closure that reevaluates
the model and returns the loss.
Returns:
computed loss
Raises:
RuntimeError: QHAdamW does not support sparse gradients
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
lr = group["lr"]
beta1, beta2 = group["betas"]
nu1, nu2 = group["nus"]
weight_decay = group["weight_decay"]
eps = group["eps"]
for p in group["params"]:
if p.grad is None:
continue
d_p = p.grad.data
if d_p.is_sparse:
raise RuntimeError("QHAdamW does not support sparse gradients")
param_state = self.state[p]
# Original QHAdam implementation for weight decay:
# if weight_decay != 0:
# d_p.add_(weight_decay, p.data)
d_p_sq = d_p.mul(d_p)
if len(param_state) == 0:
param_state["beta1_weight"] = 0.0
param_state["beta2_weight"] = 0.0
param_state["exp_avg"] = torch.zeros_like(p.data)
param_state["exp_avg_sq"] = torch.zeros_like(p.data)
param_state["beta1_weight"] = 1.0 + beta1 * param_state["beta1_weight"]
param_state["beta2_weight"] = 1.0 + beta2 * param_state["beta2_weight"]
beta1_weight = param_state["beta1_weight"]
beta2_weight = param_state["beta2_weight"]
exp_avg = param_state["exp_avg"]
exp_avg_sq = param_state["exp_avg_sq"]
beta1_adj = 1.0 - (1.0 / beta1_weight)
beta2_adj = 1.0 - (1.0 / beta2_weight)
exp_avg.mul_(beta1_adj).add_(1.0 - beta1_adj, d_p)
exp_avg_sq.mul_(beta2_adj).add_(1.0 - beta2_adj, d_p_sq)
avg_grad = exp_avg.mul(nu1)
if nu1 != 1.0:
avg_grad.add_(1.0 - nu1, d_p)
avg_grad_rms = exp_avg_sq.mul(nu2)
if nu2 != 1.0:
avg_grad_rms.add_(1.0 - nu2, d_p_sq)
avg_grad_rms.sqrt_()
if eps != 0.0:
avg_grad_rms.add_(eps)
# Original QHAdam implementation:
# p.data.addcdiv_(-lr, avg_grad, avg_grad_rms)
# Implementation following AdamW paper:
p.data.add_(-weight_decay, p.data).addcdiv_(-lr, avg_grad, avg_grad_rms)
return loss
__all__ = ["QHAdamW"]