# Shortcuts
#
# Source code for catalyst.engines.torch

from typing import Any, Dict, List, Mapping, Tuple, Union
import copy
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from catalyst.core.engine import IEngine
from catalyst.typing import (
    Device,
    Model,
    Optimizer,
    RunnerCriterion,
    RunnerModel,
    RunnerOptimizer,
    RunnerScheduler,
)
from catalyst.utils.distributed import ddp_reduce
from catalyst.utils.torch import (
    any2device,
    load_checkpoint,
    pack_checkpoint,
    save_checkpoint,
    unpack_checkpoint,
)


class DeviceEngine(IEngine):
    """Single training device engine.

    Args:
        device: use device, default is ``"cuda"`` when CUDA is available,
            otherwise ``"cpu"``.

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.DeviceEngine("cuda:1"),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.DeviceEngine("cuda:1")
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: DeviceEngine
            device: cuda:1

        stages:
            ...

    """

    def __init__(self, device: str = None):
        """Init."""
        # Fall back to the first CUDA device when available, else CPU.
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._device = device

    def __repr__(self) -> str:  # noqa: D105
        return f"{self.__class__.__name__}(device='{self._device}')"

    @property
    def device(self) -> Device:
        """Pytorch device."""
        return self._device

    @property
    def rank(self) -> int:
        """Process rank for distributed training (``-1``: not distributed)."""
        return -1

    @property
    def world_size(self) -> int:
        """Process world size for distributed training (always ``1`` here)."""
        return 1

    def sync_device(
        self, tensor_or_module: Union[Dict, List, Tuple, np.ndarray, torch.Tensor, nn.Module]
    ) -> Union[Dict, List, Tuple, torch.Tensor, nn.Module]:
        """Moves ``tensor_or_module`` to Engine's device."""
        return any2device(tensor_or_module, device=self.device)

    def sync_tensor(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
        """Syncs ``tensor`` over ``world_size`` in distributed mode.

        Single-device engine: ``tensor`` is returned unchanged and ``mode``
        is ignored.
        """
        return tensor

    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
    ):
        """Inits the run's components.

        Each factory is called with no arguments and its result is moved to
        the engine's device via :meth:`sync_device`. A factory left as
        ``None`` yields ``None`` for that component instead of raising.

        Args:
            model_fn: zero-argument callable returning the model.
            criterion_fn: zero-argument callable returning the criterion.
            optimizer_fn: zero-argument callable returning the optimizer.
            scheduler_fn: zero-argument callable returning the scheduler.

        Returns:
            tuple ``(model, criterion, optimizer, scheduler)``.
        """

        def _build(factory):
            # Optional components (e.g. no scheduler) simply stay ``None``.
            return None if factory is None else self.sync_device(factory())

        # Order matters: the model is built (and moved) before the optimizer
        # factory runs, which typically reads the model's parameters.
        model = _build(model_fn)
        criterion = _build(criterion_fn)
        optimizer = _build(optimizer_fn)
        scheduler = _build(scheduler_fn)
        return model, criterion, optimizer, scheduler

    def deinit_components(self, runner=None):
        """Deinits the runs components (nothing to release on a single device)."""
        pass

    def zero_grad(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``model.zero_grad()`` step."""
        model.zero_grad()

    def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``loss.backward()`` step."""
        loss.backward()

    def optimizer_step(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``optimizer.step()`` step."""
        optimizer.step()

    def pack_checkpoint(
        self,
        model: RunnerModel = None,
        criterion: RunnerCriterion = None,
        optimizer: RunnerOptimizer = None,
        scheduler: RunnerScheduler = None,
        **kwargs,
    ) -> Dict:
        """
        Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
        and some extra info ``**kwargs`` to torch-based checkpoint.

        Args:
            model: torch model
            criterion: torch criterion
            optimizer: torch optimizer
            scheduler: torch scheduler
            **kwargs: some extra info to pack

        Returns:
            torch-based checkpoint with ``model_state_dict``,
            ``criterion_state_dict``, ``optimizer_state_dict``,
            ``scheduler_state_dict`` keys.
        """
        return pack_checkpoint(
            model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, **kwargs
        )

    def unpack_checkpoint(
        self,
        checkpoint: Dict,
        model: RunnerModel = None,
        criterion: RunnerCriterion = None,
        optimizer: RunnerOptimizer = None,
        scheduler: RunnerScheduler = None,
        **kwargs,
    ) -> None:
        """Unpacks an already loaded ``checkpoint`` into the given components.

        The state is restored into the model (if not None), criterion
        (if not None), optimizer (if not None) and scheduler (if not None).
        Use :meth:`load_checkpoint` to read a checkpoint from a file first.

        Args:
            checkpoint: checkpoint to unpack
            model: model where should be updated state
            criterion: criterion where should be updated state
            optimizer: optimizer where should be updated state
            scheduler: scheduler where should be updated state
            kwargs: extra arguments; accepted for interface compatibility
                and ignored here
        """
        unpack_checkpoint(
            checkpoint=checkpoint,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
        )

    def save_checkpoint(self, checkpoint: Mapping[str, Any], path: str):
        """Saves checkpoint to a file.

        Args:
            checkpoint: data to save.
            path: filepath where checkpoint should be stored.
        """
        save_checkpoint(checkpoint=checkpoint, path=path)

    def load_checkpoint(self, path: str):
        """Load checkpoint from path.

        Args:
            path: checkpoint file to load

        Returns:
            loaded checkpoint
        """
        return load_checkpoint(path=path)
class DataParallelEngine(DeviceEngine):
    """MultiGPU training device engine.

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.DataParallelEngine(),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.DataParallelEngine()
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: DataParallelEngine

        stages:
            ...

    """

    def __init__(self):
        """Init."""
        # Components are first placed on the current CUDA device;
        # ``nn.DataParallel`` then scatters work across the GPUs.
        super().__init__(f"cuda:{torch.cuda.current_device()}")
        self.device_count = torch.cuda.device_count()

    def __repr__(self) -> str:  # noqa: D105
        return f"{self.__class__.__name__}(device_count={self.device_count})"

    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
    ):
        """Inits the run's components and wraps the model(s) in ``nn.DataParallel``.

        Each factory is called with no arguments and its result is moved to
        the engine's device via :meth:`sync_device`. Optional factories
        (criterion, optimizer, scheduler) left as ``None`` yield ``None``.

        Args:
            model_fn: zero-argument callable returning an ``nn.Module``
                or a dict of them.
            criterion_fn: zero-argument callable returning the criterion.
            optimizer_fn: zero-argument callable returning the optimizer.
            scheduler_fn: zero-argument callable returning the scheduler.

        Returns:
            tuple ``(model, criterion, optimizer, scheduler)``.

        Raises:
            ValueError: if the model is neither ``nn.Module`` nor ``dict``.
        """

        def _build(factory):
            # Missing factories produce ``None`` instead of a TypeError.
            return None if factory is None else self.sync_device(factory())

        # model: required, wrapped (per entry for dicts) in nn.DataParallel
        model = _build(model_fn)
        if isinstance(model, nn.Module):
            model = nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: nn.DataParallel(v) for k, v in model.items()}
        else:
            raise ValueError("Model should be ``nn.Module`` or ``dict``")

        criterion = _build(criterion_fn)
        optimizer = _build(optimizer_fn)
        scheduler = _build(scheduler_fn)
        return model, criterion, optimizer, scheduler
class DistributedDataParallelEngine(DeviceEngine):
    """Distributed MultiGPU training device engine.

    Args:
        address: address to use for backend (default ``"localhost"``).
        port: port to use for backend (default ``12345``).
        ddp_kwargs: parameters for `torch.nn.parallel.DistributedDataParallel`.
            More info here:
            https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel
        process_group_kwargs: parameters for `torch.distributed.init_process_group`.
            More info here:
            https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.DistributedDataParallelEngine(),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.DistributedDataParallelEngine(
                    address="0.0.0.0",
                    port=23234,
                    ddp_kwargs={"find_unused_parameters": False},
                    process_group_kwargs={"backend": "nccl"},
                )
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: DistributedDataParallelEngine
            address: 0.0.0.0
            port: 23234
            ddp_kwargs:
                find_unused_parameters: false
            process_group_kwargs:
                backend: nccl

        stages:
            ...

    """

    def __init__(
        self,
        address: str = None,
        port: Union[str, int] = None,
        ddp_kwargs: Dict[str, Any] = None,
        process_group_kwargs: Dict[str, Any] = None,
    ):
        """Init."""
        super().__init__()
        self.address = address or "localhost"
        self.port = port or 12345
        self._rank = 0
        # The real device is only known once ``setup_process`` assigns a rank.
        self._device = None

        # Deep copies so later mutation (e.g. ``device_ids``, ``rank``)
        # never leaks back into the caller's dicts.
        if ddp_kwargs is None:
            ddp_kwargs = {}
        self.ddp_kwargs = copy.deepcopy(ddp_kwargs)

        if process_group_kwargs is None:
            process_group_kwargs = {}
        self.process_group_kwargs = copy.deepcopy(process_group_kwargs)

        # add missing arguments
        if "backend" not in self.process_group_kwargs:
            self.process_group_kwargs["backend"] = "nccl"
        if "world_size" not in self.process_group_kwargs:
            self.process_group_kwargs["world_size"] = torch.cuda.device_count()

        # A falsy user-provided world_size (e.g. None/0) falls back to the GPU count.
        self._world_size = (
            self.process_group_kwargs.get("world_size", None) or torch.cuda.device_count()
        )

    def __repr__(self):  # noqa: D105
        return (
            f"{self.__class__.__name__}(address={self.address}, "
            f"port={self.port}, "
            f"ddp_kwargs={self.ddp_kwargs}, "
            f"process_group_kwargs={self.process_group_kwargs})"
        )

    @property
    def rank(self) -> int:
        """Process rank for distributed training."""
        return self._rank

    @property
    def world_size(self) -> int:
        """Process world size for distributed training."""
        return self._world_size

    def setup_process(self, rank: int = -1, world_size: int = 1):
        """Initialize DDP variables and processes.

        Args:
            rank: process rank. Default is `-1`.
            world_size: number of devices in network to expect for train.
                Default is `1`.
        """
        self._rank = rank
        self._world_size = world_size
        # The rank is used directly as the CUDA device index, which presumes
        # one process per GPU on a single node.
        torch.cuda.set_device(int(self._rank))
        self._device = f"cuda:{int(self._rank)}"
        self.process_group_kwargs["rank"] = rank
        self.process_group_kwargs["world_size"] = world_size
        # Rendezvous information read by the default ``env://`` init method.
        os.environ["MASTER_ADDR"] = str(self.address)
        os.environ["MASTER_PORT"] = str(self.port)
        dist.init_process_group(**self.process_group_kwargs)

    def cleanup_process(self):
        """Clean DDP variables and processes."""
        # Wait for every rank before tearing the group down.
        dist.barrier()
        dist.destroy_process_group()

    def sync_tensor(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
        """Syncs ``tensor`` over ``world_size`` in distributed mode.

        Args:
            tensor: tensor to sync across the processes.
            mode: tensor synchronization type,
                should be one of 'sum' or 'mean'.

        Returns:
            torch.Tensor with synchronized values.
        """
        return ddp_reduce(tensor, mode, self.world_size)

    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
    ):
        """Inits the run's components and wraps the model(s) in DDP.

        Each factory is called with no arguments and its result is moved to
        the engine's device via :meth:`sync_device`. Optional factories
        (criterion, optimizer, scheduler) left as ``None`` yield ``None``.
        All processes are synchronized with a barrier before returning.

        Args:
            model_fn: zero-argument callable returning an ``nn.Module``
                or a dict of them.
            criterion_fn: zero-argument callable returning the criterion.
            optimizer_fn: zero-argument callable returning the optimizer.
            scheduler_fn: zero-argument callable returning the scheduler.

        Returns:
            tuple ``(model, criterion, optimizer, scheduler)``.

        Raises:
            ValueError: if the model is neither ``nn.Module`` nor ``dict``.
        """
        # ``self._device`` is assigned in ``setup_process``; call that first.
        # The default is stored in ``self.ddp_kwargs`` and reused by later calls.
        if "device_ids" not in self.ddp_kwargs:
            self.ddp_kwargs["device_ids"] = [self._device]

        def _build(factory):
            # Missing factories produce ``None`` instead of a TypeError.
            return None if factory is None else self.sync_device(factory())

        # model: required, wrapped (per entry for dicts) in DistributedDataParallel
        model = _build(model_fn)
        if isinstance(model, nn.Module):
            model = DistributedDataParallel(model, **self.ddp_kwargs)
        elif isinstance(model, dict):
            model = {k: DistributedDataParallel(v, **self.ddp_kwargs) for k, v in model.items()}
        else:
            raise ValueError("Model should be ``nn.Module`` or ``dict``")

        criterion = _build(criterion_fn)
        optimizer = _build(optimizer_fn)
        scheduler = _build(scheduler_fn)
        dist.barrier()
        return model, criterion, optimizer, scheduler
# Public API of this module: the three torch engines.
__all__ = [
    "DeviceEngine",
    "DataParallelEngine",
    "DistributedDataParallelEngine",
]