# Source code for catalyst.core.callbacks.phase
from typing import List  # isort:skip
from collections import OrderedDict
from catalyst.core import Callback, CallbackOrder
class Phase:
    """
    Record describing a single phase of an experiment.

    Every instance keeps:
    - ``name``: phase name
    - ``steps``: number of steps (batches) in phase before next phase
      is chosen
    - ``curr_step``: how many steps (batches) are done already
    """

    def __init__(self, name: str = None, steps: int = None):
        """
        Args:
            name: phase name, stored unchanged
            steps: length of the phase in steps; coerced with ``int``
                unless it is ``None``, in which case ``None`` is kept
        """
        if steps is None:
            self.steps = None
        else:
            self.steps = int(steps)
        # a freshly created phase has not processed any batch yet
        self.curr_step = 0
        self.name = name
class PhaseManager:
    """
    Keeps two independent phase sequences (train & validation) together
    with the index of the currently active phase in each of them.

    ``state.need_backward`` selects which sequence is used: the train
    one when it is truthy, the validation one otherwise.
    Calling `.step(...)` advances the active phase by the step size and,
    once that phase is finished, makes the next phase current
    (wrapping around to the first phase after the last one).
    """

    def __init__(
        self,
        train_phases: List[Phase],
        valid_phases: List[Phase]
    ):
        """
        Args:
            train_phases: phases cycled through when
                ``state.need_backward`` is truthy
            valid_phases: phases cycled through otherwise
        """
        self.train_index = 0
        self.valid_index = 0
        self.train_phases = train_phases
        self.valid_phases = valid_phases

    @staticmethod
    def _advance(phases, index, step_size):
        """
        Add ``step_size`` to the phase at ``index`` and return the index
        of the phase that should be active afterwards.

        A sequence with at most one phase is left untouched.
        """
        if len(phases) <= 1:
            return index
        current = phases[index]
        current.curr_step += step_size
        if current.curr_step >= current.steps:
            # phase finished: reset its counter and move on cyclically
            current.curr_step = 0
            return (index + 1) % len(phases)
        return index

    def step(self, state, step_size=1):
        """
        Advance the phase sequence selected by ``state.need_backward``.

        Args:
            state: object exposing a ``need_backward`` flag
            step_size: how many steps to add to the current phase
        """
        if state.need_backward:
            self.train_index = self._advance(
                self.train_phases, self.train_index, step_size
            )
        else:
            self.valid_index = self._advance(
                self.valid_phases, self.valid_index, step_size
            )

    def get_phase_name(self, state):
        """
        Return the name of the phase that is currently active in the
        sequence selected by ``state.need_backward``.
        """
        if state.need_backward:
            active = self.train_phases[self.train_index]
        else:
            active = self.valid_phases[self.valid_index]
        return active.name
class PhaseManagerCallback(Callback):
    """
    PhaseManagerCallback updates state.phase

    Before every batch the name of the current phase is written to
    ``state.phase``; after every batch the phase manager is advanced.

    Args:
        train_phases: ordered mapping ``phase name -> number of steps``
            used when ``state.need_backward`` is truthy; ``None`` means
            a single unnamed phase
        valid_phases: ordered mapping ``phase name -> number of steps``
            used otherwise
        valid_mode: how to build the validation phases when
            ``valid_phases`` is not given, one of ``allowed_valid_modes``

    Exactly one of ``valid_phases`` / ``valid_mode`` must be specified.
    """
    VALIDATION_MODE_ALL = "all"  # (in validation) use all callbacks
    VALIDATION_MODE_SAME = "same"  # (in validation) same phases as in training
    allowed_valid_modes = [VALIDATION_MODE_SAME, VALIDATION_MODE_ALL]

    def __init__(
        self,
        train_phases: "OrderedDict[str, int]" = None,
        valid_phases: "OrderedDict[str, int]" = None,
        valid_mode: str = None
    ):
        super().__init__(CallbackOrder.Internal)
        self.phase_manager = self._get_phase_manager(
            train_phases=train_phases,
            valid_phases=valid_phases,
            valid_mode=valid_mode
        )

    def _get_phase_manager(
        self,
        train_phases: "OrderedDict[str, int]" = None,
        valid_phases: "OrderedDict[str, int]" = None,
        valid_mode: str = None
    ):
        """
        Build the ``PhaseManager`` from the user-facing configuration.

        Returns:
            PhaseManager holding lists of ``Phase`` objects for the
            train and validation sequences

        Raises:
            AssertionError: if both or neither of ``valid_phases`` and
                ``valid_mode`` are given
            ValueError: if ``valid_mode`` is not one of
                ``allowed_valid_modes``
        """
        assert (valid_phases is None) ^ (valid_mode is None), \
            "Exactly one of them must be specified"

        if train_phases is None:
            train_phases = [Phase(name=None, steps=None)]
        else:
            train_phases = [
                Phase(name=name, steps=steps)
                for name, steps in train_phases.items()
            ]

        if valid_phases is None:
            if valid_mode == self.VALIDATION_MODE_ALL:
                valid_phases = [Phase(name=None, steps=None)]
            elif valid_mode == self.VALIDATION_MODE_SAME:
                # independent copies, so that validation progress does
                # not interfere with the training counters
                valid_phases = [
                    Phase(name=p.name, steps=p.steps)
                    for p in train_phases
                ]
            else:
                raise ValueError(
                    f"Unsupported validation_mode, should be one of "
                    f"{self.allowed_valid_modes}"
                )
        elif hasattr(valid_phases, "items"):
            # a name -> steps mapping has to be turned into Phase objects
            # exactly like train_phases; PhaseManager indexes the phases
            # by position and reads ``.name`` / ``.steps`` from them
            valid_phases = [
                Phase(name=name, steps=steps)
                for name, steps in valid_phases.items()
            ]
        # anything without ``.items`` (e.g. an already built list of
        # Phase objects) is handed over unchanged

        return PhaseManager(
            train_phases=train_phases,
            valid_phases=valid_phases
        )

    def on_batch_start(self, state):
        """Write the name of the current phase to ``state.phase``."""
        state.phase = self.phase_manager.get_phase_name(state)

    def on_batch_end(self, state):
        """Advance the phase manager after the batch is processed."""
        self.phase_manager.step(state)
# Public API of the module: only the callback is exported by
# ``from <module> import *``.
__all__ = ["PhaseManagerCallback"]