Shortcuts

Source code for catalyst.contrib.nn.modules.common

from torch import nn
from torch.nn import functional as F


class Flatten(nn.Module):
    """Flattens every dimension after the first (batch) dimension.

    An input of shape ``(batch, d1, d2, ..., dk)`` becomes
    ``(batch, d1 * d2 * ... * dk)``; a 1-D input of shape ``(batch,)``
    becomes ``(batch, 1)``. The batch size is never changed.

    Example:
        >>> import torch
        >>> Flatten()(torch.zeros(2, 3, 4)).shape
        torch.Size([2, 12])
    """

    def __init__(self):
        """Create the module; it holds no parameters or buffers."""
        super().__init__()

    def forward(self, x):
        """Forward call.

        Args:
            x: tensor with at least one dimension; dim 0 is kept as the
                batch dimension.

        Returns:
            tensor of shape ``(x.shape[0], prod(x.shape[1:]))``. It is a view
            of ``x`` when the memory layout allows it, otherwise a copy.
        """
        # Compute the flattened size explicitly instead of passing -1:
        # ``view``/``reshape`` cannot infer -1 for an empty batch
        # (``x.shape[0] == 0``). The product over an empty shape is 1, so a
        # 1-D input still maps to ``(batch, 1)`` as before.
        num_features = 1
        for size in x.shape[1:]:
            num_features *= size
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs,
        # e.g. the result of ``permute``/``transpose``, copying only if needed.
        return x.reshape(x.shape[0], num_features)
class Lambda(nn.Module):
    """Wraps an arbitrary callable so it can be used as an ``nn.Module``.

    ``forward`` simply calls the stored callable on its input and returns
    whatever the callable returns.

    Example:
        >>> double = Lambda(lambda t: t * 2)
    """

    def __init__(self, lambda_fn):
        """
        Args:
            lambda_fn: callable applied to the input in ``forward``;
                stored as ``self.lambda_fn``
        """
        super().__init__()
        self.lambda_fn = lambda_fn

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        fn = self.lambda_fn
        result = fn(x)
        return result
class Normalize(nn.Module):
    """Applies :math:`L_p` normalization to its input along one dimension.

    A module wrapper around ``torch.nn.functional.normalize``. Every keyword
    argument given at construction is forwarded to it unchanged on each call.

    Example:
        >>> norm = Normalize(p=2, dim=1)
    """

    def __init__(self, **normalize_kwargs):
        """
        Args:
            **normalize_kwargs: keyword arguments forwarded to
                ``torch.nn.functional.normalize`` (e.g. ``p``, ``dim``,
                ``eps``); any omitted ones fall back to that function's
                own defaults
        """
        super().__init__()
        self.normalize_kwargs = normalize_kwargs

    def forward(self, x):
        """Return ``F.normalize(x, **self.normalize_kwargs)``."""
        options = self.normalize_kwargs
        return F.normalize(x, **options)