# Source code for catalyst.contrib.nn.modules.rms_norm
# flake8: noqa
# @TODO: code formatting issue for 20.07 release
import torch
from torch import nn
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Normalizes the last dimension of the input by its root-mean-square
    statistic (instead of mean/variance as in LayerNorm), then applies a
    learned per-feature scale and, optionally, a learned bias.

    Reference: "Root Mean Square Layer Normalization"
    (Zhang & Sennrich, 2019), https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self, dimension: int, epsilon: float = 1e-8, is_bias: bool = False
    ):
        """
        Args:
            dimension: size of the last (feature) dimension of the
                inputs to normalize
            epsilon: small constant added to the RMS denominator to
                prevent division by zero in case the layer has zero
                variance (default: 1e-8)
            is_bias: whether to add a learned bias term after scaling
        """
        super().__init__()
        self.dimension = dimension
        self.epsilon = epsilon
        self.is_bias = is_bias
        # Learned per-feature gain, initialized to identity (ones).
        self.scale = nn.Parameter(torch.ones(self.dimension))
        if self.is_bias:
            # Learned per-feature shift, initialized to zeros; only
            # created (and only used in forward) when is_bias is True.
            self.bias = nn.Parameter(torch.zeros(self.dimension))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension by its RMS value.

        Args:
            x: input tensor whose last dimension has size ``dimension``

        Returns:
            torch.Tensor: normalized tensor of the same shape as ``x``
        """
        # RMS over the last (feature) dimension; keepdim for broadcasting.
        x_std = torch.sqrt(torch.mean(x ** 2, -1, keepdim=True))
        x_norm = x / (x_std + self.epsilon)
        if self.is_bias:
            return self.scale * x_norm + self.bias
        return self.scale * x_norm
# Explicit public API of this module: only the RMSNorm class is exported.
__all__ = ["RMSNorm"]