Shortcuts

Source code for catalyst.dl.callbacks.phase

from typing import List
from collections import OrderedDict

from catalyst.core import Callback, CallbackNode, CallbackOrder, State


class Phase:
    """Bookkeeping for one phase of training or validation.

    Attributes:
        name: phase name (``None`` for an unnamed phase)
        steps: number of steps (batches) the phase lasts before the next
            phase is chosen, converted with ``int()``; ``None`` if not given
        curr_step: number of steps (batches) of this phase already done
    """

    def __init__(self, name: str = None, steps: int = None):
        """
        Args:
            name: phase name
            steps: phase length in steps (batches); any value accepted by
                ``int()``, or ``None`` to leave the length unset
        """
        self.name = name
        self.curr_step = 0
        if steps is None:
            self.steps = None
        else:
            self.steps = int(steps)


class PhaseManager:
    """Class for storing & managing all phases in experiment configuration.

    Keeps separate phase lists and current-phase indices for train loaders
    and for all other (validation) loaders.

    Calling ``.step(...)`` advances the current phase by ``step_size``
    steps; once the phase has done at least ``phase.steps`` steps, its
    counter is reset to 0 and the next phase (cyclically) becomes current.
    """

    def __init__(self, train_phases: List[Phase], valid_phases: List[Phase]):
        """
        Args:
            train_phases: phases cycled through while
                ``state.is_train_loader`` is true
            valid_phases: phases cycled through on all other loaders
        """
        self.train_phases = train_phases
        self.valid_phases = valid_phases

        self.train_index = 0
        self.valid_index = 0

    @staticmethod
    def _advance(phases: List[Phase], index: int, step_size: int) -> int:
        """Advance ``phases[index]`` by ``step_size`` and return the new index.

        Lists with at most one phase are never advanced. When the current
        phase reaches its ``steps`` budget, its counter is reset to 0 and the
        following phase (wrapping around to the first) becomes current.

        Note: ``phase.steps`` must be an int here; a ``None`` budget inside a
        multi-phase list makes the comparison raise ``TypeError``.
        """
        if len(phases) <= 1:
            return index
        phase = phases[index]
        phase.curr_step += step_size
        if phase.curr_step >= phase.steps:
            phase.curr_step = 0
            index = (index + 1) % len(phases)
        return index

    def step(self, state: State, step_size: int = 1):
        """Advance the current phase of the active loader type.

        Args:
            state: current state; ``state.is_train_loader`` selects whether
                the train or the validation phases are advanced
            step_size: number of steps (batches) to add to the current phase
        """
        if state.is_train_loader:
            self.train_index = self._advance(
                self.train_phases, self.train_index, step_size
            )
        else:
            self.valid_index = self._advance(
                self.valid_phases, self.valid_index, step_size
            )

    def get_phase_name(self, state: State):
        """Return the name of the current phase for the active loader type.

        Args:
            state: current state; ``state.is_train_loader`` selects train or
                validation phases

        Returns:
            the ``name`` of the current phase (may be ``None``)
        """
        if state.is_train_loader:
            return self.train_phases[self.train_index].name
        return self.valid_phases[self.valid_index].name


class PhaseManagerCallback(Callback):
    """Callback to update ``state.phase``.

    At the start of every batch ``state.phase`` is set to the name of the
    current phase (train or validation, chosen by
    ``state.is_train_loader``); at the end of every batch the current phase
    is advanced by one step.
    """

    VALIDATION_MODE_ALL = "all"  # (in validation) use all callbacks
    VALIDATION_MODE_SAME = "same"  # (in validation) same phases as in training
    allowed_valid_modes = [VALIDATION_MODE_SAME, VALIDATION_MODE_ALL]

    def __init__(
        self,
        train_phases: "OrderedDict[str, int]" = None,
        valid_phases: "OrderedDict[str, int]" = None,
        valid_mode: str = None,
    ):
        """
        Args:
            train_phases: ordered mapping ``phase name -> number of steps
                (batches)`` used on train loaders. If ``None``, a single
                unnamed phase (``state.phase = None``) is used.
            valid_phases: ordered mapping ``phase name -> number of steps
                (batches)`` used on validation loaders. Exactly one of
                ``valid_phases`` / ``valid_mode`` must be given.
            valid_mode: how to derive validation phases when
                ``valid_phases`` is not given: ``"all"`` uses a single
                unnamed phase, ``"same"`` copies the training phases.

        Raises:
            ValueError: if both or neither of ``valid_phases`` and
                ``valid_mode`` are given (previously an ``assert``, which is
                skipped under ``python -O``), or if ``valid_mode`` is not in
                ``allowed_valid_modes``.
        """
        super().__init__(order=CallbackOrder.Internal, node=CallbackNode.All)
        self.phase_manager = self._get_phase_manager(
            train_phases=train_phases,
            valid_phases=valid_phases,
            valid_mode=valid_mode,
        )

    @staticmethod
    def _to_phases(phases: "OrderedDict[str, int]") -> List[Phase]:
        """Build fresh ``Phase`` objects from a ``name -> steps`` mapping."""
        return [Phase(name=name, steps=steps) for name, steps in phases.items()]

    def _get_phase_manager(
        self,
        train_phases: "OrderedDict[str, int]" = None,
        valid_phases: "OrderedDict[str, int]" = None,
        valid_mode: str = None,
    ):
        """Build the ``PhaseManager`` from the constructor arguments."""
        # An explicit check (not ``assert``) so it also runs under ``-O``.
        if (valid_phases is None) == (valid_mode is None):
            raise ValueError(
                "Exactly one of `valid_phases` and `valid_mode` "
                "must be specified"
            )

        if train_phases is None:
            train_phase_list = [Phase(name=None, steps=None)]
        else:
            train_phase_list = self._to_phases(train_phases)

        if valid_phases is not None:
            # PhaseManager indexes phases positionally, so the mapping must
            # be converted just like the training phases.
            valid_phase_list = self._to_phases(valid_phases)
        elif valid_mode == self.VALIDATION_MODE_ALL:
            valid_phase_list = [Phase(name=None, steps=None)]
        elif valid_mode == self.VALIDATION_MODE_SAME:
            # Fresh copies so train and validation step counters are
            # tracked independently.
            valid_phase_list = [
                Phase(name=p.name, steps=p.steps) for p in train_phase_list
            ]
        else:
            raise ValueError(
                f"Unsupported validation_mode {valid_mode!r}, "
                f"should be one of {self.allowed_valid_modes}"
            )
        return PhaseManager(
            train_phases=train_phase_list, valid_phases=valid_phase_list
        )

    def on_batch_start(self, state: State):
        """Batch start hook: set ``state.phase`` to the current phase name.

        Args:
            state (State): current state
        """
        state.phase = self.phase_manager.get_phase_name(state)

    def on_batch_end(self, state: State):
        """Batch end hook: advance the current phase by one step.

        Args:
            state (State): current state
        """
        self.phase_manager.step(state)


# Only the callback is exported by ``from ... import *``; ``Phase`` and
# ``PhaseManager`` remain importable by name.
__all__ = ["PhaseManagerCallback"]