from typing import Any, Callable, Dict, List, Union # isort:skip
from abc import ABC, abstractmethod
from collections import defaultdict
import logging
import torch
from catalyst import utils
from catalyst.core import Callback, CallbackNode, CallbackOrder, State
from catalyst.utils import meters
logger = logging.getLogger(__name__)


class _MetricCallback(ABC, Callback):
    """
    Base class for metric callbacks: computes ``metric_fn`` over
    the batch input/output and stores the result on batch end.
    """
def __init__(
self,
prefix: str,
input_key: Union[str, List[str], Dict[str, str]] = "targets",
output_key: Union[str, List[str], Dict[str, str]] = "logits",
multiplier: float = 1.0,
**metrics_kwargs,
):
super().__init__(order=CallbackOrder.Metric, node=CallbackNode.All)
self.prefix = prefix
self.input_key = input_key
self.output_key = output_key
self.multiplier = multiplier
self.metrics_kwargs = metrics_kwargs
self._get_input = utils.get_dictkey_auto_fn(self.input_key)
self._get_output = utils.get_dictkey_auto_fn(self.output_key)
        kv_types = (dict, tuple, list, type(None))

        # a plain string key (other than "__all__") selects a single value;
        # dict/tuple/list/None keys (or "__all__") select key-value pairs
        is_value_input = (
            isinstance(self.input_key, str) and self.input_key != "__all__"
        )
        is_value_output = (
            isinstance(self.output_key, str) and self.output_key != "__all__"
        )
        is_kv_input = (
            isinstance(self.input_key, kv_types) or self.input_key == "__all__"
        )
        is_kv_output = (
            isinstance(self.output_key, kv_types)
            or self.output_key == "__all__"
        )
# @TODO: fix to only KV usage
if hasattr(self, "_compute_metric"):
pass # overridden in descendants
elif is_value_input and is_value_output:
self._compute_metric = self._compute_metric_value
elif is_kv_input and is_kv_output:
self._compute_metric = self._compute_metric_key_value
        else:
            raise NotImplementedError(
                "input_key and output_key must both be value-based (str) "
                "or both be key-value-based (dict/list/None)"
            )

@property
@abstractmethod
def metric_fn(self):
        pass

    def _compute_metric_value(self, state: State):
        output = self._get_output(state.batch_out, self.output_key)
        input = self._get_input(state.batch_in, self.input_key)
        metric = self.metric_fn(output, input, **self.metrics_kwargs)
        return metric

    def _compute_metric_key_value(self, state: State):
        output = self._get_output(state.batch_out, self.output_key)
        input = self._get_input(state.batch_in, self.input_key)
        metric = self.metric_fn(**output, **input, **self.metrics_kwargs)
        return metric

def on_batch_end(self, state: State):
"""
Computes the metric and add it to batch metrics
"""
metric = self._compute_metric(state) * self.multiplier
state.batch_metrics[self.prefix] = metric
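
# A sketch of the two dispatch modes above (key names are illustrative):
# with plain string keys the selected values are passed positionally,
#
#     metric_fn(state.batch_out["logits"], state.batch_in["targets"])
#
# while with dict/list keys they are passed as keyword arguments,
#
#     metric_fn(**output_kv, **input_kv, **metrics_kwargs)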


class MetricCallback(_MetricCallback):
"""
    A callback that computes a single metric on ``on_batch_end``
    and stores it under ``prefix`` in ``state.batch_metrics``.
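
    Example (a minimal sketch; ``accuracy_fn`` is an illustrative
    user-defined function, not part of catalyst)::

        callback = MetricCallback(
            prefix="accuracy",
            metric_fn=accuracy_fn,
            input_key="targets",
            output_key="logits",
        )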
"""
def __init__(
self,
prefix: str,
metric_fn: Callable,
input_key: Union[str, List[str], Dict[str, str]] = "targets",
output_key: Union[str, List[str], Dict[str, str]] = "logits",
multiplier: float = 1.0,
**metric_kwargs,
):
super().__init__(
prefix=prefix,
input_key=input_key,
output_key=output_key,
multiplier=multiplier,
**metric_kwargs,
)
self.metric = metric_fn
@property
def metric_fn(self):
return self.metric


class MultiMetricCallback(MetricCallback):
"""
    A callback that computes multiple metrics on ``on_batch_end``
    and stores them in ``state.batch_metrics``.
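
    Example (a sketch; ``topk_accuracy_fn`` is illustrative and is assumed
    to return one value per entry in ``list_args``)::

        callback = MultiMetricCallback(
            prefix="accuracy",
            metric_fn=topk_accuracy_fn,
            list_args=[1, 3],
            input_key="targets",
            output_key="logits",
        )

    With integer ``list_args``, the values are logged as ``accuracy01``
    and ``accuracy03``.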
"""
def __init__(
self,
prefix: str,
metric_fn: Callable,
list_args: List,
input_key: Union[str, List[str], Dict[str, str]] = "targets",
output_key: Union[str, List[str], Dict[str, str]] = "logits",
multiplier: float = 1.0,
**metrics_kwargs,
):
super().__init__(
prefix=prefix,
metric_fn=metric_fn,
input_key=input_key,
output_key=output_key,
multiplier=multiplier,
**metrics_kwargs,
)
self.list_args = list_args

    def on_batch_end(self, state: State):
        """
        Computes the metrics and adds them to the batch metrics.
        """
metrics_ = self._compute_metric(state)
for arg, metric in zip(self.list_args, metrics_):
if isinstance(arg, int):
key = f"{self.prefix}{arg:02}"
else:
key = f"{self.prefix}_{arg}"
state.batch_metrics[key] = metric * self.multiplier


class MetricAggregationCallback(Callback):
"""
    A callback to aggregate several batch metrics into a single value.
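
    Example (a sketch; ``loss_bce`` and ``loss_dice`` are illustrative keys
    assumed to be present in ``state.batch_metrics``)::

        callback = MetricAggregationCallback(
            prefix="loss",
            metrics={"loss_bce": 0.7, "loss_dice": 0.3},
            mode="weighted_sum",
        )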
"""

    def __init__(
self,
prefix: str,
metrics: Union[str, List[str], Dict[str, float]] = None,
mode: str = "mean",
multiplier: float = 1.0
) -> None:
"""
Args:
            prefix (str): new key for the aggregated metric
            metrics (Union[str, List[str], Dict[str, float]]): if not None,
                only the values under these keys are aggregated;
                for ``weighted_sum`` and ``weighted_mean`` aggregation
                it must be a Dict[str, float]
            mode (str): aggregation function; must be ``sum``, ``mean``,
                ``weighted_sum`` or ``weighted_mean``
            multiplier (float): scale factor for the aggregated metric
"""
super().__init__(
order=CallbackOrder.MetricAggregation, node=CallbackNode.All
)
        if not isinstance(prefix, str):
            raise ValueError("prefix must be str")

        if mode in ("sum", "mean"):
            if metrics is not None and not isinstance(metrics, (str, list)):
                raise ValueError(
                    "For `sum` or `mean` mode the `metrics` argument must "
                    "be None, a str or a list of str (not a dict)"
                )
        elif mode in ("weighted_sum", "weighted_mean"):
            if metrics is None or not isinstance(metrics, dict):
                raise ValueError(
                    "For `weighted_sum` or `weighted_mean` mode "
                    "the `metrics` argument must be specified "
                    "and must be a dict"
                )
        else:
            raise NotImplementedError(
                "mode must be `sum`, `mean`, "
                "`weighted_sum` or `weighted_mean`"
            )

if isinstance(metrics, str):
metrics = [metrics]
self.prefix = prefix
self.metrics = metrics
self.mode = mode
self.multiplier = multiplier
if mode in ("sum", "weighted_sum", "weighted_mean"):
self.aggregation_fn = \
lambda x: torch.sum(torch.stack(x)) * multiplier
if mode == "weighted_mean":
weights_sum = sum(metrics.items())
self.metrics = {
key: weight / weights_sum
for key, weight in metrics.items()
}
elif mode == "mean":
self.aggregation_fn = \
lambda x: torch.mean(torch.stack(x)) * multiplier
    def _preprocess(self, metrics: Dict[str, Any]) -> List[float]:
        if self.metrics is not None:
            if self.mode in ("weighted_sum", "weighted_mean"):
                result = [
                    metrics[key] * weight
                    for key, weight in self.metrics.items()
                ]
            else:
                result = [metrics[key] for key in self.metrics]
        else:
            result = list(metrics.values())
        return result

    def on_batch_end(self, state: State) -> None:
        """
        Computes the aggregated metric and adds it to the batch metrics.
        """
metrics = self._preprocess(state.batch_metrics)
metric = self.aggregation_fn(metrics)
state.batch_metrics[self.prefix] = metric


class MetricManagerCallback(Callback):
    """
    Prepares metrics for logging: converts PyTorch tensor values to
    single Python floats and averages them over each loader.
"""
def __init__(self):
super().__init__(
order=CallbackOrder.Logging - 1,
node=CallbackNode.All,
)
self.meters: Dict[str, meters.AverageValueMeter] = None
@staticmethod
def _to_single_value(value: Any) -> float:
if hasattr(value, "item"):
value = value.item()
value = float(value)
return value

    @staticmethod
    def _process_metrics(metrics: Dict[str, Any]) -> Dict[str, float]:
output = {}
for key, value in metrics.items():
value = utils.get_distributed_mean(value)
value = MetricManagerCallback._to_single_value(value)
output[key] = value
return output

    def on_epoch_start(self, state: State):
        """Resets the epoch metrics storage."""
        state.epoch_metrics = defaultdict(None)

    def on_loader_start(self, state: State):
        """Resets the loader metrics and the running average meters."""
        state.loader_metrics = defaultdict(None)
        self.meters = defaultdict(meters.AverageValueMeter)

    def on_loader_end(self, state: State):
        """Moves per-loader means to the loader and epoch metrics."""
        for key, value in self.meters.items():
            value = value.mean
            state.loader_metrics[key] = value
        for key, value in state.loader_metrics.items():
            state.epoch_metrics[f"{state.loader_name}_{key}"] = value

    def on_batch_start(self, state: State):
        """Resets the batch metrics storage."""
        state.batch_metrics = defaultdict(None)

    def on_batch_end(self, state: State):
        """Averages batch metrics across workers and updates the meters."""
        state.batch_metrics = self._process_metrics(state.batch_metrics)
        for key, value in state.batch_metrics.items():
            self.meters[key].add(value)
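
# A sketch of the metric flow handled above: per-batch values land in
# `state.batch_metrics`, their running means are accumulated in
# `AverageValueMeter`s, flushed to `state.loader_metrics` on loader end,
# and finally copied to `state.epoch_metrics` under keys of the form
# f"{state.loader_name}_{key}", e.g. "train_loss" or "valid_loss".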


__all__ = [
"MetricCallback", "MultiMetricCallback", "MetricAggregationCallback",
"MetricManagerCallback"
]