Source code for catalyst.contrib.nn.modules.common
from torch import nn
from torch.nn import functional as F
class Flatten(nn.Module):
    """Flattens every dimension except the first (batch) one.

    Does not affect the batch size.

    Example:
        >>> import torch
        >>> Flatten()(torch.zeros(8, 3, 4, 5)).shape
        torch.Size([8, 60])
    """

    def __init__(self):
        """Flatten has no parameters or state."""
        super().__init__()

    def forward(self, x):
        """Forward call.

        Args:
            x: tensor whose first dimension is treated as the batch dimension

        Returns:
            tensor of shape ``(x.shape[0], prod(x.shape[1:]))``
        """
        if x.dim() >= 2:
            # Unlike ``view``, ``flatten`` copies when the input is
            # non-contiguous (e.g. after ``transpose``/``permute``) instead of
            # raising, and it also handles an empty batch (shape ``(0, ...)``).
            return x.flatten(start_dim=1)
        # 1-D input: keep the original "one feature per sample" behaviour,
        # i.e. ``(N,) -> (N, 1)``; a 0-D input still raises IndexError here.
        return x.reshape(x.shape[0], -1)
class Lambda(nn.Module):
    """Wraps an arbitrary callable so it can be used as an ``nn.Module``.

    The callable is stored as ``self.lambda_fn`` and applied to the input on
    every forward pass.
    """

    def __init__(self, lambda_fn):
        """
        Args:
            lambda_fn: callable invoked as ``lambda_fn(x)`` in ``forward``
        """
        super(Lambda, self).__init__()
        self.lambda_fn = lambda_fn

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        fn = self.lambda_fn
        output = fn(x)
        return output
class Normalize(nn.Module):
    """Performs :math:`L_p` normalization of inputs over specified dimension.

    Thin module wrapper around ``torch.nn.functional.normalize``; all
    configuration is forwarded to that function on every call.

    Example:
        >>> import torch
        >>> Normalize(p=1, dim=0)(torch.tensor([1.0, 3.0]))
        tensor([0.2500, 0.7500])
    """

    def __init__(self, **normalize_kwargs):
        """
        Args:
            **normalize_kwargs: see ``torch.nn.functional.normalize`` params
                (e.g. ``p``, ``dim``, ``eps``); any that are omitted fall
                back to that function's own defaults.
        """
        super().__init__()
        self.normalize_kwargs = normalize_kwargs

    def forward(self, x):
        """Forward call."""
        return F.normalize(x, **self.normalize_kwargs)

    def extra_repr(self):
        """Show the configured ``F.normalize`` kwargs when the module is printed."""
        return ", ".join(
            f"{key}={value!r}" for key, value in self.normalize_kwargs.items()
        )