# Source code for catalyst.core.callbacks.control_flow

from typing import Callable, Mapping, Sequence, Union
from collections import OrderedDict

from catalyst.core.callback import Callback, WrapperCallback
from catalyst.core.runner import IRunner

# Accepted shapes of the ``loaders`` / ``ignore_loaders`` arguments:
# one loader name, several loader names, or a mapping from a loader name
# to a single epoch number or to a sequence of epoch numbers.
LOADERS = Union[
    str,
    Sequence[str],
    Mapping[str, Union[int, Sequence[int]]],
]
# Filter signature: (stage, epoch, loader) -> ``True`` to disable the callback.
FILTER_FN = Callable[[str, int, str], bool]


def _filter_fn_from_epochs(
    epochs: Union[int, float, Sequence[int]], reverse_condition: bool
) -> FILTER_FN:
    """Build ``filter_fn`` from epochs for ``ControlFlowCallback``.

    Args:
        epochs (int/float/Sequence[int]): epochs description. A number is
            used as a period (truncated with ``int``); a list/tuple is used
            as an explicit collection of epoch numbers.
        reverse_condition (bool): if ``True`` the returned filter disables
            the callback on every epoch *except* the described ones;
            if ``False`` it disables the callback *only* on the described
            ones.

    Raises:
        ValueError: if ``epochs`` is a number that truncates to zero
        ValueError: if passed object with unexpected type

    Returns:
        filter function which accepts 3 arguments - stage (str),
        epoch (int), loader (str) and return ``True`` if
        need to disable callback
    """
    if isinstance(epochs, (int, float)):
        period = int(epochs)
        if period == 0:
            # ``epoch % 0`` would raise ZeroDivisionError on every call,
            # i.e. at every loader start - fail early with a clear message.
            raise ValueError(
                f"'epochs' period should be non-zero! (got {epochs})"
            )
        if reverse_condition:
            return lambda stage, epoch, loader: epoch % period != 0
        return lambda stage, epoch, loader: epoch % period == 0

    if isinstance(epochs, (list, tuple)):
        # frozenset: O(1) membership, duplicates are irrelevant
        selected = frozenset(epochs)
        if reverse_condition:
            return lambda stage, epoch, loader: epoch not in selected
        return lambda stage, epoch, loader: epoch in selected

    raise ValueError(
        "'epochs' should be int/float/Sequence[int]! "
        f"(got {type(epochs)})"
    )


def _filter_fn_from_loaders(
    loaders: LOADERS, reverse_condition: bool
) -> FILTER_FN:
    """Build ``filter_fn`` from loaders for ``ControlFlowCallback``.

    Args:
        loaders (str/Sequence[str]/Mapping[str, int/Sequence[int]]):
            loaders description. A string or a list/tuple of strings
            names loaders; a mapping pairs a loader name with one epoch
            number or a sequence of epoch numbers (each truncated
            with ``int``).
        reverse_condition (bool): if ``True`` the returned filter disables
            the callback everywhere *except* the described loaders (and,
            for a mapping, their listed epochs); if ``False`` it disables
            the callback *only* there.

    Raises:
        ValueError: if a mapping value can't be converted to epoch numbers
        ValueError: if passed object with unexpected type

    Returns:
        filter function which accepts 3 arguments - stage (str),
        epoch (int), loader (str) and return ``True`` if
        need to disable callback
    """
    if isinstance(loaders, str):
        loaders = [loaders]

    # sequence of loader names
    if isinstance(loaders, (list, tuple)):
        names = frozenset(loaders)  # duplicates are irrelevant
        if reverse_condition:
            return lambda stage, epoch, loader: loader not in names
        return lambda stage, epoch, loader: loader in names

    # mapping {loader name: epoch or epochs}
    if isinstance(loaders, (dict, OrderedDict)):
        epochs_by_loader = {}
        for name, epochs in loaders.items():
            if isinstance(epochs, (int, float)):
                epochs_by_loader[name] = frozenset((int(epochs),))
                continue
            try:
                epochs_by_loader[name] = frozenset(
                    int(num) for num in epochs
                )
            except (ValueError, TypeError) as err:
                raise ValueError(
                    "'loaders' should be a dict where "
                    "values are int/float/List[int]/Tuple[int]!"
                ) from err

        # loaders missing from the mapping get an empty epoch set
        no_epochs = frozenset()
        if reverse_condition:
            return lambda stage, epoch, loader: epoch not in (
                epochs_by_loader.get(loader, no_epochs)
            )
        return lambda stage, epoch, loader: epoch in (
            epochs_by_loader.get(loader, no_epochs)
        )

    raise ValueError(
        "'loaders' type should be one of - str, "
        "Sequence[str], Mapping[str, int] or "
        "Mapping[str, Sequence[int]]! "
        f"(got {type(loaders)})"
    )


def _filter_fn_from_arg(filter_fn: Union[str, FILTER_FN]) -> FILTER_FN:
    """Check if filter function from arguments
    can be used with ``ControlFlowCallback``.

    Args:
        filter_fn (str or Callable): filter function to check. A string
            is evaluated as python code (expected to be a lambda).

    Raises:
        ValueError: if ``filter_fn`` is a string and evaluating it
            fails (syntax error, unknown name, ...)
        ValueError: if passed not callable object
        ValueError: if the filter function can't be called with three
            positional arguments

    Returns:
        filter function which accepts 3 arguments - stage (str),
        epoch (int), loader (str) and return ``True`` if
        need to disable callback
    """
    import inspect  # local import keeps the file-level imports untouched

    if isinstance(filter_fn, str):
        # lambda function from string
        # NOTE(security): ``eval`` executes arbitrary code - only pass
        # strings that come from trusted configs.
        try:
            filter_fn = eval(filter_fn)  # noqa: WPS400
        except Exception as err:  # eval may raise anything (NameError, ...)
            raise ValueError(
                "'filter_fn' should be a valid "
                "python lambda function with "
                "three arguments - 'stage', 'epoch' and 'loader'!"
            ) from err
    if not callable(filter_fn):
        raise ValueError("'filter_fn' should be a callable!")

    # ``__code__`` exists only on plain functions; ``inspect.signature``
    # also handles callable objects, partials and bound methods.
    try:
        signature = inspect.signature(filter_fn)
    except (TypeError, ValueError):
        # no introspectable signature (some builtins) - cannot verify arity
        return filter_fn
    try:
        signature.bind(None, None, None)
    except TypeError as err:
        raise ValueError(
            "Filter function should have three arguments - "
            "'stage', 'epoch' and 'loader'!"
        ) from err
    return filter_fn


class ControlFlowCallback(WrapperCallback):
    """Customize callback execution on different stages, loaders and epochs.

    For example, if you don't want to compute loss on a validation
    you can ignore ``CriterionCallback``, for notebook API
    need to wrap callback:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.dl import (
            SupervisedRunner, AccuracyCallback,
            CriterionCallback, ControlFlowCallback,
        )

        num_samples, num_features = 10_000, 10
        n_classes = 10
        X = torch.rand(num_samples, num_features)
        y = torch.randint(0, n_classes, [num_samples])
        loader = DataLoader(TensorDataset(X, y), batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        model = torch.nn.Linear(num_features, n_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        runner = SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=5,
            verbose=False,
            main_metric="accuracy03",
            minimize_metric=False,
            callbacks=[
                AccuracyCallback(
                    accuracy_args=[1, 3, 5]
                ),
                ControlFlowCallback(
                    base_callback=CriterionCallback(),
                    ignore_loaders="valid"  # or loaders="train"
                )
            ]
        )

    In config API need to use ``_wrapper`` argument:

    .. code-block:: yaml

        callbacks_params:
          ...
          loss:
            _wrapper:
              callback: ControlFlowCallback
              ignore_loaders: valid
            callback: CriterionCallback
          ...

    """

    def __init__(
        self,
        base_callback: Callback,
        epochs: Union[int, float, Sequence[int]] = None,
        ignore_epochs: Union[int, float, Sequence[int]] = None,
        loaders: LOADERS = None,
        ignore_loaders: LOADERS = None,
        filter_fn: Union[str, FILTER_FN] = None,
        use_global_epochs: bool = False,
    ):
        """
        Args:
            base_callback (Callback): callback to wrap
            epochs (int/float/Sequence[int]): epochs where the callback
                should be executed; on all other epochs it is disabled.
                If an int/float is passed it is used as a period
                (truncated with ``int``): the callback is enabled on
                epochs where ``epoch_number % epochs == 0`` and disabled
                on other epochs.
                If a list/tuple of epochs is passed the callback is
                enabled only on the listed epochs.
                Default value is ``None``.
            ignore_epochs (int/float/Sequence[int]): epochs where the
                callback should be disabled; on all other epochs it is
                enabled.
                If an int/float is passed it is used as a period
                (truncated with ``int``): the callback is disabled on
                epochs where ``epoch_number % epochs == 0`` and enabled
                on other epochs.
                If a list/tuple of epochs is passed the callback is
                disabled on the listed epochs.
                Default value is ``None``.
            loaders (str/Sequence[str]/Mapping[str, int/Sequence[int]]):
                loader names where the callback should be enabled; on all
                other loaders it is disabled.
                If a string is passed the callback is enabled only for
                the loader with that name.
                If a list/tuple of strings is passed the callback is
                enabled only for loaders with those names.
                If a dictionary is passed, where keys are loader names and
                values are an int or a list of ints, the callback is
                enabled only on the listed epochs (dictionary value) of the
                listed loaders (dictionary key); it is disabled for loaders
                missing from the dictionary.
                Default value is ``None``.
            ignore_loaders (str/Sequence[str]/Mapping[str, int/Sequence[int]]):
                loader names where the callback should be disabled; on all
                other loaders it is enabled.
                If a string is passed the callback is disabled for the
                loader with that name.
                If a list/tuple of strings is passed the callback is
                disabled for loaders with those names.
                If a dictionary is passed, where keys are loader names and
                values are an int or a list of ints, the callback is
                disabled on the listed epochs (dictionary value) of the
                listed loaders (dictionary key) and enabled elsewhere.
                Default value is ``None``.
            filter_fn (str or Callable[[str, int, str], bool]):
                function to use instead of ``loaders`` or ``epochs``
                arguments.
                If a string is passed it is evaluated as python code
                (with ``eval`` - use only trusted strings); it is expected
                to be a lambda function with three arguments: stage name
                (str), epoch number (int), loader name (str).
                If a callable is passed it should accept the same three
                arguments.
                In both cases the function should return ``True`` if the
                callback should be disabled, otherwise ``False``.
                Default value is ``None``.
            use_global_epochs (bool): if ``True`` then will be used global
                epochs instead of epochs in a stage, the default value is
                ``False``

        Raises:
            ValueError: if none of ``epochs``, ``ignore_epochs``,
                ``loaders``, ``ignore_loaders`` or ``filter_fn`` is passed

        Note:
            If several of the arguments above are passed, only the first
            non-``None`` one in the order ``epochs``, ``ignore_epochs``,
            ``loaders``, ``ignore_loaders``, ``filter_fn`` is used.
        """
        required_args = (
            epochs,
            ignore_epochs,
            loaders,
            ignore_loaders,
            filter_fn,
        )
        if all(arg is None for arg in required_args):
            raise ValueError(
                "Expected one of arguments - "
                "'epochs', 'ignore_epochs', "
                "'loaders', 'ignore_loaders' "
                "or 'filter_fn'!"
            )

        super().__init__(base_callback, True)
        self.use_global_epochs = use_global_epochs
        # ``reverse_condition=True`` turns an "enable here" description
        # into a filter that disables everywhere else.
        self.filter_fn = None
        if epochs is not None:
            self.filter_fn = _filter_fn_from_epochs(epochs, True)
        elif ignore_epochs is not None:
            self.filter_fn = _filter_fn_from_epochs(ignore_epochs, False)
        elif loaders is not None:
            self.filter_fn = _filter_fn_from_loaders(loaders, True)
        elif ignore_loaders is not None:
            self.filter_fn = _filter_fn_from_loaders(ignore_loaders, False)
        elif filter_fn is not None:
            self.filter_fn = _filter_fn_from_arg(filter_fn)

    def on_loader_start(self, runner: IRunner) -> None:
        """
        Check if current epoch should be skipped.

        Args:
            runner (IRunner): current runner
        """
        stage = runner.stage_name
        loader = runner.loader_name
        epoch = runner.global_epoch if self.use_global_epochs else runner.epoch

        if self.filter_fn is not None:
            # ``filter_fn`` returns True when the callback must be disabled
            self._is_enabled = not self.filter_fn(stage, epoch, loader)

        if self._is_enabled:
            self.callback.on_loader_start(runner)

    def on_loader_end(self, runner: IRunner) -> None:
        """
        Reset status of callback

        Args:
            runner (IRunner): current runner
        """
        if self._is_enabled:
            self.callback.on_loader_end(runner)

        # re-enable so the next loader's decision starts from a clean state
        self._is_enabled = True
# Public API of this module: the names exported by ``import *``.
__all__ = ["ControlFlowCallback"]