Source code for catalyst.metrics._classification

from typing import Any, Dict, List, Optional, Tuple, Union
from collections import defaultdict

import numpy as np

import torch

from catalyst.metrics._metric import ICallbackBatchMetric
from catalyst.metrics.functional._classification import (
    get_aggregated_metrics,
    get_binary_metrics,
)
from catalyst.metrics.functional._misc import (
    get_binary_statistics,
    get_multiclass_statistics,
    get_multilabel_statistics,
)
from catalyst.settings import SETTINGS
from catalyst.utils import get_device
from catalyst.utils.distributed import all_gather, get_backend

if SETTINGS.xla_required:
    import torch_xla.core.xla_model as xm


class BinaryStatisticsMetric(ICallbackBatchMetric):
    """
    This metric accumulates true positive, false positive, true negative,
    false negative, support statistics from binary data.

    Args:
        compute_on_call: if True, computes and returns metric value during metric call
        prefix: metric prefix
        suffix: metric suffix

    Raises:
        ValueError: if mode is incorrect

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl
        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples,) * num_classes).to(torch.int64)
        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}
        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
        # model training
        runner = dl.SupervisedRunner(
            input_key="features",
            output_key="logits",
            target_key="targets",
            loss_key="loss"
        )
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=3,
            valid_loader="valid",
            valid_metric="accuracy03",
            minimize_valid_metric=False,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.PrecisionRecallF1SupportCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.AUCCallback(input_key="logits", target_key="targets"),
            ],
        )
    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505

    """

    def __init__(
        self,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Init params"""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.statistics = None
        self.num_classes = 2
        self._ddp_backend = None
        self.reset()

    # multiprocessing cannot pickle lambdas, so we use a bound method as the defaultdict factory
    def _mp_hack(self):
        return np.zeros(shape=(self.num_classes,))

    def reset(self) -> None:
        """Reset all the statistics."""
        self.statistics = defaultdict(self._mp_hack)
        self._ddp_backend = get_backend()

    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Union[Tuple[int, int, int, int, int], Tuple[Any, Any, Any, Any, Any]]:
        """
        Compute statistics from outputs and targets,
        update accumulated statistics with new values.

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            Tuple of int or array: true negative, false positive, false
                negative, true positive and support statistics

        """
        tn, fp, fn, tp, support = get_binary_statistics(
            outputs=outputs.cpu().detach(), targets=targets.cpu().detach()
        )

        tn = tn.numpy()
        fp = fp.numpy()
        fn = fn.numpy()
        tp = tp.numpy()
        support = support.numpy()

        self.statistics["tn"] += tn
        self.statistics["fp"] += fp
        self.statistics["fn"] += fn
        self.statistics["tp"] += tp
        self.statistics["support"] += support

        return tn, fp, fn, tp, support

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, float]:
        """
        Update statistics and return statistics intermediate result

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            dict of statistics for current input

        """
        tn, fp, fn, tp, support = self.update(outputs=outputs, targets=targets)
        return {"fn": fn, "fp": fp, "support": support, "tn": tn, "tp": tp}

    def compute(self) -> Dict[str, Union[int, np.ndarray]]:
        """
        Return accumulated statistics

        Returns:
            dict of statistics

        """
        return self.statistics

    def compute_key_value(self) -> Dict[str, float]:
        """
        Return accumulated statistics

        Returns:
            dict of statistics

        Examples:
            >>> {"tp": 3, "fp": 4, "tn": 5, "fn": 1, "support": 13}

        """
        result = self.compute()
        return {k: result[k] for k in sorted(result.keys())}


class MulticlassStatisticsMetric(ICallbackBatchMetric):
    """
    This metric accumulates true positive, false positive, true negative,
    false negative, support statistics from multiclass data.

    Args:
        compute_on_call: if True, computes and returns metric value during metric call
        prefix: metric prefix
        suffix: metric suffix
        num_classes: number of classes

    Raises:
        ValueError: if mode is incorrect

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl
        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples,) * num_classes).to(torch.int64)
        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}
        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
        # model training
        runner = dl.SupervisedRunner(
            input_key="features",
            output_key="logits",
            target_key="targets",
            loss_key="loss"
        )
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=3,
            valid_loader="valid",
            valid_metric="accuracy03",
            minimize_valid_metric=False,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.PrecisionRecallF1SupportCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.AUCCallback(input_key="logits", target_key="targets"),
            ],
        )
    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505

    """

    def __init__(
        self,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        num_classes: Optional[int] = None,
    ):
        """Init params"""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.statistics = None
        self.num_classes = num_classes
        self._ddp_backend = None
        self.reset()

    # multiprocessing cannot pickle lambdas, so we use a bound method as the defaultdict factory
    def _mp_hack(self):
        return np.zeros(shape=(self.num_classes,))

    def reset(self) -> None:
        """Reset all the statistics."""
        self.statistics = defaultdict(self._mp_hack)
        self._ddp_backend = get_backend()

    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Union[Tuple[int, int, int, int, int, int], Tuple[Any, Any, Any, Any, Any, int]]:
        """
        Compute statistics from outputs and targets,
        update accumulated statistics with new values.

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            Tuple of int or array: true negative, false positive, false
                negative, true positive, support statistics and num_classes

        """
        tn, fp, fn, tp, support, num_classes = get_multiclass_statistics(
            outputs=outputs.cpu().detach(),
            targets=targets.cpu().detach(),
            num_classes=self.num_classes,
        )

        tn = tn.numpy()
        fp = fp.numpy()
        fn = fn.numpy()
        tp = tp.numpy()
        support = support.numpy()

        if self.num_classes is None:
            self.num_classes = num_classes

        self.statistics["tn"] += tn
        self.statistics["fp"] += fp
        self.statistics["fn"] += fn
        self.statistics["tp"] += tp
        self.statistics["support"] += support

        return tn, fp, fn, tp, support, self.num_classes

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, float]:
        """
        Update statistics and return statistics intermediate result

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            dict of statistics for current input

        """
        tn, fp, fn, tp, support, _ = self.update(outputs=outputs, targets=targets)
        return {"fn": fn, "fp": fp, "support": support, "tn": tn, "tp": tp}

    def compute(self) -> Dict[str, Union[int, np.ndarray]]:
        """
        Return accumulated statistics

        Returns:
            dict of statistics

        """
        return self.statistics

    def compute_key_value(self) -> Dict[str, float]:
        """
        Return accumulated statistics

        Returns:
            dict of statistics

        Examples:
            >>> {"tp": np.array([1, 2, 1]), "fp": np.array([2, 1, 0]), ...}

        """
        result = self.compute()
        return {k: result[k] for k in sorted(result.keys())}
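
# A minimal usage sketch (illustrative values only), assuming `outputs` already
# contains hard class predictions; statistics are kept as per-class arrays.
#
#     metric = MulticlassStatisticsMetric(num_classes=3)
#     metric.update(
#         outputs=torch.tensor([0, 2, 2, 1]),
#         targets=torch.tensor([0, 1, 2, 1]),
#     )
#     stats = metric.compute_key_value()
#     # stats["tp"], stats["fp"], ... are np.ndarray of shape (num_classes,);
#     # if num_classes is omitted, it is inferred from the data on the first update().
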


class MultilabelStatisticsMetric(ICallbackBatchMetric):
    """
    This metric accumulates true positive, false positive, true negative,
    false negative, support statistics from multilabel data.

    Args:
        compute_on_call: if True, computes and returns metric value during metric call
        prefix: metric prefix
        suffix: metric suffix
        num_classes: number of classes

    Raises:
        ValueError: if mode is incorrect

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl
        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples,) * num_classes).to(torch.int64)
        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}
        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
        # model training
        runner = dl.SupervisedRunner(
            input_key="features",
            output_key="logits",
            target_key="targets",
            loss_key="loss"
        )
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=3,
            valid_loader="valid",
            valid_metric="accuracy03",
            minimize_valid_metric=False,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.PrecisionRecallF1SupportCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.AUCCallback(input_key="logits", target_key="targets"),
            ],
        )
    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505

    """

    def __init__(
        self,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        num_classes: Optional[int] = None,
    ):
        """Init params"""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.statistics = None
        self.num_classes = num_classes
        self._ddp_backend = None
        self.reset()

    # multiprocessing cannot pickle lambdas, so we use a bound method as the defaultdict factory
    def _mp_hack(self):
        return np.zeros(shape=(self.num_classes,))

    def reset(self) -> None:
        """Reset all the statistics."""
        self.statistics = defaultdict(self._mp_hack)
        self._ddp_backend = get_backend()

    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Union[Tuple[int, int, int, int, int, int], Tuple[Any, Any, Any, Any, Any, int]]:
        """
        Compute statistics from outputs and targets,
        update accumulated statistics with new values.

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            Tuple of int or array: true negative, false positive, false
                negative, true positive, support statistics and num_classes

        """
        tn, fp, fn, tp, support, num_classes = get_multilabel_statistics(
            outputs=outputs.cpu().detach(), targets=targets.cpu().detach()
        )

        tn = tn.numpy()
        fp = fp.numpy()
        fn = fn.numpy()
        tp = tp.numpy()
        support = support.numpy()
        if self.num_classes is None:
            self.num_classes = num_classes

        self.statistics["tn"] += tn
        self.statistics["fp"] += fp
        self.statistics["fn"] += fn
        self.statistics["tp"] += tp
        self.statistics["support"] += support

        return tn, fp, fn, tp, support, self.num_classes

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, float]:
        """
        Update statistics and return statistics intermediate result

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            dict of statistics for current input

        """
        tn, fp, fn, tp, support, _ = self.update(outputs=outputs, targets=targets)
        return {"fn": fn, "fp": fp, "support": support, "tn": tn, "tp": tp}

    def compute(self) -> Dict[str, Union[int, np.ndarray]]:
        """
        Return accumulated statistics

        Returns:
            dict of statistics

        """
        return self.statistics

    def compute_key_value(self) -> Dict[str, float]:
        """
        Return accumulated statistics

        Returns:
            dict of statistics

        Examples:
            >>> {"tp": np.array([1, 2, 1]), "fp": np.array([2, 1, 0]), ...}

        """
        result = self.compute()
        return {k: result[k] for k in sorted(result.keys())}
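
# A minimal usage sketch (illustrative values only) for the multilabel case,
# assuming binarized predictions and targets of shape (batch_size, num_classes).
#
#     metric = MultilabelStatisticsMetric(num_classes=3)
#     metric.update(
#         outputs=torch.tensor([[1, 0, 1], [0, 1, 1]]),
#         targets=torch.tensor([[1, 0, 0], [0, 1, 1]]),
#     )
#     stats = metric.compute_key_value()  # one entry per label in each array
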


class BinaryPrecisionRecallF1Metric(BinaryStatisticsMetric):
    """Precision, recall, f1_score and support metrics for binary classification.

    Args:
        zero_division: value to set in case of zero division during metrics
            (precision, recall) computation; should be one of 0 or 1
        compute_on_call: if True, allows compute metric's value on call
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(
        self,
        zero_division: int = 0,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Init BinaryPrecisionRecallF1Metric instance"""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.statistics = None
        self.zero_division = zero_division
        self.reset()

    @staticmethod
    def _convert_metrics_to_kv(
        precision_value: float, recall_value: float, f1_value: float
    ) -> Dict[str, float]:
        """
        Convert list of metrics to key-value

        Args:
            precision_value: precision value
            recall_value: recall value
            f1_value: f1 value

        Returns:
            dict of metrics
        """
        kv_metrics = {
            "precision": precision_value,
            "recall": recall_value,
            "f1": f1_value,
        }
        return kv_metrics

    def reset(self) -> None:
        """Reset all the statistics and metrics fields."""
        self.statistics = defaultdict(int)

    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Tuple[float, float, float]:
        """
        Update statistics and return metrics intermediate results

        Args:
            outputs: predicted labels
            targets: target labels

        Returns:
            tuple of intermediate metrics: precision, recall, f1 score
        """
        tn, fp, fn, tp, support = super().update(outputs=outputs, targets=targets)
        precision_value, recall_value, f1_value = get_binary_metrics(
            tp=tp, fp=fp, fn=fn, zero_division=self.zero_division
        )
        return precision_value, recall_value, f1_value

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, float]:
        """
        Update statistics and return metrics intermediate results

        Args:
            outputs: predicted labels
            targets: target labels

        Returns:
            dict of intermediate metrics
        """
        precision_value, recall_value, f1_value = self.update(
            outputs=outputs, targets=targets
        )
        kv_metrics = self._convert_metrics_to_kv(
            precision_value=precision_value, recall_value=recall_value, f1_value=f1_value
        )
        return kv_metrics

    def compute(self) -> Tuple[float, float, float]:
        """
        Compute metrics with accumulated statistics

        Returns:
            tuple of metrics: precision, recall, f1 score
        """
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        if self._ddp_backend == "xla":
            self.statistics = {
                k: xm.mesh_reduce(k, v, np.sum) for k, v in self.statistics.items()
            }
        elif self._ddp_backend == "ddp":
            for key in self.statistics:
                value: List[int] = all_gather(self.statistics[key])
                value: int = sum(value)
                self.statistics[key] = value

        precision_value, recall_value, f1_value = get_binary_metrics(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            zero_division=self.zero_division,
        )
        return precision_value, recall_value, f1_value

    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute metrics with all accumulated statistics

        Returns:
            dict of metrics
        """
        precision_value, recall_value, f1_value = self.compute()
        kv_metrics = self._convert_metrics_to_kv(
            precision_value=precision_value, recall_value=recall_value, f1_value=f1_value
        )
        return kv_metrics
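
# For reference, get_binary_metrics is expected to follow the standard definitions,
# with `zero_division` substituted whenever a denominator is zero:
#
#     precision = tp / (tp + fp)
#     recall    = tp / (tp + fn)
#     f1        = 2 * precision * recall / (precision + recall)
#
# Worked example (illustrative): tp=3, fp=1, fn=2 gives
# precision = 0.75, recall = 0.6, f1 = 2 * 0.75 * 0.6 / 1.35 ≈ 0.667.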


class MulticlassPrecisionRecallF1SupportMetric(MulticlassStatisticsMetric):
    """
    Metric that can collect statistics and compute precision, recall, f1_score
    and support from them.

    Args:
        zero_division: value to set in case of zero division during metrics
            (precision, recall) computation; should be one of 0 or 1
        compute_on_call: if True, allows compute metric's value on call
        compute_per_class_metrics: boolean flag to compute per-class metrics
            (default: SETTINGS.compute_per_class_metrics or False).
        prefix: metrics prefix
        suffix: metrics suffix
        num_classes: number of classes
    """

    def __init__(
        self,
        zero_division: int = 0,
        compute_on_call: bool = True,
        compute_per_class_metrics: bool = SETTINGS.compute_per_class_metrics,
        prefix: str = None,
        suffix: str = None,
        num_classes: Optional[int] = None,
    ) -> None:
        """Init MulticlassPrecisionRecallF1SupportMetric instance"""
        super().__init__(
            compute_on_call=compute_on_call,
            prefix=prefix,
            suffix=suffix,
            num_classes=num_classes,
        )
        self.compute_per_class_metrics = compute_per_class_metrics
        self.zero_division = zero_division
        self.num_classes = num_classes
        self.reset()

    def _convert_metrics_to_kv(
        self, per_class, micro, macro, weighted
    ) -> Dict[str, float]:
        """
        Convert metrics aggregation to key-value format

        Args:
            per_class: per-class metrics, array of shape (4, self.num_classes)
                of precision, recall, f1 and support metrics
            micro: micro averaged metrics, array of shape (self.num_classes)
                of precision, recall, f1 and support metrics
            macro: macro averaged metrics, array of shape (self.num_classes)
                of precision, recall, f1 and support metrics
            weighted: weighted averaged metrics, array of shape (self.num_classes)
                of precision, recall, f1 and support metrics

        Returns:
            dict of key-value metrics
        """
        kv_metrics = {}
        for aggregation_name, aggregated_metrics in zip(
            ("_micro", "_macro", "_weighted"), (micro, macro, weighted)
        ):
            metrics = {
                f"{metric_name}/{aggregation_name}": metric_value
                for metric_name, metric_value in zip(
                    ("precision", "recall", "f1"), aggregated_metrics[:-1]
                )
            }
            kv_metrics.update(metrics)

        # @TODO: rewrite this block - should be without `num_classes`
        if self.compute_per_class_metrics:
            per_class_metrics = {
                f"{metric_name}/class_{i:02d}": metric_value[i]
                for metric_name, metric_value in zip(
                    ("precision", "recall", "f1", "support"), per_class
                )
                for i in range(self.num_classes)
            }
            kv_metrics.update(per_class_metrics)
        return kv_metrics

    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Tuple[Any, Any, Any, Any]:
        """
        Update statistics and return intermediate metrics results

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            tuple of metrics intermediate results with per-class, micro, macro
                and weighted averaging
        """
        tn, fp, fn, tp, support, num_classes = super().update(
            outputs=outputs, targets=targets
        )
        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=tp, fp=fp, fn=fn, support=support, zero_division=self.zero_division
        )
        if self.num_classes is None:
            self.num_classes = num_classes
        return per_class, micro, macro, weighted

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, float]:
        """
        Update statistics and return intermediate metrics results

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            dict of metrics intermediate results
        """
        per_class, micro, macro, weighted = self.update(outputs=outputs, targets=targets)
        metrics = self._convert_metrics_to_kv(
            per_class=per_class, micro=micro, macro=macro, weighted=weighted
        )
        return metrics

    def compute(self) -> Any:
        """
        Compute precision, recall, f1 score and support.
        Compute micro, macro and weighted average for the metrics.

        Returns:
            list of aggregated metrics: per-class, micro, macro and weighted
                averaging of precision, recall, f1 score and support metrics
        """
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        if self._ddp_backend == "xla":
            device = get_device()
            for key in self.statistics:
                key_statistics = torch.tensor([self.statistics[key]], device=device)
                key_statistics = xm.all_gather(key_statistics).sum(dim=0).cpu().numpy()
                self.statistics[key] = key_statistics
        elif self._ddp_backend == "ddp":
            for key in self.statistics:
                value: List[np.ndarray] = all_gather(self.statistics[key])
                value: np.ndarray = np.sum(np.vstack(value), axis=0)
                self.statistics[key] = value

        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            support=self.statistics["support"],
            zero_division=self.zero_division,
        )
        if self.compute_per_class_metrics:
            return per_class, micro, macro, weighted
        else:
            return [], micro, macro, weighted

    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute precision, recall, f1 score and support.
        Compute micro, macro and weighted average for the metrics.

        Returns:
            dict of metrics
        """
        per_class, micro, macro, weighted = self.compute()
        metrics = self._convert_metrics_to_kv(
            per_class=per_class, micro=micro, macro=macro, weighted=weighted
        )
        return metrics
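
# For reference, the three aggregations returned by get_aggregated_metrics are
# expected to follow the usual conventions:
#     micro:    sum tp/fp/fn over all classes first, then compute a single
#               precision/recall/f1 from the pooled counts
#     macro:    compute precision/recall/f1 per class, then take their unweighted mean
#     weighted: compute precision/recall/f1 per class, then average them using the
#               per-class support as weights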


class MultilabelPrecisionRecallF1SupportMetric(MultilabelStatisticsMetric):
    """
    Metric that can collect statistics and compute precision, recall, f1_score
    and support from them.

    Args:
        zero_division: value to set in case of zero division during metrics
            (precision, recall) computation; should be one of 0 or 1
        compute_on_call: if True, allows compute metric's value on call
        compute_per_class_metrics: boolean flag to compute per-class metrics
            (default: SETTINGS.compute_per_class_metrics or False).
        prefix: metrics prefix
        suffix: metrics suffix
        num_classes: number of classes
    """

    def __init__(
        self,
        zero_division: int = 0,
        compute_on_call: bool = True,
        compute_per_class_metrics: bool = SETTINGS.compute_per_class_metrics,
        prefix: str = None,
        suffix: str = None,
        num_classes: Optional[int] = None,
    ) -> None:
        """Init MultilabelPrecisionRecallF1SupportMetric instance"""
        super().__init__(
            compute_on_call=compute_on_call,
            prefix=prefix,
            suffix=suffix,
            num_classes=num_classes,
        )
        self.compute_per_class_metrics = compute_per_class_metrics
        self.zero_division = zero_division
        self.num_classes = num_classes
        self.reset()

    def _convert_metrics_to_kv(
        self, per_class, micro, macro, weighted
    ) -> Dict[str, float]:
        """
        Convert metrics aggregation to key-value format

        Args:
            per_class: per-class metrics, array of shape (4, self.num_classes)
                of precision, recall, f1 and support metrics
            micro: micro averaged metrics, array of shape (self.num_classes)
                of precision, recall, f1 and support metrics
            macro: macro averaged metrics, array of shape (self.num_classes)
                of precision, recall, f1 and support metrics
            weighted: weighted averaged metrics, array of shape (self.num_classes)
                of precision, recall, f1 and support metrics

        Returns:
            dict of key-value metrics
        """
        kv_metrics = {}
        for aggregation_name, aggregated_metrics in zip(
            ("_micro", "_macro", "_weighted"), (micro, macro, weighted)
        ):
            metrics = {
                f"{metric_name}/{aggregation_name}": metric_value
                for metric_name, metric_value in zip(
                    ("precision", "recall", "f1"), aggregated_metrics[:-1]
                )
            }
            kv_metrics.update(metrics)

        # @TODO: rewrite this block - should be without `num_classes`
        if self.compute_per_class_metrics:
            per_class_metrics = {
                f"{metric_name}/class_{i:02d}": metric_value[i]
                for metric_name, metric_value in zip(
                    ("precision", "recall", "f1", "support"), per_class
                )
                for i in range(self.num_classes)
            }
            kv_metrics.update(per_class_metrics)
        return kv_metrics

    def update(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Tuple[Any, Any, Any, Any]:
        """
        Update statistics and return intermediate metrics results

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            tuple of metrics intermediate results with per-class, micro, macro
                and weighted averaging
        """
        tn, fp, fn, tp, support, num_classes = super().update(
            outputs=outputs, targets=targets
        )
        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=tp, fp=fp, fn=fn, support=support, zero_division=self.zero_division
        )
        if self.num_classes is None:
            self.num_classes = num_classes
        return per_class, micro, macro, weighted

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, float]:
        """
        Update statistics and return intermediate metrics results

        Args:
            outputs: prediction values
            targets: true answers

        Returns:
            dict of metrics intermediate results
        """
        per_class, micro, macro, weighted = self.update(outputs=outputs, targets=targets)
        metrics = self._convert_metrics_to_kv(
            per_class=per_class, micro=micro, macro=macro, weighted=weighted
        )
        return metrics

    def compute(self) -> Any:
        """
        Compute precision, recall, f1 score and support.
        Compute micro, macro and weighted average for the metrics.

        Returns:
            list of aggregated metrics: per-class, micro, macro and weighted
                averaging of precision, recall, f1 score and support metrics
        """
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        if self._ddp_backend == "xla":
            device = get_device()
            for key in self.statistics:
                key_statistics = torch.tensor([self.statistics[key]], device=device)
                key_statistics = xm.all_gather(key_statistics).sum(dim=0).cpu().numpy()
                self.statistics[key] = key_statistics
        elif self._ddp_backend == "ddp":
            for key in self.statistics:
                value: List[np.ndarray] = all_gather(self.statistics[key])
                value: np.ndarray = np.sum(np.vstack(value), axis=0)
                self.statistics[key] = value

        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            support=self.statistics["support"],
            zero_division=self.zero_division,
        )
        if self.compute_per_class_metrics:
            return per_class, micro, macro, weighted
        else:
            return [], micro, macro, weighted

    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute precision, recall, f1 score and support.
        Compute micro, macro and weighted average for the metrics.

        Returns:
            dict of metrics
        """
        per_class, micro, macro, weighted = self.compute()
        metrics = self._convert_metrics_to_kv(
            per_class=per_class, micro=micro, macro=macro, weighted=weighted
        )
        return metrics
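
# For reference, _convert_metrics_to_kv above produces keys of the form
# "precision/_micro", "recall/_macro", "f1/_weighted", and, when
# compute_per_class_metrics is enabled, per-class keys such as
# "precision/class_00" and "support/class_01".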


__all__ = [
    "BinaryPrecisionRecallF1Metric",
    "MulticlassPrecisionRecallF1SupportMetric",
    "MultilabelPrecisionRecallF1SupportMetric",
]