Source code for catalyst.contrib.models.cv.segmentation.fpn
from typing import Dict # isort:skip
from .blocks import EncoderDownsampleBlock
from .bridge import UnetBridge
from .core import ResnetUnetSpec, UnetSpec
from .decoder import FPNDecoder
from .encoder import ResnetEncoder, UnetEncoder
from .head import FPNHead
[docs]class FPNUnet(UnetSpec):
def _get_components(
self,
encoder: UnetEncoder,
num_classes: int,
bridge_params: Dict,
decoder_params: Dict,
head_params: Dict,
):
bridge = UnetBridge(
in_channels=encoder.out_channels,
in_strides=encoder.out_strides,
out_channels=encoder.out_channels[-1] * 2,
block_fn=EncoderDownsampleBlock,
**bridge_params
)
decoder = FPNDecoder(
in_channels=bridge.out_channels,
in_strides=bridge.out_strides,
**decoder_params
)
head = FPNHead(
in_channels=decoder.out_channels,
in_strides=decoder.out_strides,
out_channels=num_classes,
upsample_scale=decoder.out_strides[-1],
interpolation_mode="bilinear",
align_corners=True,
**head_params
)
return encoder, bridge, decoder, head
[docs]class ResnetFPNUnet(ResnetUnetSpec):
def _get_components(
self,
encoder: ResnetEncoder,
num_classes: int,
bridge_params: Dict,
decoder_params: Dict,
head_params: Dict,
):
bridge = None
decoder = FPNDecoder(
in_channels=encoder.out_channels,
in_strides=encoder.out_strides,
**decoder_params
)
head = FPNHead(
in_channels=decoder.out_channels,
in_strides=decoder.out_strides,
out_channels=num_classes,
upsample_scale=decoder.out_strides[-1],
interpolation_mode="bilinear",
align_corners=True,
**head_params
)
return encoder, bridge, decoder, head