Source code for catalyst.dl.core.metric_manager

from typing import Any, Dict  # isort:skip
from collections import defaultdict
from numbers import Number
from time import perf_counter, time

from catalyst.dl.meters import AverageValueMeter


class TimerManager:
    """Named stopwatch timers.

    :meth:`start` records a start point for a timer name; :meth:`stop`
    stores the seconds elapsed since that start point in :attr:`elapsed`
    and forgets the start point.
    """

    def __init__(self):
        # start points of currently running timers, keyed by timer name
        self._starts = {}
        # durations (seconds) of stopped timers, keyed by timer name
        self.elapsed = {}

    def start(self, name: str) -> None:
        """Starts timer ``name``

        Starting a timer that is already running restarts it.

        Args:
            name (str): name of a timer
        """
        # perf_counter is monotonic (unlike time.time), so measured durations
        # cannot go negative or jump when the system clock is adjusted;
        # only differences between its readings are meaningful.
        self._starts[name] = perf_counter()

    def stop(self, name: str) -> None:
        """Stops timer ``name`` and stores its duration in ``self.elapsed``

        Args:
            name (str): name of a timer

        Raises:
            AssertionError: if timer ``name`` was not started
        """
        # read the clock first so bookkeeping is not included in the duration
        now = perf_counter()
        assert name in self._starts, f"Timer '{name}' wasn't started"
        self.elapsed[name] = now - self._starts.pop(name)

    def reset(self) -> None:
        """Reset all previous timers"""
        self.elapsed = {}
        self._starts = {}
class MetricManager:
    """Collects per-batch metrics and aggregates them per loader and epoch.

    Lifecycle (as implemented below):

    * :meth:`begin_epoch` resets :attr:`epoch_values`;
    * :meth:`begin_loader` starts one ``AverageValueMeter`` per metric name;
    * :meth:`begin_batch` / :meth:`add_batch_value` / :meth:`end_batch`
      collect batch values and feed them into the meters;
    * :meth:`end_loader` stores each meter's ``mean`` under the loader name;
    * :meth:`end_epoch_train` compares the main metric of the validation
      loader with the best value seen so far and updates :attr:`is_best`.
    """

    @staticmethod
    def _to_single_value(value: Any) -> float:
        """Converts ``value`` to a python ``float``.

        Objects exposing ``.item()`` (presumably tensors / numpy scalars)
        are unwrapped first.

        Raises:
            AssertionError: if the (unwrapped) value is not a
                ``numbers.Number``
        """
        if hasattr(value, "item"):
            value = value.item()
        assert isinstance(value, Number), \
            f"{type(value)} is not a python number"
        return float(value)

    def __init__(
        self,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize: bool = True,
        batch_consistant_metrics: bool = True,
    ):
        """
        Args:
            valid_loader (str): name of the loader whose metrics are used
                to decide whether the current epoch is the best one
            main_metric (str): name of the metric used for that decision
            minimize (bool): if True, lower ``main_metric`` values are
                better; otherwise higher values are better
            batch_consistant_metrics (bool): if True, every batch of a
                loader must report the same set of metric names
        """
        self._valid_loader = valid_loader
        self._main_metric = main_metric
        self._minimize = minimize
        self._batch_consistant_metrics = batch_consistant_metrics

        # per-metric meters of the current loader
        self._meters: Dict[str, AverageValueMeter] = None
        # metric values of the current batch
        self._batch_values: Dict[str, float] = None
        # loader name -> {metric name -> mean value over that loader}
        self.epoch_values: Dict[str, Dict[str, float]] = None
        # metrics of ``valid_loader`` for the last finished epoch
        self.valid_values: Dict[str, float] = None
        self.best_main_metric_value: float = \
            float("+inf") if self._minimize else float("-inf")
        self.is_best: bool = None
        self._current_loader_name: str = None

    @property
    def batch_values(self) -> Dict[str, float]:
        """Metric values collected for the current batch."""
        return self._batch_values

    @property
    def main_metric_value(self) -> float:
        """Epoch value of ``main_metric`` on ``valid_loader``.

        Raises:
            AssertionError: if the loader or the metric has no value yet
        """
        assert self._valid_loader in self.epoch_values, \
            f"{self._valid_loader} is not available yet"
        assert self._main_metric in self.epoch_values[self._valid_loader], \
            f"{self._main_metric} is not available yet"
        return self.epoch_values[self._valid_loader][self._main_metric]

    def begin_epoch(self):
        """Resets the per-loader epoch values."""
        self.epoch_values = defaultdict(dict)

    def end_epoch_train(self):
        """Updates ``valid_values``, ``is_best`` and the best metric value.

        Raises:
            AssertionError: if ``valid_loader`` or ``main_metric`` values
                are missing at the end of the epoch
        """
        assert self._valid_loader in self.epoch_values, \
            f"{self._valid_loader} is not available by the epoch end"
        assert self._main_metric in self.epoch_values[self._valid_loader], \
            f"{self._main_metric} value is not available by the epoch end"
        self.valid_values = self.epoch_values[self._valid_loader]

        current_value = self.main_metric_value
        if self._minimize:
            self.is_best = current_value < self.best_main_metric_value
        else:
            self.is_best = current_value > self.best_main_metric_value

        if self.is_best:
            self.best_main_metric_value = current_value

    def begin_loader(self, name: str):
        """Starts collecting metrics for the loader ``name``."""
        self._current_loader_name = name
        self._meters = defaultdict(AverageValueMeter)

    def end_loader(self):
        """Stores each meter's mean under the current loader's name."""
        for metric_name, meter in self._meters.items():
            self.epoch_values[self._current_loader_name][metric_name] = \
                meter.mean
        self._current_loader_name = None

    def begin_batch(self):
        """Clears the metric values of the previous batch."""
        self._batch_values = {}

    def end_batch(self):
        """Feeds the current batch values into the loader meters.

        Raises:
            AssertionError: if ``batch_consistant_metrics`` is enabled and
                the batch reports a different metric set than earlier
                batches of the same loader
        """
        # the first batch of a loader defines the expected metric set
        if self._meters and self._batch_consistant_metrics:
            assert self._meters.keys() == self._batch_values.keys(), \
                "Metric set is not consistent among batches"
        for metric_name, metric_value in self._batch_values.items():
            self._meters[metric_name].add(metric_value)

    def add_batch_value(
        self,
        name: str = None,
        value: Any = None,
        metrics_dict: Dict[str, Any] = None
    ):
        """Adds metric values for the current batch.

        Args:
            name (str): name of a single metric (ignored if falsy)
            value (Any): value of the metric ``name``
            metrics_dict (Dict[str, Any]): several metrics at once;
                the caller's dict is not modified

        Raises:
            AssertionError: if a value is not convertible to a number
        """
        # copy, so adding ``name`` never mutates the caller's dict
        new_values = dict(metrics_dict) if metrics_dict else {}
        if name:
            new_values[name] = value
        for metric_name, raw_value in new_values.items():
            self._batch_values[metric_name] = \
                self._to_single_value(raw_value)
# Public names exported by ``from ... import *``.
__all__ = ["TimerManager", "MetricManager"]