Source code for catalyst.rl.core.policy_handler

from typing import Union  # isort:skip
import numpy as np

from gym.spaces import Discrete
import torch

from catalyst.rl import utils
from .agent import ActorSpec, CriticSpec
from .environment import EnvironmentSpec


def _state2device(array: np.ndarray, device):
    array = utils.any2device(array, device)

    if isinstance(array, dict):
        array = {
            key: value.to(device).unsqueeze(0)
            for key, value in array.items()
        }
    else:
        array = array.to(device).unsqueeze(0)

    return array
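
# Illustrative sketch (not part of the original module): `_state2device` accepts
# either a single observation array or a dict of named observation arrays and
# returns tensor(s) with a leading batch dimension of 1, so a single environment
# state can be fed to a batch-first network. The dict keys below are arbitrary
# examples.
#
#     state = np.zeros((4,), dtype=np.float32)
#     batch = _state2device(state, "cpu")        # tensor of shape [1, 4]
#
#     state = {"obs": np.zeros((4,), dtype=np.float32),
#              "mask": np.zeros((2,), dtype=np.float32)}
#     batch = _state2device(state, "cpu")        # {"obs": [1, 4], "mask": [1, 2]}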


class PolicyHandler:
    def __init__(
        self,
        env: EnvironmentSpec,
        agent: Union[ActorSpec, CriticSpec],
        device
    ):
        self.action_fn = None
        self.discrete_actions = isinstance(env.action_space, Discrete)

        # PPO, REINFORCE, DQN
        if self.discrete_actions:
            if isinstance(agent, ActorSpec):
                self.action_clip = None
                self.action_fn = self._actor_handler
            elif isinstance(agent, CriticSpec):
                self.action_fn = self._critic_handler
                self.value_distribution = agent.distribution
                if self.value_distribution == "categorical":
                    v_min, v_max = agent.values_range
                    self.z = torch.linspace(
                        start=v_min, end=v_max, steps=agent.num_atoms
                    ).to(device)
            else:
                raise NotImplementedError()
        # PPO, DDPG, SAC, TD3
        else:
            assert isinstance(agent, ActorSpec)
            self.action_fn = self._actor_handler

    @torch.no_grad()
    def _get_q_values(self, critic: CriticSpec, state: np.ndarray, device):
        states = _state2device(state, device)
        output = critic(states)
        # We use the last head to perform actions
        # This is the head corresponding to the largest gamma
        if self.value_distribution == "categorical":
            probs = torch.softmax(output[0, -1, :, :], dim=-1)
            q_values = torch.sum(probs * self.z, dim=-1)
        elif self.value_distribution == "quantile":
            q_values = torch.mean(output[0, -1, :, :], dim=-1)
        else:
            q_values = output[0, -1, :, 0]
        return q_values.cpu().numpy()

    @torch.no_grad()
    def _sample_from_actor(
        self,
        actor: ActorSpec,
        state: np.ndarray,
        device,
        deterministic: bool = False
    ):
        states = _state2device(state, device)
        action = actor(states, deterministic=deterministic)
        action = action[0].cpu().numpy()
        return action

    def _critic_handler(
        self,
        agent: CriticSpec,
        state: np.ndarray,
        device,
        deterministic: bool = False,
        exploration_strategy=None
    ):
        q_values = self._get_q_values(agent, state, device)
        if not deterministic and exploration_strategy is not None:
            action = exploration_strategy.get_action(q_values)
        else:
            action = np.argmax(q_values)
        return action

    def _actor_handler(
        self,
        agent: ActorSpec,
        state: np.ndarray,
        device,
        deterministic: bool = False,
        exploration_strategy=None
    ):
        action = self._sample_from_actor(agent, state, device, deterministic)
        if not deterministic and exploration_strategy is not None:
            action = exploration_strategy.get_action(action)
        return action
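
# Usage sketch (an assumption about the calling code, not part of this module):
# a sampler typically resolves `action_fn` once at construction time and then
# calls it on every environment step. The names `env_spec`, `actor` and
# `exploration_strategy` below are hypothetical placeholders.
#
#     handler = PolicyHandler(env=env_spec, agent=actor, device="cpu")
#     action = handler.action_fn(
#         actor,
#         state,                    # raw observation from the environment
#         "cpu",
#         deterministic=False,      # keep exploration/stochasticity on
#         exploration_strategy=exploration_strategy,
#     )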