
Source code for catalyst.core.runner

from typing import Any, Dict, Iterable, Mapping, Optional, Tuple
from abc import ABC, abstractmethod
from collections import defaultdict, OrderedDict
from functools import lru_cache
import logging

import torch
import torch.distributed
import torch.multiprocessing
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from catalyst.core.callback import Callback, ICallback
from catalyst.core.engine import IEngine
from catalyst.core.logger import ILogger
from catalyst.core.misc import filter_callbacks_by_node, sort_callbacks_by_order, validate_loaders
from catalyst.core.trial import ITrial
from catalyst.typing import (
    Criterion,
    Device,
    Model,
    Optimizer,
    RunnerCriterion,
    RunnerModel,
    RunnerOptimizer,
    RunnerScheduler,
    Sampler,
    Scheduler,
)
from catalyst.utils.distributed import ddp_sync_run
from catalyst.utils.misc import maybe_recursive_call, set_global_seed

LOGGER = logging.getLogger(__name__)


BATCH_METRICS = Dict[str, float]
LOADER_METRICS = Dict[str, BATCH_METRICS]
EPOCH_METRICS = Dict[str, LOADER_METRICS]
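
# Illustrative sketch (not part of the original module): how these aliases nest.
# BATCH_METRICS maps a metric name to its value, e.g. {"loss": 0.41, "accuracy": 0.87};
# LOADER_METRICS adds one key level on top (Dict[str, BATCH_METRICS]), and EPOCH_METRICS
# one more (Dict[str, LOADER_METRICS]). At runtime the runner stores per-loader aggregates
# under `epoch_metrics[loader_key]`, e.g. (values are made up):
#     {"train": {"loss": 0.41, "accuracy": 0.87}, "valid": {"loss": 0.45, "accuracy": 0.85}}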


@lru_cache(maxsize=42)
def _has_str_intersections(origin_string: str, strings: Tuple):
    return any(x in origin_string for x in strings)
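
# Illustrative sketch (not part of the original module): `IRunner._run_event` below uses this
# helper to decide dispatch order by event-name suffix, e.g.:
#     _has_str_intersections("on_batch_start", ("_start",))            # -> True
#     _has_str_intersections("on_batch_end", ("_end", "_exception"))   # -> True
#     _has_str_intersections("on_batch_end", ("_start",))              # -> False
# The `lru_cache` above memoizes these checks across the many events fired during a run.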


def _get_batch_size(loader: DataLoader):
    batch_size = loader.batch_size
    if batch_size is not None:
        return batch_size

    batch_size = loader.batch_sampler.batch_size
    if batch_size is not None:
        return batch_size
    raise NotImplementedError(
        "No `batch_size` found, "
        "please specify it through `loader.batch_size` or `loader.batch_sampler.batch_size`"
    )
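
# Illustrative, self-contained check (not part of the original module): when a DataLoader is
# built with a custom `batch_sampler`, PyTorch sets `loader.batch_size` to None, so the
# fallback branch above reads the size from `loader.batch_sampler.batch_size` instead.
#
#     from torch.utils.data import BatchSampler, SequentialSampler, TensorDataset
#
#     dataset = TensorDataset(torch.arange(10))
#     batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
#     loader = DataLoader(dataset, batch_sampler=batch_sampler)
#     assert loader.batch_size is None
#     assert _get_batch_size(loader) == 4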


class RunnerException(Exception):
    """Exception class for all runner errors."""

    pass


class IRunner(ICallback, ILogger, ABC):
    """
    An abstraction that contains all the logic of how to run the experiment,
    stages, epochs, loaders and batches.

    IRunner supports the logic for deep learning pipeline configuration
    with pure python code. Please check the examples for intuition.

    Args:
        model: Torch model object
        engine: IEngine instance

    Abstraction, please check out implementations for more details:

        - :py:mod:`catalyst.runners.runner.Runner`
        - :py:mod:`catalyst.runners.config.ConfigRunner`
        - :py:mod:`catalyst.runners.hydra.HydraRunner`

    .. note::
        To learn more about Catalyst Core concepts, please check out

            - :py:mod:`catalyst.core.runner.IRunner`
            - :py:mod:`catalyst.core.engine.IEngine`
            - :py:mod:`catalyst.core.callback.Callback`

    .. note::
        Please follow the `minimal examples`_ sections for use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples

    Examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.utils.data import DataLoader
        from catalyst import dl, utils
        from catalyst.contrib.datasets import MNIST
        from catalyst.data import ToTensor

        class CustomRunner(dl.IRunner):
            def __init__(self, logdir, device):
                super().__init__()
                self._logdir = logdir
                self._device = device

            def get_engine(self):
                return dl.DeviceEngine(self._device)

            def get_loggers(self):
                return {
                    "console": dl.ConsoleLogger(),
                    "csv": dl.CSVLogger(logdir=self._logdir),
                    "tensorboard": dl.TensorboardLogger(logdir=self._logdir),
                }

            @property
            def stages(self):
                return ["train_freezed", "train_unfreezed"]

            def get_stage_len(self, stage: str) -> int:
                return 3

            def get_loaders(self, stage: str):
                loaders = {
                    "train": DataLoader(
                        MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
                        batch_size=32,
                    ),
                    "valid": DataLoader(
                        MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                        batch_size=32,
                    ),
                }
                return loaders

            def get_model(self, stage: str):
                model = (
                    self.model
                    if self.model is not None
                    else nn.Sequential(
                        nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
                    )
                )
                if stage == "train_freezed":
                    # freeze layer
                    utils.set_requires_grad(model[1], False)
                else:
                    utils.set_requires_grad(model, True)
                return model

            def get_criterion(self, stage: str):
                return nn.CrossEntropyLoss()

            def get_optimizer(self, stage: str, model):
                if stage == "train_freezed":
                    return optim.Adam(model.parameters(), lr=1e-3)
                else:
                    return optim.SGD(model.parameters(), lr=1e-1)

            def get_scheduler(self, stage: str, optimizer):
                return None

            def get_callbacks(self, stage: str):
                return {
                    "criterion": dl.CriterionCallback(
                        metric_key="loss", input_key="logits", target_key="targets"
                    ),
                    "optimizer": dl.OptimizerCallback(metric_key="loss"),
                    "accuracy": dl.AccuracyCallback(
                        input_key="logits", target_key="targets", topk_args=(1, 3, 5)
                    ),
                    "classification": dl.PrecisionRecallF1SupportCallback(
                        input_key="logits", target_key="targets", num_classes=10
                    ),
                    "checkpoint": dl.CheckpointCallback(
                        self._logdir,
                        loader_key="valid",
                        metric_key="loss",
                        minimize=True,
                        save_n_best=3,
                    ),
                }

            def handle_batch(self, batch):
                x, y = batch
                logits = self.model(x)
                self.batch = {
                    "features": x,
                    "targets": y,
                    "logits": logits,
                }

        runner = CustomRunner("./logs", "cpu")
        runner.run()
    """

    def __init__(
        self,
        model: RunnerModel = None,
        engine: IEngine = None,
    ):
        """Init."""
        # the core
        self.model: RunnerModel = model
        self.engine: IEngine = engine
        self.trial: ITrial = None
        # the data
        self.loaders: Dict[str, DataLoader] = None
        # the components
        self.criterion: RunnerCriterion = None
        self.optimizer: RunnerOptimizer = None
        self.scheduler: RunnerScheduler = None
        # the callbacks
        self.callbacks: Dict[str, Callback] = {}
        # the loggers
        self.loggers: Dict[str, ILogger] = {}
        # the dataflow - model input/output and other batch tensors
        self.batch: Dict[str, torch.Tensor] = None
        # metrics flow - batch, loader and epoch metrics
        self.batch_metrics: BATCH_METRICS = defaultdict(None)
        self.loader_metrics: LOADER_METRICS = defaultdict(None)
        self.epoch_metrics: EPOCH_METRICS = defaultdict(None)
        # experiment info
        self.run_key: str = None
        self.global_epoch_step: int = 0
        self.global_batch_step: int = 0
        self.global_sample_step: int = 0
        # stage info
        self.stage_key: str = "infer"
        self.is_infer_stage: bool = self.stage_key.startswith("infer")
        self.stage_epoch_len: int = 0
        self.stage_epoch_step: int = 0
        self.stage_batch_step: int = 0
        self.stage_sample_step: int = 0
        # loader info
        self.loader: DataLoader = None
        self.loader_key: str = None
        self.is_train_loader: bool = False
        self.is_valid_loader: bool = False
        self.is_infer_loader: bool = True
        self.loader_batch_size: int = 0
        self.loader_batch_len: int = 0
        self.loader_sample_len: int = 0
        self.loader_batch_step: int = 0
        self.loader_sample_step: int = 0
        # batch info
        self.batch_size: int = 0
        # extra
        self.exception: Exception = None
        self.need_early_stop: bool = False
        self._stage_rank: int = -1
        self._stage_world_size: int = -1

    # @TODO: remove hotfix
    @property
    def device(self) -> Device:
        """Returns the runner's device instance."""
        return self.engine.device

    @property
    def seed(self) -> int:
        """Experiment's seed for reproducibility."""
        return 42

    @property
    def name(self) -> str:
        """Returns run name for monitoring tools."""
        return "IRunner"

    @property
    def hparams(self) -> OrderedDict:
        """
        Returns hyper-parameters for current run.

        Example::

            >>> runner.hparams
            OrderedDict([('optimizer', 'Adam'),
             ('lr', 0.02),
             ('betas', (0.9, 0.999)),
             ('eps', 1e-08),
             ('weight_decay', 0),
             ('amsgrad', False),
             ('train_batch_size', 32)])

        Returns:
            dictionary with hyperparameters
        """
        return {}

    @property
    def _log_defaults(self) -> Dict:
        return {
            # experiment info
            "run_key": self.run_key,
            "global_sample_step": self.global_sample_step,
            "global_batch_step": self.global_batch_step,
            "global_epoch_step": self.global_epoch_step,
            # stage info
            "stage_key": self.stage_key,
            "stage_epoch_len": self.stage_epoch_len,
            "stage_epoch_step": self.stage_epoch_step,
            "stage_batch_step": self.stage_batch_step,
            "stage_sample_step": self.stage_sample_step,
            # loader info
            "loader_key": self.loader_key,
            "loader_batch_len": self.loader_batch_len,
            "loader_sample_len": self.loader_sample_len,
            "loader_batch_step": self.loader_batch_step,
            "loader_sample_step": self.loader_sample_step,
        }

    @property
    @abstractmethod
    def stages(self) -> Iterable[str]:
        """Run's stage names.

        Example::

            >>> runner.stages
            ["pretraining", "finetuning"]
        """
        pass
    def get_stage_len(self, stage: str) -> int:
        """Returns number of epochs for the selected stage.

        Args:
            stage: current stage

        Returns:
            number of epochs in stage

        Example::

            >>> runner.get_stage_len("pretraining")
            3
        """
        return 1
    def get_trial(self) -> Optional[ITrial]:
        """Returns the trial for the run."""
        return None  # noqa: WPS324
    @abstractmethod
    def get_engine(self) -> IEngine:
        """Returns the engine for the run."""
        return None  # noqa: WPS324
    def get_loggers(self) -> Dict[str, ILogger]:
        """Returns the loggers for the run."""
        return {}
    def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
        """Returns the datasets for a given stage and epoch.  # noqa: DAR401

        .. note::
            For Deep Learning cases you have the same dataset
            during the whole stage.

            For Reinforcement Learning it's common to change the dataset
            (experiment) every training epoch.

        Args:
            stage: stage name of interest,
                like "pretrain" / "train" / "finetune" / etc

        Returns:  # noqa: DAR202
            OrderedDict[str, Dataset]: Ordered dictionary
                with datasets for current stage and epoch.

        .. note::
            We need an ordered dictionary to guarantee the correct dataflow
            and order of our training datasets.
            For example, to run the train loader before the validation one :)

        Example::

            >>> runner.get_datasets(stage="training")
            OrderedDict({
                "train": CsvDataset(in_csv=in_csv_train, ...),
                "valid": CsvDataset(in_csv=in_csv_valid, ...),
            })
        """
        raise NotImplementedError
    def get_samplers(self, stage: str = None) -> "OrderedDict[str, Sampler]":
        """Returns samplers for a given stage.  # noqa: DAR401

        Args:
            stage: stage name of interest,
                like "pretrain" / "train" / "finetune" / etc

        Returns:  # noqa: DAR201, DAR202
            OrderedDict[str, Sampler]: Ordered dictionary
                with samplers for current stage and epoch.
        """
        raise NotImplementedError

    # def get_transforms(self, stage: str = None):
    #     """Returns the data transforms for a given stage and dataset.
    #
    #     Args:
    #         stage: stage name of interest,
    #             like "pretrain" / "train" / "finetune" / etc
    #         dataset: dataset name of interest,
    #             like "train" / "valid" / "infer"
    #
    #     .. note::
    #         For datasets/loaders naming please follow
    #         :py:mod:`catalyst.core.runner` documentation.
    #
    #     Returns:  # noqa: DAR202
    #         Data transformations to use for specified dataset.
    #
    #     """
    #     raise NotImplementedError
    @abstractmethod  # noqa: WPS463
    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Returns the loaders for a given stage.  # noqa: DAR401

        .. note::
            Wrapper for
            :py:mod:`catalyst.core.experiment.IExperiment.get_datasets`.
            For most of your experiments you only need to rewrite
            the `get_datasets` method.

        Args:
            stage: stage name of interest,
                like "pretrain" / "train" / "finetune" / etc

        Returns:  # noqa: DAR201, DAR202
            OrderedDict[str, DataLoader]: Ordered dictionary
                with loaders for current stage and epoch.
        """
        pass
    @abstractmethod  # noqa: WPS463
    def get_model(self, stage: str) -> Model:
        """Returns the model for a given stage and epoch.

        Example::

            # suppose we have typical MNIST model, like
            # nn.Sequential(nn.Linear(28*28, 128), nn.Linear(128, 10))
            >>> runner.get_model(stage="train")
            Sequential(
              (0): Linear(in_features=784, out_features=128, bias=True)
              (1): Linear(in_features=128, out_features=10, bias=True)
            )

        Args:
            stage: stage name of interest
                like "pretrain" / "train" / "finetune" / etc

        Returns:  # noqa: DAR201, DAR202
            Model: model for a given stage.
        """
        pass
    def get_criterion(self, stage: str) -> Optional[Criterion]:
        """Returns the criterion for a given stage and epoch.

        Example::

            # for typical classification task
            >>> runner.get_criterion(stage="train")
            nn.CrossEntropyLoss()

        Args:
            stage: stage name of interest
                like "pretrain" / "train" / "finetune" / etc

        Returns:  # noqa: DAR201, DAR202
            Criterion: criterion for a given stage.
        """
        return None  # noqa: WPS324
    def get_optimizer(self, stage: str, model: Model) -> Optional[Optimizer]:
        """Returns the optimizer for a given stage and model.

        Example::

            >>> runner.get_optimizer(model=model, stage="train")
            torch.optim.Adam(model.parameters())

        Args:
            stage: stage name of interest
                like "pretrain" / "train" / "finetune" / etc
            model: model to optimize with stage optimizer

        Returns:  # noqa: DAR201, DAR202
            Optimizer: optimizer for a given stage and model.
        """
        return None  # noqa: WPS324
    def get_scheduler(self, stage: str, optimizer: Optimizer) -> Optional[Scheduler]:
        """Returns the scheduler for a given stage and optimizer.

        Example::

            >>> runner.get_scheduler(stage="training", optimizer=optimizer)
            torch.optim.lr_scheduler.StepLR(optimizer)

        Args:
            stage: stage name of interest
                like "pretrain" / "train" / "finetune" / etc
            optimizer: optimizer to schedule with stage scheduler

        Returns:  # noqa: DAR201, DAR202
            Scheduler: scheduler for a given stage and optimizer.
        """
        return None  # noqa: WPS324
    def _get_model(self) -> Model:
        self.model = self.get_model(stage=self.stage_key)
        return self.model

    def _get_criterion(self) -> Criterion:
        self.criterion = self.get_criterion(stage=self.stage_key)
        return self.criterion

    def _get_optimizer(self, model: Model = None) -> Optimizer:
        if model is not None:
            self.model = model
        # assert self.model is not None, "You need to setup model first"
        self.optimizer = self.get_optimizer(stage=self.stage_key, model=self.model)
        return self.optimizer

    def _get_scheduler(self, optimizer: Optimizer = None) -> Scheduler:
        if optimizer is not None:
            self.optimizer = optimizer
        # assert self.optimizer is not None, "You need to setup optimizer first"
        self.scheduler = self.get_scheduler(stage=self.stage_key, optimizer=self.optimizer)
        return self.scheduler
    def get_callbacks(self, stage: str) -> "OrderedDict[str, ICallback]":
        """Returns callbacks for a given stage.

        Args:
            stage: stage name of interest
                like "pretrain" / "train" / "finetune" / etc

        Returns:
            OrderedDict[str, Callback]: Ordered dictionary  # noqa: DAR202
                with callbacks for current stage.
        """
        return {}
    def log_hparams(self, *args, **kwargs) -> None:
        """Logs hyperparameters to available loggers."""
        for logger in self.loggers.values():
            logger.log_hparams(
                *args,
                **kwargs,
                # experiment info
                run_key=self.run_key,
                stage_key=self.stage_key,
            )
    def log_metrics(self, *args, **kwargs) -> None:
        """Logs batch, loader and epoch metrics to available loggers."""
        for logger in self.loggers.values():
            logger.log_metrics(*args, **kwargs, **self._log_defaults)
    def log_image(self, *args, **kwargs) -> None:
        """Logs image to available loggers."""
        for logger in self.loggers.values():
            logger.log_image(*args, **kwargs, **self._log_defaults)
    def log_artifact(self, *args, **kwargs) -> None:
        """Logs artifact (file like audio, video, csv, etc.) to available loggers."""
        for logger in self.loggers.values():
            logger.log_artifact(*args, **kwargs, **self._log_defaults)

    def flush_log(self) -> None:
        """Flushes the loggers."""
        for logger in self.loggers.values():
            logger.flush_log()

    def close_log(self, *args, **kwargs) -> None:
        """Closes the loggers."""
        for logger in self.loggers.values():
            logger.close_log(*args, **kwargs)

    def _setup_loaders(self) -> None:
        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
        loaders = self.get_loaders(stage=self.stage_key)
        loaders = validate_loaders(loaders)
        self.loaders = loaders

    def _setup_components(self) -> None:
        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
        self.model, self.criterion, self.optimizer, self.scheduler = self.engine.init_components(
            model_fn=self._get_model,
            criterion_fn=self._get_criterion,
            optimizer_fn=self._get_optimizer,
            scheduler_fn=self._get_scheduler,
        )

    def _setup_callbacks(self):
        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
        callbacks = self.get_callbacks(self.stage_key)
        callbacks = filter_callbacks_by_node(callbacks)
        callbacks = sort_callbacks_by_order(callbacks)
        self.callbacks = callbacks

    def on_experiment_start(self, runner: "IRunner"):
        """Event handler."""
        self.run_key = self.name
        self.global_epoch_step: int = 0
        self.global_batch_step: int = 0
        self.global_sample_step: int = 0
        self.exception: Exception = None
        self.need_early_stop: bool = False

        self.trial = self.get_trial()
        self.engine = self.get_engine()
        self.loggers = self.get_loggers()
        self.log_hparams(hparams=self.hparams, scope="experiment")

    def on_stage_start(self, runner: "IRunner"):
        """Event handler."""
        assert self.stage_key is not None
        self.is_infer_stage: bool = self.stage_key.startswith("infer")
        self.stage_epoch_len = self.get_stage_len(stage=self.stage_key)
        self.stage_epoch_step: int = 0
        self.stage_batch_step: int = 0
        self.stage_sample_step: int = 0

        if self.engine.is_ddp:
            self.engine.setup_process(rank=self._stage_rank, world_size=self._stage_world_size)
            if not self.engine.is_master_process:
                del self.loggers
                self.loggers = {}

        ddp_sync_run(self._setup_loaders)
        self._setup_components()
        self._setup_callbacks()
        self.log_hparams(hparams=self.hparams, scope="stage")

    def on_epoch_start(self, runner: "IRunner"):
        """Event handler."""
        self.global_epoch_step += 1
        self.stage_epoch_step += 1
        self.epoch_metrics: Dict = defaultdict(None)
        # storage for pure epoch-based metrics, like lr/momentum
        self.epoch_metrics["_epoch_"] = {}

        assert self.loaders is not None
        for loader_key, loader in self.loaders.items():
            if len(loader) == 0:
                raise RunnerException(f"DataLoader with name {loader_key} is empty.")
        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)

    def on_loader_start(self, runner: "IRunner"):
        """Event handler."""
        assert self.loader is not None
        self.is_train_loader: bool = self.loader_key.startswith("train")
        self.is_valid_loader: bool = self.loader_key.startswith("valid")
        self.is_infer_loader: bool = self.loader_key.startswith("infer")
        assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
        self.loader_batch_size: int = _get_batch_size(self.loader)
        self.loader_batch_len: int = len(self.loader)
        self.loader_sample_len: int = len(self.loader.dataset)
        self.loader_batch_step: int = 0
        self.loader_sample_step: int = 0
        self.loader_metrics: Dict = defaultdict(None)

        if self.loader_batch_len == 0:
            raise NotImplementedError(f"DataLoader with name {self.loader_key} is empty.")
        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)

        maybe_recursive_call(self.model, "train", mode=self.is_train_loader)
        if isinstance(self.loader.sampler, DistributedSampler):
            self.loader.sampler.set_epoch(self.stage_epoch_step)

    def on_batch_start(self, runner: "IRunner"):
        """Event handler."""
        self.batch = self.engine.sync_device(tensor_or_module=self.batch)

        if isinstance(self.batch, dict):
            self.batch_size = len(next(iter(self.batch.values())))
        else:
            self.batch_size = len(self.batch[0])

        # we have one batch per worker...
        self.global_batch_step += self.engine.world_size
        self.stage_batch_step += self.engine.world_size
        self.loader_batch_step += self.engine.world_size
        self.global_sample_step += self.batch_size * self.engine.world_size
        self.stage_sample_step += self.batch_size * self.engine.world_size
        self.loader_sample_step += self.batch_size * self.engine.world_size
        self.batch_metrics: Dict = defaultdict(None)

    def on_batch_end(self, runner: "IRunner"):
        """Event handler."""
        # since we could `backward` anything from `batch_metrics` on the nodes during training,
        # the metrics could not be synced earlier, so we sync them at the end of the batch
        # @TODO: could be done better
        if self.engine.is_ddp:
            self.batch_metrics = {
                k: runner.engine.sync_tensor(torch.tensor(v, device=runner.device), "mean")
                for k, v in self.batch_metrics.items()
            }
        self.log_metrics(metrics=self.batch_metrics, scope="batch")

    def on_loader_end(self, runner: "IRunner"):
        """Event handler."""
        self.log_metrics(metrics=self.loader_metrics, scope="loader")
        self.epoch_metrics[self.loader_key] = {
            key: float(value) for key, value in self.loader_metrics.items()
        }

    def on_epoch_end(self, runner: "IRunner"):
        """Event handler."""
        self.log_metrics(metrics=self.epoch_metrics, scope="epoch")
        self.flush_log()

    def on_stage_end(self, runner: "IRunner"):
        """Event handler."""
        del self.callbacks
        self.callbacks = {}
        del self.loaders
        self.loaders = {}
        self.engine.deinit_components(runner=self)
        self.close_log(scope="stage")

        # due to multiprocessing setup we have to close current loggers
        # to prevent EOF-like errors
        if self.engine.is_ddp:
            self.flush_log()
            self.close_log()
            self.engine.cleanup_process()

    def on_experiment_end(self, runner: "IRunner"):
        """Event handler."""
        self.flush_log()
        self.close_log(scope="experiment")

    def on_exception(self, runner: "IRunner"):
        """Event handler."""
        raise self.exception

    def _run_event(self, event: str) -> None:
        if _has_str_intersections(event, ("_start",)):
            getattr(self, event)(self)
        for callback in self.callbacks.values():
            getattr(callback, event)(self)
        if _has_str_intersections(event, ("_end", "_exception")):
            getattr(self, event)(self)
    @abstractmethod
    def handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle specified data batch.
        Used to make a train/valid/infer stage during Experiment run.

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoader.
        """
        pass
    def _run_batch(self) -> None:
        self._run_event("on_batch_start")
        self.handle_batch(batch=self.batch)
        self.batch = self.engine.sync_device(self.batch)
        self._run_event("on_batch_end")

    def _run_loader(self) -> None:
        # NOTE: the forward pass is wrapped in `autocast`
        # because forward propagation needs to be scaled,
        # as noted in the docs:
        # https://pytorch.org/docs/stable/notes/amp_examples.html#typical-mixed-precision-training
        self._run_event("on_loader_start")
        with torch.set_grad_enabled(self.is_train_loader):
            for self.loader_batch_step, self.batch in enumerate(self.loader):
                with self.engine.autocast():
                    self._run_batch()
                if self.need_early_stop:
                    self.need_early_stop = False
                    break
        self._run_event("on_loader_end")

    def _run_epoch(self) -> None:
        self._run_event("on_epoch_start")
        for self.loader_key, self.loader in self.loaders.items():
            self._run_loader()
        self._run_event("on_epoch_end")

    def _run_stage(self, rank: int = -1, world_size: int = 1) -> None:
        self._stage_rank, self._stage_world_size = rank, world_size
        self._run_event("on_stage_start")
        while self.stage_epoch_step < self.stage_epoch_len:
            self._run_epoch()
            if self.need_early_stop:
                self.need_early_stop = False
                break
        self._run_event("on_stage_end")

    def _run_experiment(self) -> None:
        self._run_event("on_experiment_start")
        for self.stage_key in self.stages:
            if self.engine.is_ddp:
                # ddp-device branch
                world_size = self.engine.world_size
                torch.multiprocessing.spawn(
                    self._run_stage,
                    args=(world_size,),
                    nprocs=world_size,
                    join=True,
                )
            else:
                # single-device branch (cpu, gpu, dp)
                self._run_stage()
        self._run_event("on_experiment_end")
    def run(self) -> "IRunner":
        """Runs the experiment.

        Returns:
            self, `IRunner` instance after the experiment
        """
        try:
            self._run_experiment()
        except (Exception, KeyboardInterrupt) as ex:
            self.exception = ex
            self._run_event("on_exception")
        return self
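
# Illustrative sketch (not part of the original module): the nesting of events that `run()`
# produces for a single stage. For "*_start" events the runner's own handler fires before the
# callbacks; for "*_end" / "*_exception" events the callbacks fire first (see `_run_event`).
#
#     on_experiment_start
#       on_stage_start
#         on_epoch_start
#           on_loader_start
#             on_batch_start -> handle_batch -> on_batch_end   # repeated for every batch
#           on_loader_end
#         on_epoch_end
#       on_stage_end
#     on_experiment_end
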
__all__ = ["IRunner", "RunnerException"]