# Source code for catalyst.dl.callbacks.metrics.iou
from typing import Dict
import numpy as np
from catalyst.dl.core import (
Callback, CallbackOrder, MetricCallback, RunnerState
)
from catalyst.dl.utils import criterion
from catalyst.utils.confusion_matrix import (
calculate_confusion_matrix_from_tensors, calculate_tp_fp_fn
)
class IouCallback(MetricCallback):
    """
    IoU (Jaccard) metric callback.

    Thin configuration wrapper around ``MetricCallback`` that wires
    ``criterion.iou`` as the metric function.
    """
    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "iou",
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
    ):
        """
        Args:
            input_key (str): input key to use for iou calculation
                specifies our ``y_true``.
            output_key (str): output key to use for iou calculation;
                specifies our ``y_pred``
            prefix (str): key to store in logs
            eps (float): epsilon to avoid zero division
            threshold (float): threshold for outputs binarization
            activation (str): An torch.nn activation applied to the outputs.
                Must be one of ['none', 'Sigmoid', 'Softmax2d']
        """
        # Everything is forwarded to MetricCallback; this class only
        # fixes the metric function to criterion.iou.
        metric_kwargs = dict(
            prefix=prefix,
            metric_fn=criterion.iou,
            input_key=input_key,
            output_key=output_key,
            eps=eps,
            threshold=threshold,
            activation=activation,
        )
        super().__init__(**metric_kwargs)
# Backward-compatible alias: IoU and the Jaccard index are the same metric.
JaccardCallback = IouCallback
def calculate_jaccard(
    true_positives: np.ndarray,
    false_positives: np.ndarray,
    false_negatives: np.ndarray,
    epsilon: float = 1e-7,
) -> np.ndarray:
    """Calculate element-wise Jaccard indices (IoU).

    Computes ``(tp + eps) / (tp + fp + fn + eps)`` for each class; the
    epsilon guards against division by zero when a class never occurs
    (all-zero counts yield a Jaccard index of 1.0).

    Args:
        true_positives: per-class true-positive counts
        false_positives: per-class false-positive counts
        false_negatives: per-class false-negative counts
        epsilon: smoothing term to avoid zero division

    Returns:
        np.ndarray: per-class Jaccard indices in the interval (0, 1]

    Raises:
        ValueError: if any computed index falls outside (0, 1], which
            indicates invalid (e.g. negative) input counts
    """
    jaccard = (true_positives + epsilon) / (
        true_positives + false_positives + false_negatives + epsilon
    )
    if not np.all(jaccard <= 1):
        raise ValueError("Jaccard index should be less or equal to 1")
    if not np.all(jaccard > 0):
        # Fixed message: the condition checks the lower bound (> 0),
        # the original text incorrectly said "more than 1".
        raise ValueError("Jaccard index should be more than 0")
    return jaccard
class MulticlassIOUMetricCallback(Callback):
    """Per-class IoU (Jaccard) metrics accumulated over a whole loader.

    A confusion matrix is summed batch by batch and converted to
    per-class Jaccard values (and their mean) at loader end.
    """
    def __init__(
        self,
        prefix: str = "jaccard",
        input_key: str = "targets",
        output_key: str = "logits",
        class_names=None,
        class_prefix="",
        **metric_params
    ):
        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.input_key = input_key
        self.output_key = output_key
        self.metric_params = metric_params
        self.confusion_matrix = None
        # dictionary {class_id: class_name}; ids missing from it are
        # skipped when logging per-class values
        self.class_names = class_names
        self.class_prefix = class_prefix

    def _reset_stats(self):
        """Forget the accumulated confusion matrix."""
        self.confusion_matrix = None

    def on_batch_end(self, state: RunnerState):
        """Add this batch's confusion matrix to the running total."""
        batch_cm = calculate_confusion_matrix_from_tensors(
            state.output[self.output_key], state.input[self.input_key]
        )
        if self.confusion_matrix is None:
            self.confusion_matrix = batch_cm
        else:
            self.confusion_matrix += batch_cm

    def on_loader_end(self, state: RunnerState):
        """Compute per-class Jaccard values, log them, and reset stats."""
        counts = calculate_tp_fp_fn(self.confusion_matrix)
        # NOTE(review): calculate_jaccard (defined above) returns an
        # np.ndarray, yet .items() below treats the result as a mapping —
        # confirm the intended return type of these helpers.
        batch_metrics: Dict = calculate_jaccard(**counts)
        loader_values = state.metrics.epoch_values[state.loader_name]
        for metric_id, jaccard_value in batch_metrics.items():
            if metric_id not in self.class_names:
                continue
            metric_name = self.class_names[metric_id]
            loader_values[f"{self.class_prefix}_{metric_name}"] = jaccard_value
        loader_values["mean"] = np.mean(list(batch_metrics.values()))
        self._reset_stats()
# Backward-compatible alias: IoU and the Jaccard index are the same metric.
MulticlassJaccardMetricCallback = MulticlassIOUMetricCallback
# Public API of this module, including the backward-compatible aliases.
__all__ = [
    "IouCallback", "JaccardCallback",
    "MulticlassIOUMetricCallback", "MulticlassJaccardMetricCallback"
]