"""Classification metric callbacks (``catalyst.callbacks.metrics.classification``)."""

from typing import Optional

from catalyst.callbacks.metric import BatchMetricCallback
from catalyst.metrics._classification import (
    MulticlassPrecisionRecallF1SupportMetric,
    MultilabelPrecisionRecallF1SupportMetric,
)


class PrecisionRecallF1SupportCallback(BatchMetricCallback):
    """Multiclass PrecisionRecallF1Support metric callback.

    Builds a ``MulticlassPrecisionRecallF1SupportMetric`` and hands it to
    ``BatchMetricCallback`` together with the input/target keys.

    Args:
        input_key: input key to use for metric calculation, specifies our `y_pred`
        target_key: target key to use for metric calculation, specifies our `y_true`
        num_classes: number of classes
        zero_division: value to set in case of zero division during metrics
            (precision, recall) computation; should be one of 0 or 1.
            Defaults to 0.
        log_on_batch: boolean flag to log computed metrics every batch.
            Defaults to True.
        prefix: metric prefix (optional)
        suffix: metric suffix (optional)
    """

    def __init__(
        self,
        input_key: str,
        target_key: str,
        num_classes: int,
        zero_division: int = 0,
        log_on_batch: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> None:
        """Init."""
        super().__init__(
            # The metric object owns the naming (prefix/suffix) and the
            # zero-division policy; the callback only wires the batch keys.
            metric=MulticlassPrecisionRecallF1SupportMetric(
                num_classes=num_classes,
                zero_division=zero_division,
                prefix=prefix,
                suffix=suffix,
            ),
            input_key=input_key,
            target_key=target_key,
            log_on_batch=log_on_batch,
        )
class MultilabelPrecisionRecallF1SupportCallback(BatchMetricCallback):
    """Multilabel PrecisionRecallF1Support metric callback.

    Builds a ``MultilabelPrecisionRecallF1SupportMetric`` and hands it to
    ``BatchMetricCallback`` together with the input/target keys.

    Args:
        input_key: input key to use for metric calculation, specifies our `y_pred`
        target_key: target key to use for metric calculation, specifies our `y_true`
        num_classes: number of classes
        zero_division: value to set in case of zero division during metrics
            (precision, recall) computation; should be one of 0 or 1.
            Defaults to 0.
        log_on_batch: boolean flag to log computed metrics every batch.
            Defaults to True.
        prefix: metric prefix (optional)
        suffix: metric suffix (optional)
    """

    def __init__(
        self,
        input_key: str,
        target_key: str,
        num_classes: int,
        zero_division: int = 0,
        log_on_batch: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> None:
        """Init."""
        super().__init__(
            # The metric object owns the naming (prefix/suffix) and the
            # zero-division policy; the callback only wires the batch keys.
            metric=MultilabelPrecisionRecallF1SupportMetric(
                num_classes=num_classes,
                zero_division=zero_division,
                prefix=prefix,
                suffix=suffix,
            ),
            input_key=input_key,
            target_key=target_key,
            log_on_batch=log_on_batch,
        )
# Public API of this module: the multiclass and multilabel callbacks.
__all__ = [
    "PrecisionRecallF1SupportCallback",
    "MultilabelPrecisionRecallF1SupportCallback",
]