Shortcuts

Source code for catalyst.contrib.nn.modules.lama

# flake8: noqa
# @TODO: code formatting issue for 20.07 release
import torch
from torch import nn

from catalyst.utils.initialization import outer_init


[docs]class TemporalLastPooling(nn.Module): """@TODO: Docs. Contribution is welcome."""
[docs] def forward( self, x: torch.Tensor, mask: torch.Tensor = None ) -> torch.Tensor: """Forward call.""" x_out = x[:, -1:, :] return x_out
[docs]class TemporalAvgPooling(nn.Module): """@TODO: Docs. Contribution is welcome."""
[docs] def forward( self, x: torch.Tensor, mask: torch.Tensor = None ) -> torch.Tensor: """Forward call.""" if mask is None: x_out = x.mean(1, keepdim=True) else: x_masked = torch.sum(x * mask.float(), dim=1, keepdim=True) mask_sum = torch.sum(mask.float(), dim=1, keepdim=True) x_out = x_masked / mask_sum return x_out
[docs]class TemporalMaxPooling(nn.Module): """@TODO: Docs. Contribution is welcome."""
[docs] def forward( self, x: torch.Tensor, mask: torch.Tensor = None ) -> torch.Tensor: """Forward call.""" if mask is not None: x_mask = (~mask.bool()).float() * (-x.max()).float() x = torch.sum(x + x_mask, dim=1, keepdim=True) x_out = x.max(1, keepdim=True)[0] return x_out
[docs]class TemporalAttentionPooling(nn.Module): """@TODO: Docs. Contribution is welcome.""" name2activation = { "softmax": nn.Softmax(dim=1), "tanh": nn.Tanh(), "sigmoid": nn.Sigmoid(), }
[docs] def __init__(self, in_features, activation=None, kernel_size=1, **params): """@TODO: Docs. Contribution is welcome.""" super().__init__() self.in_features = in_features activation = activation or "softmax" self.attention_pooling = nn.Sequential( nn.Conv1d( in_channels=in_features, out_channels=1, kernel_size=kernel_size, **params ), TemporalAttentionPooling.name2activation[activation], ) self.attention_pooling.apply(outer_init)
[docs] def forward( self, x: torch.Tensor, mask: torch.Tensor = None ) -> torch.Tensor: """ Forward call. Args: x: tensor of size (batch_size, history_len, feature_size) mask: mask to use Returns: pooling result """ batch_size, history_len, feature_size = x.shape x = x.view(batch_size, history_len, -1) x_a = x.transpose(1, 2) x_attn = (self.attention_pooling(x_a) * x_a).transpose(1, 2) x_attn = x_attn.sum(1, keepdim=True) return x_attn
[docs]class TemporalConcatPooling(nn.Module): """@TODO: Docs. Contribution is welcome."""
[docs] def __init__(self, in_features, history_len=1): """@TODO: Docs. Contribution is welcome.""" super().__init__() self.in_features = in_features self.out_features = in_features * history_len
[docs] def forward( self, x: torch.Tensor, mask: torch.Tensor = None ) -> torch.Tensor: """ Concat pooling forward. Args: x: tensor of size (batch_size, history_len, feature_size) mask: mask to use Returns: concated result """ x = x.view(x.shape[0], -1) return x
[docs]class TemporalDropLastWrapper(nn.Module): """@TODO: Docs. Contribution is welcome."""
[docs] def __init__(self, net): """@TODO: Docs. Contribution is welcome.""" super().__init__() self.net = net
[docs] def forward(self, x: torch.Tensor, mask: torch.Tensor = None): """@TODO: Docs. Contribution is welcome.""" x = x[:, :-1, :] x_out = self.net(x) return x_out
def _get_pooling(key, in_features, **params): """@TODO: Docs. Contribution is welcome.""" key_prefix = key.split("_", 1)[0] if key_prefix == "last": return TemporalLastPooling() elif key_prefix == "avg": layer = TemporalAvgPooling() elif key_prefix == "max": layer = TemporalMaxPooling() elif key_prefix in ["softmax", "tanh", "sigmoid"]: layer = TemporalAttentionPooling( in_features=in_features, activation=key_prefix, **params ) else: raise NotImplementedError() if "droplast" in key: layer = TemporalDropLastWrapper(layer) return layer
[docs]class LamaPooling(nn.Module): """@TODO: Docs. Contribution is welcome.""" available_groups = [ "last", "avg", "avg_droplast", "max", "max_droplast", "sigmoid", "sigmoid_droplast", "softmax", "softmax_droplast", "tanh", "tanh_droplast", ]
[docs] def __init__(self, in_features, groups=None): """@TODO: Docs. Contribution is welcome.""" super().__init__() self.in_features = in_features self.groups = groups or [ "last", "avg_droplast", "max_droplast", "softmax_droplast", ] self.out_features = in_features * len(self.groups) groups = {} for key in self.groups: if isinstance(key, str): groups[key] = _get_pooling(key, self.in_features) elif isinstance(key, dict): key_key = key.pop("key") groups[key_key] = _get_pooling(key_key, in_features, **key) else: raise NotImplementedError() self.groups = nn.ModuleDict(groups)
[docs] def forward( self, x: torch.Tensor, mask: torch.Tensor = None ) -> torch.Tensor: """ Forward method of the LAMA. Args: x: tensor of size (batch_size, history_len, feature_size) mask: mask to use for attention compute Returns: torch.Tensor: LAMA pooling result """ batch_size, history_len, feature_size = x.shape features_list = [] for pooling_fn in self.groups.values(): features = pooling_fn(x, mask) features_list.append(features) x = torch.cat(features_list, dim=1) x = x.view(batch_size, -1) return x
__all__ = [ "TemporalLastPooling", "TemporalAvgPooling", "TemporalMaxPooling", "TemporalDropLastWrapper", "TemporalAttentionPooling", "TemporalConcatPooling", "LamaPooling", ]