Shortcuts

Source code for catalyst.metrics._accumulative

from typing import Any, Dict, Iterable, Optional
from collections import defaultdict

import torch

from catalyst.metrics._metric import ICallbackLoaderMetric


class AccumulativeMetric(ICallbackLoaderMetric):
    """This metric accumulates all the input data along loader

    For every key listed in ``keys`` a buffer with ``num_samples`` rows is
    allocated on the first ``update`` call after ``reset``; each batch is then
    copied (detached, moved to CPU) into the next free rows of that buffer.

    Args:
        keys: list of keys to accumulate data from batch
        compute_on_call: if True, allows compute metric's value on call
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(
        self,
        keys: Optional[Iterable[str]] = None,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> None:
        """Init AccumulativeMetric"""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        # ``update`` iterates the keys on every batch (twice on the first one),
        # so a one-shot iterable such as a generator would be exhausted by the
        # allocation pass and no data would ever be stored. Materialise it once.
        self.keys = tuple(keys) if keys is not None else ()
        self.storage = None
        self.num_samples = None
        self.collected_batches = None
        self.collected_samples = None

    def reset(self, num_batches: int, num_samples: int) -> None:
        """
        Reset metrics fields

        Args:
            num_batches: expected number of batches (not used by this class)
            num_samples: expected number of samples to accumulate
        """
        self.num_samples = num_samples
        self.collected_batches = 0
        self.collected_samples = 0
        # Buffers are re-created lazily by the next ``update`` call.
        self.storage = None

    def _allocate_memory(self, shape_type_dict: Dict[str, Any]) -> None:
        """
        Allocate memory for data accumulation

        Args:
            shape_type_dict: dict that maps each key to a dict with the
                ``"shape"`` and ``"dtype"`` of the tensor to allocate
        """
        self.storage = defaultdict(torch.Tensor)
        for key, spec in shape_type_dict.items():
            # ``torch.empty`` does not initialise the buffer; rows only get
            # meaningful values once ``update`` writes into them.
            self.storage[key] = torch.empty(size=spec["shape"], dtype=spec["dtype"])

    def update(self, **kwargs) -> None:
        """
        Update accumulated data with new batch

        Args:
            **kwargs: tensors that should be accumulated; every key from
                ``self.keys`` has to be present
        """
        if self.collected_batches == 0:
            # First batch after ``reset``: the trailing dimensions and dtype
            # of each tensor define the layout of its accumulation buffer.
            shape_type_dict = {
                field_name: {
                    "shape": (self.num_samples, *kwargs[field_name].shape[1:]),
                    "dtype": kwargs[field_name].dtype,
                }
                for field_name in self.keys
            }
            self._allocate_memory(shape_type_dict=shape_type_dict)

        bs = 0
        start = self.collected_samples
        for field_name in self.keys:
            batch = kwargs[field_name]
            bs = batch.shape[0]
            self.storage[field_name][start : start + bs, ...] = batch.detach().cpu()
        # The offset advances once per batch, by the batch size of the last key.
        self.collected_samples += bs
        self.collected_batches += 1

    def compute(self) -> Dict[str, torch.Tensor]:
        """
        Return accumulated data

        Returns:
            dict of accumulated data; ``None`` if ``update`` has not been
            called since the last ``reset``. Rows that were not written by
            ``update`` are uninitialised.
        """
        return self.storage

    def compute_key_value(self) -> Dict[str, torch.Tensor]:
        """
        Return accumulated data

        Returns:
            dict of accumulated data, same as ``compute``
        """
        return self.compute()
# Names exported by a star-import of this module.
__all__ = ["AccumulativeMetric"]