"""Phase scheduling for callbacks (module ``catalyst.dl.callbacks.phase``)."""
from typing import List
from collections import OrderedDict
from catalyst.core import Callback, CallbackNode, CallbackOrder, State
class Phase:
    """
    Mutable record describing one phase of a phase schedule.

    Attributes tracked per phase:
    - ``name``: the phase name (may be ``None`` for an unnamed phase)
    - ``steps``: how many steps (batches) the phase lasts before the next
      phase is chosen, or ``None`` when no budget is set
    - ``curr_step``: how many steps (batches) have been done in the phase
      so far
    """

    def __init__(self, name: str = None, steps: int = None):
        """
        Args:
            name: phase name; ``None`` is allowed.
            steps: step (batch) budget of the phase. Any value accepted by
                ``int()`` works and is coerced to ``int``; ``None`` is kept
                as ``None``.
        """
        self.name = name
        self.curr_step = 0
        if steps is None:
            self.steps = None
        else:
            self.steps = int(steps)
class PhaseManager:
    """Class for storing & managing all phases in experiment configuration.

    Stores separately current phases in train & validation modes.
    By calling ``.step(...)`` method current phase is updated by step-size
    and if current phase is finished, the next phase becomes current.
    Phases are cycled through in order; after the last phase the schedule
    wraps around to the first one.
    """

    def __init__(
        self, train_phases: List["Phase"], valid_phases: List["Phase"]
    ):
        """
        Args:
            train_phases: phases used while ``state.is_train_loader`` is true.
            valid_phases: phases used for every other loader.

        Raises:
            ValueError: if a list holding more than one phase contains a
                phase whose ``steps`` is ``None`` (it could never finish,
                so the schedule could not advance past it).
        """
        self._check_phases(train_phases)
        self._check_phases(valid_phases)
        self.train_phases = train_phases
        self.valid_phases = valid_phases
        self.train_index = 0
        self.valid_index = 0

    @staticmethod
    def _check_phases(phases):
        """Reject multi-phase lists that contain a phase without a budget."""
        if len(phases) <= 1:
            # A single phase never advances, so a missing budget is fine.
            return
        missing = [p.name for p in phases if p.steps is None]
        if missing:
            raise ValueError(
                f"Phases {missing} have no `steps`; every phase needs a "
                f"step budget when cycling through more than one phase"
            )

    @staticmethod
    def _advance(phases, index: int, step_size: int) -> int:
        """Credit ``step_size`` steps to ``phases[index]``; return new index."""
        if len(phases) <= 1:
            # Nothing to switch to: a lone phase stays current forever.
            return index
        phase = phases[index]
        phase.curr_step += step_size
        if phase.curr_step >= phase.steps:
            # Phase finished: reset its counter (any overshoot is dropped)
            # and move to the next phase, wrapping around at the end.
            phase.curr_step = 0
            index = (index + 1) % len(phases)
        return index

    def step(self, state: "State", step_size: int = 1):
        """Advance the schedule of the active loader by ``step_size`` steps.

        Args:
            state: current state; ``state.is_train_loader`` selects whether
                the train or the validation schedule is advanced.
            step_size: number of steps (batches) to credit to the current
                phase.
        """
        if state.is_train_loader:
            self.train_index = self._advance(
                self.train_phases, self.train_index, step_size
            )
        else:
            self.valid_index = self._advance(
                self.valid_phases, self.valid_index, step_size
            )

    def get_phase_name(self, state: "State"):
        """Return the name of the current phase for the active loader.

        Args:
            state: current state; ``state.is_train_loader`` selects the
                train or the validation schedule.

        Returns:
            The ``name`` of the current phase (may be ``None``).
        """
        if state.is_train_loader:
            return self.train_phases[self.train_index].name
        return self.valid_phases[self.valid_index].name
class PhaseManagerCallback(Callback):
    """Callback to update ``state.phase``.

    Before every batch the name of the current phase is written to
    ``state.phase``; after every batch the schedule of the active loader
    (train or validation) is advanced by one step.
    """

    VALIDATION_MODE_ALL = "all"  # (in validation) use all callbacks
    VALIDATION_MODE_SAME = "same"  # (in validation) same phases as in training
    allowed_valid_modes = [VALIDATION_MODE_SAME, VALIDATION_MODE_ALL]

    def __init__(
        self,
        train_phases: "OrderedDict[str, int]" = None,
        valid_phases: "OrderedDict[str, int]" = None,
        valid_mode: str = None,
    ):
        """
        Args:
            train_phases: ordered mapping ``phase name -> number of batches``
                used for training loaders. ``None`` means a single unnamed
                phase (``state.phase`` is then always ``None``).
            valid_phases: ordered mapping ``phase name -> number of batches``
                used for validation loaders. Mutually exclusive with
                ``valid_mode``.
            valid_mode: how to build validation phases when ``valid_phases``
                is not given; one of ``allowed_valid_modes``. ``"same"``
                copies the training schedule, ``"all"`` uses a single
                unnamed phase. Mutually exclusive with ``valid_phases``.

        Raises:
            AssertionError: if both or neither of ``valid_phases`` and
                ``valid_mode`` are given (checked with ``assert``, so this
                check is skipped under ``python -O``).
            ValueError: if ``valid_mode`` is not a supported mode.
        """
        super().__init__(order=CallbackOrder.Internal, node=CallbackNode.All)
        self.phase_manager = self._get_phase_manager(
            train_phases=train_phases,
            valid_phases=valid_phases,
            valid_mode=valid_mode,
        )

    @staticmethod
    def _to_phases(phases) -> List[Phase]:
        """Build a list of fresh ``Phase`` objects from a phase spec.

        A mapping (``name -> steps``) is converted entry by entry, in order.
        Anything else is assumed to already be a sequence of ``Phase``
        objects and is returned as a list unchanged.
        """
        if hasattr(phases, "items"):
            return [
                Phase(name=name, steps=steps)
                for name, steps in phases.items()
            ]
        return list(phases)

    def _get_phase_manager(
        self,
        train_phases: "OrderedDict[str, int]" = None,
        valid_phases: "OrderedDict[str, int]" = None,
        valid_mode: str = None,
    ):
        """Create the ``PhaseManager`` for the given train/valid specs."""
        assert (valid_phases is None) ^ (
            valid_mode is None
        ), "Exactly one of them must be specified"

        if train_phases is None:
            train_phases = [Phase(name=None, steps=None)]
        else:
            train_phases = self._to_phases(train_phases)

        if valid_phases is not None:
            # PhaseManager indexes phases by position and reads ``.name`` /
            # ``.steps``, so an explicit mapping must be converted too.
            valid_phases = self._to_phases(valid_phases)
        elif valid_mode == self.VALIDATION_MODE_ALL:
            valid_phases = [Phase(name=None, steps=None)]
        elif valid_mode == self.VALIDATION_MODE_SAME:
            # Fresh copies so train and valid step counters stay independent.
            valid_phases = [
                Phase(name=p.name, steps=p.steps) for p in train_phases
            ]
        else:
            raise ValueError(
                f"Unsupported validation_mode, should be one of "
                f"{self.allowed_valid_modes}"
            )

        return PhaseManager(
            train_phases=train_phases, valid_phases=valid_phases
        )

    def on_batch_start(self, state: State):
        """Batch start hook: publish the current phase name.

        Args:
            state (State): current state
        """
        state.phase = self.phase_manager.get_phase_name(state)

    def on_batch_end(self, state: State):
        """Batch end hook: advance the active schedule by one step.

        Args:
            state (State): current state
        """
        self.phase_manager.step(state)
# Only the callback is exported by ``from ... import *``; ``Phase`` and
# ``PhaseManager`` remain importable by explicit name.
__all__ = [
    "PhaseManagerCallback",
]