# Source code for catalyst.core.engine
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from abc import ABC, abstractmethod
from contextlib import contextmanager
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from catalyst.typing import Criterion, Device, Model, Optimizer, Scheduler
@contextmanager
def nullcontext(enter_result: Any = None):
    """No-op context manager that simply yields ``enter_result``."""
    yield enter_result


class IEngine(ABC):
    """
    An abstraction that syncs the experiment run with
    different hardware-specific configurations:

    - CPU
    - GPU
    - DataParallel (deepspeed, fairscale, nvidia, torch)
    - AMP (deepspeed, fairscale, nvidia, torch)
    - DDP (deepspeed, fairscale, nvidia, torch)
    - XLA

    This is an abstraction; please check out the implementations for more details:

    - :py:mod:`catalyst.engines.amp.AMPEngine`
    - :py:mod:`catalyst.engines.apex.APEXEngine`
    - :py:mod:`catalyst.engines.torch.DeviceEngine`
    """

    @property
    @abstractmethod
    def device(self) -> Device:
        """Pytorch device."""
        pass

    # @property
    # @abstractmethod
    # def local_rank(self) -> int:
    #     """Process local rank for distributed training."""
    #     pass

    @property
    @abstractmethod
    def rank(self) -> int:
        """Process rank for distributed training."""
        pass

    @property
    @abstractmethod
    def world_size(self) -> int:
        """Process world size for distributed training."""
        pass

    # @property
    # @abstractmethod
    # def num_nodes(self) -> int:
    #     pass
    #
    # @property
    # @abstractmethod
    # def num_proc_per_node(self) -> int:
    #     pass
    #
    # @property
    # @abstractmethod
    # def node_rank(self) -> int:
    #     pass

    @property
    @abstractmethod
    def backend(self) -> Optional[str]:
        """String identifier for distributed backend."""
        pass

    @property
    def is_ddp(self) -> bool:
        """Boolean flag for distributed run."""
        return self.backend is not None

    @property
    def is_master_process(self) -> bool:
        """Checks if a process is the master process.

        Should be implemented only for distributed training (ddp).
        For non-distributed training it should always return `True`.

        Returns:
            `True` if the current process is the master process, `False` otherwise.
        """
        # -1 for non-distributed setup
        # 0 for distributed setup
        return self.rank <= 0

    @property
    def is_worker_process(self) -> bool:
        """Checks if a process is a worker process.

        Should be implemented only for distributed training (ddp).
        For non-distributed training it should always return `False`.

        Returns:
            `True` if the current process is a worker process, `False` otherwise.
        """
        return self.rank > 0
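
    # A minimal usage sketch (hypothetical ``engine``/``logger`` objects): gate
    # I/O-heavy work on the master process so it happens exactly once per run.
    #
    #     if engine.is_master_process:
    #         logger.log_artifact("best.pth")  # only rank <= 0 touches the filesystem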

    def barrier(self) -> None:
        """Synchronizes all processes.

        This collective blocks processes until all of them enter the function.
        """
        pass

    def spawn(self, fn: Callable, *args: Any, **kwargs: Any) -> None:
        """Spawn abstraction for ``nprocs`` process creation with the specified
        ``fn`` and ``args``/``kwargs``.

        Args:
            fn (function): Function called as the entrypoint of the
                spawned process. This function must be defined at the top
                level of a module so it can be pickled and spawned. This
                is a requirement imposed by multiprocessing.
                The function is called as ``fn(i, *args)``, where ``i`` is
                the process index and ``args`` is the passed-through tuple
                of arguments.
            *args: Arguments passed to the spawn method.
            **kwargs: Keyword arguments passed to the spawn method.

        Returns:
            wrapped function (if needed).
        """
        return fn(*args, **kwargs)
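
    # Illustrative sketch of how ``spawn`` is meant to be used; ``train_fn`` is a
    # hypothetical top-level function (top-level so it can be pickled). Note that
    # this base implementation simply calls ``fn(*args, **kwargs)`` in-process.
    #
    #     def train_fn(rank: int, world_size: int) -> None:
    #         ...  # build the runner/components inside the spawned process
    #
    #     engine.spawn(train_fn, rank=0, world_size=1)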

    def setup_process(self, rank: int = -1, world_size: int = 1):
        """Initialize DDP variables and processes.

        Args:
            rank: process rank. Default is `-1`.
            world_size: number of devices in the network to expect for training.
                Default is `1`.
        """
        pass

    def cleanup_process(self):
        """Clean up DDP variables and processes."""
        pass

    # TODO: make context manager
    def ddp_sync_run(self, function: Callable):
        """Function wrapper for a synchronous run in distributed mode."""
        if self.rank > 0:
            self.barrier()
        function()
        if self.rank == 0:
            self.barrier()
        if self.rank > -1:
            self.barrier()
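
    # Illustrative sketch (``download_dataset`` is a hypothetical callable):
    # rank 0 runs the function first while the workers wait on the barrier,
    # then the workers run it, so every process executes it without racing.
    #
    #     engine.ddp_sync_run(download_dataset)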

    @abstractmethod
    def sync_device(
        self, tensor_or_module: Union[Dict, List, Tuple, np.ndarray, torch.Tensor, nn.Module]
    ) -> Union[Dict, List, Tuple, torch.Tensor, nn.Module]:
        """Moves ``tensor_or_module`` to the Engine's device.

        Args:
            tensor_or_module: tensor or module to move
        """
        pass

    @abstractmethod
    def sync_metrics(self, metrics: Dict) -> Dict:
        """Syncs ``metrics`` over ``world_size`` in distributed mode."""
        return metrics

    @abstractmethod
    def sync_tensor(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
        """Syncs ``tensor`` over ``world_size`` in distributed mode."""
        pass
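
    # Illustrative sketch (the set of ``mode`` values is engine-specific;
    # "mean" is assumed here): average a per-worker metric over the whole
    # world before logging it once.
    #
    #     synced_loss = engine.sync_tensor(batch_loss.detach(), mode="mean")
    #     synced_metrics = engine.sync_metrics({"loss": synced_loss.item()})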

    @abstractmethod
    def init_components(
        self,
        model_fn: Callable = None,
        criterion_fn: Callable = None,
        optimizer_fn: Callable = None,
        scheduler_fn: Callable = None,
    ):
        """Inits the run's components."""
        pass

    # due to the FairScale setup, we need to manually delete the model in the end;
    # that's why we need the runner.model here
    @abstractmethod
    def deinit_components(self, runner=None):
        """Deinits the run's components. In distributed mode it should destroy the process group."""
        pass

    @abstractmethod
    def zero_grad(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over the ``model.zero_grad()`` step.

        Should be overloaded when you need to pass arguments to
        ``model.zero_grad()``, such as ``set_to_none=True``, or when you need
        a custom scheme that replaces/improves the ``.zero_grad()`` method.

        Args:
            loss: tensor with loss value.
            model: model module.
            optimizer: model optimizer.
        """
        pass

    @abstractmethod
    def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over the ``loss.backward()`` step.

        Should be overloaded when loss scaling is required,
        e.g. for APEX and AMP.

        Args:
            loss: tensor with loss value.
            model: model module.
            optimizer: model optimizer.
        """
        pass

    @abstractmethod
    def optimizer_step(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over the ``optimizer.step()`` step.

        Should be overloaded when gradient scaling is required,
        e.g. for AMP.

        Args:
            loss: tensor with loss value.
            model: model module.
            optimizer: model optimizer.
        """
        pass
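
    # Taken together, these three hooks form the engine-agnostic train step.
    # A rough sketch of how a runner might call them (illustrative only):
    #
    #     loss = criterion(model(batch["features"]), batch["targets"])
    #     engine.zero_grad(loss, model, optimizer)
    #     engine.backward_loss(loss, model, optimizer)   # may scale the loss (APEX/AMP)
    #     engine.optimizer_step(loss, model, optimizer)  # may unscale gradients (AMP)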

    @abstractmethod
    def pack_checkpoint(
        self,
        model: Model = None,
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        **kwargs,
    ) -> Dict:
        """
        Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
        and some extra info ``**kwargs`` into a torch-based checkpoint.

        Args:
            model: torch model
            criterion: torch criterion
            optimizer: torch optimizer
            scheduler: torch scheduler
            **kwargs: some extra info to pack
        """
        pass

    @abstractmethod
    def unpack_checkpoint(
        self,
        checkpoint: Dict,
        model: Model = None,
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        **kwargs,
    ) -> None:
        """Unpacks the ``checkpoint`` content into the model (if not None),
        criterion (if not None), optimizer (if not None)
        and scheduler (if not None).

        Args:
            checkpoint: checkpoint to unpack
            model: model whose state should be updated
            criterion: criterion whose state should be updated
            optimizer: optimizer whose state should be updated
            scheduler: scheduler whose state should be updated
            kwargs: extra arguments
        """
        pass

    @abstractmethod
    def save_checkpoint(self, checkpoint: Dict, path: str) -> None:
        """Saves checkpoint to a file.

        Args:
            checkpoint: data to save.
            path: filepath where checkpoint should be stored.
        """
        pass

    @abstractmethod
    def load_checkpoint(self, path: str) -> Dict:
        """Load checkpoint from path.

        Args:
            path: checkpoint file to load
        """
        pass
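
    # Illustrative round trip over the checkpoint API (the path is hypothetical):
    #
    #     checkpoint = engine.pack_checkpoint(model=model, optimizer=optimizer, epoch=7)
    #     engine.save_checkpoint(checkpoint, "logs/checkpoint.pth")
    #     ...
    #     checkpoint = engine.load_checkpoint("logs/checkpoint.pth")
    #     engine.unpack_checkpoint(checkpoint, model=model, optimizer=optimizer)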

    def autocast(self, *args, **kwargs):
        """AMP scaling context.

        Default autocast context does not scale anything.

        Args:
            *args: some args
            **kwargs: some kwargs

        Returns:
            context
        """
        return nullcontext()
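
    # Illustrative sketch: wrap the forward pass into the engine's autocast
    # context so the same runner code works with and without mixed precision.
    #
    #     with engine.autocast():
    #         logits = model(batch["features"])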

    def autocast_loader(self, loader: DataLoader):
        """Loader wrapper for the distributed mode."""
        return loader
__all__ = ["IEngine"]
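

# For reference, a rough sketch of how a single-device engine could fill in the
# core abstract pieces (illustrative only; see catalyst.engines.torch.DeviceEngine
# for the actual implementation):
#
#     class MyDeviceEngine(IEngine):
#         def __init__(self, device: str = "cpu"):
#             self._device = device
#
#         @property
#         def device(self) -> Device:
#             return self._device
#
#         @property
#         def rank(self) -> int:
#             return -1  # non-distributed setup
#
#         @property
#         def world_size(self) -> int:
#             return 1
#
#         @property
#         def backend(self) -> Optional[str]:
#             return None  # no distributed backend => is_ddp is False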