# Source code for catalyst.contrib.nn.modules.common

import torch.nn as nn
import torch.nn.functional as F


class Flatten(nn.Module):
    """Flatten every dimension except the leading one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``;
    a 1-D input of shape ``(N,)`` becomes ``(N, 1)``. The leading dimension
    is presumably the batch dimension — it is kept as-is.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], prod(x.shape[1:]))``.

        Args:
            x: tensor with at least one dimension.

        Returns:
            A 2-D tensor. It is a view when ``x``'s memory layout allows,
            otherwise a copy.
        """
        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors,
        # e.g. the output of ``transpose``/``permute``.
        # The trailing size is computed explicitly rather than passed as -1,
        # because -1 cannot be inferred when the leading dim is 0.
        return x.reshape(x.shape[0], x.shape[1:].numel())
class Lambda(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    Useful for dropping a one-off function (e.g. a lambda) into an
    ``nn.Sequential`` or any other place that expects a module.
    """

    def __init__(self, lambda_fn):
        """Store the callable to apply.

        Args:
            lambda_fn: any callable taking a single argument; it is invoked
                on every forward pass.
        """
        super().__init__()
        self.lambda_fn = lambda_fn

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result unchanged."""
        wrapped = self.lambda_fn
        result = wrapped(x)
        return result
class Normalize(nn.Module):
    """Module wrapper around ``torch.nn.functional.normalize``.

    The keyword arguments given at construction time are stored and passed
    unchanged to ``F.normalize`` on every forward call.
    """

    def __init__(self, **normalize_kwargs):
        """Remember the options for ``F.normalize``.

        Args:
            **normalize_kwargs: keyword arguments forwarded verbatim to
                ``torch.nn.functional.normalize``; see its documentation
                for the accepted keys.
        """
        super().__init__()
        self.normalize_kwargs = normalize_kwargs

    def forward(self, x):
        """Return ``F.normalize(x, **self.normalize_kwargs)``."""
        options = self.normalize_kwargs
        normalized = F.normalize(x, **options)
        return normalized