Source code for catalyst.contrib.nn.criterion.iou
# flake8: noqa
# @TODO: code formatting issue for 20.07 release
from functools import partial
import torch
from torch import nn
from catalyst.metrics.functional import wrap_metric_fn_with_activation
from catalyst.metrics.iou import iou
class IoULoss(nn.Module):
    """Intersection over union (Jaccard) loss.

    The loss value is ``1 - mean(IoU)``, where the IoU scores are produced
    by the ``iou`` metric wrapped with the requested activation.
    """

    def __init__(
        self,
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
    ):
        """
        Args:
            eps: epsilon to avoid zero division
            threshold: threshold for outputs binarization
            activation: A torch.nn activation applied to the outputs.
                Must be one of ``'none'``, ``'Sigmoid'``, ``'Softmax'``
        """
        super().__init__()
        # Wrap the metric with the activation first, then freeze the
        # ``eps``/``threshold`` settings so ``forward`` only has to pass
        # the tensors.
        self.loss_fn = partial(
            wrap_metric_fn_with_activation(
                metric_fn=iou, activation=activation
            ),
            eps=eps,
            threshold=threshold,
        )

    def forward(self, outputs, targets):
        """Compute the IoU loss.

        Args:
            outputs: model predictions passed to the wrapped IoU metric
            targets: ground-truth values passed to the wrapped IoU metric

        Returns:
            ``1 - mean(scores)``, where ``scores`` is whatever the wrapped
            metric returns (the original author notes the shape as
            ``[bs; num_classes]``).
        """
        scores = self.loss_fn(outputs, targets)
        return 1 - torch.mean(scores)
class BCEIoULoss(nn.Module):
    """Intersection over union (Jaccard) loss combined with BCE loss.

    The result is the sum of :class:`IoULoss` and
    ``nn.BCEWithLogitsLoss``, both evaluated on the same
    ``outputs``/``targets`` pair.

    Note:
        ``activation`` is only forwarded to the IoU term. The BCE term is
        always computed by ``BCEWithLogitsLoss`` on the raw ``outputs``.
    """

    def __init__(
        self,
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
        reduction: str = "mean",
    ):
        """
        Args:
            eps: epsilon to avoid zero division
            threshold: threshold for outputs binarization
            activation: A torch.nn activation applied to the outputs.
                Must be one of ``'none'``, ``'Sigmoid'``, ``'Softmax'``
            reduction: Specifies the reduction to apply
                to the output of BCE
        """
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)
        self.iou_loss = IoULoss(eps, threshold, activation)

    def forward(self, outputs, targets):
        """Compute the combined IoU + BCE loss.

        Args:
            outputs: model predictions, passed unchanged to both terms
            targets: ground-truth values, passed unchanged to both terms

        Returns:
            The sum of the IoU loss and the BCE-with-logits loss.
        """
        # Invoke the sub-module itself rather than its ``.forward`` so that
        # any forward hooks / pre-hooks registered on it are executed.
        iou_term = self.iou_loss(outputs, targets)
        bce_term = self.bce_loss(outputs, targets)
        return iou_term + bce_term
# Names exported by ``from <this module> import *``.
__all__ = ["IoULoss", "BCEIoULoss"]