from typing import Union  # isort:skip

import numpy as np
from gym.spaces import Discrete
import torch

from catalyst.rl import utils

from .agent import ActorSpec, CriticSpec
from .environment import EnvironmentSpec


def _state2device(array: np.ndarray, device):
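    """Moves a state (array or dict of arrays) to ``device``
    and adds a batch dimension.
    """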
    array = utils.any2device(array, device)
    if isinstance(array, dict):
        array = {
            key: value.to(device).unsqueeze(0)
            for key, value in array.items()
        }
    else:
        array = array.to(device).unsqueeze(0)
    return array


class PolicyHandler:
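    """Maps an agent to a callable that turns states into actions.

    For discrete action spaces (PPO, REINFORCE, DQN) the agent may be an
    ``ActorSpec`` or a ``CriticSpec``; for continuous action spaces
    (PPO, DDPG, SAC, TD3) an ``ActorSpec`` is required.
    """
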
    def __init__(
        self, env: EnvironmentSpec, agent: Union[ActorSpec, CriticSpec], device
    ):
        self.action_fn = None
        self.discrete_actions = isinstance(env.action_space, Discrete)
        # discrete action space: PPO, REINFORCE, DQN
        if self.discrete_actions:
            if isinstance(agent, ActorSpec):
                self.action_clip = None
                self.action_fn = self._actor_handler
            elif isinstance(agent, CriticSpec):
                self.action_fn = self._critic_handler
                self.value_distribution = agent.distribution
                if self.value_distribution == "categorical":
                    v_min, v_max = agent.values_range
                    self.z = torch.linspace(
                        start=v_min, end=v_max, steps=agent.num_atoms
                    ).to(device)
            else:
                raise NotImplementedError(
                    "for discrete action spaces "
                    "the agent must be an ActorSpec or a CriticSpec"
                )
        # continuous action space: PPO, DDPG, SAC, TD3
        else:
            assert isinstance(agent, ActorSpec)
            self.action_fn = self._actor_handler

    @torch.no_grad()
    def _get_q_values(self, critic: CriticSpec, state: np.ndarray, device):
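        """Returns per-action Q-values for a single state as a numpy array."""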
        states = _state2device(state, device)
        output = critic(states)
        # We use the last head to perform actions
        # This is the head corresponding to the largest gamma
        if self.value_distribution == "categorical":
            # categorical critic: expectation of the value distribution
            # over the support atoms ``self.z``
            probs = torch.softmax(output[0, -1, :, :], dim=-1)
            q_values = torch.sum(probs * self.z, dim=-1)
        elif self.value_distribution == "quantile":
            # quantile critic: mean over the quantile estimates
            q_values = torch.mean(output[0, -1, :, :], dim=-1)
        else:
            # plain critic: one scalar value per action
            q_values = output[0, -1, :, 0]
        return q_values.cpu().numpy()

    @torch.no_grad()
    def _sample_from_actor(
        self,
        actor: ActorSpec,
        state: np.ndarray,
        device,
        deterministic: bool = False
    ):
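        """Queries the actor for an action for a single state."""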
        states = _state2device(state, device)
        action = actor(states, deterministic=deterministic)
        action = action[0].cpu().numpy()
        return action

    def _critic_handler(
        self,
        agent: CriticSpec,
        state: np.ndarray,
        device,
        deterministic: bool = False,
        exploration_strategy=None
    ):
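        """Selects the greedy action from Q-values, or delegates to the
        exploration strategy when not deterministic.
        """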
        q_values = self._get_q_values(agent, state, device)
        if not deterministic and exploration_strategy is not None:
            action = exploration_strategy.get_action(q_values)
        else:
            action = np.argmax(q_values)
        return action

    def _actor_handler(
        self,
        agent: ActorSpec,
        state: np.ndarray,
        device,
        deterministic: bool = False,
        exploration_strategy=None
    ):
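        """Samples an action from the actor and optionally lets the
        exploration strategy post-process it.
        """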
        action = self._sample_from_actor(agent, state, device, deterministic)
        if not deterministic and exploration_strategy is not None:
            action = exploration_strategy.get_action(action)
        return action
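

# Usage sketch (illustrative only, not part of the original module).
# ``env``, ``actor``, ``state`` and ``exploration`` below are placeholders
# for an EnvironmentSpec, an ActorSpec, an environment state and an
# exploration strategy exposing a ``get_action`` method; how they are built
# depends on the surrounding code.
#
#     handler = PolicyHandler(env=env, agent=actor, device="cpu")
#     action = handler.action_fn(
#         actor,
#         state,
#         "cpu",
#         deterministic=False,
#         exploration_strategy=exploration,
#     )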