Source code for catalyst.contrib.nn.criterion.trevsky
from typing import List, Optional
from functools import partial
import torch
from torch import nn
from catalyst.metrics.functional import trevsky

class TrevskyLoss(nn.Module):
    """The trevsky (Tversky) loss.

    TrevskyIndex = TP / (TP + alpha * FN + beta * FP)

    TrevskyLoss = 1 - TrevskyIndex
    """
    def __init__(
self,
alpha: float,
beta: Optional[float] = None,
class_dim: int = 1,
mode: str = "macro",
weights: List[float] = None,
eps: float = 1e-7,
):
"""
Args:
alpha: false negative coefficient, bigger alpha bigger penalty for
false negative. Must be in (0, 1)
beta: false positive coefficient, bigger alpha bigger penalty for
false positive. Must be in (0, 1), if None beta = (1 - alpha)
class_dim: indicates class dimention (K) for
``outputs`` and ``targets`` tensors (default = 1)
mode: class summation strategy. Must be one of ['micro', 'macro',
'weighted']. If mode='micro', classes are ignored, and metric
are calculated generally. If mode='macro', metric are
calculated separately and than are averaged over all classes.
If mode='weighted', metric are calculated separately and than
summed over all classes with weights.
weights: class weights(for mode="weighted")
eps: epsilon to avoid zero division
"""
super().__init__()
assert mode in ["micro", "macro", "weighted"]
self.loss_fn = partial(
trevsky,
eps=eps,
alpha=alpha,
beta=beta,
class_dim=class_dim,
            threshold=None,  # score soft predictions directly, no binarization
mode=mode,
weights=weights,
)
def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""Calculates loss between ``logits`` and ``target`` tensors."""
trevsky_score = self.loss_fn(outputs, targets)
return 1 - trevsky_score
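
# Usage sketch (illustrative, not part of the original module). It assumes
# ``outputs`` are probabilities in [0, 1] (e.g. after a sigmoid) with the class
# dimension at ``class_dim`` and binary ``targets`` of the same shape; the
# shapes, alpha, and weights below are arbitrary examples:
#
#     criterion = TrevskyLoss(alpha=0.7, mode="macro")
#     outputs = torch.sigmoid(torch.randn(4, 2, 32, 32))  # [batch, K, H, W]
#     targets = (torch.rand(4, 2, 32, 32) > 0.5).float()  # binary masks
#     loss = criterion(outputs, targets)  # scalar: 1 - averaged Tversky index
#
#     # per-class weights are only used with mode="weighted"
#     weighted = TrevskyLoss(alpha=0.7, mode="weighted", weights=[0.3, 0.7])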

class FocalTrevskyLoss(nn.Module):
    """The focal trevsky (Tversky) loss.

    TrevskyIndex = TP / (TP + alpha * FN + beta * FP)

    FocalTrevskyLoss = (1 - TrevskyIndex)^gamma

    Note: the focal term is applied per image, so the loss pays more
    attention to complicated images.
    """
    def __init__(
self,
alpha: float,
beta: Optional[float] = None,
gamma: float = 4 / 3,
class_dim: int = 1,
mode: str = "macro",
weights: List[float] = None,
eps: float = 1e-7,
):
"""
Args:
alpha: false negative coefficient, bigger alpha bigger penalty for
false negative. Must be in (0, 1)
beta: false positive coefficient, bigger alpha bigger penalty for
false positive. Must be in (0, 1), if None beta = (1 - alpha)
gamma: focal coefficient. It determines how much the weight of
simple examples is reduced.
class_dim: indicates class dimention (K) for
``outputs`` and ``targets`` tensors (default = 1)
mode: class summation strategy. Must be one of ['micro', 'macro',
'weighted']. If mode='micro', classes are ignored, and metric
are calculated generally. If mode='macro', metric are
calculated separately and than are averaged over all classes.
If mode='weighted', metric are calculated separately and than
summed over all classes with weights.
weights: class weights(for mode="weighted")
eps: epsilon to avoid zero division
"""
super().__init__()
self.gamma = gamma
self.trevsky_loss = TrevskyLoss(
alpha=alpha, beta=beta, class_dim=class_dim, mode=mode, weights=weights, eps=eps,
)
def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""Calculates loss between ``logits`` and ``target`` tensors."""
loss = 0
batch_size = len(outputs)
        for output_sample, target_sample in zip(outputs, targets):
            # treat each sample as a batch of one so the focal term
            # is applied per image
            output_sample = torch.unsqueeze(output_sample, dim=0)
            target_sample = torch.unsqueeze(target_sample, dim=0)
sample_loss = self.trevsky_loss(output_sample, target_sample)
loss += sample_loss ** self.gamma
loss = loss / batch_size # mean over batch
return loss
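
# Usage sketch (illustrative, not part of the original module). Because the
# focal term is applied per sample, a gamma above 1 shrinks the contribution
# of images that already score a low Tversky loss; gamma=4/3 mirrors the
# default above:
#
#     criterion = FocalTrevskyLoss(alpha=0.7, gamma=4 / 3)
#     outputs = torch.sigmoid(torch.randn(4, 2, 32, 32))
#     targets = (torch.rand(4, 2, 32, 32) > 0.5).float()
#     loss = criterion(outputs, targets)  # batch mean of (1 - Tversky)^gamma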
__all__ = ["TrevskyLoss", "FocalTrevskyLoss"]