Source code for catalyst.contrib.modules.noisy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class NoisyLinear(nn.Linear):
    """
    Linear layer whose weight and bias are perturbed by independent
    gaussian noise, resampled on every forward pass.

    The effective parameters used in ``forward`` are
    ``weight + sigma_weight * epsilon_weight`` and, when the layer has a
    bias, ``bias + sigma_bias * epsilon_bias``. ``sigma_*`` are learnable
    parameters; ``epsilon_*`` are buffers refilled with standard normal
    samples on each call.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        sigma_init: initial value of every element of ``sigma_weight``
            and ``sigma_bias``
        bias: if ``False``, the layer has neither a bias nor bias noise
    """

    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        # nn.Linear.__init__ calls the overridden reset_parameters(), so
        # that method must not rely on any attribute defined below.
        super().__init__(in_features, out_features, bias=bias)
        self.sigma_weight = nn.Parameter(
            torch.full_like(self.weight, sigma_init)
        )
        self.register_buffer(
            "epsilon_weight", torch.zeros(out_features, in_features)
        )
        if bias:
            self.sigma_bias = nn.Parameter(
                torch.full_like(self.bias, sigma_init)
            )
            self.register_buffer("epsilon_bias", torch.zeros(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialise ``weight`` (and ``bias``, if present) uniformly in
        ``[-sqrt(3 / in_features), sqrt(3 / in_features)]``.

        The ``sigma_*`` parameters are left untouched.
        """
        std = math.sqrt(3 / self.in_features)
        nn.init.uniform_(self.weight, -std, std)
        # The layer may have been built with bias=False.
        if self.bias is not None:
            nn.init.uniform_(self.bias, -std, std)

    def forward(self, input):
        """
        Apply the linear transformation with freshly sampled noise.

        Args:
            input: tensor whose last dimension is ``in_features``

        Returns:
            tensor whose last dimension is ``out_features``
        """
        # Resample the noise in place so the buffers keep their storage.
        torch.randn(self.epsilon_weight.size(), out=self.epsilon_weight)
        bias = self.bias
        if bias is not None:
            torch.randn(self.epsilon_bias.size(), out=self.epsilon_bias)
            bias = bias + self.sigma_bias * self.epsilon_bias
        noisy_weight = self.weight + self.sigma_weight * self.epsilon_weight
        return F.linear(input, noisy_weight, bias)
class NoisyFactorizedLinear(nn.Linear):
    """
    NoisyNet layer with factorized gaussian noise.

    Instead of one noise value per weight, one noise vector is sampled per
    input feature and one per output feature; the weight noise is their
    outer product after the transform ``sign(x) * sqrt(|x|)``, and the
    bias noise is the transformed output vector.

    N.B. ``reset_parameters`` is not overridden, so ``weight`` and ``bias``
    keep the initialisation performed by ``nn.Linear``.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        sigma_zero: every element of ``sigma_weight`` and ``sigma_bias``
            starts at ``sigma_zero / sqrt(in_features)``
        bias: if ``False``, the layer has neither a bias nor bias noise
    """

    def __init__(self, in_features, out_features, sigma_zero=0.4, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        sigma_init = sigma_zero / math.sqrt(in_features)
        self.sigma_weight = nn.Parameter(
            torch.full_like(self.weight, sigma_init)
        )
        # Row / column shapes so that their product broadcasts to
        # (out_features, in_features).
        self.register_buffer("epsilon_input", torch.zeros(1, in_features))
        self.register_buffer("epsilon_output", torch.zeros(out_features, 1))
        if bias:
            self.sigma_bias = nn.Parameter(
                torch.full_like(self.bias, sigma_init)
            )

    def forward(self, input):
        """
        Apply the linear transformation with freshly sampled factorized
        noise.

        Args:
            input: tensor whose last dimension is ``in_features``

        Returns:
            tensor whose last dimension is ``out_features``
        """
        torch.randn(self.epsilon_input.size(), out=self.epsilon_input)
        torch.randn(self.epsilon_output.size(), out=self.epsilon_output)

        def scale_noise(x):
            return torch.sign(x) * torch.sqrt(torch.abs(x))

        eps_in = scale_noise(self.epsilon_input)
        eps_out = scale_noise(self.epsilon_output)

        bias = self.bias
        if bias is not None:
            # Keep the bias 1-D, shape (out_features,), as F.linear
            # expects; a (1, out_features) bias adds a spurious leading
            # dimension for unbatched inputs.
            bias = bias + self.sigma_bias * eps_out.squeeze(1)
        noise_v = torch.mul(eps_in, eps_out)
        return F.linear(input, self.weight + self.sigma_weight * noise_v, bias)