# Source code for catalyst.engines.torch

# taken from https://github.com/Scitator/animus/blob/main/animus/torch/accelerate.py
from typing import Any, Callable, Dict, Optional, Union
import os

import numpy as np

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from catalyst import SETTINGS
from catalyst.core.engine import Engine
from catalyst.utils.distributed import mean_reduce

if SETTINGS.xla_required:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp


class CPUEngine(Engine):
    """Engine that runs the whole experiment on the CPU.

    Every positional and keyword argument is forwarded unchanged to the
    parent constructor, which always receives ``cpu=True`` as well.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Init."""
        super().__init__(*args, **kwargs, cpu=True)
class GPUEngine(Engine):
    """Engine that runs the experiment on a single GPU.

    Every positional and keyword argument is forwarded unchanged to the
    parent constructor, which always receives ``cpu=False`` as well.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Init."""
        super().__init__(*args, **kwargs, cpu=False)
class DataParallelEngine(Engine):
    """Multi-GPU engine built on ``torch.nn.DataParallel``.

    Constructor arguments are forwarded to the parent constructor together
    with ``cpu=False``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Init."""
        super().__init__(*args, **kwargs, cpu=False)

    def prepare_model(self, model):
        """Overrides.

        Wraps ``model`` in ``torch.nn.DataParallel`` first, then hands the
        wrapper to the parent ``prepare_model`` and returns its result.
        """
        return super().prepare_model(torch.nn.DataParallel(model))
class DistributedDataParallelEngine(Engine):
    """Distributed multi-GPU-based engine.

    Args:
        *args: args for Accelerator.__init__ (forwarded in ``setup``)
        address: master node (rank 0)'s address, should be either the IP
            address or the hostname of node 0; for single node multi-proc
            training it can simply be 127.0.0.1. The ``MASTER_ADDR``
            environment variable, when set, takes precedence.
        port: master node (rank 0)'s free port that needs to be used for
            communication during distributed training. The ``MASTER_PORT``
            environment variable, when set, takes precedence.
        world_size: the total number of processes taking part in the
            distributed training (summed over all nodes).
            Defaults to ``num_node_workers``.
        workers_dist_rank: the rank of the first process to run on the node.
            It should be a number between `number of initialized processes`
            and `world_size - 1`, the other processes on the node will have
            ranks `# of initialized processes + 1`,
            `# of initialized processes + 2`, ...,
            `# of initialized processes + num_node_workers - 1`
        num_node_workers: the number of processes to launch on the node.
            For GPU training, this is recommended to be set to the number
            of GPUs on the current node so that each process can be bound
            to a single GPU. Defaults to ``torch.cuda.device_count()``,
            or 1 when no GPU is visible.
        process_group_kwargs: parameters for
            `torch.distributed.init_process_group`; they override the
            defaults (``backend="nccl"`` and ``world_size``) used in ``setup``.
            More info here:
            https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group # noqa: E501, W505
        **kwargs: kwargs for Accelerator.__init__ (forwarded in ``setup``)
    """

    def __init__(
        self,
        *args,
        address: str = "127.0.0.1",
        port: Union[str, int] = 2112,
        world_size: Optional[int] = None,
        workers_dist_rank: int = 0,
        num_node_workers: Optional[int] = None,
        process_group_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """Init."""
        # Environment variables, when present, take precedence over arguments.
        self._address = os.environ.get("MASTER_ADDR", address)
        self._port = os.environ.get("MASTER_PORT", port)
        self._num_local_workers = num_node_workers or torch.cuda.device_count() or 1
        self._workers_global_rank = workers_dist_rank
        self._world_size = world_size or self._num_local_workers
        self._process_group_kwargs = process_group_kwargs or {}
        # The parent ``__init__`` is deliberately NOT called here: ``setup``
        # calls it inside every spawned worker, after the process group exists.
        self._args = args
        self._kwargs = kwargs

    def spawn(self, fn: Callable, *args, **kwargs):
        """Spawns ``num_node_workers`` local processes running ``fn``.

        Args:
            fn (function): Function is called as the entrypoint of the
                spawned process. This function must be defined at the top
                level of a module so it can be pickled and spawned. This
                is a requirement imposed by multiprocessing.
                The function is called as ``fn(local_rank, world_size)``,
                where ``local_rank`` is the process index on this node.
            *args: accepted for signature compatibility; NOT forwarded to ``fn``.
            **kwargs: accepted for signature compatibility; NOT forwarded to ``fn``.

        Returns:
            The result of ``torch.multiprocessing.spawn``. With ``join=True``
            this call blocks until every worker has finished.
        """
        return mp.spawn(
            fn,
            args=(self._world_size,),
            nprocs=self._num_local_workers,
            join=True,
        )

    def setup(self, local_rank: int, world_size: int):
        """Initialize DDP variables and processes if required.

        Exports the rendezvous environment variables, initializes the
        ``torch.distributed`` process group and only then runs the parent
        constructor with the arguments captured in ``__init__``.

        Args:
            local_rank: index of this process on the current node.
            world_size: total number of processes in the process group.
        """
        process_group_kwargs = {
            "backend": "nccl",
            "world_size": world_size,
            **self._process_group_kwargs,
        }
        global_rank = self._workers_global_rank + local_rank

        os.environ["MASTER_ADDR"] = str(self._address)
        os.environ["MASTER_PORT"] = str(self._port)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(global_rank)
        os.environ["LOCAL_RANK"] = str(local_rank)
        dist.init_process_group(**process_group_kwargs)
        # ``super().__init__`` is already bound to ``self``; passing ``self``
        # explicitly would shift every user-supplied positional argument by one.
        super().__init__(*self._args, **self._kwargs)

    def cleanup(self):
        """Cleans DDP variables and processes."""
        dist.destroy_process_group()

    def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
        """Syncs ``metrics`` over ``world_size`` in the distributed mode.

        Each value is turned into a tensor on ``self.device`` and averaged
        across ``self.state.num_processes`` processes via ``mean_reduce``.
        """
        metrics = {
            k: mean_reduce(
                torch.tensor(v, device=self.device),
                world_size=self.state.num_processes,
            )
            for k, v in metrics.items()
        }
        return metrics
class DistributedXLAEngine(Engine):
    """Distributed XLA-based engine.

    Processes are launched with ``torch_xla``'s ``xmp.spawn``; the number of
    processes is currently hard-coded to 8 in ``spawn``
    (NOTE(review): presumably one per TPU core of an 8-core host — confirm).

    Args:
        *args: args forwarded to the parent ``__init__`` in ``setup``
        **kwargs: kwargs forwarded to the parent ``__init__`` in ``setup``
    """

    def __init__(self, *args, **kwargs):
        """Init."""
        # The parent ``__init__`` is deliberately NOT called here: ``setup``
        # calls it inside every spawned worker.
        self._args = args
        self._kwargs = kwargs

    def spawn(self, fn: Callable, *args, **kwargs):
        """Spawns XLA processes running ``fn``.

        Args:
            fn (function): Function is called as the entrypoint of the
                spawned process. This function must be defined at the top
                level of a module so it can be pickled and spawned. This
                is a requirement imposed by multiprocessing.
                The function is called as ``fn(index, world_size)``.
            *args: accepted for signature compatibility; NOT forwarded to ``fn``.
            **kwargs: accepted for signature compatibility; NOT forwarded to ``fn``.

        Returns:
            The result of ``xmp.spawn``.
        """
        world_size: int = 8
        return xmp.spawn(fn, args=(world_size,), nprocs=world_size, start_method="fork")

    def setup(self, local_rank: int, world_size: int):
        """Initialize the engine inside a spawned XLA process.

        Runs the parent constructor with the arguments captured in
        ``__init__``.

        Args:
            local_rank: process index (not used by this method).
            world_size: number of processes (not used by this method).
        """
        # ``super().__init__`` is already bound to ``self``; passing ``self``
        # explicitly would shift every user-supplied positional argument by one.
        super().__init__(*self._args, **self._kwargs)

    def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
        """Syncs ``metrics`` over ``world_size`` in the distributed mode.

        Tensor values are converted with ``.item()`` (so they are assumed to
        hold a single element) and every value is averaged across replicas
        with ``xm.mesh_reduce(..., np.mean)``, keyed by the metric name.
        """
        metrics = {
            k: xm.mesh_reduce(k, v.item() if isinstance(v, torch.Tensor) else v, np.mean)
            for k, v in metrics.items()
        }
        return metrics
# Public API of this module. ``__all__`` must contain name *strings*;
# class objects here make ``from catalyst.engines.torch import *`` raise TypeError.
__all__ = [
    "CPUEngine",
    "GPUEngine",
    "DataParallelEngine",
    "DistributedDataParallelEngine",
    "DistributedXLAEngine",
]