Source code for catalyst.rl.offpolicy.algorithms.critic

from typing import Dict, Union  # isort:skip
import copy

from catalyst.rl import utils
from catalyst.rl.core import (
    ActorSpec, AlgorithmSpec, CriticSpec, EnvironmentSpec
)
from catalyst.rl.registry import AGENTS


class OffpolicyCritic(AlgorithmSpec):
    def __init__(
        self,
        critic: CriticSpec,
        gamma: float,
        n_step: int,
        critic_loss_params: Dict = None,
        critic_optimizer_params: Dict = None,
        critic_scheduler_params: Dict = None,
        critic_grad_clip_params: Dict = None,
        critic_tau: float = 1.0,
        **kwargs
    ):
        self._device = utils.get_device()
        self.critic = critic.to(self._device)
        self.target_critic = copy.deepcopy(critic).to(self._device)

        # preparation
        critic_components = utils.get_trainer_components(
            agent=self.critic,
            loss_params=critic_loss_params,
            optimizer_params=critic_optimizer_params,
            scheduler_params=critic_scheduler_params,
            grad_clip_params=critic_grad_clip_params
        )

        # criterion
        self._critic_loss_params = critic_components["loss_params"]
        self.critic_criterion = critic_components["criterion"]

        # optimizer
        self._critic_optimizer_params = critic_components["optimizer_params"]
        self.critic_optimizer = critic_components["optimizer"]

        # scheduler
        self._critic_scheduler_params = critic_components["scheduler_params"]
        self.critic_scheduler = critic_components["scheduler"]

        # grad clipping
        self._critic_grad_clip_params = critic_components["grad_clip_params"]
        self.critic_grad_clip_fn = critic_components["grad_clip_fn"]

        # other hyperparameters
        self._n_step = n_step
        self._gamma = gamma
        self.critic_tau = critic_tau

        # other init
        self._init(**kwargs)

    def _init(self, **kwargs):
        assert len(kwargs) == 0

    @property
    def n_step(self) -> int:
        return self._n_step

    @property
    def gamma(self) -> float:
        return self._gamma
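
    # NOTE: this base class does not define `_loss_fn` (used in `train`
    # below) and leaves `update_step` unimplemented; both are expected to
    # be provided by subclasses that implement a concrete value-based
    # off-policy algorithm.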

    def pack_checkpoint(self, with_optimizer: bool = True):
        checkpoint = {}

        for key in ["critic"]:
            checkpoint[f"{key}_state_dict"] = getattr(self, key).state_dict()
            if with_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value2 = getattr(self, key2, None)
                    if value2 is not None:
                        checkpoint[f"{key2}_state_dict"] = value2.state_dict()

        return checkpoint

    def unpack_checkpoint(self, checkpoint, with_optimizer: bool = True):
        for key in ["critic"]:
            value_l = getattr(self, key, None)
            if value_l is not None:
                value_r = checkpoint[f"{key}_state_dict"]
                value_l.load_state_dict(value_r)

            if with_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value_l = getattr(self, key2, None)
                    if value_l is not None:
                        value_r = checkpoint[f"{key2}_state_dict"]
                        value_l.load_state_dict(value_r)
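
    # The packed checkpoint maps "critic_state_dict" (plus, with
    # `with_optimizer=True`, "critic_optimizer_state_dict" and
    # "critic_scheduler_state_dict" when those components exist) to the
    # corresponding `state_dict()`; it can be serialized with e.g.
    # `torch.save` and restored later through `unpack_checkpoint`.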

    def critic_update(self, loss):
        self.critic.zero_grad()
        self.critic_optimizer.zero_grad()
        loss.backward()
        if self.critic_grad_clip_fn is not None:
            self.critic_grad_clip_fn(self.critic.parameters())
        self.critic_optimizer.step()
        if self.critic_scheduler is not None:
            self.critic_scheduler.step()
            return {"lr_critic": self.critic_scheduler.get_lr()[0]}

    def target_critic_update(self):
        utils.soft_update(self.target_critic, self.critic, self.critic_tau)
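
    # Assuming `utils.soft_update` performs the usual Polyak blend
    #   target <- tau * source + (1 - tau) * target,
    # the default `critic_tau=1.0` makes this a hard copy of the online
    # critic into the target critic.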

    def update_step(self, value_loss, critic_update=True):
        """
        Updates parameters of neural networks and returns learning metrics

        Args:
            value_loss: computed critic (value) loss
            critic_update: whether to update the critic on this step

        Returns:
            dict of learning metrics
        """
        raise NotImplementedError
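
    # A typical implementation calls `critic_update(value_loss)` when
    # `critic_update` is True and then `target_critic_update()`; see the
    # sketch after the class definition.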

    def train(self, batch, actor_update=False, critic_update=True):
        states_t, actions_t, rewards_t, states_tp1, done_t = \
            batch["state"], batch["action"], batch["reward"], \
            batch["next_state"], batch["done"]

        states_t = utils.any2device(states_t, self._device)
        actions_t = utils.any2device(
            actions_t, self._device
        ).unsqueeze(1).long()
        rewards_t = utils.any2device(rewards_t, self._device).unsqueeze(1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device).unsqueeze(1)
        """
        states_t: [bs; history_len; observation_len]
        actions_t: [bs; 1]
        rewards_t: [bs; 1]
        states_tp1: [bs; history_len; observation_len]
        done_t: [bs; 1]
        """

        value_loss = self._loss_fn(
            states_t, actions_t, rewards_t, states_tp1, done_t
        )

        metrics = self.update_step(
            value_loss=value_loss, critic_update=critic_update
        )

        return metrics
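
    # The sampled `batch` carries n-step transitions; `_loss_fn` (left to
    # subclasses) typically bootstraps from `states_tp1` using a discount
    # of `gamma ** n_step`, as in the sketch after the class.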

    @classmethod
    def prepare_for_trainer(
        cls, env_spec: EnvironmentSpec, config: Dict
    ) -> "AlgorithmSpec":
        config_ = config.copy()
        agents_config = config_["agents"]

        critic_params = agents_config["critic"]
        critic = AGENTS.get_from_params(
            **critic_params,
            env_spec=env_spec,
        )

        algorithm = cls(
            **config_["algorithm"],
            critic=critic,
        )

        return algorithm

    @classmethod
    def prepare_for_sampler(
        cls, env_spec: EnvironmentSpec, config: Dict
    ) -> Union[ActorSpec, CriticSpec]:
        config_ = config.copy()
        agents_config = config_["agents"]

        critic_params = agents_config["critic"]
        critic = AGENTS.get_from_params(
            **critic_params,
            env_spec=env_spec,
        )

        return critic
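

# A minimal sketch of how a concrete algorithm could specialize
# OffpolicyCritic. The class name `SimpleDQN`, the assumption that
# `self.critic(states)` returns per-action Q-values of shape
# [bs; num_actions], and the exact TD-target formula are illustrative
# assumptions, not the actual catalyst.rl DQN implementation.
import torch


class SimpleDQN(OffpolicyCritic):
    def _loss_fn(self, states_t, actions_t, rewards_t, states_tp1, done_t):
        # Q(s_t, a_t) for the actions taken in the batch
        q_values_t = self.critic(states_t).gather(-1, actions_t)
        # bootstrapped n-step target from the frozen target critic
        with torch.no_grad():
            q_values_tp1 = \
                self.target_critic(states_tp1).max(dim=-1, keepdim=True)[0]
            gamma = self._gamma ** self._n_step
            q_target_t = \
                rewards_t + (1 - done_t.float()) * gamma * q_values_tp1
        return self.critic_criterion(q_values_t, q_target_t)

    def update_step(self, value_loss, critic_update=True):
        metrics = {"loss_critic": value_loss.item()}
        if critic_update:
            # critic_update returns the lr metric only when a scheduler is set
            metrics.update(self.critic_update(value_loss) or {})
        self.target_critic_update()
        return metrics


# Usage sketch (the config layout with "agents"/"critic" and "algorithm"
# sections follows `prepare_for_trainer` above; the concrete keys inside
# each sub-dict depend on the registered agent and are omitted here):
#
#   algorithm = SimpleDQN.prepare_for_trainer(env_spec=env_spec, config=config)
#   metrics = algorithm.train(batch)
#   torch.save(algorithm.pack_checkpoint(), "critic_checkpoint.pth")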