Shortcuts

Source code for catalyst.contrib.models.cv.segmentation.fpn

from typing import Dict

from .blocks import EncoderDownsampleBlock
from .bridge import UnetBridge
from .core import ResnetUnetSpec, UnetSpec
from .decoder import FPNDecoder
from .encoder import ResnetEncoder, UnetEncoder
from .head import FPNHead


[docs]class FPNUnet(UnetSpec): """@TODO: Docs. Contribution is welcome.""" def _get_components( self, encoder: UnetEncoder, num_classes: int, bridge_params: Dict, decoder_params: Dict, head_params: Dict, ): bridge = UnetBridge( in_channels=encoder.out_channels, in_strides=encoder.out_strides, out_channels=encoder.out_channels[-1] * 2, block_fn=EncoderDownsampleBlock, **bridge_params ) decoder = FPNDecoder( in_channels=bridge.out_channels, in_strides=bridge.out_strides, **decoder_params ) head = FPNHead( in_channels=decoder.out_channels, in_strides=decoder.out_strides, out_channels=num_classes, upsample_scale=decoder.out_strides[-1], interpolation_mode="bilinear", align_corners=True, **head_params ) return encoder, bridge, decoder, head
[docs]class ResnetFPNUnet(ResnetUnetSpec): """@TODO: Docs. Contribution is welcome.""" def _get_components( self, encoder: ResnetEncoder, num_classes: int, bridge_params: Dict, decoder_params: Dict, head_params: Dict, ): bridge = None decoder = FPNDecoder( in_channels=encoder.out_channels, in_strides=encoder.out_strides, **decoder_params ) head = FPNHead( in_channels=decoder.out_channels, in_strides=decoder.out_strides, out_channels=num_classes, upsample_scale=decoder.out_strides[-1], interpolation_mode="bilinear", align_corners=True, **head_params ) return encoder, bridge, decoder, head