from typing import Any, Dict, Union
from collections import OrderedDict
import os
import torch
from torch import nn
import torch.distributed as dist
from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine
from catalyst.settings import SETTINGS
from catalyst.typing import Model, Optimizer, RunnerModel, RunnerOptimizer
from catalyst.utils.misc import get_fn_default_params
if SETTINGS.apex_required:
import apex
import apex.amp as amp
from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel
def _initialize_apex(model, optimizer=None, **engine_params):
"""
Prepares model and optimizer for work with Nvidia Apex.
Args:
model: torch model
optimizer: torch optimizer
**engine_params: extra params for ``apex.amp.initialize``
Returns:
model and optimizer, wrapped with Nvidia Apex initialization
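Example (a minimal sketch; assumes Apex is installed, the modules and
optimizers already live on a CUDA device, and the names below are
illustrative placeholders):
.. code-block:: python
model = {"generator": generator, "discriminator": discriminator}
optimizer = {"generator": opt_gen, "discriminator": opt_disc}
model, optimizer = _initialize_apex(model, optimizer, opt_level="O1")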
"""
amp_params = get_fn_default_params(apex.amp.initialize, ["models", "optimizers"])
amp_params["opt_level"] = "O0"
for dp in engine_params:
if dp in amp_params:
amp_params[dp] = engine_params[dp]
# NVIDIA apex supports only:
#   model: nn.Module or a list of modules
#   optimizer: None, torch.Optimizer or a list of optimizers
# while key-value (dict) containers are preferred in `catalyst`.
# So if the model/optimizer is a dict, convert it to lists of keys
# and values first, and cast it back to a dict after apex initialization.
model_keys, optimizer_keys = None, None
if isinstance(model, dict):
model_keys, model = list(model.keys()), list(model.values())
if isinstance(optimizer, dict):
optimizer_keys = list(optimizer.keys())
optimizer = list(optimizer.values())
amp_result = apex.amp.initialize(model, optimizer, **amp_params)
if optimizer is not None:
model, optimizer = amp_result
else:
model = amp_result
# convert model/optimizer back to dict if needed
if model_keys is not None:
model = OrderedDict([(k, v) for k, v in zip(model_keys, model)])
if optimizer_keys is not None:
optimizers = [(k, v) for k, v in zip(optimizer_keys, optimizer)]
optimizer = OrderedDict(optimizers)
return model, optimizer
# taken from https://github.com/catalyst-team/catalyst/blob/master/catalyst/utils/components.py
def _patch_forward(model):
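"""Patch ``model.forward`` so that floating point inputs are cast to the
dtype Apex selected for the model (``cast_model_type``) and floating point
outputs are cast back (``cast_model_outputs``, ``torch.float32`` by default).
Used after re-wrapping an Apex-initialized module, see
``_wrap_into_data_parallel_with_apex`` below.
"""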
input_caster_lambda = (
lambda tensor: tensor.to(apex.amp._amp_state.opt_properties.options["cast_model_type"])
if tensor.is_floating_point()
else tensor
)
output_caster_lambda = (
lambda tensor: tensor.to(
apex.amp._amp_state.opt_properties.options.get("cast_model_outputs", torch.float32)
)
if tensor.is_floating_point()
else tensor
)
def new_fwd(
*args,
old_fwd=model.forward,
input_caster=input_caster_lambda,
output_caster=output_caster_lambda,
**kwargs,
):
return apex.amp._initialize.applier(
old_fwd(
*apex.amp._initialize.applier(args, input_caster),
**apex.amp._initialize.applier(kwargs, input_caster),
),
output_caster,
)
model.forward = new_fwd
return model
# taken from https://github.com/catalyst-team/catalyst/blob/master/catalyst/utils/components.py
# apex issue https://github.com/deepset-ai/FARM/issues/210
# solution: https://github.com/NVIDIA/apex/issues/503#issuecomment-566181771
def _wrap_into_data_parallel_with_apex(
model: RunnerModel, optimizer: RunnerOptimizer, distributed_params: Dict
):
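"""Wrap a model (``nn.Module`` or dict of modules) into ``nn.DataParallel``
combined with Apex initialization; ``forward`` is re-patched afterwards
(see the apex issue linked above).
Args:
model: model (``nn.Module`` or dict of modules) to wrap
optimizer: optimizer (or dict of optimizers) to initialize with Apex
distributed_params: kwargs for ``apex.amp.initialize``
Returns:
wrapped model (or dict of models) and Apex-initialized optimizer
Raises:
ValueError: if ``model`` is neither ``nn.Module`` nor ``dict``
"""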
if isinstance(model, nn.Module):
model = nn.Sequential(model)
model, optimizer = _initialize_apex(model, optimizer, **distributed_params)
model = torch.nn.DataParallel(model[0])
model = _patch_forward(model)
elif isinstance(model, dict):
model = {k: nn.Sequential(v) for k, v in model.items()}
model, optimizer = _initialize_apex(model, optimizer, **distributed_params)
model = {k: nn.DataParallel(v[0]) for k, v in model.items()}
model = {k: _patch_forward(v) for k, v in model.items()}
else:
raise ValueError("Model should be ``nn.Module`` or ``dict``")
return model, optimizer
class APEXEngine(DeviceEngine):
"""Apex single training device engine.
Args:
device: device to use, default is `"cuda"`.
apex_kwargs: parameters for `apex.amp.initialize`
except models and optimizers (they will be forwarded automatically).
Docs for `apex.amp.initialize`:
https://nvidia.github.io/apex/amp.html#apex.amp.initialize
Examples:
.. code-block:: python
from catalyst import dl
runner = dl.SupervisedRunner()
runner.train(
engine=dl.APEXEngine(apex_kwargs=dict(opt_level="O1", keep_batchnorm_fp32=False)),
...
)
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.APEXEngine(apex_kwargs=dict(opt_level="O1", keep_batchnorm_fp32=False))
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: APEXEngine
apex_kwargs:
opt_level: O1
keep_batchnorm_fp32: false
stages:
...
"""
def __init__(self, device: str = "cuda", apex_kwargs: Dict[str, Any] = None):
"""Init."""
super().__init__(device)
self.apex_kwargs = apex_kwargs or {}
def __repr__(self) -> str: # noqa: D105
args_list = [f"device='{self._device}'", f"apex_kwargs={self.apex_kwargs}"]
return f"{self.__class__.__name__}(" + ",".join(args_list) + ")"
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None
):
"""Inits the runs components."""
# model
model = model_fn()
# model = _patch_forward(model)
model = self.sync_device(model)
# criterion
criterion = criterion_fn()
criterion = self.sync_device(criterion)
# optimizer
optimizer = optimizer_fn()
optimizer = self.sync_device(optimizer)
# from official docs:
# https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
model, optimizer = _initialize_apex(model, optimizer, **self.apex_kwargs)
# scheduler
scheduler = scheduler_fn()
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler
def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
def pack_checkpoint(
self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
) -> Dict:
"""
Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
and some extra info ``**kwargs`` to torch-based checkpoint.
Args:
model: torch model
criterion: torch criterion
optimizer: torch optimizer
scheduler: torch scheduler
**kwargs: some extra info to pack
Returns:
torch-based checkpoint with ``model_state_dict``,
``criterion_state_dict``, ``optimizer_state_dict``,
``scheduler_state_dict`` keys.
"""
checkpoint = {"amp": amp.state_dict()}
checkpoint = super().pack_checkpoint(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**checkpoint,
)
return checkpoint
def unpack_checkpoint(
self,
checkpoint: Dict,
model=None,
criterion=None,
optimizer=None,
scheduler=None,
**kwargs,
) -> None:
"""Load checkpoint from file and unpack the content to a model
(if not None), criterion (if not None), optimizer (if not None),
scheduler (if not None).
Args:
checkpoint: checkpoint to load
model: model where should be updated state
criterion: criterion where should be updated state
optimizer: optimizer where should be updated state
scheduler: scheduler where should be updated state
kwargs: extra arguments
"""
super().unpack_checkpoint(
checkpoint,
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**kwargs,
)
# NOTE: proper way to load state, docs:
# https://nvidia.github.io/apex/amp.html#checkpointing
if "amp" in checkpoint:
amp.load_state_dict(checkpoint["amp"])
class DataParallelAPEXEngine(APEXEngine):
"""Apex multi-gpu training device engine.
Args:
apex_kwargs: parameters for `apex.amp.initialize`
except models and optimizers (they will be forwarded automatically).
Docs for `apex.amp.initialize`:
https://nvidia.github.io/apex/amp.html#apex.amp.initialize
Examples:
.. code-block:: python
from catalyst import dl
runner = dl.SupervisedRunner()
runner.train(
engine=dl.DataParallelApexEngine(apex_kwargs=dict(opt_level="O1")),
...
)
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DataParallelApexEngine(apex_kwargs=dict(opt_level="O1"))
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DataParallelApexEngine
apex_kwargs:
opt_level: O1
stages:
...
"""
def __init__(self, apex_kwargs: Dict[str, Any] = None):
"""Init."""
super().__init__(f"cuda:{torch.cuda.current_device()}", apex_kwargs)
self.device_count = torch.cuda.device_count()
def __repr__(self) -> str: # noqa: D105
return (
f"{self.__class__.__name__}(device='{self._device}', apex_kwargs={self.apex_kwargs})"
)
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None
):
"""Inits the runs components."""
model = model_fn()
model = self.sync_device(model)
# criterion
criterion = criterion_fn()
criterion = self.sync_device(criterion)
# optimizer
optimizer = optimizer_fn()
optimizer = self.sync_device(optimizer)
model, optimizer = _wrap_into_data_parallel_with_apex(
model, optimizer, distributed_params=self.apex_kwargs
)
# scheduler
scheduler = scheduler_fn()
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler
class DistributedDataParallelAPEXEngine(DistributedDataParallelEngine):
"""Distributed Apex MultiGPU training device engine.
Args:
address: address to use for backend.
port: port to use for backend.
sync_bn: boolean flag for batchnorm synchronization during distributed training.
If True, applies Apex `convert_syncbn_model`_ to the model.
Default is False.
ddp_kwargs: parameters for `apex.parallel.DistributedDataParallel`.
More info here:
https://nvidia.github.io/apex/parallel.html#apex.parallel.DistributedDataParallel
process_group_kwargs: parameters for `torch.distributed.init_process_group`.
More info here:
https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
apex_kwargs: parameters for `apex.amp.initialize`
except models and optimizers (they will be forwarded automatically).
Docs for `apex.amp.initialize`:
https://nvidia.github.io/apex/amp.html#apex.amp.initialize
Examples:
.. code-block:: python
from catalyst import dl
runner = dl.SupervisedRunner()
runner.train(
engine=dl.DistributedDataParallelApexEngine(
ddp_kwargs={"allreduce_always_fp32": True},
process_group_kwargs={"backend": "nccl"},
apex_kwargs={"opt_level": "O1"},
),
...
)
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DistributedDataParallelApexEngine(
address="0.0.0.0",
port=23234,
ddp_kwargs={"allreduce_always_fp32": True},
process_group_kwargs={"backend": "nccl"},
apex_kwargs={"opt_level": "O1"},
)
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DistributedDataParallelApexEngine
address: 0.0.0.0
port: 23234
ddp_kwargs:
allreduce_always_fp32: true
process_group_kwargs:
backend: nccl
apex_kwargs:
opt_level: O1
stages:
...
.. _`convert_syncbn_model`:
https://nvidia.github.io/apex/parallel.html#apex.parallel.convert_syncbn_model
"""
def __init__(
self,
address: str = None,
port: Union[str, int] = None,
sync_bn: bool = False,
ddp_kwargs: Dict[str, Any] = None,
process_group_kwargs: Dict[str, Any] = None,
apex_kwargs: Dict[str, Any] = None,
):
"""Init."""
super().__init__(
address=address,
port=port,
sync_bn=sync_bn,
ddp_kwargs=None,
process_group_kwargs=process_group_kwargs,
)
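# ``ddp_kwargs`` is intentionally not passed to the parent engine:
# the model is wrapped with ``apex.parallel.DistributedDataParallel``
# in ``init_components``, where ``self.ddp_kwargs`` is applied instead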
self.ddp_kwargs = ddp_kwargs or {}
self.apex_kwargs = apex_kwargs or {}
def __repr__(self): # noqa: D105
return (
f"{self.__class__.__name__}(address={self.address}, "
f"port={self.port}, "
f"ddp_kwargs={self.ddp_kwargs}, "
f"process_group_kwargs={self.process_group_kwargs}, "
f"apex_kwargs={self.apex_kwargs})"
)
def setup_process(self, rank: int = -1, world_size: int = 1):
"""Initialize DDP variables and processes.
Args:
rank: process rank. Default is `-1`.
world_size: number of devices in the network to expect for training.
Default is `1`.
"""
self._rank = rank
self._world_size = world_size
self.process_group_kwargs["rank"] = rank
self.process_group_kwargs["world_size"] = world_size
os.environ["MASTER_ADDR"] = str(self.address)
os.environ["MASTER_PORT"] = str(self.port)
dist.init_process_group(**self.process_group_kwargs)
torch.cuda.set_device(int(self._rank))
self._device = f"cuda:{int(self._rank)}"
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None
):
"""Inits the runs components."""
model = model_fn()
model = self.sync_device(model)
if self._sync_bn:
model = apex.parallel.convert_syncbn_model(model)
criterion = criterion_fn()
criterion = self.sync_device(criterion)
optimizer = optimizer_fn()
optimizer = self.sync_device(optimizer)
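# per the Apex docs, ``amp.initialize`` should be called before wrapping
# the model with ``DistributedDataParallel``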
model, optimizer = amp.initialize(model, optimizer, **self.apex_kwargs)
model = ApexDistributedDataParallel(model, **self.ddp_kwargs)
scheduler = scheduler_fn()
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler
def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
def pack_checkpoint(
self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
) -> Dict:
"""
Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
and some extra info ``**kwargs`` to torch-based checkpoint.
Args:
model: torch model
criterion: torch criterion
optimizer: torch optimizer
scheduler: torch scheduler
**kwargs: some extra info to pack
Returns:
torch-based checkpoint with ``model_state_dict``,
``criterion_state_dict``, ``optimizer_state_dict``,
``scheduler_state_dict`` keys.
"""
checkpoint = {"amp": amp.state_dict()}
checkpoint = super().pack_checkpoint(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**checkpoint,
)
return checkpoint
def unpack_checkpoint(
self,
checkpoint: Dict,
model=None,
criterion=None,
optimizer=None,
scheduler=None,
**kwargs,
) -> None:
"""Load checkpoint from file and unpack the content to a model
(if not None), criterion (if not None), optimizer (if not None),
scheduler (if not None).
Args:
checkpoint: checkpoint to load
model: model where should be updated state
criterion: criterion where should be updated state
optimizer: optimizer where should be updated state
scheduler: scheduler where should be updated state
kwargs: extra arguments
"""
super().unpack_checkpoint(
checkpoint,
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**kwargs,
)
# NOTE: proper way to load state, docs:
# https://nvidia.github.io/apex/amp.html#checkpointing
if "amp" in checkpoint:
amp.load_state_dict(checkpoint["amp"])
DataParallelApexEngine = DataParallelAPEXEngine
DistributedDataParallelApexEngine = DistributedDataParallelAPEXEngine
__all__ = [
"APEXEngine",
"DataParallelAPEXEngine",
"DistributedDataParallelAPEXEngine",
"DataParallelApexEngine",
"DistributedDataParallelApexEngine",
]