Shortcuts

Source code for catalyst.metrics.iou

from functools import partial
from typing import Optional

import torch


def iou(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    class_dim: int = 1,
    threshold: Optional[float] = None,
    eps: float = 1e-7,
) -> torch.Tensor:
    """Computes the per-class iou/jaccard score.

    Args:
        outputs: [N; K; ...] tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets: binary [N; K; ...] tensor that encodes which of the K
            classes are associated with the N-th input
        class_dim: indicates class dimension (K) for ``outputs`` and
            ``targets`` tensors (default = 1); negative values count from
            the last dimension
        threshold: threshold for ``outputs`` binarization; if ``None``,
            ``outputs`` are used as-is (no binarization)
        eps: epsilon to avoid zero division

    Returns:
        IoU (Jaccard) score: one value per class, i.e. a 1-D tensor of
        length K (every dimension except ``class_dim`` is summed out)

    Raises:
        AssertionError: if the inputs have fewer than 3 dimensions or
            ``outputs`` and ``targets`` shapes differ
        IndexError: if ``class_dim`` is out of range for the inputs

    Examples:
        >>> size = 4
        >>> half_size = size // 2
        >>> shape = (1, 1, size, size)
        >>> empty = torch.zeros(shape)
        >>> full = torch.ones(shape)
        >>> left = torch.ones(shape)
        >>> left[:, :, :, half_size:] = 0
        >>> right = torch.ones(shape)
        >>> right[:, :, :, :half_size] = 0
        >>> top_left = torch.zeros(shape)
        >>> top_left[:, :, :half_size, :half_size] = 1
        >>> pred = torch.cat([empty, left, empty, full, left, top_left], dim=1)
        >>> targets = torch.cat([full, right, empty, full, left, left], dim=1)
        >>> iou(
        ...     outputs=pred,
        ...     targets=targets,
        ...     class_dim=1,
        ...     threshold=0.5,
        ... )
        tensor([0.0000, 0.0000, 1.0000, 1.0000, 1.0000, 0.5000])
    """
    if threshold is not None:
        outputs = (outputs > threshold).float()

    num_dims = len(outputs.shape)
    assert num_dims > 2, (
        "expected inputs of shape [N; K; ...] with at least 3 dimensions, "
        f"got {tuple(outputs.shape)}; please check the docs for more info"
    )
    assert outputs.shape == targets.shape, (
        f"shape mismatch: outputs {tuple(outputs.shape)} vs "
        f"targets {tuple(targets.shape)}; please check the docs for more info"
    )

    # Validate before normalizing: a negative index below -num_dims must
    # not silently wrap around to some other dimension.
    if not -num_dims <= class_dim < num_dims:
        raise IndexError(
            f"class_dim={class_dim} is out of range for "
            f"a {num_dims}-dimensional input"
        )
    if class_dim < 0:
        class_dim += num_dims

    # Reduce over every dimension except the class one -> one value per class.
    dims = [dim for dim in range(num_dims) if dim != class_dim]
    sum_fn = partial(torch.sum, dim=dims)

    intersection = sum_fn(targets * outputs)
    # |T| + |O|; the actual union is |T| + |O| - |T & O|.
    cardinality = sum_fn(targets) + sum_fn(outputs)
    union = cardinality - intersection

    # The `eps * (cardinality == 0)` term looks awkward but is deliberate:
    # - if I and U are both 0 (both sets empty), IoU == eps / eps == 1;
    # - if U != 0 the term is zeroed out, so an empty intersection gives
    #   exactly 0 rather than eps / (U + eps).
    iou_score = (intersection + eps * (cardinality == 0).float()) / (
        union + eps
    )
    return iou_score
# ``jaccard`` is an alias of ``iou``: the IoU and the Jaccard score are the
# same metric, exposed under both names.
jaccard = iou

__all__ = [
    "iou",
    "jaccard",
]