Source code for catalyst.tools.meters.ppv_tpr_f1_meter
"""
In this module **precision**, **recall** and **F1 score**
calculations are defined in separate functions.
:py:class:`PrecisionRecallF1ScoreMeter` can keep track of all three of these.
"""
from collections import defaultdict
import torch
from . import meter


def f1score(precision_value, recall_value, eps=1e-5):
    """
    Calculates the F1-score from already computed precision and recall
    values, avoiding redundant computation.
    Args:
        precision_value: precision (0-1)
        recall_value: recall (0-1)
        eps: epsilon to avoid zero division
    Returns:
        F1 score (0-1)
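    Example:
        A quick illustrative check with made-up precision/recall values:

        >>> round(f1score(0.8, 0.6), 4)
        0.6857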
    """
    numerator = 2 * (precision_value * recall_value)
    denominator = precision_value + recall_value + eps
    return numerator / denominator 


def precision(tp, fp, eps: float = 1e-5) -> float:
    """
    Calculates precision (a.k.a. positive predictive value) for binary
    classification and segmentation.
    Args:
        tp (int): number of true positives
        fp (int): number of false positives
        eps (float): epsilon to avoid zero division
    Returns:
        precision value (0-1)
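    Example:
        Illustrative counts; note that two empty masks (``tp=0``, ``fp=0``)
        give a precision close to 1 rather than 0, which is the behaviour
        motivated in the implementation note below:

        >>> precision(tp=0, fp=0)
        1.0
        >>> round(precision(tp=8, fp=2), 4)
        0.8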
    """
    # originally precision is: ppv = tp / (tp + fp + eps)
    # but when both masks are empty this gives: tp=0 and fp=0 => ppv=0
    # so here precision is defined as ppv := 1 - fdr (false discovery rate)
    return 1 - fp / (tp + fp + eps) 


def recall(tp, fn, eps=1e-5) -> float:
    """
    Calculates recall (a.k.a. true positive rate) for binary classification
    and segmentation.
    Args:
        tp: number of true positives
        fn: number of false negatives
        eps: epsilon to avoid zero division
    Returns:
        recall value (0-1)
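    Example:
        Illustrative counts; as with :py:func:`precision`, two empty masks
        give a recall close to 1:

        >>> recall(tp=0, fn=0)
        1.0
        >>> round(recall(tp=6, fn=4), 4)
        0.6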
    """
    # originally recall is: tpr := tp / (tp + fn + eps)
    # but when both masks are empty this gives: tp=0 and fn=0 => tpr=0
    # so here recall is defined as tpr := 1 - fnr (false negative rate)
    return 1 - fn / (fn + tp + eps) 


class PrecisionRecallF1ScoreMeter(meter.Meter):
    """
    Keeps track of global true positives, false positives, and false negatives
    for each epoch and calculates precision, recall, and F1-score based on
    those counts. Currently, this meter works for binary cases only; please
    use multiple instances of this class for multi-label cases.
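    Example:
        An illustrative run on a single made-up batch of four predictions:

        >>> import torch
        >>> f1_meter = PrecisionRecallF1ScoreMeter(threshold=0.5)
        >>> f1_meter.add(
        ...     torch.tensor([0.9, 0.1, 0.8, 0.3]),
        ...     torch.tensor([1.0, 0.0, 0.0, 1.0]),
        ... )
        >>> [round(value, 2) for value in f1_meter.value()]
        [0.5, 0.5, 0.5]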
    """

    def __init__(self, threshold=0.5):
        """
        Constructor method for the ``PrecisionRecallF1ScoreMeter`` class.
        Args:
            threshold (float): threshold used to binarize predictions
                (default: ``0.5``)
        """
        super(PrecisionRecallF1ScoreMeter, self).__init__()
        self.threshold = threshold
        self.reset() 

    def reset(self) -> None:
        """
        Resets true positive, false positive and false negative counts to 0.
        """
        self.tp_fp_fn_counts = defaultdict(int) 

    def add(self, output: torch.Tensor, target: torch.Tensor) -> None:
        """
        Thresholds predictions and calculates the true positives,
        false positives, and false negatives in comparison to the target.
        Args:
            output (torch.Tensor): prediction after activation function
                shape should be (batch_size, ...), but works with any shape
            target (torch.Tensor): label (binary),
                shape should be the same as output's shape
        """
        output = (output > self.threshold).float()
        tp = torch.sum(target * output)
        fp = torch.sum(output) - tp
        fn = torch.sum(target) - tp
        self.tp_fp_fn_counts["tp"] += tp
        self.tp_fp_fn_counts["fp"] += fp
        self.tp_fp_fn_counts["fn"] += fn 

    def value(self):
        """
        Calculates precision/recall/f1 based on the current stored
        tp/fp/fn counts.
        Returns:
            tuple of floats: (precision, recall, f1)
        """
        precision_value = precision(
            self.tp_fp_fn_counts["tp"], self.tp_fp_fn_counts["fp"]
        )
        recall_value = recall(
            self.tp_fp_fn_counts["tp"], self.tp_fp_fn_counts["fn"]
        )
        f1_value = f1score(precision_value, recall_value)
        return (float(precision_value), float(recall_value), float(f1_value))
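

if __name__ == "__main__":
    # Illustrative usage sketch with made-up tensors (not part of the
    # library API): counts accumulate across ``add`` calls until ``reset``
    # is called, so the reported metrics are global over everything seen.
    example_meter = PrecisionRecallF1ScoreMeter(threshold=0.5)
    example_meter.add(torch.tensor([0.9, 0.2]), torch.tensor([1.0, 0.0]))
    example_meter.add(torch.tensor([0.7, 0.6]), torch.tensor([0.0, 1.0]))
    ppv, tpr, f1 = example_meter.value()
    print(f"precision={ppv:.4f} recall={tpr:.4f} f1={f1:.4f}")
    example_meter.reset()  # start fresh counts for the next epoch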