# Shortcuts
#
# Source code for catalyst.contrib.nn.modules.common

# flake8: noqa
# @TODO: code formatting issue for 20.07 release
import torch
from torch import nn
from torch.nn import functional as F


class Flatten(nn.Module):
    """Flattens the input. Does not affect the batch size.

    @TODO: Docs (add `Example`). Contribution is welcome.
    """

    def __init__(self):
        """@TODO: Docs. Contribution is welcome."""
        super().__init__()

    def forward(self, x):
        """Forward call: collapse all non-batch dimensions into one."""
        batch_size = x.size(0)
        return x.view(batch_size, -1)
class Lambda(nn.Module):
    """Wraps an arbitrary callable as an ``nn.Module``.

    @TODO: Docs. Contribution is welcome.
    """

    def __init__(self, lambda_fn):
        """
        Args:
            lambda_fn: callable applied to the input on ``forward``
        """
        super().__init__()
        self.lambda_fn = lambda_fn

    def forward(self, x):
        """Forward call: apply the wrapped callable to ``x``."""
        return self.lambda_fn(x)
class Normalize(nn.Module):
    """Performs :math:`L_p` normalization of inputs over specified dimension.

    @TODO: Docs (add `Example`). Contribution is welcome.
    """

    def __init__(self, **normalize_kwargs):
        """
        Args:
            **normalize_kwargs: see ``torch.nn.functional.normalize`` params
        """
        super().__init__()
        self.normalize_kwargs = normalize_kwargs

    def forward(self, x):
        """Forward call: delegate to ``F.normalize`` with the stored kwargs."""
        return F.normalize(x, **self.normalize_kwargs)
class GaussianNoise(nn.Module):
    r"""A gaussian noise module.

    Adds zero-mean gaussian noise with the configured standard deviation
    to the input on every forward pass.

    Shape:
        - Input: (batch, \*)
        - Output: (batch, \*) (same shape as input)
    """

    def __init__(self, stddev: float = 0.1):
        """
        Args:
            stddev (float): The standard deviation of the normal distribution.
                Default: 0.1.
        """
        super().__init__()
        self.stddev = stddev

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call.

        Args:
            x (torch.Tensor): input tensor of any shape

        Returns:
            torch.Tensor: ``x`` plus sampled gaussian noise, same shape as ``x``
        """
        noise = torch.empty_like(x)
        noise.normal_(0, self.stddev)
        # BUG FIX: the original forward sampled the noise but fell off the
        # end without a return statement, so callers always received None.
        return x + noise
# Explicitly declared public API of this module.
__all__ = ["Flatten", "Lambda", "Normalize", "GaussianNoise"]