Shortcuts

Source code for catalyst.contrib.models.cv.segmentation.linknet

from typing import Dict
from functools import partial

import numpy as np

from .blocks import DecoderSumBlock, EncoderDownsampleBlock
from .bridge import UnetBridge
from .core import ResnetUnetSpec, UnetSpec
from .decoder import UNetDecoder
from .encoder import ResnetEncoder, UnetEncoder
from .head import UnetHead


class Linknet(UnetSpec):
    """Segmentation model spec wiring a bridge, a decoder and a head
    onto a user-supplied encoder.

    The bridge is built from ``EncoderDownsampleBlock`` and the decoder
    from ``DecoderSumBlock``.
    """

    def _get_components(
        self,
        encoder: UnetEncoder,
        num_classes: int,
        bridge_params: Dict,
        decoder_params: Dict,
        head_params: Dict,
    ):
        """Build the bridge, decoder and head for the given encoder.

        Args:
            encoder: encoder whose ``out_channels`` / ``out_strides``
                define the bridge input
            num_classes: passed to the head as ``out_channels``
            bridge_params: extra keyword arguments forwarded
                to ``UnetBridge``
            decoder_params: extra keyword arguments forwarded
                to ``UNetDecoder``
            head_params: extra keyword arguments forwarded
                to ``UnetHead``

        Returns:
            tuple ``(encoder, bridge, decoder, head)``; the encoder is
            returned as passed in
        """
        # Bridge: output width is twice the last encoder channel count.
        enc_channels = encoder.out_channels
        enc_strides = encoder.out_strides
        bridge_width = encoder.out_channels[-1] * 2
        bridge = UnetBridge(
            in_channels=enc_channels,
            in_strides=enc_strides,
            out_channels=bridge_width,
            block_fn=EncoderDownsampleBlock,
            **bridge_params
        )

        # Decoder consumes whatever the bridge exposes.
        decoder = UNetDecoder(
            in_channels=bridge.out_channels,
            in_strides=bridge.out_strides,
            block_fn=DecoderSumBlock,
            **decoder_params
        )

        # Head: the number of upsample blocks is log2 of the last
        # decoder stride, truncated to an int.
        dec_channels = decoder.out_channels
        dec_strides = decoder.out_strides
        upsample_steps = int(np.log2(decoder.out_strides[-1]))
        head = UnetHead(
            in_channels=dec_channels,
            in_strides=dec_strides,
            out_channels=num_classes,
            num_upsample_blocks=upsample_steps,
            **head_params
        )

        return encoder, bridge, decoder, head
class ResnetLinknet(ResnetUnetSpec):
    """Segmentation model spec that attaches a decoder and a head
    directly to a ResNet encoder, with no bridge in between.

    The decoder is built from ``DecoderSumBlock`` configured with
    ``aggregate_first=False`` and ``upsample_scale=None``.
    """

    def _get_components(
        self,
        encoder: ResnetEncoder,
        num_classes: int,
        bridge_params: Dict,
        decoder_params: Dict,
        head_params: Dict,
    ):
        """Build the decoder and head for the given ResNet encoder.

        Args:
            encoder: encoder whose ``out_channels`` / ``out_strides``
                feed the decoder directly
            num_classes: passed to the head as ``out_channels``
            bridge_params: not used by this implementation
                (no bridge is built)
            decoder_params: extra keyword arguments forwarded
                to ``UNetDecoder``
            head_params: extra keyword arguments forwarded
                to ``UnetHead``

        Returns:
            tuple ``(encoder, bridge, decoder, head)`` where ``bridge``
            is always ``None`` and the encoder is returned as passed in
        """
        # The decoder is attached straight to the encoder outputs.
        enc_channels = encoder.out_channels
        enc_strides = encoder.out_strides
        sum_block_fn = partial(
            DecoderSumBlock, aggregate_first=False, upsample_scale=None
        )
        decoder = UNetDecoder(
            in_channels=enc_channels,
            in_strides=enc_strides,
            block_fn=sum_block_fn,
            **decoder_params
        )

        # Head: the number of upsample blocks is log2 of the last
        # decoder stride, truncated to an int.
        dec_channels = decoder.out_channels
        dec_strides = decoder.out_strides
        upsample_steps = int(np.log2(decoder.out_strides[-1]))
        head = UnetHead(
            in_channels=dec_channels,
            in_strides=dec_strides,
            out_channels=num_classes,
            num_upsample_blocks=upsample_steps,
            **head_params
        )

        return encoder, None, decoder, head