Source code for catalyst.rl.agent.actor

from typing import Dict  # isort:skip
import copy

from gym import spaces
import torch

from catalyst.rl.core import ActorSpec, EnvironmentSpec
from .head import PolicyHead
from .network import StateNet


class Actor(ActorSpec):
    """
    Actor which learns the agent's policy.
    """
    def __init__(
        self,
        state_net: StateNet,
        head_net: PolicyHead,
    ):
        super().__init__()
        self.state_net = state_net
        self.head_net = head_net

    @property
    def policy_type(self) -> str:
        return self.head_net.policy_type

    def forward(self, state: torch.Tensor, logprob=False, deterministic=False):
        x = self.state_net(state)
        x = self.head_net(x, logprob, deterministic)
        return x

    @classmethod
    def get_from_params(
        cls,
        state_net_params: Dict,
        policy_head_params: Dict,
        env_spec: EnvironmentSpec,
    ):
        state_net_params = copy.deepcopy(state_net_params)
        policy_head_params = copy.deepcopy(policy_head_params)

        # @TODO: any better solution?
        action_space = env_spec.action_space
        if isinstance(action_space, spaces.Box):
            # continuous control
            policy_head_params["out_features"] = action_space.shape[0]
        elif isinstance(action_space, spaces.Discrete):
            # discrete control
            policy_head_params["out_features"] = action_space.n
        else:
            raise NotImplementedError()

        # @TODO: any better solution?
        if isinstance(env_spec.state_space, spaces.Dict):
            state_net_params["state_shape"] = {
                k: v.shape
                for k, v in env_spec.state_space.spaces.items()
            }
        else:
            state_net_params["state_shape"] = env_spec.state_space.shape

        state_net = StateNet.get_from_params(**state_net_params)
        head_net = PolicyHead(**policy_head_params)

        net = cls(state_net=state_net, head_net=head_net)

        return net
__all__ = ["ActorSpec", "Actor"]
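
The only environment-dependent values that get_from_params fills in are the policy head's out_features and the state network's state_shape. A minimal, self-contained sketch of that derivation using plain gym spaces (the spaces below are illustrative examples, not taken from the library):

from gym import spaces

# Continuous control (Box): out_features is the action dimensionality
box_actions = spaces.Box(low=-1.0, high=1.0, shape=(6,))
print(box_actions.shape[0])  # 6

# Discrete control: out_features is the number of actions
discrete_actions = spaces.Discrete(4)
print(discrete_actions.n)  # 4

# Dict observation space: state_shape maps each key to its sub-space shape
dict_observations = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(84, 84, 3)),
    "vector": spaces.Box(low=-1.0, high=1.0, shape=(8,)),
})
print({k: v.shape for k, v in dict_observations.spaces.items()})
# {'image': (84, 84, 3), 'vector': (8,)}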