Source code for catalyst.contrib.models.segmentation.linknet

from typing import Dict  # isort:skip
from functools import partial

import numpy as np

from .blocks import DecoderSumBlock, EncoderDownsampleBlock
from .bridge import UnetBridge
from .core import ResnetUnetSpec, UnetSpec
from .decoder import UNetDecoder
from .encoder import ResnetEncoder, UnetEncoder
from .head import UnetHead


class Linknet(UnetSpec):
    """LinkNet-style segmentation model assembled from U-Net building blocks.

    This class only defines how the components are wired together; all other
    behaviour is inherited from ``UnetSpec``.
    """

    def _get_components(
        self,
        encoder: UnetEncoder,
        num_classes: int,
        bridge_params: Dict,
        decoder_params: Dict,
        head_params: Dict,
    ):
        """Build the bridge, decoder and head that follow ``encoder``.

        Args:
            encoder: an already-built encoder; its ``out_channels`` and
                ``out_strides`` determine the bridge inputs.
            num_classes: number of output channels requested from the head.
            bridge_params: extra keyword arguments forwarded to ``UnetBridge``.
            decoder_params: extra keyword arguments forwarded to
                ``UNetDecoder``.
            head_params: extra keyword arguments forwarded to ``UnetHead``.

        Each ``*_params`` dict is unpacked at the call site, so it must not
        repeat a keyword this method already passes explicitly (Python raises
        ``TypeError`` on duplicate keyword arguments).

        Returns:
            The tuple ``(encoder, bridge, decoder, head)``; ``encoder`` is
            returned unchanged.
        """
        enc_channels = encoder.out_channels
        enc_strides = encoder.out_strides
        # The bridge outputs twice as many channels as the deepest encoder
        # stage produces.
        bridge_width = enc_channels[-1] * 2
        bridge = UnetBridge(
            in_channels=enc_channels,
            in_strides=enc_strides,
            out_channels=bridge_width,
            block_fn=EncoderDownsampleBlock,
            **bridge_params
        )

        # The decoder reads the bridge outputs, using DecoderSumBlock for
        # every stage.
        decoder = UNetDecoder(
            in_channels=bridge.out_channels,
            in_strides=bridge.out_strides,
            block_fn=DecoderSumBlock,
            **decoder_params
        )

        # int() truncates log2 of the decoder's last stride; presumably that
        # stride is a power of two so the result is exact — TODO confirm.
        upsample_steps = int(np.log2(decoder.out_strides[-1]))
        head = UnetHead(
            in_channels=decoder.out_channels,
            in_strides=decoder.out_strides,
            out_channels=num_classes,
            num_upsample_blocks=upsample_steps,
            **head_params
        )

        components = (encoder, bridge, decoder, head)
        return components
class ResnetLinknet(ResnetUnetSpec):
    """LinkNet-style segmentation model built on a ``ResnetEncoder``.

    No bridge is created: the decoder consumes the encoder's outputs directly.
    All other behaviour is inherited from ``ResnetUnetSpec``.
    """

    def _get_components(
        self,
        encoder: ResnetEncoder,
        num_classes: int,
        bridge_params: Dict,
        decoder_params: Dict,
        head_params: Dict,
    ):
        """Build the decoder and head that follow ``encoder``.

        Args:
            encoder: an already-built ResNet encoder whose ``out_channels``
                and ``out_strides`` feed the decoder.
            num_classes: number of output channels requested from the head.
            bridge_params: accepted for signature compatibility with the
                other specs but ignored, because no bridge is built.
            decoder_params: extra keyword arguments forwarded to
                ``UNetDecoder``.
            head_params: extra keyword arguments forwarded to ``UnetHead``.

        Returns:
            The tuple ``(encoder, None, decoder, head)``; the ``None`` sits in
            the bridge slot.
        """
        # Each decoder stage is a DecoderSumBlock with aggregate_first=False
        # and upsample_scale=None; see DecoderSumBlock for what these
        # options control.
        decoder_block = partial(
            DecoderSumBlock, aggregate_first=False, upsample_scale=None
        )
        decoder = UNetDecoder(
            in_channels=encoder.out_channels,
            in_strides=encoder.out_strides,
            block_fn=decoder_block,
            **decoder_params
        )

        # int() truncates log2 of the decoder's last stride; presumably that
        # stride is a power of two so the result is exact — TODO confirm.
        final_stride = decoder.out_strides[-1]
        head = UnetHead(
            in_channels=decoder.out_channels,
            in_strides=decoder.out_strides,
            out_channels=num_classes,
            num_upsample_blocks=int(np.log2(final_stride)),
            **head_params
        )

        return encoder, None, decoder, head