Source code for catalyst.rl.agent.network

import collections

import torch
import torch.nn as nn

from catalyst.contrib.models import get_linear_net
from catalyst.contrib.nn.modules import LamaPooling, TemporalConcatPooling
from catalyst.rl import utils


[docs]class StateNet(nn.Module):
[docs] def __init__( self, main_net: nn.Module, observation_net: nn.Module = None, aggregation_net: nn.Module = None, ): """ Abstract network, that takes some tensor T of shape [bs; history_len; ...] and outputs some representation tensor R of shape [bs; representation_size] input_T [bs; history_len; in_features] -> observation_net (aka observation_encoder) -> observations_representations [bs; history_len; obs_features] -> aggregation_net (flatten in simplified case) -> aggregated_representation [bs; hid_features] -> main_net -> output_T [bs; representation_size] Args: main_net: observation_net: aggregation_net: """ super().__init__() self.main_net = main_net self.observation_net = observation_net or (lambda x: x) self.aggregation_net = aggregation_net self._forward_fn = None if aggregation_net is None: self._forward_fn = self._forward_ff self._process_state = utils.process_state_ff_kv \ if isinstance(self.observation_net, nn.ModuleDict) \ else utils.process_state_ff elif isinstance(aggregation_net, (TemporalConcatPooling, LamaPooling)): self._forward_fn = self._forward_temporal self._process_state = utils.process_state_temporal_kv \ if isinstance(self.observation_net, nn.ModuleDict) \ else utils.process_state_temporal else: raise NotImplementedError()
def _forward_ff(self, state): x = state x = self._process_state(x, self.observation_net) x = self.main_net(x) return x def _forward_temporal(self, state): x = state x = self._process_state(x, self.observation_net) x = self.aggregation_net(x) x = self.main_net(x) return x
[docs] def forward(self, state): x = self._forward_fn(state) return x
[docs] @classmethod def get_from_params( cls, state_shape, observation_net_params=None, aggregation_net_params=None, main_net_params=None, ) -> "StateNet": assert main_net_params is not None # @TODO: refactor, too complicated; fast&furious development main_net_in_features = 0 observation_net_out_features = 0 # observation net if observation_net_params is not None: key_value_flag = observation_net_params.pop("_key_value", False) if key_value_flag: observation_net = collections.OrderedDict() for key in observation_net_params: net_, out_features_ = \ utils.get_observation_net( state_shape[key], **observation_net_params[key] ) observation_net[key] = net_ observation_net_out_features += out_features_ observation_net = nn.ModuleDict(observation_net) else: observation_net, observation_net_out_features = \ utils.get_observation_net( state_shape, **observation_net_params ) main_net_in_features += observation_net_out_features else: observation_net, observation_net_out_features = \ utils.get_observation_net(state_shape) main_net_in_features += observation_net_out_features # aggregation net if aggregation_net_params is not None: aggregation_type = \ aggregation_net_params.pop("_network_type", "concat") if aggregation_type == "concat": aggregation_net = TemporalConcatPooling( observation_net_out_features, **aggregation_net_params ) elif aggregation_type == "lama": aggregation_net = LamaPooling( observation_net_out_features, **aggregation_net_params ) else: raise NotImplementedError() main_net_in_features = aggregation_net.out_features else: aggregation_net = None # main net main_net_params["in_features"] = main_net_in_features main_net = get_linear_net(**main_net_params) net = cls( main_net=main_net, aggregation_net=aggregation_net, observation_net=observation_net ) return net
[docs]class StateActionNet(nn.Module): def __init__( self, main_net: nn.Module, observation_net: nn.Module = None, action_net: nn.Module = None, aggregation_net: nn.Module = None ): super().__init__() self.main_net = main_net self.observation_net = observation_net or (lambda x: x) self.action_net = action_net or (lambda x: x) self.aggregation_net = aggregation_net self._forward_fn = None if aggregation_net is None: self._forward_fn = self._forward_ff self._process_state = utils.process_state_ff_kv \ if isinstance(self.observation_net, nn.ModuleDict) \ else utils.process_state_ff elif isinstance(aggregation_net, (TemporalConcatPooling, LamaPooling)): self._forward_fn = self._forward_temporal self._process_state = utils.process_state_temporal_kv \ if isinstance(self.observation_net, nn.ModuleDict) \ else utils.process_state_temporal else: raise NotImplementedError() def _forward_ff(self, state, action): state_ = self._process_state(state, self.observation_net) action_ = self.action_net(action) x = torch.cat((state_, action_), dim=1) x = self.main_net(x) return x def _forward_temporal(self, state, action): state_ = self._process_state(state, self.observation_net) state_ = self.aggregation_net(state_) action_ = self.action_net(action) x = torch.cat((state_, action_), dim=1) x = self.main_net(x) return x
[docs] def forward(self, state, action): x = self._forward_fn(state, action) return x
[docs] @classmethod def get_from_params( cls, state_shape, action_shape, observation_net_params=None, action_net_params=None, aggregation_net_params=None, main_net_params=None, ) -> "StateNet": assert main_net_params is not None # @TODO: refactor, too complicated; fast&furious development main_net_in_features = 0 observation_net_out_features = 0 # observation net if observation_net_params is not None: key_value_flag = observation_net_params.pop("_key_value", False) if key_value_flag: observation_net = collections.OrderedDict() for key in observation_net_params: net_, out_features_ = \ utils.get_observation_net( state_shape[key], **observation_net_params[key] ) observation_net[key] = net_ observation_net_out_features += out_features_ observation_net = nn.ModuleDict(observation_net) else: observation_net, observation_net_out_features = \ utils.get_observation_net( state_shape, **observation_net_params ) else: observation_net, observation_net_out_features = \ utils.get_observation_net(state_shape) main_net_in_features += observation_net_out_features # aggregation net if aggregation_net_params is not None: aggregation_type = \ aggregation_net_params.pop("_network_type", "concat") if aggregation_type == "concat": aggregation_net = TemporalConcatPooling( observation_net_out_features, **aggregation_net_params ) elif aggregation_type == "lama": aggregation_net = LamaPooling( observation_net_out_features, **aggregation_net_params ) else: raise NotImplementedError() main_net_in_features = aggregation_net.out_features else: aggregation_net = None # action net if action_net_params is not None: # @TODO: hacky solution for code reuse action_shape = (1, ) + action_shape action_net, action_net_out_features = \ utils.get_observation_net(action_shape, **action_net_params) else: action_net, action_net_out_features = \ utils.get_observation_net(action_shape) main_net_in_features += action_net_out_features # main net main_net_params["in_features"] = main_net_in_features main_net = get_linear_net(**main_net_params) net = cls( observation_net=observation_net, action_net=action_net, aggregation_net=aggregation_net, main_net=main_net ) return net
__all__ = ["StateNet", "StateActionNet"]