Source code for catalyst.utils.distributed

from typing import Any, List, Optional
import os
import pickle

import torch
from torch import nn
import torch.distributed as dist

from catalyst.settings import SETTINGS

if SETTINGS.xla_required:
    import torch_xla.core.xla_env_vars as xenv
    import torch_xla.core.xla_model as xm


def _is_torch_distributed_initialized() -> bool:
    """Checks if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def _is_xla_distributed_initialized() -> bool:
    """Checks if XLA distributed mode is configured via environment variables."""
    return (
        SETTINGS.xla_required and os.environ.get(xenv.TORCH_DIST_ROOT, None) is not None
    )


def _is_ddp_wrapped(model: nn.Module) -> bool:
    """Checks whether model is wrapped with DataParallel/DistributedDataParallel."""
    parallel_wrappers = nn.DataParallel, nn.parallel.DistributedDataParallel

    # Check whether Apex is installed and if it is,
    # add Apex's DistributedDataParallel to list of checked types
    if SETTINGS.apex_required:
        from apex.parallel import DistributedDataParallel as apex_DDP

        parallel_wrappers = parallel_wrappers + (apex_DDP,)

    if SETTINGS.fairscale_required:
        from fairscale.nn.data_parallel import (
            FullyShardedDataParallel,
            ShardedDataParallel,
        )

        parallel_wrappers = parallel_wrappers + (
            ShardedDataParallel,
            FullyShardedDataParallel,
        )

    if SETTINGS.deepspeed_required:
        from deepspeed import DeepSpeedEngine, PipelineEngine

        parallel_wrappers = parallel_wrappers + (DeepSpeedEngine, PipelineEngine)

    return isinstance(model, parallel_wrappers)


def get_nn_from_ddp_module(model: nn.Module) -> nn.Module:
    """
    Return a real model from a torch.nn.DataParallel,
    torch.nn.parallel.DistributedDataParallel, or
    apex.parallel.DistributedDataParallel.

    Args:
        model: A model, or DataParallel wrapper.

    Returns:
        A model
    """
    if _is_ddp_wrapped(model):
        model = model.module
    return model
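

# A minimal usage sketch (hypothetical helper, not part of the original module):
# ``get_nn_from_ddp_module`` is typically used to unwrap a model before saving
# its weights, so that checkpoint keys are not prefixed with ``module.``.
def _example_unwrap_and_save(model: nn.Module, path: str) -> None:
    """Saves the state dict of the underlying module, wrapped or not."""
    real_model = get_nn_from_ddp_module(model)
    torch.save(real_model.state_dict(), path)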


def get_backend() -> Optional[str]:
    """Returns the backend for distributed training."""
    if _is_xla_distributed_initialized():
        return "xla"
    elif _is_torch_distributed_initialized():
        return "ddp"
    else:
        return None


def get_rank() -> int:
    """
    Returns the rank of the current worker.

    Returns:
        int: ``rank`` if torch.distributed is initialized, otherwise ``-1``
    """
    if _is_xla_distributed_initialized():
        return xm.get_ordinal()
    elif _is_torch_distributed_initialized():
        return dist.get_rank()
    else:
        return -1
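

# A minimal usage sketch (hypothetical helper): ``get_rank`` is commonly used
# to restrict side effects such as logging or checkpointing to the main
# process, i.e. rank 0, or to the non-distributed case, where it returns -1.
def _example_is_main_process() -> bool:
    """Returns True on rank 0 or when running without distributed training."""
    return get_rank() in (-1, 0)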


# def get_local_rank() -> int:
#     pass


def get_world_size() -> int:
    """Returns the world size for distributed training."""
    if _is_xla_distributed_initialized():
        return xm.xrt_world_size()
    elif _is_torch_distributed_initialized():
        return dist.get_world_size()
    else:
        return 1
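

# A minimal usage sketch (hypothetical helper): ``get_world_size`` turns
# per-process quantities into global ones, e.g. the effective batch size of
# a DDP run.
def _example_global_batch_size(per_process_batch_size: int) -> int:
    """Returns the effective batch size across all processes."""
    return per_process_batch_size * get_world_size()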


# def get_num_nodes() -> int:
#     pass
#
#
# def get_num_proc_per_nodes() -> int:
#     pass
#
#
# def get_node_rank() -> int:
#     pass


def sum_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """Reduces the tensor across all processes and computes the total (sum) value.

    Args:
        tensor: tensor to reduce.

    Returns:
        reduced tensor
    """
    cloned = tensor.clone()
    dist.all_reduce(cloned, dist.ReduceOp.SUM)
    return cloned
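

# A minimal usage sketch (hypothetical helper): summing a per-process scalar
# counter, e.g. the number of correct predictions, across all DDP workers.
# Assumes ``torch.distributed`` is already initialized and that ``count`` is a
# scalar tensor on the process's device (CUDA when using the NCCL backend).
def _example_global_correct_count(count: torch.Tensor) -> int:
    """Returns the total count accumulated over all processes."""
    return int(sum_reduce(count).item())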


def mean_reduce(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    """Reduces the tensor across all processes and computes the mean value.

    Args:
        tensor: tensor to reduce.
        world_size: number of processes in DDP setup.

    Returns:
        reduced tensor
    """
    # TODO: fix division operator for int/long tensors
    reduced = sum_reduce(tensor) / world_size
    return reduced


def all_gather(data: Any) -> List[Any]:
    """Run all_gather on arbitrary picklable data (not necessarily tensors).

    .. note::
        If the data lives on different devices across processes,
        the items in the resulting list stay on those same devices.

    Source: http://github.com/facebookresearch/detr/blob/master/util/misc.py#L88-L128

    Args:
        data: any picklable object

    Returns:
        list of data gathered from each process.
    """
    if not dist.is_available() or not dist.is_initialized():
        world_size = 1
    else:
        world_size = dist.get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object to a byte tensor on the current CUDA device
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving tensors from all ranks:
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(
            size=(max_size - local_size,), dtype=torch.uint8, device="cuda"
        )
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
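

# A minimal usage sketch (hypothetical helper): gathering arbitrary picklable
# objects, here per-process metric dicts, and merging them on every process.
# Assumes an initialized process group with a CUDA device per process, since
# ``all_gather`` serializes the payload through CUDA byte tensors.
def _example_gather_metrics(metrics: dict) -> dict:
    """Merges per-process metric dicts into ``{key: [value_per_rank, ...]}``."""
    gathered = all_gather(metrics)
    merged: dict = {}
    for part in gathered:
        for key, value in part.items():
            merged.setdefault(key, []).append(value)
    return merged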


def ddp_reduce(tensor: torch.Tensor, mode: str, world_size: int):
    """Syncs ``tensor`` over ``world_size`` in distributed mode.

    Args:
        tensor: tensor to sync across the processes.
        mode: tensor synchronization type, should be one of
            'sum', 'mean' or 'all'.
        world_size: world size

    Returns:
        torch.Tensor with synchronized values.

    Raises:
        ValueError: if ``mode`` is not one of ``sum``, ``mean``, ``all``.
    """
    if mode not in {"sum", "mean", "all"}:
        raise ValueError(f"Unknown sync_type '{mode}'")
    if mode == "sum":
        return sum_reduce(tensor)
    elif mode == "mean":
        return mean_reduce(tensor, world_size)
    else:
        return all_gather(tensor)
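

# A minimal usage sketch (hypothetical helper): averaging a scalar loss over
# all workers via the ``ddp_reduce`` dispatcher. Assumes ``torch.distributed``
# is already initialized, since ``sum_reduce`` calls ``dist.all_reduce``
# unconditionally.
def _example_mean_loss(loss: torch.Tensor) -> torch.Tensor:
    """Returns the loss averaged over all DDP processes."""
    return ddp_reduce(loss.detach(), mode="mean", world_size=get_world_size())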


__all__ = [
    "get_backend",
    "get_rank",
    "get_world_size",
    "get_nn_from_ddp_module",
    "sum_reduce",
    "mean_reduce",
    "all_gather",
    "ddp_reduce",
]