# Source code for catalyst.contrib.nn.optimizers.lookahead

from typing import Dict  # isort:skip
from collections import defaultdict

import torch
from torch.optim import Optimizer


class Lookahead(Optimizer):
    """Lookahead wrapper around an arbitrary ``torch.optim.Optimizer``.

    The wrapped ("fast") optimizer updates the parameters on every
    :meth:`step`. Per param group, on the first step and then on every
    ``k``-th step after it, a "slow" copy of each parameter is moved
    ``alpha`` of the way towards the fast value, and the fast parameter is
    reset to that slow value.

    The wrapper shares ``param_groups`` and ``defaults`` with the inner
    optimizer, so changes made through either object are visible to both.

    Taken from: https://github.com/alphadl/lookahead.pytorch
    """

    def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.5):
        """
        Args:
            optimizer: inner ("fast") optimizer to wrap.
            k: number of fast steps between two slow-weight synchronisations.
            alpha: interpolation factor used to move slow weights towards the
                fast weights.

        Raises:
            ValueError: if ``k < 1`` or ``alpha`` is outside ``[0, 1]``.
        """
        if k < 1:
            raise ValueError("Invalid lookahead steps: {}".format(k))
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("Invalid slow update rate: {}".format(alpha))
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        # Shared (not copied) with the inner optimizer on purpose.
        self.param_groups = self.optimizer.param_groups
        self.defaults = self.optimizer.defaults
        # Slow weights, keyed by parameter tensor.
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
        """Run one slow-weight synchronisation for every param of ``group``.

        The slow copy is lazily initialised to the current fast value, moved
        ``alpha`` of the way towards the fast value, and then written back
        into the fast parameter.
        """
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                param_state["slow_param"] = fast.data.clone()
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)

    def update_lookahead(self):
        """Synchronise slow weights for all groups, ignoring step counters."""
        for group in self.param_groups:
            self.update(group)

    def step(self, closure=None):
        """Do one fast step and, when due, a slow-weight synchronisation.

        Args:
            closure: optional closure forwarded to the inner optimizer.

        Returns:
            Whatever the inner optimizer's ``step`` returns.
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            # counter == 0 on the very first step, so the slow weights are
            # initialised right away; afterwards every k-th step syncs.
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def state_dict(self):
        """Return fast state, slow state and the shared param groups.

        Slow-state entries are keyed exactly like the parameters in the inner
        optimizer's packed ``param_groups[i]["params"]`` (e.g. integer
        indices), so :meth:`load_state_dict` can map them back onto the
        parameters.
        """
        fast_state_dict = self.optimizer.state_dict()
        param_to_key = {
            id(param): key
            for group, packed_group in zip(
                self.param_groups, fast_state_dict["param_groups"]
            )
            for param, key in zip(group["params"], packed_group["params"])
        }
        slow_state = {
            (
                param_to_key.get(id(k), id(k))
                if isinstance(k, torch.Tensor)
                else k
            ): v
            for k, v in self.state.items()
        }
        return {
            "fast_state": fast_state_dict["state"],
            "slow_state": slow_state,
            "param_groups": fast_state_dict["param_groups"],
        }

    def load_state_dict(self, state_dict):
        """Restore a state produced by :meth:`state_dict`.

        The inner optimizer is restored first (it validates the groups).
        Then ``self.param_groups`` is re-bound to the inner optimizer's
        list, so that both objects keep sharing the same groups. The slow
        weights are mapped back onto the live parameters and copied to
        each parameter's device and dtype.
        """
        fast_state_dict = {
            "state": state_dict["fast_state"],
            "param_groups": state_dict["param_groups"],
        }
        self.optimizer.load_state_dict(fast_state_dict)
        self.param_groups = self.optimizer.param_groups
        self.fast_state = self.optimizer.state

        # Saved keys -> live params, in the same group/param order the inner
        # optimizer used when it restored its own state.
        key_to_param = {
            key: param
            for saved_group, group in zip(
                state_dict["param_groups"], self.param_groups
            )
            for key, param in zip(saved_group["params"], group["params"])
        }
        slow_state = defaultdict(dict)
        for key, value in state_dict["slow_state"].items():
            param = key_to_param.get(key)
            if param is None:
                # Unknown key: keep verbatim, like Optimizer.load_state_dict.
                slow_state[key] = value
                continue
            slow_state[param] = {
                name: (
                    # copy=True: never alias tensors of the caller's dict.
                    tensor.to(
                        device=param.device, dtype=param.dtype, copy=True
                    )
                    if isinstance(tensor, torch.Tensor)
                    else tensor
                )
                for name, tensor in value.items()
            }
        self.state = slow_state

    def add_param_group(self, param_group):
        """Add a param group (with a fresh step counter) to the inner
        optimizer; the shared ``param_groups`` list makes it visible here.
        """
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)

    @classmethod
    def get_from_params(
        cls,
        params: Dict,
        base_optimizer_params: Dict = None,
        **kwargs,
    ) -> "Lookahead":
        """Build a Lookahead around an optimizer created from the registry.

        Args:
            params: model parameters forwarded to the base optimizer factory.
            base_optimizer_params: keyword arguments for the base optimizer
                factory; ``None`` means no extra arguments.
            **kwargs: forwarded to :class:`Lookahead` (``k``, ``alpha``).

        Returns:
            The constructed :class:`Lookahead` instance.
        """
        from catalyst.dl.registry import OPTIMIZERS

        # ``**None`` would raise TypeError, so default to an empty mapping.
        base_optimizer_params = base_optimizer_params or {}
        base_optimizer = OPTIMIZERS.get_from_params(
            params=params, **base_optimizer_params
        )
        optimizer = cls(optimizer=base_optimizer, **kwargs)
        return optimizer