Source code for catalyst.dl.callbacks.metrics.auc

from typing import List, Optional

from catalyst.core.callbacks import LoaderMetricCallback
from catalyst.utils import metrics
from catalyst.utils.metrics.functional import wrap_class_metric2dict

class AUCCallback(LoaderMetricCallback):
    """Calculates the AUC per class for each loader.

    .. note::
        Currently, supports binary and multi-label cases.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "auc",
        multiplier: float = 1.0,
        class_args: Optional[List[str]] = None,
        **kwargs,
    ) -> None:
        """
        Args:
            input_key (str): input key to use for AUC calculation;
                specifies our ``y_true``.
            output_key (str): output key to use for AUC calculation;
                specifies our ``y_pred``.
            prefix (str): metric's name.
            multiplier (float): scale factor for the metric.
            class_args (List[str], optional): class names to display in
                the logs. If None, defaults to indices for each class,
                starting from 0.
            **kwargs: extra keyword arguments, forwarded unchanged to
                ``LoaderMetricCallback.__init__``.
        """
        # ``metrics.auc`` is wrapped by ``wrap_class_metric2dict`` together
        # with ``class_args``; presumably this turns its per-class output
        # into a dict keyed by class name/index — confirm against the helper.
        metric_fn = wrap_class_metric2dict(metrics.auc, class_args=class_args)
        super().__init__(
            prefix=prefix,
            metric_fn=metric_fn,
            input_key=input_key,
            output_key=output_key,
            multiplier=multiplier,
            **kwargs,
        )
# Names exported by ``from <this module> import *``.
__all__ = [
    "AUCCallback",
]