# Source code for catalyst.contrib.modules.pooling

import torch
import torch.nn as nn
import torch.nn.functional as F

from catalyst.contrib.registry import MODULES


class GlobalAvgPool2d(nn.Module):
    """Average-pool a 4D feature map over its entire spatial extent.

    The pooling window is set to the input's own height and width, so each
    channel is reduced to a single value.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce the two trailing (spatial) dimensions by averaging.

        Args:
            x: tensor whose dimensions from index 2 onward are exactly
                ``(height, width)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a
            ``1 x 1`` spatial size.
        """
        # The kernel covers the whole map, so pooling yields one cell.
        height, width = x.shape[2:]
        return F.avg_pool2d(input=x, kernel_size=(height, width))

    @staticmethod
    def out_features(in_features: int) -> int:
        """Return the number of output channels for ``in_features`` inputs.

        Pooling does not change the channel count.
        """
        return in_features
class GlobalMaxPool2d(nn.Module):
    """Max-pool a 4D feature map over its entire spatial extent.

    The pooling window is set to the input's own height and width, so each
    channel is reduced to its single largest value.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce the two trailing (spatial) dimensions by taking the max.

        Args:
            x: tensor whose dimensions from index 2 onward are exactly
                ``(height, width)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a
            ``1 x 1`` spatial size.
        """
        # The kernel covers the whole map, so pooling yields one cell.
        height, width = x.shape[2:]
        return F.max_pool2d(input=x, kernel_size=(height, width))

    @staticmethod
    def out_features(in_features: int) -> int:
        """Return the number of output channels for ``in_features`` inputs.

        Pooling does not change the channel count.
        """
        return in_features
class GlobalConcatPool2d(nn.Module):
    """Concatenate global average pooling and global max pooling.

    Both poolings are applied to the same input and their results are
    joined along dimension 1, average first and max second.
    """

    def __init__(self):
        super().__init__()
        self.avg = GlobalAvgPool2d()
        self.max = GlobalMaxPool2d()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` with both branches and concatenate along dim 1.

        Args:
            x: input tensor passed unchanged to both pooling branches.

        Returns:
            ``torch.cat`` of the average-pooled and max-pooled outputs
            (in that order) along dimension 1.
        """
        pooled = [self.avg(x), self.max(x)]
        return torch.cat(pooled, 1)

    @staticmethod
    def out_features(in_features: int) -> int:
        """Return the number of output channels for ``in_features`` inputs.

        Two branches are concatenated, so the channel count doubles.
        """
        return in_features * 2
class GlobalAttnPool2d(nn.Module):
    """Attention-weighted global sum pooling.

    A bias-free 1x1 convolution maps the input to a single-channel map,
    an activation is applied to it, and the input is multiplied by that
    map before being summed over its last two dimensions.

    The weights are not normalised, so the result is a weighted sum
    rather than a weighted average.
    """

    def __init__(self, in_features: int, activation_fn="Sigmoid"):
        """Build the attention branch.

        Args:
            in_features: number of input channels of the 1x1 convolution.
            activation_fn: passed through ``MODULES.get_if_str`` and then
                called with no arguments to build the activation layer.
                NOTE(review): presumably a string is looked up in the
                module registry and a non-string is returned as is;
                confirm against the registry implementation.
        """
        super().__init__()
        activation_fn = MODULES.get_if_str(activation_fn)
        self.attn = nn.Sequential(
            nn.Conv2d(
                in_features,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            activation_fn(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Weight ``x`` by its attention map and sum the last two dims.

        Args:
            x: input tensor accepted by the attention branch.

        Returns:
            Tensor with the last two dimensions reduced to size 1
            (``keepdim=True``).
        """
        # Single-channel map; broadcasts over the channel dimension below.
        weights = self.attn(x)
        weighted = x * weights
        return torch.sum(weighted, dim=[-2, -1], keepdim=True)

    @staticmethod
    def out_features(in_features: int) -> int:
        """Return the number of output channels for ``in_features`` inputs.

        The weighted sum does not change the channel count.
        """
        return in_features
class GlobalAvgAttnPool2d(nn.Module):
    """Concatenate global average pooling and attention pooling.

    Both branches receive the same input and their results are joined
    along dimension 1, average first and attention second.
    """

    def __init__(self, in_features: int, activation_fn="Sigmoid"):
        """Build the two pooling branches.

        Args:
            in_features: forwarded to ``GlobalAttnPool2d``.
            activation_fn: forwarded to ``GlobalAttnPool2d``.
        """
        super().__init__()
        self.avg = GlobalAvgPool2d()
        self.attn = GlobalAttnPool2d(in_features, activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` with both branches and concatenate along dim 1.

        Args:
            x: input tensor passed unchanged to both branches.

        Returns:
            ``torch.cat`` of the average-pooled and attention-pooled
            outputs (in that order) along dimension 1.
        """
        pooled = [self.avg(x), self.attn(x)]
        return torch.cat(pooled, 1)

    @staticmethod
    def out_features(in_features: int) -> int:
        """Return the number of output channels for ``in_features`` inputs.

        Two branches are concatenated, so the channel count doubles.
        """
        return in_features * 2
class GlobalMaxAttnPool2d(nn.Module):
    """Concatenate global max pooling and attention pooling.

    Both branches receive the same input and their results are joined
    along dimension 1, max first and attention second.
    """

    def __init__(self, in_features: int, activation_fn="Sigmoid"):
        """Build the two pooling branches.

        Args:
            in_features: forwarded to ``GlobalAttnPool2d``.
            activation_fn: forwarded to ``GlobalAttnPool2d``.
        """
        super().__init__()
        self.max = GlobalMaxPool2d()
        self.attn = GlobalAttnPool2d(in_features, activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` with both branches and concatenate along dim 1.

        Args:
            x: input tensor passed unchanged to both branches.

        Returns:
            ``torch.cat`` of the max-pooled and attention-pooled outputs
            (in that order) along dimension 1.
        """
        pooled = [self.max(x), self.attn(x)]
        return torch.cat(pooled, 1)

    @staticmethod
    def out_features(in_features: int) -> int:
        """Return the number of output channels for ``in_features`` inputs.

        Two branches are concatenated, so the channel count doubles.
        """
        return in_features * 2
class GlobalConcatAttnPool2d(nn.Module):
    """Concatenate global average, global max and attention pooling.

    All three branches receive the same input and their results are
    joined along dimension 1 in the order average, max, attention.
    """

    def __init__(self, in_features: int, activation_fn="Sigmoid"):
        """Build the three pooling branches.

        Args:
            in_features: forwarded to ``GlobalAttnPool2d``.
            activation_fn: forwarded to ``GlobalAttnPool2d``.
        """
        super().__init__()
        self.avg = GlobalAvgPool2d()
        self.max = GlobalMaxPool2d()
        self.attn = GlobalAttnPool2d(in_features, activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` with all branches and concatenate along dim 1.

        Args:
            x: input tensor passed unchanged to every branch.

        Returns:
            ``torch.cat`` of the average-, max- and attention-pooled
            outputs (in that order) along dimension 1.
        """
        pooled = [self.avg(x), self.max(x), self.attn(x)]
        return torch.cat(pooled, 1)

    @staticmethod
    def out_features(in_features: int) -> int:
        """Return the number of output channels for ``in_features`` inputs.

        Three branches are concatenated, so the channel count triples.
        """
        return in_features * 3