Source code for catalyst.contrib.nn.optimizers.lookahead
from typing import Dict  # isort:skip
from collections import defaultdict
import torch
from torch.optim import Optimizer

class Lookahead(Optimizer):
    """Implements the Lookahead optimizer as a wrapper around an
    arbitrary ``torch.optim.Optimizer``.

    Taken from: https://github.com/alphadl/lookahead.pytorch

    Paper: "Lookahead Optimizer: k steps forward, 1 step back",
    https://arxiv.org/abs/1907.08610
    """

    def __init__(
        self,
        optimizer: Optimizer,
        k: int = 5,
        alpha: float = 0.5,
    ):
        """
        Args:
            optimizer: inner ("fast") optimizer to wrap
            k: number of fast steps between slow-weight updates
            alpha: step size of the slow-weight update
        """
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        # expose the inner optimizer's groups and defaults so the
        # wrapper behaves like a regular ``Optimizer``
        self.param_groups = self.optimizer.param_groups
        self.defaults = self.optimizer.defaults
        # slow weights are stored in ``self.state``; the inner optimizer
        # keeps owning the fast weights
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
        """Pulls the slow weights towards the current fast weights and
        resets the fast weights to the result."""
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                # lazily initialize the slow weights from the fast ones
                param_state["slow_param"] = torch.zeros_like(fast.data)
                param_state["slow_param"].copy_(fast.data)
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)
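    # In the paper's notation (arXiv:1907.08610), with fast weights theta
    # and slow weights phi, ``update`` computes
    #     phi <- phi + alpha * (theta - phi);  theta <- phi
    # i.e. an exponential-moving-average pull of the slow weights towards
    # the fast weights, followed by a reset of the fast weights.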

    def update_lookahead(self):
        """Runs the slow-weight update for every parameter group."""
        for group in self.param_groups:
            self.update(group)

    def step(self, closure=None):
        """Performs a single fast-optimizer step; on every ``k``-th call
        the slow and fast weights are synchronized via ``update``."""
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def state_dict(self):
        """Packs the inner (fast) optimizer state and the slow weights
        into a single checkpointable dict."""
        fast_state_dict = self.optimizer.state_dict()
        # tensors are not serializable dict keys, so key slow state by id
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "fast_state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restores the slow weights into this wrapper and the fast
        state into the inner optimizer."""
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],
        }
        fast_state_dict = {
            "state": state_dict["fast_state"],
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.optimizer.load_state_dict(fast_state_dict)
        self.fast_state = self.optimizer.state
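    # Checkpoint round-trip sketch (illustrative, assuming ``optimizer``
    # is a constructed ``Lookahead`` instance):
    #     checkpoint = optimizer.state_dict()
    #     optimizer.load_state_dict(checkpoint)
    # Both the inner optimizer's state and the slow weights survive.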

    def add_param_group(self, param_group):
        """Registers a new parameter group with the inner optimizer and
        initializes its slow-update counter."""
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)

    @classmethod
    def get_from_params(
        cls,
        params: Dict,
        base_optimizer_params: Dict = None,
        **kwargs,
    ) -> "Lookahead":
        """Builds a Lookahead-wrapped optimizer from Catalyst config dicts."""
        from catalyst.dl.registry import OPTIMIZERS

        # guard against a missing base-optimizer config section
        base_optimizer_params = base_optimizer_params or {}
        base_optimizer = OPTIMIZERS.get_from_params(
            params=params, **base_optimizer_params
        )
        optimizer = cls(optimizer=base_optimizer, **kwargs)
        return optimizer
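

# Usage sketch (illustrative, not part of the original module): Lookahead
# wraps any ``torch.optim`` optimizer; every ``k`` calls to ``step`` the
# slow weights are pulled towards the fast weights by a factor ``alpha``.
# The model and hyperparameters below are arbitrary examples.
if __name__ == "__main__":
    from torch import nn

    model = nn.Linear(10, 2)
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)

    for _ in range(10):
        model.zero_grad()
        loss = model(torch.randn(4, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()  # slow/fast sync happens on every k-th call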