from collections import OrderedDict
import inspect
from typing import Callable, Mapping, Sequence, Union

from catalyst.core.callback import Callback, WrapperCallback
from catalyst.core.runner import IRunner
# Accepted forms of a loaders description: a single loader name,
# a sequence of loader names, or a mapping from loader name to
# an epoch number / sequence of epoch numbers.
LOADERS = Union[
    str,
    Sequence[str],
    Mapping[str, Union[int, Sequence[int]]],
]
# Filter function signature: ``(stage, epoch, loader) -> bool``.
FILTER_FN = Callable[[str, int, str], bool]
def _filter_fn_from_epochs(
    epochs: Union[int, float, Sequence[int]], reverse_condition: bool
) -> FILTER_FN:
    """Build ``filter_fn`` from epochs for ``ControlFlowCallback``.

    Args:
        epochs (int/float/Sequence[int]): epochs description.
            A number is truncated with ``int`` and used as a period
            (base condition is ``epoch % epochs == 0``), a list/tuple
            is used as an explicit collection of epoch numbers
            (base condition is ``epoch in epochs``).
        reverse_condition (bool): indicator to use reversed
            condition in filter function

    Raises:
        ValueError: if passed object with unexpected type
        ValueError: if the period is equal to ``0`` after
            conversion to int

    Returns:
        filter function which accepts 3 arguments - stage (str),
        epoch (int), loader (str) and return ``True`` if
        need to disable callback
    """
    if isinstance(epochs, (int, float)):
        period = int(epochs)
        if period == 0:
            # Fail while building the filter instead of raising
            # ZeroDivisionError from the modulo on the first call.
            raise ValueError(
                "'epochs' should not be equal to 0 "
                "when passed as int/float!"
            )

        def matches(epoch: int) -> bool:
            return epoch % period == 0

    elif isinstance(epochs, (list, tuple)):
        selected = frozenset(epochs)  # ignore duplicates

        def matches(epoch: int) -> bool:
            return epoch in selected

    else:
        raise ValueError(
            "'epochs' should be int/float/Sequence[int]! "
            f"(got {type(epochs)})"
        )

    def filter_fn(stage: str, epoch: int, loader: str) -> bool:
        # ``!=`` on two booleans acts as XOR: the base condition is
        # negated only when ``reverse_condition`` is True.
        return matches(epoch) != reverse_condition

    return filter_fn
def _filter_fn_from_loaders(
    loaders: LOADERS, reverse_condition: bool
) -> FILTER_FN:
    """Build ``filter_fn`` from loaders for ``ControlFlowCallback``.

    Args:
        loaders (str/Sequence[str]/Mapping[str, int/Sequence[int]]):
            loaders description.
            A string or a list/tuple of strings is treated as
            loader names (base condition is ``loader in loaders``).
            A dict maps a loader name to an epoch or to an iterable
            of epochs (base condition is "epoch is listed for
            this loader"; a loader missing from the dict has
            no listed epochs).
        reverse_condition (bool): indicator to use reversed
            condition in filter function

    Raises:
        ValueError: if can't build filter_fn from mappings
        ValueError: if passed object with unexpected type

    Returns:
        filter function which accepts 3 arguments - stage (str),
        epoch (int), loader (str) and return ``True`` if
        need to disable callback
    """
    if isinstance(loaders, str):
        loaders = [loaders]

    if isinstance(loaders, (list, tuple)):
        # sequence of loader names
        names = frozenset(loaders)  # ignore duplicates

        def matches(epoch: int, loader: str) -> bool:
            return loader in names

    elif isinstance(loaders, dict):
        # {loader name: epoch or epochs}
        epochs_by_loader = {}
        for name, epochs in loaders.items():
            if isinstance(epochs, (int, float)):
                epochs_by_loader[name] = frozenset((int(epochs),))
                continue
            try:
                epochs_by_loader[name] = frozenset(
                    int(num) for num in epochs
                )
            except (ValueError, TypeError) as err:
                raise ValueError(
                    "'loaders' should be a dict where "
                    "values are int/float/List[int]/Tuple[int]!"
                ) from err

        def matches(epoch: int, loader: str) -> bool:
            return epoch in epochs_by_loader.get(loader, ())

    else:
        raise ValueError(
            "'loaders' type should be one of - str, "
            "Sequence[str], Mapping[str, int] or "
            "Mapping[str, Sequence[int]]! "
            f"(got {type(loaders)})"
        )

    def filter_fn(stage: str, epoch: int, loader: str) -> bool:
        # ``!=`` on two booleans acts as XOR: the base condition is
        # negated only when ``reverse_condition`` is True.
        return matches(epoch, loader) != reverse_condition

    return filter_fn
def _filter_fn_from_arg(filter_fn: Union[str, FILTER_FN]) -> FILTER_FN:
    """Check if filter function from arguments
    can be used with ``ControlFlowCallback``.

    Args:
        filter_fn (str or Callable): filter function to check.
            A string is evaluated as python code and the result
            of the evaluation is checked instead.

    Raises:
        ValueError: if ``filter_fn`` is a string and can not be
            interpreted as python code then an error will be raised
        ValueError: if passed not callable object then will be
            raised an error
        ValueError: will be raised error if filter function can not
            be called with three positional arguments

    Returns:
        filter function which accepts 3 arguments - stage (str),
        epoch (int), loader (str) and return ``True`` if
        need to disable callback
    """
    if isinstance(filter_fn, str):
        # NOTE(security): the string is executed with ``eval``, so it
        # must come from a trusted source; never pass external input.
        try:
            filter_fn = eval(filter_fn)  # noqa: WPS400
        except Exception as err:
            # Evaluating arbitrary code may fail with any exception
            # (NameError, TypeError, ...), report all of them uniformly.
            raise ValueError(
                "'filter_fn' should be a valid "
                "python lambda function with "
                "three arguments - 'stage', 'epoch' and 'loader'!"
            ) from err

    if not callable(filter_fn):
        raise ValueError("'filter_fn' should be a callable!")

    try:
        signature = inspect.signature(filter_fn)
    except (ValueError, TypeError):
        # No introspectable signature is available for this callable,
        # so the number of arguments can not be validated here.
        return filter_fn

    try:
        # Works for plain functions, bound methods, partials and
        # callable objects, unlike inspecting ``__code__`` directly.
        signature.bind("stage", 0, "loader")
    except TypeError as err:
        raise ValueError(
            "Filter function should have three arguments - "
            "'stage', 'epoch' and 'loader'!"
        ) from err

    return filter_fn
class ControlFlowCallback(WrapperCallback):
    """Customize callback execution on different
    stages, loaders and epochs.

    For example, if you don't want to compute loss on a validation
    you can ignore ``CriterionCallback``, for notebook API
    need to wrap callback:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.dl import (
            SupervisedRunner, AccuracyCallback,
            CriterionCallback, ControlFlowCallback,
        )

        num_samples, num_features = 10_000, 10
        n_classes = 10
        X = torch.rand(num_samples, num_features)
        y = torch.randint(0, n_classes, [num_samples])
        loader = DataLoader(TensorDataset(X, y), batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        model = torch.nn.Linear(num_features, n_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        runner = SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=5,
            verbose=False,
            main_metric="accuracy03",
            minimize_metric=False,
            callbacks=[
                AccuracyCallback(
                    accuracy_args=[1, 3, 5]
                ),
                ControlFlowCallback(
                    base_callback=CriterionCallback(),
                    ignore_loaders="valid"  # or loaders="train"
                )
            ]
        )

    In config API need to use ``_wrapper`` argument:

    .. code-block:: yaml

        callbacks_params:
          ...
          loss:
            _wrapper:
              callback: ControlFlowCallback
              ignore_loaders: valid
            callback: CriterionCallback
          ...

    """

    def __init__(
        self,
        base_callback: Callback,
        epochs: Union[int, Sequence[int]] = None,
        ignore_epochs: Union[int, Sequence[int]] = None,
        loaders: LOADERS = None,
        ignore_loaders: LOADERS = None,
        filter_fn: Union[str, FILTER_FN] = None,
        use_global_epochs: bool = False,
    ):
        """
        Only one condition is used: the first argument which is not
        ``None`` in the order ``epochs``, ``ignore_epochs``,
        ``loaders``, ``ignore_loaders``, ``filter_fn``; the remaining
        ones are ignored.

        Args:
            base_callback (Callback): callback to wrap
            epochs (int/Sequence[int]): epochs numbers where
                need to execute callback, on other epochs
                callback will be disabled.

                If passed int/float then callback will be enabled
                with period specified as epochs value
                (on epochs where ``epoch_number % epochs == 0``)
                and disabled on other epochs.

                If passed list of epochs then will be executed callback
                on specified epochs.

                Default value is ``None``.
            ignore_epochs: (int/Sequence[int]): epochs numbers where
                need to disable callback, on other epochs
                callback will be enabled.

                If passed int/float then callback will be disabled
                with period specified as epochs value
                (on epochs where ``epoch_number % epochs == 0``)
                and enabled on other epochs.

                If passed list of epochs then will be disabled callback
                on specified epochs.

                Default value is ``None``.
            loaders (str/Sequence[str]/Mapping[str, int/Sequence[int]]):
                loader names where should be enabled callback, on
                other loaders callback will be disabled.

                If passed string object then will be enabled callback
                only for loader with specified name.

                If passed list/tuple of strings then will be enabled
                callback only for loaders with specified names.

                If passed dictionary where key is a string and values
                int or list of integers then callback will be
                enabled on epochs (dictionary value) for specified
                loader (dictionary key) and disabled in other cases.

                Default value is ``None``.
            ignore_loaders (str/Sequence[str]/Mapping[str, int/Sequence[int]]):
                loader names where should be disabled callback, on
                other loaders callback will be enabled.

                If passed string object then will be disabled callback for
                loader with specified name.

                If passed list/tuple of strings then will be disabled callback
                for loaders with specified names.

                If passed dictionary where key is a string and values
                int or list of integers then callback will be
                disabled on epochs (dictionary value) for specified
                loader (dictionary key).

                Default value is ``None``.
            filter_fn (str or Callable[[str, int, str], bool]):
                function to use instead of ``loaders`` or ``epochs`` arguments.

                If the object passed to a ``filter_fn`` is a string
                then it will be interpreted as python code. Expected
                lambda function with three arguments stage name (str),
                epoch number (int), loader name (str) and this function
                should return ``True`` if callback should be disabled
                on some condition.

                If passed callable object then it should accept
                three arguments - stage name (str), epoch number (int),
                loader name (str) and should return ``True`` if callback
                should be disabled on some condition otherwise should
                return ``False``.

                Default value is ``None``.
            use_global_epochs (bool): if ``True`` then
                will be used global epochs instead of epochs in
                a stage, the default value is ``False``

        Raises:
            ValueError: if all of ``epochs``, ``ignore_epochs``,
                ``loaders``, ``ignore_loaders`` and ``filter_fn``
                are ``None``
        """
        required_args = (
            epochs,
            ignore_epochs,
            loaders,
            ignore_loaders,
            filter_fn,
        )
        if all(arg is None for arg in required_args):
            raise ValueError(
                "Expected one of arguments - "
                "'epochs', 'ignore_epochs', "
                "'loaders', 'ignore_loaders' "
                "or 'filter_fn'!"
            )

        # NOTE(review): the second positional argument is presumably the
        # initial "enabled" state of the wrapper - confirm against
        # ``WrapperCallback.__init__``.
        super().__init__(base_callback, True)
        self.use_global_epochs = use_global_epochs

        # The filter returns ``True`` when the wrapped callback must be
        # disabled, so "enable on X" arguments use the reversed condition.
        self.filter_fn = None
        if epochs is not None:
            self.filter_fn = _filter_fn_from_epochs(epochs, True)
        elif ignore_epochs is not None:
            self.filter_fn = _filter_fn_from_epochs(ignore_epochs, False)
        elif loaders is not None:
            self.filter_fn = _filter_fn_from_loaders(loaders, True)
        elif ignore_loaders is not None:
            self.filter_fn = _filter_fn_from_loaders(ignore_loaders, False)
        elif filter_fn is not None:
            self.filter_fn = _filter_fn_from_arg(filter_fn)

    def on_loader_start(self, runner: IRunner) -> None:
        """
        Check if the wrapped callback should be skipped for the
        current stage, epoch and loader, and forward the event
        to the wrapped callback if it is enabled.

        Args:
            runner (IRunner): current runner
        """
        stage = runner.stage_name
        loader = runner.loader_name
        epoch = runner.global_epoch if self.use_global_epochs else runner.epoch

        if self.filter_fn is not None:
            # the filter answers "should be disabled?", hence the negation
            self._is_enabled = not self.filter_fn(stage, epoch, loader)

        if self._is_enabled:
            self.callback.on_loader_start(runner)

    def on_loader_end(self, runner: IRunner) -> None:
        """
        Forward the event to the wrapped callback if it is enabled
        and reset status of callback to enabled.

        Args:
            runner (IRunner): current runner
        """
        if self._is_enabled:
            self.callback.on_loader_end(runner)
        self._is_enabled = True
# Names exported by ``from <module> import *``.
__all__ = ["ControlFlowCallback"]