# Source code for catalyst.callbacks.metrics.functional_metric
from typing import Callable, Dict, Iterable, Union
from catalyst.callbacks.metric import FunctionalBatchMetricCallback
from catalyst.metrics._functional_metric import FunctionalBatchMetric
class FunctionalMetricCallback(FunctionalBatchMetricCallback):
    """Batch-metric callback built from a plain metric function.

    The supplied ``metric_fn`` is wrapped into a ``FunctionalBatchMetric``,
    which is then handed to ``FunctionalBatchMetricCallback`` together with
    the input/target keys and the per-batch logging flag.

    Args:
        input_key: input key to use for metric calculation,
            specifies our `y_pred`
        target_key: output key to use for metric calculation,
            specifies our `y_true`
        metric_fn: metric function that takes outputs and targets
            and returns the score as a ``torch.Tensor``
        metric_key: key to store the computed metric
            in the ``runner.batch_metrics`` dictionary
        compute_on_call: computes and returns the metric value during
            the metric call. Used for per-batch logging. default: True
        log_on_batch: boolean flag to log computed metrics every batch
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(
        self,
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
        metric_fn: Callable,
        metric_key: str,
        compute_on_call: bool = True,
        log_on_batch: bool = True,
        prefix: str = None,
        suffix: str = None,
    ) -> None:
        """Init."""
        # Metric-specific options go to the metric object; the key mapping
        # and logging options are passed on to the parent callback.
        super().__init__(
            metric=FunctionalBatchMetric(
                metric_fn=metric_fn,
                metric_key=metric_key,
                compute_on_call=compute_on_call,
                prefix=prefix,
                suffix=suffix,
            ),
            input_key=input_key,
            target_key=target_key,
            log_on_batch=log_on_batch,
        )
# Names exported by ``from <this module> import *``.
__all__ = ["FunctionalMetricCallback"]