import torch
import torch.nn as nn
from catalyst.contrib.models import SequentialNet
from catalyst.contrib.registry import MODULES
from catalyst.utils import create_optimal_inner_init, log1p_exp, outer_init

class SquashingLayer(nn.Module):
    def __init__(self, squashing_fn=nn.Tanh):
        """
        Layer that squashes samples from some distribution to be bounded
        and corrects their log-probabilities accordingly.

        Args:
            squashing_fn: squashing module (or its name in the MODULES
                registry); ``nn.Tanh``, ``nn.Sigmoid`` or ``None``
        """
        super().__init__()
        squashing_fn = MODULES.get_if_str(squashing_fn)
        # instantiate only if given, so that ``squashing_fn=None``
        # (the identity pass-through in ``forward``) stays valid
        self.squashing_fn = \
            squashing_fn() if squashing_fn is not None else None
    def forward(self, action, action_logprob):
        """
        Squashes ``action`` and corrects ``action_logprob`` by the
        log-determinant of the Jacobian of the squashing transformation.
        """
        if isinstance(self.squashing_fn, nn.Tanh):
            # log |d tanh(x)/dx| = log(1 - tanh^2(x))
            #                    = 2 * (log 2 + x - log(1 + e^{2x}))
            log2 = torch.log(torch.tensor(2.0).to(action.device))
            log_det_jacobian = 2 * (log2 + action - log1p_exp(2 * action))
            log_det_jacobian = torch.sum(log_det_jacobian, dim=-1)
        elif isinstance(self.squashing_fn, nn.Sigmoid):
            # log |d sigmoid(x)/dx| = -x - 2 * log(1 + e^{-x})
            log_det_jacobian = -action - 2 * log1p_exp(-action)
            log_det_jacobian = torch.sum(log_det_jacobian, dim=-1)
        elif self.squashing_fn is None:
            return action, action_logprob
        else:
            raise NotImplementedError
        action = self.squashing_fn(action)
        action_logprob = action_logprob - log_det_jacobian
        return action, action_logprob
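
# Minimal usage sketch (illustrative, not part of the original module):
# sample pre-tanh actions from a diagonal Gaussian policy, squash them
# to [-1, 1], and keep the log-probabilities consistent via the Jacobian
# correction. The helper name ``_squashing_example`` is hypothetical.
def _squashing_example():
    from torch.distributions import Normal

    squasher = SquashingLayer(squashing_fn=nn.Tanh)
    dist = Normal(torch.zeros(4, 6), torch.ones(4, 6))
    raw_action = dist.rsample()
    raw_logprob = dist.log_prob(raw_action).sum(dim=-1)
    action, action_logprob = squasher(raw_action, raw_logprob)
    assert action.abs().max() <= 1.0  # tanh output is bounded
    return action, action_logprob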

class CouplingLayer(nn.Module):
    def __init__(
        self,
        action_size,
        layer_fn,
        activation_fn=nn.ReLU,
        bias=True,
        parity="odd"
    ):
        """
        Conditional affine coupling layer used in the Real NVP Bijector.

        Original paper: https://arxiv.org/abs/1605.08803
        Adaptation to RL: https://arxiv.org/abs/1804.02808

        Important notes
        ---------------
        1. State embeddings are expected to have size ``action_size * 2``.
        2. The scale and translation networks used in the Real NVP
           Bijector each have one hidden layer of ``action_size``
           units with ``activation_fn`` activations.
        3. Parity ("odd" or "even") determines which part of the input
           is copied unchanged and which part is transformed.
        """
        super().__init__()
        assert parity in ["odd", "even"]
        layer_fn = MODULES.get_if_str(layer_fn)

        self.parity = parity
        # split the action into a copied part and a transformed part
        if self.parity == "odd":
            self.copy_size = action_size // 2
        else:
            self.copy_size = action_size - action_size // 2

        # scale network: maps [state_embedding, copied part]
        # to a log-scale for the transformed part
        self.scale_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn={"module": layer_fn, "bias": bias},
            activation_fn=activation_fn,
            norm_fn=None,
        )
        self.scale_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn={"module": layer_fn, "bias": True},
            activation_fn=None,
            norm_fn=None,
        )

        # translation network: maps [state_embedding, copied part]
        # to a shift for the transformed part
        self.translation_prenet = SequentialNet(
            hiddens=[action_size * 2 + self.copy_size, action_size],
            layer_fn={"module": layer_fn, "bias": bias},
            activation_fn=activation_fn,
            norm_fn=None,
        )
        self.translation_net = SequentialNet(
            hiddens=[action_size, action_size - self.copy_size],
            layer_fn={"module": layer_fn, "bias": True},
            activation_fn=None,
            norm_fn=None,
        )

        inner_init = create_optimal_inner_init(nonlinearity=activation_fn)
        self.scale_prenet.apply(inner_init)
        self.scale_net.apply(outer_init)
        self.translation_prenet.apply(inner_init)
        self.translation_net.apply(outer_init)
    def forward(self, action, state_embedding, action_logprob):
        # split the action according to parity
        if self.parity == "odd":
            action_copy = action[:, :self.copy_size]
            action_transform = action[:, self.copy_size:]
        else:
            action_copy = action[:, -self.copy_size:]
            action_transform = action[:, :-self.copy_size]

        # condition scale and translation on the state embedding
        # and the copied part of the action
        x = torch.cat((state_embedding, action_copy), dim=1)
        t = self.translation_prenet(x)
        t = self.translation_net(t)
        s = self.scale_prenet(x)
        s = self.scale_net(s)

        # affine transformation: y = t + x * exp(s)
        out_transform = t + action_transform * torch.exp(s)

        if self.parity == "odd":
            action = torch.cat((action_copy, out_transform), dim=1)
        else:
            action = torch.cat((out_transform, action_copy), dim=1)

        # log |det J| of an affine coupling layer is simply sum(s)
        log_det_jacobian = s.sum(dim=1)
        action_logprob = action_logprob - log_det_jacobian
        return action, action_logprob
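
# Minimal usage sketch (illustrative, not part of the original module):
# chain two CouplingLayer's with alternating parity, as in a Real NVP
# flow conditioned on a state embedding. The helper name
# ``_coupling_example`` and the ``layer_fn=nn.Linear`` choice are
# assumptions, not prescribed by the library.
def _coupling_example(action_size=6, batch_size=4):
    layers = [
        CouplingLayer(action_size, layer_fn=nn.Linear, parity="odd"),
        CouplingLayer(action_size, layer_fn=nn.Linear, parity="even"),
    ]
    action = torch.randn(batch_size, action_size)
    action_logprob = torch.zeros(batch_size)
    # state embeddings must have size action_size * 2 (see docstring)
    state_embedding = torch.randn(batch_size, action_size * 2)
    for layer in layers:
        action, action_logprob = layer(
            action, state_embedding, action_logprob
        )
    return action, action_logprob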
__all__ = ["SquashingLayer", "CouplingLayer"]