import torch
import torch.nn as nn
from catalyst.contrib.models import SequentialNet
from catalyst.contrib.registry import MODULES
from catalyst.utils import create_optimal_inner_init, log1p_exp, outer_init
class SquashingLayer(nn.Module):
    def __init__(self, squashing_fn=nn.Tanh):
        """
        Layer that squashes samples from some distribution to be bounded.

        Args:
            squashing_fn: squashing module class or its registry name;
                supported values are ``nn.Tanh``, ``nn.Sigmoid`` and ``None``.
        """
        super().__init__()
        # resolve registry strings to module classes; keep ``None`` as-is
        # (``forward`` has an explicit pass-through branch for it)
        squashing_fn = MODULES.get_if_str(squashing_fn)
        self.squashing_fn = \
            squashing_fn() if squashing_fn is not None else None
    def forward(self, action, action_logprob):
        # change of variables: for y = f(x),
        # log p(y) = log p(x) - log|det(dy/dx)|
        if isinstance(self.squashing_fn, nn.Tanh):
            # log(1 - tanh(x)^2) = 2 * (log(2) + x - log(1 + exp(2x))),
            # a numerically stable form of the tanh log-det-jacobian
            log2 = torch.log(torch.tensor(2.0).to(action.device))
            log_det_jacobian = 2 * (log2 + action - log1p_exp(2 * action))
            log_det_jacobian = torch.sum(log_det_jacobian, dim=-1)
        elif isinstance(self.squashing_fn, nn.Sigmoid):
            # log(sigmoid'(x)) = -x - 2 * log(1 + exp(-x))
            log_det_jacobian = -action - 2 * log1p_exp(-action)
            log_det_jacobian = torch.sum(log_det_jacobian, dim=-1)
        elif self.squashing_fn is None:
            return action, action_logprob
        else:
            raise NotImplementedError
        action = self.squashing_fn(action)
        action_logprob = action_logprob - log_det_jacobian
        return action, action_logprob
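
# Usage sketch (illustrative, not part of the original module; the helper
# name and the batch/action sizes are assumptions): sample unbounded
# Gaussian actions, squash them to (-1, 1) with ``nn.Tanh`` and apply the
# matching log-probability correction.
def _squashing_layer_example(batch_size=8, action_size=4):
    dist = torch.distributions.Normal(
        torch.zeros(batch_size, action_size),
        torch.ones(batch_size, action_size),
    )
    raw_action = dist.sample()
    # joint log-prob of the independent action dimensions
    raw_logprob = dist.log_prob(raw_action).sum(dim=-1)
    squashing = SquashingLayer(nn.Tanh)
    action, action_logprob = squashing(raw_action, raw_logprob)
    # ``action`` is now bounded; ``action_logprob`` is the density of the
    # squashed sample under the change-of-variables formula
    return action, action_logprob
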
class CouplingLayer(nn.Module):
    def __init__(
        self,
        action_size,
        layer_fn,
        activation_fn=nn.ReLU,
        bias=True,
        parity="odd"
    ):
        """
        Conditional affine coupling layer used in Real NVP Bijector.

        Original paper: https://arxiv.org/abs/1605.08803
        Adaptation to RL: https://arxiv.org/abs/1804.02808

        Important notes
        ---------------
        1. State embeddings are expected to have size (action_size * 2).
        2. The scale and translation networks used in the Real NVP Bijector
        both have one hidden layer of (action_size) units with
        (activation_fn) activations.
        3. Parity ("odd" or "even") determines which part of the input
        is copied unchanged and which is transformed.
        """
        super().__init__()
        assert parity in ["odd", "even"]
        layer_fn = MODULES.get_if_str(layer_fn)
        self.parity = parity
        # split the action into (almost) equal halves; parity decides
        # which half is copied through unchanged
        if self.parity == "odd":
            self.copy_size = action_size // 2
        else:
            self.copy_size = action_size - action_size // 2
        # scale and translation nets are conditioned on the state embedding
        # concatenated with the copied part of the action
        self.scale_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn={"module": layer_fn, "bias": bias},
            activation_fn=activation_fn,
            norm_fn=None,
        )
        self.scale_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn={"module": layer_fn, "bias": True},
            activation_fn=None,
            norm_fn=None,
        )
        self.translation_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn={"module": layer_fn, "bias": bias},
            activation_fn=activation_fn,
            norm_fn=None,
        )
        self.translation_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn={"module": layer_fn, "bias": True},
            activation_fn=None,
            norm_fn=None,
        )
        # activation-aware init for hidden layers; ``outer_init`` keeps the
        # initial scales and translations small, so the coupling starts
        # close to the identity transform
        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        self.scale_prenet.apply(inner_init)
        self.scale_net.apply(outer_init)
        self.translation_prenet.apply(inner_init)
        self.translation_net.apply(outer_init)
    def forward(self, action, state_embedding, action_logprob):
        # split the action according to parity
        if self.parity == "odd":
            action_copy = action[:, :self.copy_size]
            action_transform = action[:, self.copy_size:]
        else:
            action_copy = action[:, -self.copy_size:]
            action_transform = action[:, :-self.copy_size]
        # condition scale and translation on [state_embedding, copied part]
        x = torch.cat((state_embedding, action_copy), dim=1)
        t = self.translation_prenet(x)
        t = self.translation_net(t)
        s = self.scale_prenet(x)
        s = self.scale_net(s)
        # affine transform: y = t + x * exp(s)
        out_transform = t + action_transform * torch.exp(s)
        if self.parity == "odd":
            action = torch.cat((action_copy, out_transform), dim=1)
        else:
            action = torch.cat((out_transform, action_copy), dim=1)
        # log|det J| of an elementwise affine transform is the sum of scales
        log_det_jacobian = s.sum(dim=1)
        action_logprob = action_logprob - log_det_jacobian
        return action, action_logprob
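
# Usage sketch (illustrative; the helper name, ``nn.Linear`` as ``layer_fn``
# and the sizes below are example choices, not prescribed by the module):
# push a batch of actions through a coupling layer conditioned on a state
# embedding of size ``action_size * 2``, as the docstring requires.
def _coupling_layer_example(batch_size=8, action_size=4):
    coupling = CouplingLayer(action_size, layer_fn=nn.Linear, parity="odd")
    action = torch.randn(batch_size, action_size)
    state_embedding = torch.randn(batch_size, action_size * 2)
    action_logprob = torch.zeros(batch_size)
    action, action_logprob = coupling(action, state_embedding, action_logprob)
    # shapes are preserved; log-probs carry the -sum(s) correction
    return action, action_logprob
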
__all__ = ["SquashingLayer", "CouplingLayer"]