Source code for catalyst.core.callbacks.wrappers

from typing import List, Optional  # isort:skip

from catalyst.core import _State, Callback


class PhaseWrapperCallback(Callback):
    """
    Callback wrapper which enables/disables the wrapped callback's handlers
    depending on the current ``state.phase``.

    May be useful, e.g., to disable/enable optimizers & losses
    on particular phases.

    Exactly one of ``active_phases`` / ``inactive_phases`` must be given:

    - ``active_phases``: the wrapped callback runs only on these phases;
    - ``inactive_phases``: the wrapped callback runs on every phase
      except these.

    If ``state.phase`` is ``None``, the wrapped callback always runs.
    ``on_exception`` is always forwarded, regardless of phase.
    """

    # Event levels passed to ``is_active_on_phase``.
    LEVEL_STAGE = "stage"
    LEVEL_EPOCH = "epoch"
    LEVEL_LOADER = "loader"
    LEVEL_BATCH = "batch"

    # Event times passed to ``is_active_on_phase``.
    TIME_START = "start"
    TIME_END = "end"

    def __init__(
        self,
        base_callback: Callback,
        active_phases: Optional[List[str]] = None,
        inactive_phases: Optional[List[str]] = None,
    ):
        """
        Args:
            base_callback: callback to wrap; its ``order`` is reused
                as the wrapper's order.
            active_phases: phases on which ``base_callback`` is enabled.
            inactive_phases: phases on which ``base_callback`` is disabled.

        Raises:
            ValueError: if both or neither of ``active_phases`` /
                ``inactive_phases`` are given, or if the given one is empty.
        """
        super().__init__(base_callback.order)
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a misconfigured wrapper through.
        if (active_phases is None) == (inactive_phases is None):
            raise ValueError(
                "Exactly one of active/inactive phases must be specified"
            )

        self.callback = base_callback
        self.active_phases = active_phases or []
        self.inactive_phases = inactive_phases or []

        if not self.active_phases and not self.inactive_phases:
            raise ValueError(
                "Wrapper has no sense if callback is always active/inactive"
            )

    def is_active_on_phase(self, phase, level, time):
        """
        Decide whether the wrapped callback should handle the current event.

        The base implementation ignores ``level`` and ``time``; subclasses
        may override it to take them into account.

        Args:
            phase: current ``state.phase`` (may be ``None``).
            level: one of the ``LEVEL_*`` constants.
            time: one of the ``TIME_*`` constants.

        Returns:
            bool: ``True`` if the wrapped callback should be called.
        """
        return self._is_active_on_phase(phase=phase)

    def _is_active_on_phase(self, phase):
        """Phase-only activity check shared by ``is_active_on_phase`` overrides."""
        if phase is None:
            # if phase is None every callback is active
            return True
        if phase in self.active_phases:
            return True
        # "inactive" mode: active on every phase not explicitly disabled
        if self.inactive_phases and phase not in self.inactive_phases:
            return True
        return False

    def _forward_if_active(self, state: _State, level, time, handler):
        """Call ``handler(state)`` only if the wrapper is active for this event."""
        if self.is_active_on_phase(phase=state.phase, level=level, time=time):
            handler(state)

    def on_stage_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_STAGE, self.TIME_START,
            self.callback.on_stage_start,
        )

    def on_stage_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_STAGE, self.TIME_END,
            self.callback.on_stage_end,
        )

    def on_epoch_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_EPOCH, self.TIME_START,
            self.callback.on_epoch_start,
        )

    def on_epoch_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_EPOCH, self.TIME_END,
            self.callback.on_epoch_end,
        )

    def on_loader_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_LOADER, self.TIME_START,
            self.callback.on_loader_start,
        )

    def on_loader_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_LOADER, self.TIME_END,
            self.callback.on_loader_end,
        )

    def on_batch_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_BATCH, self.TIME_START,
            self.callback.on_batch_start,
        )

    def on_batch_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_BATCH, self.TIME_END,
            self.callback.on_batch_end,
        )

    def on_exception(self, state: _State):
        # Exceptions are always forwarded, independent of the current phase.
        self.callback.on_exception(state)
class PhaseBatchWrapperCallback(PhaseWrapperCallback):
    """
    Phase wrapper that applies the phase filter to batch-level events only.

    Stage-, epoch- and loader-level events are always forwarded to the
    wrapped callback; batch-level events are forwarded only when the
    phase check passes.
    """

    def is_active_on_phase(self, phase, level, time):
        """Return ``True`` for non-batch events, otherwise check ``phase``."""
        non_batch_event = level != self.LEVEL_BATCH
        return non_batch_event or self._is_active_on_phase(phase)
# Public API of this module.
__all__ = [
    "PhaseWrapperCallback",
    "PhaseBatchWrapperCallback",
]