# Source code for catalyst.dl.callbacks.metrics.auc
from typing import List # isort:skip
from catalyst.dl.core import MeterMetricsCallback
from catalyst.dl.meters import AUCMeter
class AUCCallback(MeterMetricsCallback):
    """
    Calculates the AUC per class for each loader.

    Currently supports binary and multi-label cases.
    One ``AUCMeter`` is created per class and handed to
    ``MeterMetricsCallback``.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "auc",
        class_names: List[str] = None,
        num_classes: int = 2,
        activation: str = "Sigmoid",
    ):
        """
        Args:
            input_key (str): input key to use for auc calculation;
                specifies our ``y_true``
            output_key (str): output key to use for auc calculation;
                specifies our ``y_pred``
            prefix (str): name to display for auc when printing
            class_names (List[str]): class names to display in the logs.
                If None, defaults to indices for each class, starting from 0.
                When given, its length takes precedence over ``num_classes``.
            num_classes (int): number of classes; must be > 1.
                Only used when ``class_names`` is None.
            activation (str): a torch.nn activation applied to the outputs.
                Must be one of ['none', 'Sigmoid', 'Softmax2d']
        """
        # An explicit list of class names is the source of truth for the
        # class count, so it overrides whatever ``num_classes`` was passed.
        if class_names is not None:
            num_classes = len(class_names)

        # One independent meter per class.
        meters = [AUCMeter() for _ in range(num_classes)]

        super().__init__(
            metric_names=[prefix],
            meter_list=meters,
            input_key=input_key,
            output_key=output_key,
            class_names=class_names,
            num_classes=num_classes,
            activation=activation,
        )
# Names exported by ``from <this module> import *``.
__all__ = ["AUCCallback"]