# Source code for catalyst.core.callbacks.wrappers
from typing import List, Optional  # isort:skip

from catalyst.core import _State, Callback
class PhaseWrapperCallback(Callback):
    """
    Callback wrapper which enables/disables the wrapped callback's handlers
    depending on the current phase (``state.phase``) and event type.

    May be useful, e.g., to enable/disable optimizers and losses
    for particular phases.

    Exactly one of ``active_phases`` / ``inactive_phases`` must be given:

    * ``active_phases`` - the wrapped handlers run only when
      ``state.phase`` is in this list;
    * ``inactive_phases`` - the wrapped handlers run unless
      ``state.phase`` is in this list.

    With the default ``is_active_on_phase``, a ``state.phase`` of ``None``
    means the wrapped handlers always run. ``on_exception`` is always
    forwarded regardless of phase.
    """

    # Event levels passed to ``is_active_on_phase`` as ``level``.
    LEVEL_STAGE = "stage"
    LEVEL_EPOCH = "epoch"
    LEVEL_LOADER = "loader"
    LEVEL_BATCH = "batch"
    # Event times passed to ``is_active_on_phase`` as ``time``.
    TIME_START = "start"
    TIME_END = "end"

    def __init__(
        self,
        base_callback: Callback,
        active_phases: Optional[List[str]] = None,
        inactive_phases: Optional[List[str]] = None
    ):
        """
        Args:
            base_callback: callback whose handlers are conditionally
                forwarded; its ``order`` is reused for this wrapper
            active_phases: phases during which the handlers run
            inactive_phases: phases during which the handlers are skipped

        Raises:
            AssertionError: if not exactly one of ``active_phases`` /
                ``inactive_phases`` is given, or if the given list is empty
        """
        super().__init__(base_callback.order)
        # Explicit raises instead of ``assert`` statements so the checks
        # still run under ``python -O``; AssertionError is kept so callers
        # observe the same exception type as before.
        if (active_phases is None) == (inactive_phases is None):
            raise AssertionError(
                "Exactly one of active/inactive phases must be specified"
            )
        self.callback = base_callback
        self.active_phases = active_phases or []
        self.inactive_phases = inactive_phases or []
        if len(self.active_phases) + len(self.inactive_phases) == 0:
            raise AssertionError(
                "Wrapper has no sense if callback is always active/inactive"
            )

    def is_active_on_phase(self, phase, level, time):
        """
        Decide whether the wrapped handler for an event should run.

        Args:
            phase: current ``state.phase``
            level: one of the ``LEVEL_*`` constants
            time: one of the ``TIME_*`` constants

        The base implementation ignores ``level`` and ``time`` and only
        checks ``phase``; subclasses may override it to use them.
        """
        return self._is_active_on_phase(phase=phase)

    def _is_active_on_phase(self, phase):
        """Phase-only activity check shared by all wrapper variants."""
        if phase is None:
            # if phase is None every callback is active
            return True
        if phase in self.active_phases:
            return True
        if self.inactive_phases and phase not in self.inactive_phases:
            return True
        return False

    def _forward_if_active(self, state: _State, level, time, handler):
        """Call ``handler(state)`` when ``is_active_on_phase`` allows it."""
        if self.is_active_on_phase(phase=state.phase, level=level, time=time):
            handler(state)

    def on_stage_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_STAGE, self.TIME_START,
            self.callback.on_stage_start
        )

    def on_stage_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_STAGE, self.TIME_END,
            self.callback.on_stage_end
        )

    def on_epoch_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_EPOCH, self.TIME_START,
            self.callback.on_epoch_start
        )

    def on_epoch_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_EPOCH, self.TIME_END,
            self.callback.on_epoch_end
        )

    def on_loader_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_LOADER, self.TIME_START,
            self.callback.on_loader_start
        )

    def on_loader_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_LOADER, self.TIME_END,
            self.callback.on_loader_end
        )

    def on_batch_start(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_BATCH, self.TIME_START,
            self.callback.on_batch_start
        )

    def on_batch_end(self, state: _State):
        self._forward_if_active(
            state, self.LEVEL_BATCH, self.TIME_END,
            self.callback.on_batch_end
        )

    def on_exception(self, state: _State):
        """Always forwarded to the wrapped callback, regardless of phase."""
        self.callback.on_exception(state)
class PhaseBatchWrapperCallback(PhaseWrapperCallback):
    """
    Phase wrapper that filters batch-level events only.

    Stage, epoch and loader handlers of the wrapped callback are always
    forwarded; batch handlers are forwarded only when the phase check
    passes.
    """

    def is_active_on_phase(self, phase, level, time):
        """Apply the phase filter to ``LEVEL_BATCH`` events only."""
        is_batch_event = level == self.LEVEL_BATCH
        return (not is_batch_event) or self._is_active_on_phase(phase)
# Public names of this module (what ``from ... import *`` exports).
__all__ = ["PhaseWrapperCallback", "PhaseBatchWrapperCallback"]