Shortcuts

Source code for catalyst.dl.callbacks.wrappers

from typing import List

from catalyst.core import Callback, State


[docs]class PhaseWrapperCallback(Callback): """ CallbackWrapper which disables/enables handlers dependant on current phase and event type May be useful i.e. to disable/enable optimizers & losses """ LEVEL_STAGE = "stage" LEVEL_EPOCH = "epoch" LEVEL_LOADER = "loader" LEVEL_BATCH = "batch" TIME_START = "start" TIME_END = "end"
[docs] def __init__( self, base_callback: Callback, active_phases: List[str] = None, inactive_phases: List[str] = None, ): """ Args: @TODO: Docs. Contribution is welcome """ super().__init__(base_callback.order) assert (active_phases is None) ^ ( inactive_phases is None ), "Exactly one of active/inactive phases must be specified" self.callback = base_callback self.active_phases = active_phases or [] self.inactive_phases = inactive_phases or [] assert ( len(self.active_phases) + len(self.inactive_phases) > 0 ), "Wrapper has no sense if callback is always active/inactive"
[docs] def is_active_on_phase(self, phase, level, time): """@TODO: Docs. Contribution is welcome.""" return self._is_active_on_phase(phase=phase)
def _is_active_on_phase(self, phase): if phase is None: # if phase is None every callback is active return True if phase in self.active_phases: return True if self.inactive_phases and phase not in self.inactive_phases: return True return False
[docs] def on_stage_start(self, state: State): """Stage start hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_STAGE, time=self.TIME_START ): self.callback.on_stage_start(state)
[docs] def on_stage_end(self, state: State): """Stage end hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_STAGE, time=self.TIME_END ): self.callback.on_stage_end(state)
[docs] def on_epoch_start(self, state: State): """Epoch start hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_EPOCH, time=self.TIME_START ): self.callback.on_epoch_start(state)
[docs] def on_epoch_end(self, state: State): """Epoch end hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_EPOCH, time=self.TIME_END ): self.callback.on_epoch_end(state)
[docs] def on_loader_start(self, state: State): """Loader start hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_LOADER, time=self.TIME_START ): self.callback.on_loader_start(state)
[docs] def on_loader_end(self, state: State): """Loader end hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_LOADER, time=self.TIME_END ): self.callback.on_loader_end(state)
[docs] def on_batch_start(self, state: State): """Batch start hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_BATCH, time=self.TIME_START ): self.callback.on_batch_start(state)
[docs] def on_batch_end(self, state: State): """Batch end hook. Args: state (State): current state """ if self.is_active_on_phase( phase=state.phase, level=self.LEVEL_BATCH, time=self.TIME_END ): self.callback.on_batch_end(state)
[docs] def on_exception(self, state: State): """On exception event. Args: state (State): current state """ self.callback.on_exception(state)
[docs]class PhaseBatchWrapperCallback(PhaseWrapperCallback): """@TODO: Docs. Contribution is welcome."""
[docs] def is_active_on_phase(self, phase, level, time): """@TODO: Docs. Contribution is welcome.""" if level != self.LEVEL_BATCH: return True return self._is_active_on_phase(phase)
__all__ = ["PhaseWrapperCallback", "PhaseBatchWrapperCallback"]