Shortcuts

Source code for catalyst.core.engine

# taken from https://github.com/Scitator/animus/blob/main/animus/torch/accelerate.py

from typing import Callable, Dict

import numpy as np

from accelerate import Accelerator
from accelerate.state import DistributedType
import torch

from catalyst import SETTINGS
from catalyst.utils.distributed import mean_reduce

if SETTINGS.xla_required:
    import torch_xla.core.xla_model as xm


[docs]class Engine(Accelerator): """ An abstraction that syncs experiment run with different hardware-specific configurations. - CPU - GPU - DataParallel (deepspeed, torch) - AMP (deepspeed, torch) - DDP (deepspeed, torch) - XLA Please check out implementations for more details: - :py:mod:`catalyst.engines.torch.CPUEngine` - :py:mod:`catalyst.engines.torch.GPUEngine` - :py:mod:`catalyst.engines.torch.DataParallelEngine` - :py:mod:`catalyst.engines.torch.DistributedDataParallelEngine` - :py:mod:`catalyst.engines.torch.DistributedXLAEngine` """ @property def is_ddp(self): """Boolean flag for distributed type.""" return self.distributed_type != DistributedType.NO
[docs] def spawn(self, fn: Callable, *args, **kwargs): """Spawns processes with specified ``fn`` and ``args``/``kwargs``. Args: fn (function): Function is called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as ``fn(i, *args)``, where ``i`` is the process index and ``args`` is the passed through tuple of arguments. *args: Arguments passed to spawn method. **kwargs: Keyword-arguments passed to spawn method. Returns: wrapped function (if needed). """ return fn(*args, **kwargs)
[docs] def setup(self, local_rank: int, world_size: int): """Initialize DDP variables and processes if required. Args: local_rank: process rank. Default is `-1`. world_size: number of devices in netwok to expect for train. Default is `1`. """ pass
[docs] def cleanup(self): """Cleans DDP variables and processes.""" pass
[docs] def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict: """Syncs ``metrics`` over ``world_size`` in the distributed mode.""" if self.state.distributed_type in [ DistributedType.MULTI_CPU, DistributedType.MULTI_GPU, ]: metrics = { k: mean_reduce( torch.tensor(v, device=self.device), world_size=self.state.num_processes, ) for k, v in metrics.items() } elif self.state.distributed_type == DistributedType.TPU: metrics = { k: xm.mesh_reduce( k, v.item() if isinstance(v, torch.Tensor) else v, np.mean ) for k, v in metrics.items() } return metrics
__all__ = ["Engine"]