Source code for catalyst.callbacks.metric_aggregation
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union
import torch
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
def _sum_aggregation(x):
return torch.sum(torch.stack(x))
def _mean_aggregation(x):
return torch.mean(torch.stack(x))
[docs]class MetricAggregationCallback(Callback):
"""A callback to aggregate several metrics in one value.
Args:
prefix: new key for aggregated metric.
metrics (Union[str, List[str], Dict[str, float]]): If not None,
it aggregates only the values from the metric by these keys.
for ``weighted_sum`` aggregation it must be a Dict[str, float].
mode: function for aggregation.
Must be either ``sum``, ``mean`` or ``weighted_sum`` or user's
function to aggregate metrics. This function must get dict of
metrics and runner and return aggregated metric. It can be
useful for complicated fine tuning with different losses that
depends on epochs and loader or something also
scope: type of metric. Must be either ``batch`` or ``loader``
multiplier: scale factor for the aggregated metric.
Examples:
Loss is a weighted sum of cross entropy loss and binary cross entropy loss
>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from catalyst import dl
>>>
>>> # sample data
>>> num_samples, num_features, num_classes = int(1e4), int(1e1), 4
>>> X = torch.rand(num_samples, num_features)
>>> y = (torch.rand(num_samples, ) * num_classes).to(torch.int64)
>>>
>>> # pytorch loaders
>>> dataset = TensorDataset(X, y)
>>> loader = DataLoader(dataset, batch_size=32, num_workers=1)
>>> loaders = {"train": loader, "valid": loader}
>>>
>>> # model, criterion, optimizer, scheduler
>>> model = torch.nn.Linear(num_features, num_classes)
>>> criterion = {"ce": torch.nn.CrossEntropyLoss(),
>>> "bce": torch.nn.BCEWithLogitsLoss()}
>>> optimizer = torch.optim.Adam(model.parameters())
>>> scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
>>>
>>> class CustomRunner(dl.Runner):
>>> def handle_batch(self, batch):
>>> x, y = batch
>>> logits = self.model(x)
>>> num_classes = logits.shape[-1]
>>> targets_onehot = torch.nn.functional.one_hot(y, num_classes=num_classes)
>>> self.batch = {
>>> "features": x,
>>> "logits": logits,
>>> "targets": y,
>>> "targets_onehot": targets_onehot.float()
>>> }
>>>
>>> # model training
>>> runner = CustomRunner()
>>> runner.train(
>>> model=model,
>>> criterion=criterion,
>>> optimizer=optimizer,
>>> scheduler=scheduler,
>>> loaders=loaders,
>>> logdir="./logdir",
>>> num_epochs=3,
>>> callbacks=[
>>> dl.AccuracyCallback(input_key="logits",
>>> target_key="targets",
>>> num_classes=num_classes),
>>> dl.CriterionCallback(input_key="logits",
>>> target_key="targets",
>>> metric_key="loss_ce",
>>> criterion_key="ce"),
>>> dl.CriterionCallback(input_key="logits",
>>> target_key="targets_onehot",
>>> metric_key="loss_bce",
>>> criterion_key="bce"),
>>> # loss aggregation
>>> dl.MetricAggregationCallback(prefix='loss',
>>> metrics={'loss_ce': 0.6, 'loss_bce': 0.4},
>>> mode='weighted_sum'),
>>> dl.OptimizerCallback(metric_key="loss")
>>> ]
>>> )
"""
def __init__(
self,
prefix: str,
metrics: Union[str, List[str], Dict[str, float]] = None,
mode: Union[str, Callable] = "mean",
scope: str = "batch",
multiplier: float = 1.0,
) -> None:
"""Init."""
super().__init__(order=CallbackOrder.metric_aggregation, node=CallbackNode.all)
if prefix is None or not isinstance(prefix, str):
raise ValueError("prefix must be str")
if mode in ("sum", "mean"):
if metrics is not None and not isinstance(metrics, list):
raise ValueError(
"For `sum` or `mean` mode the metrics must be "
"None or list or str (not dict)"
)
elif mode in ("weighted_sum", "weighted_mean"):
if metrics is None or not isinstance(metrics, dict):
raise ValueError(
"For `weighted_sum` or `weighted_mean` mode "
"the metrics must be specified "
"and must be a dict"
)
elif not callable(mode):
raise NotImplementedError(
"mode must be `sum`, `mean` " "or `weighted_sum` or `weighted_mean` or be Callable"
)
assert scope in ("batch", "loader")
if isinstance(metrics, str):
metrics = [metrics]
self.prefix = prefix
self.metrics = metrics
self.mode = mode
self.scope = scope
self.multiplier = multiplier
if mode in ("sum", "weighted_sum", "weighted_mean"):
self.aggregation_fn = _sum_aggregation
if mode == "weighted_mean":
weights_sum = sum(metrics.items())
self.metrics = {key: weight / weights_sum for key, weight in metrics.items()}
elif mode == "mean":
self.aggregation_fn = _mean_aggregation
elif callable(mode):
self.aggregation_fn = mode
def _preprocess(self, metrics: Any) -> List[float]:
if self.metrics is not None:
try:
if self.mode == "weighted_sum":
result = [metrics[key] * value for key, value in self.metrics.items()]
else:
result = [metrics[key] for key in self.metrics]
except KeyError:
raise KeyError(f"Could not found required key out of {metrics.keys()}")
else:
result = list(metrics.values())
result = [metric.float() for metric in result]
return result
def _process_metrics(self, metrics: Dict, runner: "IRunner") -> None:
if callable(self.mode):
metric_aggregated = self.aggregation_fn(metrics, runner) * self.multiplier
else:
metrics_processed = self._preprocess(metrics)
metric_aggregated = self.aggregation_fn(metrics_processed) * self.multiplier
metrics[self.prefix] = metric_aggregated
def on_batch_end(self, runner: "IRunner") -> None:
"""Computes the metric and add it to the batch metrics.
Args:
runner: current runner
"""
if self.scope == "batch":
self._process_metrics(runner.batch_metrics, runner)
def on_loader_end(self, runner: "IRunner") -> None:
"""Computes the metric and add it to the loader metrics.
Args:
runner: current runner
"""
self._process_metrics(runner.loader_metrics, runner)
__all__ = [
"MetricAggregationCallback",
]