# Source code for catalyst.utils.ddp

from torch import nn


def is_wrapped_with_ddp(model: nn.Module) -> bool:
    """
    Checks whether model is wrapped with DataParallel/DistributedDataParallel.

    The model is tested against ``torch.nn.DataParallel`` and
    ``torch.nn.parallel.DistributedDataParallel``. If Apex can be imported,
    ``apex.parallel.DistributedDataParallel`` is tested as well.

    Args:
        model: the module to inspect.

    Returns:
        ``True`` if ``model`` is an instance of one of the wrapper classes
        listed above, ``False`` otherwise.
    """
    parallel_wrappers = nn.DataParallel, nn.parallel.DistributedDataParallel

    # Check whether Apex is installed and if it is,
    # add Apex's DistributedDataParallel to list of checked types.
    # Apex is an optional dependency, so a failed import is not an error.
    try:
        from apex.parallel import DistributedDataParallel as apex_DDP
    except ImportError:
        pass
    else:
        parallel_wrappers = parallel_wrappers + (apex_DDP,)

    return isinstance(model, parallel_wrappers)
def get_real_module(model: nn.Module) -> nn.Module:
    """
    Return a real model from a torch.nn.DataParallel,
    torch.nn.parallel.DistributedDataParallel, or
    apex.parallel.DistributedDataParallel.

    Only one level of wrapping is removed: when ``model`` is recognised as a
    parallel wrapper by ``is_wrapped_with_ddp``, its ``module`` attribute is
    returned; otherwise ``model`` itself is returned unchanged.

    Args:
        model: A model, or DataParallel wrapper.

    Returns:
        A model
    """
    if is_wrapped_with_ddp(model):
        # Parallel wrappers keep the wrapped network in ``.module``.
        model = model.module
    return model