Source code for catalyst.dl.callbacks.metrics.iou

from typing import Dict

import numpy as np

from catalyst.dl.core import (
    Callback, CallbackOrder, MetricCallback, RunnerState
from catalyst.dl.utils import criterion
from catalyst.utils.confusion_matrix import (
    calculate_confusion_matrix_from_tensors, calculate_tp_fp_fn

class IouCallback(MetricCallback):
    """
    IoU (Jaccard) metric callback.

    Thin wrapper around ``MetricCallback`` that plugs in ``criterion.iou``
    as the metric function and forwards the IoU-specific settings to it.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "iou",
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
    ):
        """
        Args:
            input_key (str): key in the batch input holding ``y_true``
            output_key (str): key in the model output holding ``y_pred``
            prefix (str): name under which the metric is logged
            eps (float): small constant guarding against division by zero
            threshold (float): binarization threshold for the outputs;
                ``None`` by default
            activation (str): a torch.nn activation applied to the outputs,
                one of ``['none', 'Sigmoid', 'Softmax2d']``
        """
        # Arguments consumed by ``criterion.iou`` itself, forwarded as-is.
        iou_params = dict(
            eps=eps,
            threshold=threshold,
            activation=activation,
        )
        super().__init__(
            prefix=prefix,
            metric_fn=criterion.iou,
            input_key=input_key,
            output_key=output_key,
            **iou_params
        )
JaccardCallback = IouCallback


def calculate_jaccard(
    true_positives: np.ndarray,
    false_positives: np.ndarray,
    false_negatives: np.ndarray,
    eps: float = 1e-7,
) -> np.ndarray:
    """Calculate Jaccard indices (IoU) element-wise from TP / FP / FN counts.

    ``jaccard = (tp + eps) / (tp + fp + fn + eps)``

    Args:
        true_positives: true-positive counts (e.g. one entry per class)
        false_positives: false-positive counts, aligned with
            ``true_positives``
        false_negatives: false-negative counts, aligned with
            ``true_positives``
        eps: smoothing term added to both numerator and denominator. It
            avoids zero division, and an entry with tp = fp = fn = 0 gets
            an index of 1. Should be positive; with ``eps=0`` such entries
            become NaN and trigger the first ``ValueError``.

    Returns:
        Jaccard indices, element-wise over the (broadcast) inputs.

    Raises:
        ValueError: if any index is not ``<= 1`` (including NaN), or is not
            strictly positive (e.g. when negative counts were passed).
    """
    jaccard = (true_positives + eps) / (
        true_positives + false_positives + false_negatives + eps
    )
    if not np.all(jaccard <= 1):
        raise ValueError("Jaccard index should be less or equal to 1")
    if not np.all(jaccard > 0):
        # The condition checks positivity; the old message wrongly said "1".
        raise ValueError("Jaccard index should be more than 0")
    return jaccard
class MulticlassIOUMetricCallback(Callback):
    """Multiclass IoU (Jaccard) metric callback.

    Sums a confusion matrix over every batch of a loader. At the end of the
    loader it writes per-class Jaccard values, plus their mean under the
    key ``"mean"``, into ``state.metrics.epoch_values[state.loader_name]``.
    """

    def __init__(
        self,
        prefix: str = "jaccard",
        input_key: str = "targets",
        output_key: str = "logits",
        class_names=None,
        class_prefix="",
        **metric_params
    ):
        """
        Args:
            prefix (str): stored as ``self.prefix``; not used when building
                the logged keys
            input_key (str): key in ``state.input`` with the ground truth
            output_key (str): key in ``state.output`` with the predictions
            class_names (dict): ``{class_id: class_name}``. Only ids present
                in it get a per-class entry. If ``None``, every class is
                reported under ``str(class_id)``.
            class_prefix (str): per-class keys are
                ``f"{class_prefix}_{class_name}"``
            **metric_params: stored as ``self.metric_params``; not used here
        """
        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.input_key = input_key
        self.output_key = output_key
        self.metric_params = metric_params
        self.confusion_matrix = None
        self.class_names = class_names  # dictionary {class_id: class_name}
        self.class_prefix = class_prefix

    def _reset_stats(self):
        """Drop the confusion matrix accumulated for the current loader."""
        self.confusion_matrix = None

    def on_batch_end(self, state: RunnerState):
        """Add this batch's confusion matrix to the running total."""
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]
        confusion_matrix = calculate_confusion_matrix_from_tensors(
            outputs, targets
        )
        if self.confusion_matrix is None:
            self.confusion_matrix = confusion_matrix
        else:
            self.confusion_matrix += confusion_matrix

    def on_loader_end(self, state: RunnerState):
        """Log per-class and mean Jaccard values, then reset the stats."""
        if self.confusion_matrix is None:
            # No batches reached on_batch_end: there is nothing to report.
            return
        try:
            tp_fp_fn_dict = calculate_tp_fp_fn(self.confusion_matrix)
            # calculate_jaccard returns an ndarray, not a dict, so index by
            # position. atleast_1d guards against a 0-d result.
            jaccard_values = np.atleast_1d(calculate_jaccard(**tp_fp_fn_dict))

            class_names = self.class_names
            if class_names is None:
                class_names = {
                    class_id: str(class_id)
                    for class_id in range(len(jaccard_values))
                }

            loader_values = state.metrics.epoch_values[state.loader_name]
            for class_id, jaccard_value in enumerate(jaccard_values):
                if class_id not in class_names:
                    continue
                metric_name = class_names[class_id]
                loader_values[
                    f"{self.class_prefix}_{metric_name}"
                ] = jaccard_value
            # The mean covers all classes, including ones without a name.
            loader_values["mean"] = np.mean(jaccard_values)
        finally:
            # Never let this loader's counts leak into the next loader.
            self._reset_stats()
# Alias: the same multiclass callback exposed under a Jaccard-based name.
MulticlassJaccardMetricCallback = MulticlassIOUMetricCallback

# Names exported by ``from ... import *``; calculate_jaccard is not listed.
__all__ = [
    "IouCallback",
    "JaccardCallback",
    "MulticlassIOUMetricCallback",
    "MulticlassJaccardMetricCallback",
]