Shortcuts

Source code for catalyst.contrib.models.cv.segmentation.psp

from typing import Dict

from catalyst.contrib.models.cv.segmentation.core import (
    ResnetUnetSpec,
    UnetSpec,
)
from catalyst.contrib.models.cv.segmentation.decoder import PSPDecoder
from catalyst.contrib.models.cv.segmentation.encoder import (
    ResnetEncoder,
    UnetEncoder,
)
from catalyst.contrib.models.cv.segmentation.head import UnetHead


def _build_psp_components(
    encoder,
    num_classes: int,
    decoder_params: Dict,
    head_params: Dict,
):
    """Assemble the (encoder, bridge, decoder, head) tuple shared by PSP models.

    Both ``PSPnet`` and ``ResnetPSPnet`` wire their components identically;
    only the encoder type differs. Keeping the wiring here avoids the two
    classes drifting apart.

    Args:
        encoder: already-built encoder; must expose ``out_channels`` and
            ``out_strides`` (both are read below).
        num_classes: number of output channels produced by the head.
        decoder_params: extra keyword arguments forwarded to ``PSPDecoder``.
        head_params: extra keyword arguments forwarded to ``UnetHead``.
            NOTE(review): passing ``interpolation_mode``/``align_corners``
            (or any other key set explicitly below) here raises ``TypeError``
            for a duplicate keyword argument.

    Returns:
        tuple: ``(encoder, bridge, decoder, head)`` where ``bridge`` is
        always ``None`` and ``encoder`` is returned unchanged.
    """
    # PSP models use no bridge between encoder and decoder.
    bridge = None
    decoder = PSPDecoder(
        in_channels=encoder.out_channels,
        in_strides=encoder.out_strides,
        **decoder_params,
    )
    head = UnetHead(
        in_channels=decoder.out_channels,
        in_strides=decoder.out_strides,
        out_channels=num_classes,
        # Upsampling factor is the decoder's last output stride; presumably
        # this brings the prediction back to input resolution — confirm
        # against UnetHead.
        upsample_scale=decoder.out_strides[-1],
        interpolation_mode="bilinear",
        align_corners=True,
        **head_params,
    )
    return encoder, bridge, decoder, head


class PSPnet(UnetSpec):
    """PSP segmentation model built on the generic ``UnetSpec`` scaffold.

    Uses a ``UnetEncoder``, no bridge, a ``PSPDecoder`` and a ``UnetHead``
    that upsamples bilinearly (``align_corners=True``) by the decoder's last
    output stride.
    """

    def _get_components(
        self,
        encoder: UnetEncoder,
        num_classes: int,
        bridge_params: Dict,
        decoder_params: Dict,
        head_params: Dict,
    ):
        """Build the model components from an already-constructed encoder.

        Args:
            encoder: encoder exposing ``out_channels`` and ``out_strides``.
            num_classes: number of output channels of the head.
            bridge_params: accepted but ignored — this model has no bridge
                (the parameter is presumably required by the ``UnetSpec``
                calling convention; confirm in ``core``).
            decoder_params: keyword arguments forwarded to ``PSPDecoder``.
            head_params: keyword arguments forwarded to ``UnetHead``.

        Returns:
            tuple: ``(encoder, None, decoder, head)``.
        """
        return _build_psp_components(
            encoder, num_classes, decoder_params, head_params
        )


class ResnetPSPnet(ResnetUnetSpec):
    """PSP segmentation model built on the ``ResnetUnetSpec`` scaffold.

    Identical wiring to ``PSPnet`` (no bridge, ``PSPDecoder``, bilinear
    ``UnetHead``), but driven by a ``ResnetEncoder``.
    """

    def _get_components(
        self,
        encoder: ResnetEncoder,
        num_classes: int,
        bridge_params: Dict,
        decoder_params: Dict,
        head_params: Dict,
    ):
        """Build the model components from an already-constructed encoder.

        Args:
            encoder: ResNet encoder exposing ``out_channels`` and
                ``out_strides``.
            num_classes: number of output channels of the head.
            bridge_params: accepted but ignored — this model has no bridge.
            decoder_params: keyword arguments forwarded to ``PSPDecoder``.
            head_params: keyword arguments forwarded to ``UnetHead``.

        Returns:
            tuple: ``(encoder, None, decoder, head)``.
        """
        return _build_psp_components(
            encoder, num_classes, decoder_params, head_params
        )
# Public API of this module: the two PSP model variants.
__all__ = [
    "PSPnet",
    "ResnetPSPnet",
]