Source code for catalyst.dl.callbacks.metrics.accuracy
from typing import List
from catalyst.core import BatchMetricCallback
from catalyst.utils import metrics
from catalyst.utils.metrics.functional import (
get_default_topk_args,
wrap_topk_metric2dict,
)
[docs]class AccuracyCallback(BatchMetricCallback):
"""Accuracy metric callback.
Computes multi-class accuracy@topk for the specified values of `topk`.
.. note::
For multi-label accuracy please use
`catalyst.dl.callbacks.metrics.MultiLabelAccuracyCallback`
"""
[docs] def __init__(
self,
input_key: str = "targets",
output_key: str = "logits",
prefix: str = "accuracy",
multiplier: float = 1.0,
topk_args: List[int] = None,
num_classes: int = None,
accuracy_args: List[int] = None,
**kwargs,
):
"""
Args:
input_key (str): input key to use for accuracy calculation;
specifies our `y_true`
output_key (str): output key to use for accuracy calculation;
specifies our `y_pred`
prefix (str): key for the metric's name
topk_args (List[int]): specifies which accuracy@K to log:
[1] - accuracy
[1, 3] - accuracy at 1 and 3
[1, 3, 5] - accuracy at 1, 3 and 5
num_classes (int): number of classes to calculate ``topk_args``
if ``accuracy_args`` is None
activation (str): An torch.nn activation applied to the outputs.
Must be one of ``"none"``, ``"Sigmoid"``, or ``"Softmax"``
"""
topk_args = (
topk_args or accuracy_args or get_default_topk_args(num_classes)
)
super().__init__(
prefix=prefix,
metric_fn=wrap_topk_metric2dict(
metrics.accuracy, topk_args=topk_args
),
input_key=input_key,
output_key=output_key,
multiplier=multiplier,
**kwargs,
)
[docs]class MultiLabelAccuracyCallback(BatchMetricCallback):
"""Accuracy metric callback.
Computes multi-class accuracy@topk for the specified values of `topk`.
.. note::
For multi-label accuracy please use
`catalyst.dl.callbacks.metrics.MultiLabelAccuracyCallback`
"""
[docs] def __init__(
self,
input_key: str = "targets",
output_key: str = "logits",
prefix: str = "multi_label_accuracy",
threshold: float = None,
activation: str = "Sigmoid",
):
"""
Args:
input_key (str): input key to use for accuracy calculation;
specifies our `y_true`
output_key (str): output key to use for accuracy calculation;
specifies our `y_pred`
prefix (str): key for the metric's name
threshold (float): threshold for for model output
activation (str): An torch.nn activation applied to the outputs.
Must be one of ``"none"``, ``"Sigmoid"``, or ``"Softmax"``
"""
super().__init__(
prefix=prefix,
metric_fn=metrics.multi_label_accuracy,
input_key=input_key,
output_key=output_key,
threshold=threshold,
activation=activation,
)
__all__ = ["AccuracyCallback", "MultiLabelAccuracyCallback"]