Shortcuts

Source code for catalyst.callbacks.misc

from abc import ABC, abstractmethod

from tqdm.auto import tqdm

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import IRunner
from catalyst.tools.metric_handler import MetricHandler
from catalyst.tools.time_manager import TimeManager
from catalyst.utils.misc import is_exception

# Small epsilon added to denominators (e.g. the batch time in the fps
# computation) so that a zero-length measurement cannot divide by zero.
EPS: float = 1.0e-08


class IBatchMetricHandlerCallback(ABC, Callback):
    """Interface for callbacks that react to a per-batch metric.

    At the end of every batch the value ``runner.batch_metrics[metric_key]``
    is compared against the best value seen so far in the current loader,
    and exactly one of the two abstract hooks is called.

    Args:
        metric_key: key of the tracked metric in ``runner.batch_metrics``.
        minimize: forwarded to ``MetricHandler``; presumably ``True`` means
            lower values are better (verify against ``MetricHandler``).
        min_delta: forwarded to ``MetricHandler`` as the improvement threshold.
    """

    def __init__(self, metric_key: str, minimize: bool = True, min_delta: float = 1e-6):
        """Store the tracked metric key and build the score comparator."""
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        self.is_better = MetricHandler(minimize=minimize, min_delta=min_delta)
        self.metric_key = metric_key
        # Best score seen so far within the current loader; ``None`` until
        # the first batch has been observed.
        self.best_score = None

    @abstractmethod
    def handle_score_is_better(self, runner: "IRunner"):
        """Hook called when the batch score improves on ``best_score`` (or is the first one)."""

    @abstractmethod
    def handle_score_is_not_better(self, runner: "IRunner"):
        """Hook called when the batch score does not improve on ``best_score``."""

    def on_loader_start(self, runner: "IRunner") -> None:
        """Forget the previous best score so every loader is tracked independently."""
        self.best_score = None

    def on_batch_end(self, runner: "IRunner") -> None:
        """Compare the batch metric with the best score and dispatch a hook."""
        current = runner.batch_metrics[self.metric_key]
        improved = self.best_score is None or self.is_better(current, self.best_score)
        if not improved:
            self.handle_score_is_not_better(runner=runner)
            return
        # Update first, so the hook observes the new ``best_score``.
        self.best_score = current
        self.handle_score_is_better(runner=runner)


class IEpochMetricHandlerCallback(ABC, Callback):
    """Interface for callbacks that react to a per-epoch loader metric.

    At the end of every epoch the value
    ``runner.epoch_metrics[loader_key][metric_key]`` is compared against the
    best value seen so far in the current stage, and exactly one of the two
    abstract hooks is called.

    Args:
        loader_key: loader whose epoch metrics are inspected.
        metric_key: key of the tracked metric inside that loader's metrics.
        minimize: forwarded to ``MetricHandler``; presumably ``True`` means
            lower values are better (verify against ``MetricHandler``).
        min_delta: forwarded to ``MetricHandler`` as the improvement threshold.
    """

    def __init__(
        self, loader_key: str, metric_key: str, minimize: bool = True, min_delta: float = 1e-6,
    ):
        """Store the tracked loader/metric keys and build the score comparator."""
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        self.is_better = MetricHandler(minimize=minimize, min_delta=min_delta)
        self.loader_key = loader_key
        self.metric_key = metric_key
        # Best score seen so far within the current stage; ``None`` until
        # the first epoch has been observed.
        self.best_score = None

    @abstractmethod
    def handle_score_is_better(self, runner: "IRunner"):
        """Hook called when the epoch score improves on ``best_score`` (or is the first one)."""

    @abstractmethod
    def handle_score_is_not_better(self, runner: "IRunner"):
        """Hook called when the epoch score does not improve on ``best_score``."""

    def on_stage_start(self, runner: "IRunner") -> None:
        """Forget the previous best score so every stage is tracked independently."""
        self.best_score = None

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Compare the epoch metric with the best score and dispatch a hook."""
        current = runner.epoch_metrics[self.loader_key][self.metric_key]
        improved = self.best_score is None or self.is_better(current, self.best_score)
        if not improved:
            self.handle_score_is_not_better(runner=runner)
            return
        # Update first, so the hook observes the new ``best_score``.
        self.best_score = current
        self.handle_score_is_better(runner=runner)


class EarlyStoppingCallback(IEpochMetricHandlerCallback):
    """Stage early stop based on metric.

    Counts consecutive epochs without improvement of the tracked metric and
    sets ``runner.need_early_stop = True`` once that count reaches
    ``patience``. The base class clears ``best_score`` on stage start, so the
    first epoch of every stage counts as an improvement and resets the counter.

    Args:
        patience: number of epochs with no improvement
            after which training will be stopped.
        loader_key: loader key for early stopping
            (based on metric score over the dataset)
        metric_key: metric key for early stopping
            (based on metric score over the dataset)
        minimize: if ``True`` then expected that metric should
            decrease and early stopping will be performed only when metric
            stops decreasing. If ``False`` then expected
            that metric should increase. Default value ``True``.
        min_delta: minimum change in the monitored metric
            to qualify as an improvement,
            i.e. an absolute change of less than min_delta,
            will count as no improvement, default value is ``1e-6``.
    """

    def __init__(
        self,
        patience: int,
        loader_key: str,
        metric_key: str,
        minimize: bool = True,
        min_delta: float = 1e-6,
    ):
        """Init."""
        super().__init__(
            loader_key=loader_key, metric_key=metric_key, minimize=minimize, min_delta=min_delta,
        )
        self.patience = patience
        # Number of consecutive epochs whose score did not improve.
        self.num_no_improvement_epochs = 0

    def handle_score_is_better(self, runner: "IRunner"):
        """Reset the no-improvement counter after an improving epoch."""
        self.num_no_improvement_epochs = 0

    def handle_score_is_not_better(self, runner: "IRunner"):
        """Count a non-improving epoch and request a stop once patience runs out."""
        self.num_no_improvement_epochs += 1
        if self.num_no_improvement_epochs >= self.patience:
            runner.need_early_stop = True
class TimerCallback(Callback):
    """Logs pipeline execution time.

    At the end of every batch, each entry of ``self.timer.elapsed`` is copied
    into ``runner.batch_metrics``. This presumably includes
    ``_timer/data_time``, ``_timer/model_time`` and ``_timer/batch_time``,
    assuming ``TimeManager.stop`` records into ``elapsed``. It also includes
    the derived ``_timer/_fps``.
    """

    def __init__(self):
        """Initialisation for TimerCallback."""
        super().__init__(order=CallbackOrder.metric + 1, node=CallbackNode.all)
        self.timer = TimeManager()

    def _start_batch_timers(self) -> None:
        """Clear previous measurements and start timing the next batch."""
        self.timer.reset()
        # batch_time spans data loading + model step; data_time is stopped
        # in ``on_batch_start`` once the batch has been fetched.
        self.timer.start("_timer/batch_time")
        self.timer.start("_timer/data_time")

    def on_loader_start(self, runner: "IRunner") -> None:
        """Loader start hook.

        Args:
            runner: current runner
        """
        self._start_batch_timers()

    def on_loader_end(self, runner: "IRunner") -> None:
        """Loader end hook.

        Args:
            runner: current runner
        """
        self.timer.reset()

    def on_batch_start(self, runner: "IRunner") -> None:
        """Batch start hook.

        Args:
            runner: current runner
        """
        self.timer.stop("_timer/data_time")
        self.timer.start("_timer/model_time")

    def on_batch_end(self, runner: "IRunner") -> None:
        """Batch end hook.

        Args:
            runner: current runner
        """
        self.timer.stop("_timer/model_time")
        self.timer.stop("_timer/batch_time")

        # Trick: fps is stored in ``timer.elapsed`` so that it is exported by
        # the loop below together with the timings; EPS avoids division by 0.
        self.timer.elapsed["_timer/_fps"] = runner.batch_size / (
            self.timer.elapsed["_timer/batch_time"] + EPS
        )
        for key, value in self.timer.elapsed.items():
            runner.batch_metrics[key] = value

        # Restart immediately so the next batch's data loading is measured.
        self._start_batch_timers()
class TqdmCallback(Callback):
    """Logs the params into tqdm console."""

    def __init__(self):
        """Init: no progress bar exists until the first loader starts."""
        super().__init__(order=CallbackOrder.external, node=CallbackNode.master)
        self.tqdm: tqdm = None
        self.step = 0

    @staticmethod
    def _format_metric(value: float) -> str:
        """Format a metric value for the progress-bar postfix.

        Values whose magnitude exceeds ``1e-3`` use fixed-point notation.
        Smaller magnitudes (including zero) use scientific notation. The
        magnitude is used so negative values are treated like positive ones.
        """
        if abs(value) > 1e-3:
            return "{:3.3f}".format(value)
        return "{:1.3e}".format(value)

    def on_loader_start(self, runner: "IRunner"):
        """Init tqdm progress bar."""
        self.step = 0
        self.tqdm = tqdm(
            total=runner.loader_batch_len,
            desc=f"{runner.stage_epoch_step}/{runner.stage_epoch_len}"
            f" * Epoch ({runner.loader_key})",
        )

    def on_batch_end(self, runner: "IRunner"):
        """Update tqdm progress bar at the end of each batch."""
        batch_metrics = {k: float(v) for k, v in runner.batch_metrics.items()}
        self.tqdm.set_postfix(
            **{k: self._format_metric(v) for k, v in sorted(batch_metrics.items())}
        )
        self.tqdm.update()

    def on_loader_end(self, runner: "IRunner"):
        """Cleanup and close tqdm progress bar."""
        self.tqdm.clear()
        self.tqdm.close()
        self.tqdm = None
        self.step = 0

    def on_exception(self, runner: "IRunner"):
        """Called if an Exception was raised.

        Only a ``KeyboardInterrupt`` closes the progress bar here; other
        exceptions leave it untouched.
        """
        exception = runner.exception
        if not is_exception(exception):
            return

        if isinstance(exception, KeyboardInterrupt) and self.tqdm is not None:
            self.tqdm.write("Keyboard Interrupt")
            self.tqdm.clear()
            self.tqdm.close()
            self.tqdm = None
class CheckRunCallback(Callback):
    """Executes only a pipeline part from the run.

    Sets ``runner.need_early_stop = True`` once ``runner.loader_batch_step``
    reaches ``num_batch_steps`` (checked after every batch), or once
    ``runner.stage_epoch_step`` reaches ``num_epoch_steps`` (checked after
    every epoch).

    Args:
        num_batch_steps: number of batches to iterate in epoch
        num_epoch_steps: number of epochs to perform in a stage

    Minimal working example (Notebook API):

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # data
        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, 1)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=8,
            verbose=True,
            callbacks=[
                dl.CheckRunCallback(num_batch_steps=3, num_epoch_steps=3)
            ]
        )
    """

    def __init__(self, num_batch_steps: int = 3, num_epoch_steps: int = 3):
        """Init."""
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        self.num_batch_steps = num_batch_steps
        self.num_epoch_steps = num_epoch_steps

    def on_epoch_end(self, runner: "IRunner"):
        """Check if iterated specified number of epochs.

        Args:
            runner: current runner
        """
        if runner.stage_epoch_step >= self.num_epoch_steps:
            runner.need_early_stop = True

    def on_batch_end(self, runner: "IRunner"):
        """Check if iterated specified number of batches.

        Args:
            runner: current runner
        """
        if runner.loader_batch_step >= self.num_batch_steps:
            runner.need_early_stop = True
# Public API exported by ``from catalyst.callbacks.misc import *``.
__all__ = [
    "TimerCallback",
    "TqdmCallback",
    "CheckRunCallback",
    "IBatchMetricHandlerCallback",
    "IEpochMetricHandlerCallback",
    "EarlyStoppingCallback",
]