Shortcuts

Source code for catalyst.contrib.nn.modules.se

# flake8: noqa
# @TODO: code formatting issue for 20.07 release
import torch
import torch.nn as nn
import torch.nn.functional as F


class cSE(nn.Module):  # noqa: N801
    """
    The channel-wise SE (Squeeze and Excitation) block from the
    `Squeeze-and-Excitation Networks`__ paper.

    The two trailing (spatial) dimensions are averaged into one descriptor
    per channel, which is passed through a two-layer bottleneck MLP and a
    sigmoid to obtain per-channel gates that rescale the input.

    Adapted from
    https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65939
    and
    https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178

    Shape:

    - Input: (batch, channels, height, width)
    - Output: (batch, channels, height, width) (same shape as input)

    __ https://arxiv.org/abs/1709.01507
    """

    def __init__(self, in_channels: int, r: int = 16):
        """
        Args:
            in_channels (int): The number of channels
                in the feature map of the input.
            r (int): The reduction ratio of the intermediate channels.
                The bottleneck width is ``in_channels // r``, but never
                less than 1.
                Default: 16.
        """
        super().__init__()
        # Clamp to one hidden unit: with ``in_channels < r`` the integer
        # division yields 0, which would create an empty bottleneck whose
        # gate no longer depends on the input at all.
        hidden_channels = max(in_channels // r, 1)
        self.linear1 = nn.Linear(in_channels, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call.

        Args:
            x (torch.Tensor): feature map whose last two dimensions are
                spatial and whose third-to-last dimension has
                ``in_channels`` entries.

        Returns:
            torch.Tensor: ``x`` rescaled by per-channel gates in ``(0, 1)``;
            same shape as ``x``.
        """
        # Squeeze: reduce over the two trailing dims directly. Unlike a
        # ``view``-based flatten this also works for non-contiguous inputs
        # (e.g. a tensor transposed over its spatial dims).
        squeezed = x.mean(dim=(-2, -1))
        # Excitation: bottleneck MLP followed by a sigmoid gate.
        hidden = F.relu(self.linear1(squeezed), inplace=True)
        gates = torch.sigmoid(self.linear2(hidden))
        # Restore two singleton spatial dims so the gates broadcast over x.
        return torch.mul(x, gates.unsqueeze(-1).unsqueeze(-1))
class sSE(nn.Module):  # noqa: N801
    """
    The sSE (Channel Squeeze and Spatial Excitation) block from the
    `Concurrent Spatial and Channel ‘Squeeze & Excitation’
    in Fully Convolutional Networks`__ paper.

    A 1x1 convolution collapses the channels into a single map, a sigmoid
    turns that map into per-location gates, and the input is rescaled
    by those gates.

    Adapted from
    https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178

    Shape:

    - Input: (batch, channels, height, width)
    - Output: (batch, channels, height, width) (same shape as input)

    __ https://arxiv.org/abs/1803.02579
    """

    def __init__(self, in_channels: int):
        """
        Args:
            in_channels (int): The number of channels
                in the feature map of the input.
        """
        super().__init__()
        # Pointwise convolution producing a single-channel attention map.
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=1, kernel_size=1, stride=1
        )

    def forward(self, x: torch.Tensor):
        """Forward call."""
        # One gate per spatial location, broadcast across all channels.
        spatial_gate = torch.sigmoid(self.conv(x))
        return torch.mul(x, spatial_gate)
class scSE(nn.Module):  # noqa: N801
    """
    The scSE (Concurrent Spatial and Channel Squeeze and Channel Excitation)
    block from the `Concurrent Spatial and Channel ‘Squeeze & Excitation’
    in Fully Convolutional Networks`__ paper.

    The input is recalibrated by a channel-wise branch (``cSE``) and a
    spatial branch (``sSE``) independently; the two results are summed.

    Adapted from
    https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/66178

    Shape:

    - Input: (batch, channels, height, width)
    - Output: (batch, channels, height, width) (same shape as input)

    __ https://arxiv.org/abs/1803.02579
    """

    def __init__(self, in_channels: int, r: int = 16):
        """
        Args:
            in_channels (int): The number of channels
                in the feature map of the input.
            r (int): The reduction ratio of the intermediate channels
                (forwarded to the ``cSE`` branch).
                Default: 16.
        """
        super().__init__()
        self.cse_block = cSE(in_channels, r)
        self.sse_block = sSE(in_channels)

    def forward(self, x: torch.Tensor):
        """Forward call."""
        channel_branch = self.cse_block(x)
        spatial_branch = self.sse_block(x)
        # Element-wise sum of the two recalibrated feature maps.
        return torch.add(channel_branch, spatial_branch)
# Public API of this module: the three squeeze-and-excitation variants.
__all__ = [
    "sSE",
    "scSE",
    "cSE",
]