Source code for catalyst.metrics._auc
from typing import Dict, Optional, Tuple
import torch
from catalyst import SETTINGS
from catalyst.metrics._metric import ICallbackLoaderMetric
from catalyst.metrics.functional._auc import auc, binary_auc
from catalyst.metrics.functional._misc import process_multilabel_components
from catalyst.utils import get_device
from catalyst.utils.distributed import all_gather, get_backend
if SETTINGS.xla_required:
import torch_xla.core.xla_model as xm
class AUCMetric(ICallbackLoaderMetric):
"""AUC metric,
Args:
compute_on_call: if True, computes and returns metric value during metric call
prefix: metric prefix
suffix: metric suffix
.. warning::
This metric is under API improvement.
Examples:
.. code-block:: python
import torch
from catalyst import metrics
scores = torch.tensor([
[0.9, 0.1],
[0.1, 0.9],
])
targets = torch.tensor([
[1, 0],
[0, 1],
])
metric = metrics.AUCMetric()
# for efficient statistics storage
metric.reset(num_batches=1, num_samples=len(scores))
metric.update(scores, targets)
metric.compute()
        # (
        #     tensor([1., 1.]),  # per class
        #     1.0,  # micro
        #     1.0,  # macro
        #     1.0,  # weighted
        # )
metric.compute_key_value()
        # {
        #     'auc': 1.0,
        #     'auc/_micro': 1.0,
        #     'auc/_macro': 1.0,
        #     'auc/_weighted': 1.0,
        #     'auc/class_00': 1.0,
        #     'auc/class_01': 1.0,
        # }
metric.reset(num_batches=1, num_samples=len(scores))
metric(scores, targets)
        # (
        #     tensor([1., 1.]),  # per class
        #     1.0,  # micro
        #     1.0,  # macro
        #     1.0,  # weighted
        # )
.. code-block:: python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl
# sample data
num_samples, num_features, num_classes = int(1e4), int(1e1), 4
X = torch.rand(num_samples, num_features)
y = (torch.rand(num_samples,) * num_classes).to(torch.int64)
# pytorch loaders
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}
# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, num_classes)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
# model training
runner = dl.SupervisedRunner(
input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
logdir="./logdir",
num_epochs=3,
valid_loader="valid",
valid_metric="accuracy03",
minimize_valid_metric=False,
verbose=True,
callbacks=[
dl.AccuracyCallback(
input_key="logits", target_key="targets", num_classes=num_classes
),
dl.PrecisionRecallF1SupportCallback(
input_key="logits", target_key="targets", num_classes=num_classes
),
dl.AUCCallback(input_key="logits", target_key="targets"),
],
)
.. note::
Please follow the `minimal examples`_ sections for more use cases.
.. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples
"""
    def __init__(
        self,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
"""Init."""
super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
self.metric_name = f"{self.prefix}auc{self.suffix}"
self._ddp_backend = None
self.scores = []
self.targets = []
self.reset(0, 0)
    def reset(self, num_batches: int, num_samples: int) -> None:
        """Resets all fields.

        Args:
            num_batches: number of expected batches (kept for the
                ``ICallbackLoaderMetric`` API; the accumulators here are plain lists)
            num_samples: number of expected samples (kept for the
                ``ICallbackLoaderMetric`` API as well)
        """
self._ddp_backend = get_backend()
self.scores = []
self.targets = []
def update(self, scores: torch.Tensor, targets: torch.Tensor) -> None:
"""Updates metric value with statistics for new data.
Args:
scores: tensor with scores
targets: tensor with targets
"""
self.scores.append(scores.cpu().detach())
self.targets.append(targets.cpu().detach())
    def compute(self) -> Tuple[torch.Tensor, float, float, float]:
        """Computes the AUC metric based on saved statistics.

        Returns:
            tuple of (per-class AUC, micro AUC, macro AUC, weighted AUC)
        """
targets = torch.cat(self.targets)
scores = torch.cat(self.scores)
        # ddp hotfix, could be done better,
        # but a metric must handle DDP on its own
if self._ddp_backend == "xla":
# if you have "RuntimeError: Aborted: Session XXX is not found" here
# please, ask Google for a more powerful TPU setup ;)
device = get_device()
scores = xm.all_gather(scores.to(device)).cpu().detach()
targets = xm.all_gather(targets.to(device)).cpu().detach()
elif self._ddp_backend == "ddp":
scores = torch.cat(all_gather(scores))
targets = torch.cat(all_gather(targets))
scores, targets, _ = process_multilabel_components(outputs=scores, targets=targets)
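        # per-class AUC plus three aggregation schemes:
        #   micro    - flatten scores/targets into one binary problem
        #   macro    - unweighted mean of the per-class AUCs
        #   weighted - per-class AUCs averaged by class support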
per_class = auc(scores=scores, targets=targets)
micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
macro = per_class.mean().item()
        weights = targets.sum(dim=0) / len(targets)
weighted = (per_class * weights).sum().item()
return per_class, micro, macro, weighted
    def compute_key_value(self) -> Dict[str, float]:
        """Computes the AUC metric based on saved statistics and returns key-value results.

        Returns:
            dict with per-class, micro, macro, and weighted AUC values
        """
per_class_auc, micro_auc, macro_auc, weighted_auc = self.compute()
output = {
f"{self.metric_name}/class_{i:02d}": value.item()
for i, value in enumerate(per_class_auc)
}
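        # aggregated values; the bare metric name carries the macro average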
output[f"{self.metric_name}/_micro"] = micro_auc
output[self.metric_name] = macro_auc
output[f"{self.metric_name}/_macro"] = macro_auc
output[f"{self.metric_name}/_weighted"] = weighted_auc
return output
__all__ = ["AUCMetric"]
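
# A minimal multi-batch sketch of the reset/update/compute cycle, reusing the
# data from the docstring example above split across two batches (illustrative
# only, not part of the library API):
#
#     import torch
#     from catalyst import metrics
#
#     scores_batches = [
#         torch.tensor([[0.9, 0.1]]),
#         torch.tensor([[0.1, 0.9]]),
#     ]
#     targets_batches = [
#         torch.tensor([[1, 0]]),
#         torch.tensor([[0, 1]]),
#     ]
#     metric = metrics.AUCMetric()
#     metric.reset(num_batches=2, num_samples=2)
#     for scores, targets in zip(scores_batches, targets_batches):
#         metric.update(scores, targets)
#     per_class, micro, macro, weighted = metric.compute()
#     # same result as the single-batch docstring example above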