import collections
import torch
import torch.nn as nn
from catalyst.contrib.models import get_linear_net
from catalyst.contrib.nn.modules import LamaPooling, TemporalConcatPooling
from catalyst.rl import utils
class StateNet(nn.Module):
    def __init__(
        self,
        main_net: nn.Module,
        observation_net: nn.Module = None,
        aggregation_net: nn.Module = None,
    ):
        """
        Abstract network, that takes some tensor
        T of shape [bs; history_len; ...]
        and outputs some representation tensor R
        of shape [bs; representation_size]
        input_T [bs; history_len; in_features]
        -> observation_net (aka observation_encoder) ->
        observations_representations [bs; history_len; obs_features]
        -> aggregation_net (flatten in simplified case) ->
        aggregated_representation [bs; hid_features]
        -> main_net ->
        output_T [bs; representation_size]
        Args:
            main_net:
            observation_net:
            aggregation_net:
        """
        super().__init__()
        self.main_net = main_net
        self.observation_net = observation_net or (lambda x: x)
        self.aggregation_net = aggregation_net
        self._forward_fn = None
        if aggregation_net is None:
            self._forward_fn = self._forward_ff
            self._process_state = utils.process_state_ff_kv \
                if isinstance(self.observation_net, nn.ModuleDict) \
                else utils.process_state_ff
        elif isinstance(aggregation_net, (TemporalConcatPooling, LamaPooling)):
            self._forward_fn = self._forward_temporal
            self._process_state = utils.process_state_temporal_kv \
                if isinstance(self.observation_net, nn.ModuleDict) \
                else utils.process_state_temporal
        else:
            raise NotImplementedError() 
    def _forward_ff(self, state):
        x = state
        x = self._process_state(x, self.observation_net)
        x = self.main_net(x)
        return x
    def _forward_temporal(self, state):
        x = state
        x = self._process_state(x, self.observation_net)
        x = self.aggregation_net(x)
        x = self.main_net(x)
        return x
    def forward(self, state):
        x = self._forward_fn(state)
        return x 
    @classmethod
    def get_from_params(
        cls,
        state_shape,
        observation_net_params=None,
        aggregation_net_params=None,
        main_net_params=None,
    ) -> "StateNet":
        assert main_net_params is not None
        # @TODO: refactor, too complicated; fast&furious development
        main_net_in_features = 0
        observation_net_out_features = 0
        # observation net
        if observation_net_params is not None:
            key_value_flag = observation_net_params.pop("_key_value", False)
            if key_value_flag:
                observation_net = collections.OrderedDict()
                for key in observation_net_params:
                    net_, out_features_ = \
                        utils.get_observation_net(
                            state_shape[key],
                            **observation_net_params[key]
                        )
                    observation_net[key] = net_
                    observation_net_out_features += out_features_
                observation_net = nn.ModuleDict(observation_net)
            else:
                observation_net, observation_net_out_features = \
                    utils.get_observation_net(
                        state_shape,
                        **observation_net_params
                    )
            main_net_in_features += observation_net_out_features
        else:
            observation_net, observation_net_out_features = \
                utils.get_observation_net(state_shape)
            main_net_in_features += observation_net_out_features
        # aggregation net
        if aggregation_net_params is not None:
            aggregation_type = \
                aggregation_net_params.pop("_network_type", "concat")
            if aggregation_type == "concat":
                aggregation_net = TemporalConcatPooling(
                    observation_net_out_features, **aggregation_net_params)
            elif aggregation_type == "lama":
                aggregation_net = LamaPooling(
                    observation_net_out_features,
                    **aggregation_net_params)
            else:
                raise NotImplementedError()
            main_net_in_features = aggregation_net.out_features
        else:
            aggregation_net = None
        # main net
        main_net_params["in_features"] = main_net_in_features
        main_net = get_linear_net(**main_net_params)
        net = cls(
            main_net=main_net,
            aggregation_net=aggregation_net,
            observation_net=observation_net)
        return net  
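
# A minimal usage sketch for ``StateNet`` with the feed-forward path
# (``aggregation_net=None``). The sizes below are illustrative assumptions,
# and ``utils.process_state_ff`` is assumed to flatten the history dimension
# into the feature axis before applying ``observation_net``:
#
#     main_net = nn.Sequential(nn.Linear(4 * 32, 64), nn.ReLU())
#     net = StateNet(main_net=main_net)
#     state = torch.randn(8, 4, 32)    # [bs; history_len; in_features]
#     representation = net(state)      # [bs; 64]
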
class StateActionNet(nn.Module):
    def __init__(
        self,
        main_net: nn.Module,
        observation_net: nn.Module = None,
        action_net: nn.Module = None,
        aggregation_net: nn.Module = None
    ):
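        """
        Abstract network that takes a state tensor of shape
        [bs; history_len; ...] and an action tensor of shape
        [bs; action_size], encodes each with its own sub-network,
        concatenates the two representations and passes the result
        through ``main_net`` (e.g. for state-action value estimation).

        A usage sketch is given in the comment block after this class.

        Args:
            main_net: network applied to the concatenated state and
                action representations
            observation_net: observation encoder; identity if ``None``
            action_net: action encoder; identity if ``None``
            aggregation_net: network that aggregates the history of
                encoded observations; if ``None``, the feed-forward
                (flattening) path is used
        """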
        super().__init__()
        self.main_net = main_net
        self.observation_net = observation_net or (lambda x: x)
        self.action_net = action_net or (lambda x: x)
        self.aggregation_net = aggregation_net
        self._forward_fn = None
        if aggregation_net is None:
            self._forward_fn = self._forward_ff
            self._process_state = utils.process_state_ff_kv \
                if isinstance(self.observation_net, nn.ModuleDict) \
                else utils.process_state_ff
        elif isinstance(aggregation_net, (TemporalConcatPooling, LamaPooling)):
            self._forward_fn = self._forward_temporal
            self._process_state = utils.process_state_temporal_kv \
                if isinstance(self.observation_net, nn.ModuleDict) \
                else utils.process_state_temporal
        else:
            raise NotImplementedError()
    def _forward_ff(self, state, action):
        state_ = self._process_state(state, self.observation_net)
        action_ = self.action_net(action)
        x = torch.cat((state_, action_), dim=1)
        x = self.main_net(x)
        return x
    def _forward_temporal(self, state, action):
        state_ = self._process_state(state, self.observation_net)
        state_ = self.aggregation_net(state_)
        action_ = self.action_net(action)
        x = torch.cat((state_, action_), dim=1)
        x = self.main_net(x)
        return x
    def forward(self, state, action):
        x = self._forward_fn(state, action)
        return x 
    @classmethod
    def get_from_params(
        cls,
        state_shape,
        action_shape,
        observation_net_params=None,
        action_net_params=None,
        aggregation_net_params=None,
        main_net_params=None,
    ) -> "StateNet":
        assert main_net_params is not None
        # @TODO: refactor, too complicated; fast&furious development
        main_net_in_features = 0
        observation_net_out_features = 0
        # observation net
        if observation_net_params is not None:
            key_value_flag = observation_net_params.pop("_key_value", False)
            if key_value_flag:
                observation_net = collections.OrderedDict()
                for key in observation_net_params:
                    net_, out_features_ = \
                        utils.get_observation_net(
                            state_shape[key],
                            **observation_net_params[key]
                        )
                    observation_net[key] = net_
                    observation_net_out_features += out_features_
                observation_net = nn.ModuleDict(observation_net)
            else:
                observation_net, observation_net_out_features = \
                    utils.get_observation_net(
                        state_shape,
                        **observation_net_params
                    )
        else:
            observation_net, observation_net_out_features = \
                utils.get_observation_net(state_shape)
        main_net_in_features += observation_net_out_features
        # aggregation net
        if aggregation_net_params is not None:
            aggregation_type = \
                aggregation_net_params.pop("_network_type", "concat")
            if aggregation_type == "concat":
                aggregation_net = TemporalConcatPooling(
                    observation_net_out_features, **aggregation_net_params)
            elif aggregation_type == "lama":
                aggregation_net = LamaPooling(
                    observation_net_out_features,
                    **aggregation_net_params)
            else:
                raise NotImplementedError()
            main_net_in_features = aggregation_net.out_features
        else:
            aggregation_net = None
        # action net
        if action_net_params is not None:
            # @TODO: hacky solution for code reuse
            action_shape = (1, ) + action_shape
            action_net, action_net_out_features = \
                utils.get_observation_net(action_shape, **action_net_params)
        else:
            action_net, action_net_out_features = \
                utils.get_observation_net(action_shape)
        main_net_in_features += action_net_out_features
        # main net
        main_net_params["in_features"] = main_net_in_features
        main_net = get_linear_net(**main_net_params)
        net = cls(
            observation_net=observation_net,
            action_net=action_net,
            aggregation_net=aggregation_net,
            main_net=main_net
        )
        return net  
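
# A minimal usage sketch for ``StateActionNet`` as a state-action value head
# (illustrative only: the sizes below are assumptions, and
# ``utils.process_state_ff`` is assumed to flatten the state history into the
# feature axis):
#
#     main_net = nn.Sequential(
#         nn.Linear(32 + 6, 64), nn.ReLU(), nn.Linear(64, 1))
#     critic = StateActionNet(main_net=main_net)
#     state = torch.randn(8, 1, 32)    # [bs; history_len; in_features]
#     action = torch.randn(8, 6)       # [bs; action_size]
#     q_value = critic(state, action)  # [bs; 1]
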
__all__ = ["StateNet", "StateActionNet"]