Source code for catalyst.dl.callbacks.metrics.auc

from typing import List  # isort:skip

from catalyst.dl.core import MeterMetricsCallback
from catalyst.dl.meters import AUCMeter


class AUCCallback(MeterMetricsCallback):
    """
    Calculates the AUC per class for each loader.

    Currently, supports binary and multi-label cases.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "auc",
        class_names: List[str] = None,
        num_classes: int = 2,
        activation: str = "Sigmoid"
    ):
        """
        Build one ``AUCMeter`` per class and hand them to the base callback.

        Args:
            input_key (str): input key to use for auc calculation;
                specifies our ``y_true``.
            output_key (str): output key to use for auc calculation;
                specifies our ``y_pred``.
            prefix (str): name to display for auc when printing.
            class_names (List[str]): class names to display in the logs.
                If None, defaults to indices for each class, starting
                from 0. When given, its length replaces ``num_classes``.
            num_classes (int): number of classes; must be > 1 (this
                constructor does not validate it). Ignored when
                ``class_names`` is provided.
            activation (str): a torch.nn activation applied to the outputs.
                Must be one of ['none', 'Sigmoid', 'Softmax2d'].
        """
        # An explicit list of class names is authoritative: its length
        # overrides whatever was passed as ``num_classes``.
        if class_names is not None:
            num_classes = len(class_names)

        # One separate meter instance per class.
        per_class_meters = [AUCMeter() for _ in range(num_classes)]

        super().__init__(
            metric_names=[prefix],
            meter_list=per_class_meters,
            input_key=input_key,
            output_key=output_key,
            class_names=class_names,
            num_classes=num_classes,
            activation=activation
        )
# Public API of this module: the only name exported by a star-import.
__all__ = ["AUCCallback"]