Source code for catalyst.rl.agent.critic

from typing import Dict, Tuple  # isort:skip
import copy

from gym import spaces

from catalyst.rl.core import CriticSpec, EnvironmentSpec
from .head import ValueHead
from .network import StateActionNet, StateNet


class StateCritic(CriticSpec):
    """
    Critic that learns state value functions, like V(s).

    The critic is ``head_net(state_net(state))``: ``state_net`` encodes the
    state and ``head_net`` (a ``ValueHead``) produces the value outputs.
    The value-head related properties below are forwarded from ``head_net``.
    """

    def __init__(self, state_net: StateNet, head_net: ValueHead):
        """
        Args:
            state_net: network that encodes the input state
            head_net: value head applied to the ``state_net`` output
        """
        super().__init__()
        self.state_net = state_net
        self.head_net = head_net

    @property
    def num_outputs(self) -> int:
        """``head_net.out_features``."""
        return self.head_net.out_features

    @property
    def num_atoms(self) -> int:
        """``head_net.num_atoms``."""
        return self.head_net.num_atoms

    @property
    def distribution(self) -> str:
        """``head_net.distribution``."""
        return self.head_net.distribution

    @property
    def values_range(self) -> Tuple:
        """``head_net.values_range``."""
        return self.head_net.values_range

    @property
    def num_heads(self) -> int:
        """``head_net.num_heads``."""
        return self.head_net.num_heads

    @property
    def hyperbolic_constant(self) -> float:
        """``head_net.hyperbolic_constant``."""
        return self.head_net.hyperbolic_constant

    def forward(self, state):
        """Encode ``state`` with ``state_net`` and return ``head_net`` output."""
        x = self.state_net(state)
        x = self.head_net(x)
        return x

    @classmethod
    def get_from_params(
        cls,
        state_net_params: Dict,
        value_head_params: Dict,
        env_spec: EnvironmentSpec,
    ):
        """
        Build the critic from config dicts and an environment spec.

        Both param dicts are deep-copied, so the caller's config is never
        mutated. ``state_shape`` is filled in from ``env_spec.state_space``.
        For a ``gym.spaces.Dict`` state space, it is set to a mapping
        ``{key: sub_space.shape}``, as ``StateActionCritic`` does. A Dict
        space does not expose a single usable ``shape`` of its own.

        Args:
            state_net_params: kwargs for ``StateNet.get_from_params``
                (``state_shape`` is overwritten)
            value_head_params: kwargs for ``ValueHead``
            env_spec: environment spec providing ``state_space``

        Returns:
            an instance of ``cls``
        """
        state_net_params = copy.deepcopy(state_net_params)
        value_head_params = copy.deepcopy(value_head_params)

        # @TODO: any better solution?
        state_space = env_spec.state_space
        if isinstance(state_space, spaces.Dict):
            # Dict observations: pass per-key shapes to the state net.
            state_net_params["state_shape"] = {
                k: v.shape for k, v in state_space.spaces.items()
            }
        else:
            state_net_params["state_shape"] = state_space.shape

        state_net = StateNet.get_from_params(**state_net_params)
        head_net = ValueHead(**value_head_params)
        net = cls(state_net=state_net, head_net=head_net)
        return net
class ActionCritic(StateCritic):
    """
    Critic that learns state-action value functions, like Q(s).

    Intended for discrete action spaces: for a given state the value head
    has one output per action (``out_features = action_space.n``).
    """

    @classmethod
    def get_from_params(
        cls,
        state_net_params: Dict,
        value_head_params: Dict,
        env_spec: EnvironmentSpec,
    ):
        """
        Build the critic, sizing the value head to the discrete action count.

        Args:
            state_net_params: kwargs for the state network
            value_head_params: kwargs for ``ValueHead``
                (``out_features`` is overwritten)
            env_spec: environment spec; its ``action_space`` must be
                ``gym.spaces.Discrete``

        Returns:
            an instance of ``cls`` built by ``StateCritic.get_from_params``
        """
        net_kwargs = copy.deepcopy(state_net_params)
        head_kwargs = copy.deepcopy(value_head_params)

        # @TODO: any better solution?
        discrete_space = env_spec.action_space
        assert isinstance(discrete_space, spaces.Discrete)
        # One Q-value output per available action.
        head_kwargs["out_features"] = discrete_space.n

        return super().get_from_params(
            state_net_params=net_kwargs,
            value_head_params=head_kwargs,
            env_spec=env_spec,
        )
class StateActionCritic(CriticSpec):
    """
    Critic which learns state-action value functions, like Q(s, a).

    ``state_action_net`` jointly encodes a state and an action, and
    ``head_net`` (a ``ValueHead``) maps that encoding to value outputs.
    The value-head related properties are forwarded from ``head_net``.
    """

    def __init__(self, state_action_net: StateActionNet, head_net: ValueHead):
        """
        Args:
            state_action_net: network encoding a (state, action) pair
            head_net: value head applied to the ``state_action_net`` output
        """
        super().__init__()
        self.state_action_net = state_action_net
        self.head_net = head_net

    def forward(self, state, action):
        """Return ``head_net(state_action_net(state, action))``."""
        features = self.state_action_net(state, action)
        return self.head_net(features)

    # -- properties forwarded from the value head --

    @property
    def num_outputs(self) -> int:
        """Forwarded ``head_net.out_features``."""
        head = self.head_net
        return head.out_features

    @property
    def num_atoms(self) -> int:
        """Forwarded ``head_net.num_atoms``."""
        head = self.head_net
        return head.num_atoms

    @property
    def distribution(self) -> str:
        """Forwarded ``head_net.distribution``."""
        head = self.head_net
        return head.distribution

    @property
    def values_range(self) -> Tuple:
        """Forwarded ``head_net.values_range``."""
        head = self.head_net
        return head.values_range

    @property
    def num_heads(self) -> int:
        """Forwarded ``head_net.num_heads``."""
        head = self.head_net
        return head.num_heads

    @property
    def hyperbolic_constant(self) -> float:
        """Forwarded ``head_net.hyperbolic_constant``."""
        head = self.head_net
        return head.hyperbolic_constant

    @classmethod
    def get_from_params(
        cls,
        state_action_net_params: Dict,
        value_head_params: Dict,
        env_spec: EnvironmentSpec,
    ):
        """
        Build the critic from config dicts and an environment spec.

        Both dicts are deep-copied before being modified.

        ``state_shape`` comes from ``env_spec.state_space``. For a
        ``gym.spaces.Dict`` it is a per-key mapping of shapes.
        ``action_shape`` comes from ``env_spec.action_space.shape``.

        The value head is forced to a single output feature.

        Returns:
            an instance of ``cls``
        """
        net_kwargs = copy.deepcopy(state_action_net_params)
        head_kwargs = copy.deepcopy(value_head_params)

        # @TODO: any better solution?
        net_kwargs["state_shape"] = (
            {
                name: sub_space.shape
                for name, sub_space in env_spec.state_space.spaces.items()
            }
            if isinstance(env_spec.state_space, spaces.Dict)
            else env_spec.state_space.shape
        )
        net_kwargs["action_shape"] = env_spec.action_space.shape

        encoder = StateActionNet.get_from_params(**net_kwargs)

        # One output feature per (state, action) input.
        head_kwargs["out_features"] = 1
        value_head = ValueHead(**head_kwargs)

        return cls(state_action_net=encoder, head_net=value_head)
# Public API of this module (CriticSpec is re-exported from catalyst.rl.core).
__all__ = [
    "CriticSpec",
    "StateCritic",
    "ActionCritic",
    "StateActionCritic",
]