Source code for catalyst.utils.initialization

from typing import Callable  # isort:skip

import numpy as np

import torch.nn as nn

ACTIVATIONS = {
    None: "sigmoid",
    nn.Sigmoid: "sigmoid",
    nn.Tanh: "tanh",
    nn.ReLU: "relu",
    nn.LeakyReLU: "leaky_relu",
    nn.ELU: "relu",
}


def _nonlinearity2name(nonlinearity):
    if isinstance(nonlinearity, nn.Module):
        nonlinearity = nonlinearity.__class__
    nonlinearity = ACTIVATIONS.get(nonlinearity, nonlinearity)
    nonlinearity = nonlinearity.lower()
    return nonlinearity
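
# Illustrative note (not part of the original module): ``_nonlinearity2name``
# accepts a module class, a module instance, or a plain string and normalizes
# it to a lowercase name, e.g.
#
#     >>> _nonlinearity2name(nn.ReLU)         # -> "relu"
#     >>> _nonlinearity2name(nn.LeakyReLU())  # -> "leaky_relu"
#     >>> _nonlinearity2name("Tanh")          # -> "tanh"
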


def create_optimal_inner_init(
    nonlinearity: nn.Module, **kwargs
) -> Callable[[nn.Module], None]:
    """
    Create initializer for inner layers
    based on their activation function (nonlinearity).

    Args:
        nonlinearity: non-linear activation
    """
    nonlinearity: str = _nonlinearity2name(nonlinearity)
    assert isinstance(nonlinearity, str)

    if nonlinearity in ["sigmoid", "tanh"]:
        weight_init_fn = nn.init.xavier_uniform_
        init_args = kwargs
    elif nonlinearity in ["relu", "leaky_relu"]:
        weight_init_fn = nn.init.kaiming_normal_
        init_args = {**{"nonlinearity": nonlinearity}, **kwargs}
    else:
        raise NotImplementedError

    def inner_init(layer):
        if isinstance(layer, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            weight_init_fn(layer.weight.data, **init_args)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias.data)

    return inner_init

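
# Usage sketch (illustrative, not part of the original module): the returned
# callable is meant to be passed to ``nn.Module.apply`` so that every inner
# Linear/Conv layer is initialized to match the network's activation, e.g.
#
#     >>> net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32))
#     >>> net.apply(create_optimal_inner_init(nn.ReLU))
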
def outer_init(layer: nn.Module) -> None:
    """
    Initialization for output layers of policy and value networks
    typically used in deep reinforcement learning literature.
    """
    if isinstance(layer, (nn.Linear, nn.Conv1d, nn.Conv2d)):
        v = 3e-3
        nn.init.uniform_(layer.weight.data, -v, v)
        if layer.bias is not None:
            nn.init.uniform_(layer.bias.data, -v, v)

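
# Usage sketch (illustrative, not part of the original module): ``outer_init``
# is also applied via ``nn.Module.apply``, typically to the final head of a
# policy/value network so its initial outputs stay close to zero, e.g.
#
#     >>> head = nn.Linear(32, 1)
#     >>> head.apply(outer_init)
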
def constant_init(module, val, bias=0):
    """
    Initialize the module with constant value
    """
    nn.init.constant_(module.weight, val)
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def uniform_init(module, a=0, b=1, bias=0):
    """
    Initialize the module with uniform distribution
    """
    nn.init.uniform_(module.weight, a, b)
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def normal_init(module, mean=0, std=1, bias=0):
    """
    Initialize the module with normal distribution
    """
    nn.init.normal_(module.weight, mean, std)
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def xavier_init(module, gain=1, bias=0, distribution="normal"):
    """
    Initialize the module with Xavier initialization
    """
    assert distribution in ["uniform", "normal"]
    if distribution == "uniform":
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def kaiming_init(
    module, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
):
    """
    Initialize the module with He (Kaiming) initialization
    """
    assert distribution in ["uniform", "normal"]
    if distribution == "uniform":
        nn.init.kaiming_uniform_(
            module.weight, mode=mode, nonlinearity=nonlinearity
        )
    else:
        nn.init.kaiming_normal_(
            module.weight, mode=mode, nonlinearity=nonlinearity
        )
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

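
# Usage sketch (illustrative, not part of the original module): the helpers
# above initialize a single module in place, e.g.
#
#     >>> conv = nn.Conv2d(3, 16, kernel_size=3)
#     >>> kaiming_init(conv, nonlinearity="relu", distribution="normal")
#     >>> fc = nn.Linear(16, 10)
#     >>> xavier_init(fc, gain=1.0, distribution="uniform", bias=0)
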
def bias_init_with_prob(prior_prob):
    """
    Initialize conv/fc bias value according to a given probability
    """
    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
    return bias_init
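
# Usage sketch (illustrative, not part of the original module): the returned
# value is the inverse sigmoid (logit) of ``prior_prob``, so a layer whose
# bias is set to it produces ``prior_prob`` after a sigmoid. For example, for
# a focal-loss-style classification head with a rare-positive prior:
#
#     >>> cls_head = nn.Conv2d(256, 80, kernel_size=3, padding=1)
#     >>> nn.init.constant_(cls_head.bias, bias_init_with_prob(0.01))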