# Source code for catalyst.contrib.nn.optimizers.lookahead
# flake8: noqa
# @TODO: code formatting issue for 20.07 release
from typing import Callable, Dict, Optional
from collections import defaultdict
import torch
from torch.optim import Optimizer
class Lookahead(Optimizer):
    r"""Implements Lookahead algorithm.

    It has been proposed in `Lookahead Optimizer: k steps forward,
    1 step back`_.

    The wrapped ("fast") optimizer performs every step. Once every ``k``
    steps a per-parameter "slow" copy is moved towards the fast weights by
    a factor ``alpha`` and the fast weights are reset to that slow copy.

    Adapted from:
    https://github.com/alphadl/lookahead.pytorch (MIT License)

    .. _`Lookahead Optimizer\: k steps forward, 1 step back`:
        https://arxiv.org/abs/1907.08610
    """

    def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.5):
        """Wraps ``optimizer`` with the Lookahead update rule.

        Args:
            optimizer: the inner (fast) optimizer that performs every step
            k: number of inner steps between two slow-weight updates
            alpha: interpolation factor used when moving the slow weights
                towards the fast weights
        """
        # NOTE(review): ``Optimizer.__init__`` is not called here; the
        # wrapper re-uses the inner optimizer's ``param_groups`` and
        # ``defaults`` objects instead of building its own.
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.defaults = self.optimizer.defaults
        # Slow-weight state, keyed by parameter tensor.
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
        """Moves slow weights of ``group`` towards the fast weights.

        For every parameter in ``group["params"]`` the slow copy is created
        on first use (as a copy of the current fast weights), interpolated
        with factor ``self.alpha``, and then written back into the
        parameter.

        Args:
            group: a parameter group (dict with a ``"params"`` list)
        """
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                param_state["slow_param"] = torch.zeros_like(fast.data)
                param_state["slow_param"].copy_(fast.data)
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)

    def update_lookahead(self):
        """Applies :meth:`update` to every parameter group."""
        for group in self.param_groups:
            self.update(group)

    def step(self, closure: Optional[Callable] = None):
        """Makes optimizer step.

        Args:
            closure (callable, optional): A closure that reevaluates
                the model and returns the loss.

        Returns:
            computed loss
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            # The slow weights are synchronised whenever the per-group
            # counter is at zero, i.e. on the first step and then every
            # ``k`` steps.
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def state_dict(self):
        """Returns the state of the wrapper and of the inner optimizer.

        Slow-weight entries are keyed by the same parameter indices that
        the inner optimizer uses in its packed ``param_groups``, so they
        can be matched back to parameters in :meth:`load_state_dict`.

        Returns:
            dict with keys ``"fast_state"`` (inner optimizer state),
            ``"slow_state"`` (slow weights) and ``"param_groups"``
        """
        fast_state_dict = self.optimizer.state_dict()
        param_groups = fast_state_dict["param_groups"]

        # Map each live parameter to the index the inner optimizer assigned
        # to it. ``id()`` is used only for the in-process lookup and never
        # ends up in the returned dict for parameters that belong to a group.
        index_of = {}
        for group, packed in zip(self.optimizer.param_groups, param_groups):
            for param, index in zip(group["params"], packed["params"]):
                index_of.setdefault(id(param), index)

        slow_state = {}
        for key, value in self.state.items():
            if isinstance(key, torch.Tensor):
                key = index_of.get(id(key), id(key))
            slow_state[key] = value

        return {
            "fast_state": fast_state_dict["state"],
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restores the state produced by :meth:`state_dict`.

        The inner optimizer state is loaded first; afterwards the wrapper
        is re-linked to the inner optimizer's ``param_groups`` and
        ``state`` and the slow weights are mapped back onto the current
        parameters by index. Slow-state entries whose key does not match
        any parameter index are ignored; the corresponding slow weights
        are then re-created on the next :meth:`update`.

        Args:
            state_dict: dict with keys ``"fast_state"``, ``"slow_state"``
                and ``"param_groups"``

        Raises:
            KeyError: if one of the required keys is missing
        """
        # Read every required key before mutating anything, so a malformed
        # ``state_dict`` fails without a partial load.
        saved_slow_state = state_dict["slow_state"]
        saved_groups = state_dict["param_groups"]
        fast_state_dict = {
            "state": state_dict["fast_state"],
            "param_groups": saved_groups,
        }

        self.optimizer.load_state_dict(fast_state_dict)

        # The inner optimizer may rebind its containers while loading;
        # keep sharing them so both objects see the same groups.
        self.param_groups = self.optimizer.param_groups
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group.setdefault("counter", 0)

        param_of = {}
        for group, saved in zip(self.param_groups, saved_groups):
            for param, index in zip(group["params"], saved["params"]):
                param_of[index] = param

        restored = defaultdict(dict)
        for key, value in saved_slow_state.items():
            param = param_of.get(key)
            if param is None:
                continue
            restored[param] = {
                name: (
                    item.detach().to(
                        device=param.device, dtype=param.dtype, copy=True
                    )
                    if isinstance(item, torch.Tensor)
                    else item
                )
                for name, item in value.items()
            }
        self.state = restored

    def add_param_group(self, param_group):
        """Adds a parameter group to the inner optimizer.

        The group gets its own ``"counter"`` entry (starting at 0) before
        it is handed to the inner optimizer.

        Args:
            param_group: dict describing the new parameter group
        """
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)

    @classmethod
    def get_from_params(
        cls, params: Dict, base_optimizer_params: Optional[Dict] = None, **kwargs
    ) -> "Lookahead":
        """Builds the inner optimizer via the registry and wraps it.

        Args:
            params: forwarded to ``REGISTRY.get_from_params`` as ``params``
            base_optimizer_params: extra keyword arguments for
                ``REGISTRY.get_from_params``; ``None`` is treated as empty
            **kwargs: forwarded to the ``Lookahead`` constructor

        Returns:
            a ``Lookahead`` instance wrapping the created optimizer
        """
        from catalyst.registry import REGISTRY

        base_optimizer = REGISTRY.get_from_params(
            params=params, **(base_optimizer_params or {})
        )
        optimizer = cls(optimizer=base_optimizer, **kwargs)
        return optimizer
# Names exported by a star-import of this module.
__all__ = ["Lookahead"]