Shortcuts

Source code for catalyst.dl.callbacks.inference

from collections import defaultdict
import os

import numpy as np

from catalyst.dl import Callback, CallbackOrder, State


# @TODO: refactor
[docs]class InferCallback(Callback): """@TODO: Docs. Contribution is welcome."""
[docs] def __init__(self, out_dir=None, out_prefix=None): """ Args: @TODO: Docs. Contribution is welcome """ super().__init__(CallbackOrder.Internal) self.out_dir = out_dir self.out_prefix = out_prefix self.predictions = defaultdict(lambda: []) self._keys_from_state = ["out_dir", "out_prefix"]
[docs] def on_stage_start(self, state: State): """Stage start hook. Args: state (State): current state """ for key in self._keys_from_state: value = getattr(state, key, None) if value is not None: setattr(self, key, value) # assert self.out_prefix is not None if self.out_dir is not None: self.out_prefix = str(self.out_dir) + "/" + str(self.out_prefix) if self.out_prefix is not None: os.makedirs(os.path.dirname(self.out_prefix), exist_ok=True)
[docs] def on_loader_start(self, state: State): """Loader start hook. Args: state (State): current state """ self.predictions = defaultdict(lambda: [])
[docs] def on_batch_end(self, state: State): """Batch end hook. Args: state (State): current state """ dct = state.output dct = {key: value.detach().cpu().numpy() for key, value in dct.items()} for key, value in dct.items(): self.predictions[key].append(value)
[docs] def on_loader_end(self, state: State): """Loader end hook. Args: state (State): current state """ self.predictions = { key: np.concatenate(value, axis=0) for key, value in self.predictions.items() } if self.out_prefix is not None: for key, value in self.predictions.items(): suffix = ".".join([state.loader_name, key]) np.save(f"{self.out_prefix}/{suffix}.npy", value)
__all__ = ["InferCallback"]