Source code for catalyst.contrib.nn.optimizers.qhadamw
from typing import Callable, Optional
import torch
from torch.optim.optimizer import Optimizer
class QHAdamW(Optimizer):
    """Implements QHAdam algorithm.
    Combines QHAdam algorithm that was proposed in  `Quasi-hyperbolic momentum
    and Adam for deep learning`_ with weight decay decoupling from
    `Decoupled Weight Decay Regularization`_ paper.
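
    The update performed by :meth:`step` can be summarized, per parameter,
    as follows (notation as in the QHAdam paper; :math:`\lambda` denotes
    ``weight_decay``, while :math:`\hat{m}_t` and :math:`\hat{v}_t` are the
    debiased moving averages of the gradient :math:`g_t` and of its square):

    .. math::

        \theta_t = (1 - \lambda)\,\theta_{t-1}
            - \alpha\,\frac{(1 - \nu_1)\, g_t + \nu_1 \hat{m}_t}
                           {\sqrt{(1 - \nu_2)\, g_t^2 + \nu_2 \hat{v}_t}
                            + \epsilon}
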
    Example:
        >>> optimizer = QHAdamW(
        ...     model.parameters(),
        ...     lr=3e-4, nus=(0.8, 1.0), betas=(0.99, 0.999))
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    Main origins of inspiration:
        https://github.com/iprally/qhadamw-pytorch/blob/master/qhadamw.py
        (MIT License)
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _Quasi-hyperbolic momentum and Adam for deep learning:
        https://arxiv.org/abs/1810.06801
    """
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.995, 0.999),
        nus=(0.7, 1.0),
        weight_decay=0.0,
        eps=1e-8,
    ):
        r"""
        Args:
            params (iterable):
                iterable of parameters to optimize or dicts defining parameter
                groups
            lr (float, optional): learning rate (:math:`\alpha` from the paper)
                (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for
                computing running averages of the gradient and its square
                (default: (0.995, 0.999))
            nus (Tuple[float, float], optional): immediate discount factors
                used to estimate the gradient and its square
                (default: (0.7, 1.0))
            weight_decay (float, optional): weight decay coefficient,
                applied in a decoupled fashion (the parameter is scaled by
                ``1 - weight_decay`` before the gradient-based update)
                (default: 0.0)
            eps (float, optional): term added to the denominator to improve
                numerical stability
                (default: 1e-8)
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = {
            "lr": lr,
            "betas": betas,
            "nus": nus,
            "weight_decay": weight_decay,
            "eps": eps,
        }
        super(QHAdamW, self).__init__(params, defaults)
    def step(self, closure: Optional[Callable] = None):
        """Makes optimizer step.
        Args:
            closure (callable, optional): A closure that reevaluates
                the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            nu1, nu2 = group["nus"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if d_p.is_sparse:
                    raise RuntimeError(
                        "QHAdamW does not support sparse gradients"
                    )
                param_state = self.state[p]
                # Original QHAdam implementation for weight decay:
                # if weight_decay != 0:
                #    d_p.add_(weight_decay, p.data)
                d_p_sq = d_p.mul(d_p)
                if len(param_state) == 0:
                    param_state["beta1_weight"] = 0.0
                    param_state["beta2_weight"] = 0.0
                    param_state["exp_avg"] = torch.zeros_like(p.data)
                    param_state["exp_avg_sq"] = torch.zeros_like(p.data)
                param_state["beta1_weight"] = (
                    1.0 + beta1 * param_state["beta1_weight"]
                )
                param_state["beta2_weight"] = (
                    1.0 + beta2 * param_state["beta2_weight"]
                )
                beta1_weight = param_state["beta1_weight"]
                beta2_weight = param_state["beta2_weight"]
                exp_avg = param_state["exp_avg"]
                exp_avg_sq = param_state["exp_avg_sq"]
                beta1_adj = 1.0 - (1.0 / beta1_weight)
                beta2_adj = 1.0 - (1.0 / beta2_weight)
                exp_avg.mul_(beta1_adj).add_(d_p, alpha=1.0 - beta1_adj)
                exp_avg_sq.mul_(beta2_adj).add_(d_p_sq, alpha=1.0 - beta2_adj)
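                # Quasi-hyperbolic mixing: blend the plain gradient
                # (weight 1 - nu) with the debiased moving average
                # (weight nu), separately for the numerator and the
                # denominator of the update.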
                avg_grad = exp_avg.mul(nu1)
                if nu1 != 1.0:
                    avg_grad.add_(d_p, alpha=1.0 - nu1)
                avg_grad_rms = exp_avg_sq.mul(nu2)
                if nu2 != 1.0:
                    avg_grad_rms.add_(d_p_sq, alpha=1.0 - nu2)
                avg_grad_rms.sqrt_()
                if eps != 0.0:
                    avg_grad_rms.add_(eps)
                # Original QHAdam implementation:
                # p.data.addcdiv_(-lr, avg_grad, avg_grad_rms)
                # Implementation following AdamW paper:
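                # (note: the decay is applied to the parameter directly,
                # decoupled from the gradient-based update, and is not
                # scaled by the learning rate here)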
                p.data.add_(p.data, alpha=-weight_decay).addcdiv_(
                    avg_grad, avg_grad_rms, value=-lr
                )
        return loss