from typing import Any, Dict, Optional, Union
from collections import OrderedDict
import torch
from torch import nn
from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine
from catalyst.settings import SETTINGS
from catalyst.typing import Model, Optimizer, RunnerModel, RunnerOptimizer
from catalyst.utils.misc import get_fn_default_params
if SETTINGS.apex_required:
import apex
import apex.amp as amp
from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel
def _initialize_apex(model, optimizer=None, **engine_params):
    """
    Run ``apex.amp.initialize`` over the given model(s) and optimizer(s).

    Apex itself accepts a module / list of modules and ``None`` / an optimizer /
    a list of optimizers. Dict containers are additionally supported here:
    they are flattened to lists for the Apex call and rebuilt afterwards as
    ``OrderedDict`` objects with the original key order.

    Args:
        model: torch model, list of models or dict of models
        optimizer: ``None``, torch optimizer, list or dict of optimizers
        **engine_params: overrides for ``apex.amp.initialize`` keyword
            arguments. Only keys already present in the collected defaults
            (plus ``opt_level``) are applied; any other key is silently dropped.

    Returns:
        tuple ``(model, optimizer)`` after Nvidia Apex initialization;
        ``optimizer`` stays ``None`` if ``None`` was passed in.
    """
    # start from the defaults of ``amp.initialize`` with pure fp32 ("O0") as baseline
    amp_params = get_fn_default_params(apex.amp.initialize, ["models", "optimizers"])
    amp_params["opt_level"] = "O0"
    amp_params.update(
        {name: value for name, value in engine_params.items() if name in amp_params}
    )

    # remember dict keys so that the containers can be restored after the Apex call
    model_keys = optimizer_keys = None
    if isinstance(model, dict):
        model_keys = list(model.keys())
        model = list(model.values())
    if isinstance(optimizer, dict):
        optimizer_keys = list(optimizer.keys())
        optimizer = list(optimizer.values())

    initialized = apex.amp.initialize(model, optimizer, **amp_params)
    # without optimizers the call result is used as the model(s) directly
    if optimizer is None:
        model = initialized
    else:
        model, optimizer = initialized

    # rebuild the key-value containers if dicts were passed in
    if model_keys is not None:
        model = OrderedDict(zip(model_keys, model))
    if optimizer_keys is not None:
        optimizer = OrderedDict(zip(optimizer_keys, optimizer))
    return model, optimizer
# taken from https://github.com/catalyst-team/catalyst/blob/master/catalyst/utils/components.py
def _patch_forward(model):
    """
    Replace ``model.forward`` with a wrapper that casts floating-point tensors.

    The wrapper passes positional and keyword inputs through
    ``apex.amp._initialize.applier`` with an input caster, calls the original
    ``forward`` and passes the result through ``applier`` with an output caster.
    Target dtypes are read from ``apex.amp._amp_state.opt_properties.options``
    at call time, not when the model is patched.

    Args:
        model: object with a ``forward`` attribute; it is modified in place

    Returns:
        the same ``model`` object with ``forward`` replaced
    """

    def _cast_input(tensor):
        # non-floating tensors (e.g. integer indices) are passed through untouched
        if not tensor.is_floating_point():
            return tensor
        return tensor.to(apex.amp._amp_state.opt_properties.options["cast_model_type"])

    def _cast_output(tensor):
        if not tensor.is_floating_point():
            return tensor
        # fall back to fp32 when Apex has no explicit output dtype configured
        target_dtype = apex.amp._amp_state.opt_properties.options.get(
            "cast_model_outputs", torch.float32
        )
        return tensor.to(target_dtype)

    # keyword-only defaults bind the *current* ``model.forward`` and the casters
    # at definition time, i.e. before ``model.forward`` is overwritten below
    def new_fwd(
        *args,
        old_fwd=model.forward,
        input_caster=_cast_input,
        output_caster=_cast_output,
        **kwargs,
    ):
        applier = apex.amp._initialize.applier
        raw_output = old_fwd(*applier(args, input_caster), **applier(kwargs, input_caster))
        return applier(raw_output, output_caster)

    model.forward = new_fwd
    return model
# taken from https://github.com/catalyst-team/catalyst/blob/master/catalyst/utils/components.py
# apex issue https://github.com/deepset-ai/FARM/issues/210
# solution: https://github.com/NVIDIA/apex/issues/503#issuecomment-566181771
def _wrap_into_data_parallel_with_apex(
    model: RunnerModel, optimizer: RunnerOptimizer, distributed_params: Dict
):
    """
    Initialize model(s) with Apex and wrap them into ``nn.DataParallel``.

    Every module is temporarily wrapped into ``nn.Sequential`` for the Apex
    initialization and unwrapped again (index ``0``) before being handed to
    ``DataParallel``; finally its ``forward`` is patched with ``_patch_forward``.
    This follows the workaround linked in the comments above this function
    (NVIDIA/apex issue 503).

    Args:
        model: a single ``nn.Module`` or a dict of modules
        optimizer: optimizer(s) forwarded to ``_initialize_apex``
        distributed_params: keyword arguments forwarded to ``_initialize_apex``

    Returns:
        tuple ``(model, optimizer)``; for dict input ``model`` is a plain dict
        with the same keys

    Raises:
        ValueError: if ``model`` is neither an ``nn.Module`` nor a ``dict``
    """
    is_single_module = isinstance(model, nn.Module)
    if not is_single_module and not isinstance(model, dict):
        raise ValueError("Model should be ``nn.Module`` or ``dict``")

    if is_single_module:
        wrapped = nn.Sequential(model)
        wrapped, optimizer = _initialize_apex(wrapped, optimizer, **distributed_params)
        parallel = torch.nn.DataParallel(wrapped[0])
        return _patch_forward(parallel), optimizer

    # key-value case: same steps, applied to every module of the dict
    wrapped = {name: nn.Sequential(module) for name, module in model.items()}
    wrapped, optimizer = _initialize_apex(wrapped, optimizer, **distributed_params)
    parallel = {name: nn.DataParallel(module[0]) for name, module in wrapped.items()}
    patched = {name: _patch_forward(module) for name, module in parallel.items()}
    return patched, optimizer
class APEXEngine(DeviceEngine):
    """Apex single training device engine.

    Args:
        device: use device, default is `"cuda"`.
        apex_kwargs: parameters for `apex.amp.initialize`
            except models and optimizers (they will be forwarded automatically).
            If ``opt_level`` is not given, ``"O0"`` is used.

            Docs for `apex.amp.initialize`:
            https://nvidia.github.io/apex/amp.html#apex.amp.initialize

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.APEXEngine(apex_kwargs=dict(opt_level="O1", keep_batchnorm_fp32=False)),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.APEXEngine(apex_kwargs=dict(opt_level="O1", keep_batchnorm_fp32=False))
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: APEXEngine
            apex_kwargs:
                opt_level: O1
                keep_batchnorm_fp32: false

        stages:
            ...

    """

    def __init__(self, device: str = "cuda", apex_kwargs: Optional[Dict[str, Any]] = None):
        """Init."""
        super().__init__(device)
        # ``None`` (or an empty dict) means "no overrides for ``apex.amp.initialize``"
        self.apex_kwargs = apex_kwargs or {}

    def __repr__(self) -> str:  # noqa: D105
        args_list = [f"device='{self._device}'", f"apex_kwargs={self.apex_kwargs}"]
        return f"{self.__class__.__name__}(" + ",".join(args_list) + ")"

    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None
    ):
        """Inits the runs components.

        Each factory is called without arguments and its result is passed
        through ``self.sync_device``. Model and optimizer are additionally
        initialized with Apex before the scheduler is created.

        Args:
            model_fn: callable that returns the model
            criterion_fn: callable that returns the criterion
            optimizer_fn: callable that returns the optimizer
            scheduler_fn: callable that returns the scheduler

        Returns:
            tuple ``(model, criterion, optimizer, scheduler)``
        """
        # model
        model = model_fn()
        model = self.sync_device(model)

        # criterion
        criterion = criterion_fn()
        criterion = self.sync_device(criterion)

        # optimizer
        optimizer = optimizer_fn()
        optimizer = self.sync_device(optimizer)

        # from official docs:
        # https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
        model, optimizer = _initialize_apex(model, optimizer, **self.apex_kwargs)

        # scheduler
        scheduler = scheduler_fn()
        scheduler = self.sync_device(scheduler)
        return model, criterion, optimizer, scheduler

    def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``loss.backward()`` step.

        The backward pass runs on the loss yielded by ``amp.scale_loss``.

        Args:
            loss: loss tensor to backpropagate
            model: unused here, kept for the common engine interface
            optimizer: optimizer passed to ``amp.scale_loss``
        """
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

    def pack_checkpoint(
        self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
    ) -> Dict:
        """
        Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
        and some extra info ``**kwargs`` to torch-based checkpoint.

        The current ``amp.state_dict()`` is forwarded to the parent
        implementation as the extra ``amp`` entry.

        Args:
            model: torch model
            criterion: torch criterion
            optimizer: torch optimizer
            scheduler: torch scheduler
            **kwargs: some extra info to pack

        Returns:
            torch-based checkpoint with ``model_state_dict``,
            ``criterion_state_dict``, ``optimizer_state_dict``,
            ``scheduler_state_dict`` keys.
        """
        checkpoint = super().pack_checkpoint(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            amp=amp.state_dict(),
            **kwargs,
        )
        return checkpoint

    def unpack_checkpoint(
        self,
        checkpoint: Dict,
        model=None,
        criterion=None,
        optimizer=None,
        scheduler=None,
        **kwargs,
    ) -> None:
        """Load checkpoint from file and unpack the content to a model
        (if not None), criterion (if not None), optimizer (if not None),
        scheduler (if not None).

        If the checkpoint has an ``amp`` entry, it is loaded into Apex AMP
        after the other components were restored.

        Args:
            checkpoint: checkpoint to load
            model: model where should be updated state
            criterion: criterion where should be updated state
            optimizer: optimizer where should be updated state
            scheduler: scheduler where should be updated state
            kwargs: extra arguments
        """
        super().unpack_checkpoint(
            checkpoint,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **kwargs,
        )

        # NOTE: proper way to load state, docs:
        #   https://nvidia.github.io/apex/amp.html#checkpointing
        if "amp" in checkpoint:
            amp.load_state_dict(checkpoint["amp"])
class DataParallelAPEXEngine(APEXEngine):
    """Apex multi-gpu training device engine.

    The engine device is always ``cuda:<current CUDA device>``; model(s) are
    initialized with Apex and wrapped into ``torch.nn.DataParallel``.

    Args:
        apex_kwargs: parameters for `apex.amp.initialize`
            except models and optimizers (they will be forwarded automatically).

            Docs for `apex.amp.initialize`:
            https://nvidia.github.io/apex/amp.html#apex.amp.initialize

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.DataParallelApexEngine(apex_kwargs=dict(opt_level="O1")),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.DataParallelApexEngine(apex_kwargs=dict(opt_level="O1"))
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: DataParallelApexEngine
            apex_kwargs:
                opt_level: O1

        stages:
            ...

    """

    def __init__(self, apex_kwargs: Dict[str, Any] = None):
        """Init."""
        current_device = f"cuda:{torch.cuda.current_device()}"
        super().__init__(current_device, apex_kwargs)
        # number of CUDA devices reported by torch at construction time
        self.device_count = torch.cuda.device_count()

    def __repr__(self) -> str:  # noqa: D105
        engine_name = self.__class__.__name__
        return f"{engine_name}(device='{self._device}', apex_kwargs={self.apex_kwargs})"

    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None
    ):
        """Inits the runs components.

        Factories are called without arguments, every result goes through
        ``self.sync_device``. Model and optimizer are then initialized with
        Apex and the model is wrapped for data-parallel execution before
        the scheduler is created.

        Returns:
            tuple ``(model, criterion, optimizer, scheduler)``
        """
        model = self.sync_device(model_fn())
        criterion = self.sync_device(criterion_fn())
        optimizer = self.sync_device(optimizer_fn())

        # Apex initialization + ``DataParallel`` wrapping + forward patching
        model, optimizer = _wrap_into_data_parallel_with_apex(
            model, optimizer, distributed_params=self.apex_kwargs
        )

        scheduler = self.sync_device(scheduler_fn())
        return model, criterion, optimizer, scheduler
class DistributedDataParallelAPEXEngine(DistributedDataParallelEngine):
    """Distributed Apex MultiGPU training device engine.

    Args:
        address: master node (rank 0)'s address, should be either the IP address or the hostname
            of node 0, for single node multi-proc training, can simply be 127.0.0.1
        port: master node (rank 0)'s free port that needs to be used for communication
            during distributed training
        world_size: the number of processes to use for distributed training.
            Should be less or equal to the number of GPUs
        workers_dist_rank: the rank of the first process to run on the node.
            It should be a number between `number of initialized processes` and `world_size - 1`,
            the other processes on the node will have ranks `# of initialized processes + 1`,
            `# of initialized processes + 2`, ...,
            `# of initialized processes + num_node_workers - 1`
        num_node_workers: the number of processes to launch on the node.
            For GPU training, this is recommended to be set to the number of GPUs
            on the current node so that each process can be bound to a single GPU
        process_group_kwargs: parameters for `torch.distributed.init_process_group`.
            More info here:
            https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
        sync_bn: boolean flag for batchnorm synchronization during distributed training.
            if True, applies Apex `convert_syncbn_model`_ to the model
            before the optimizer is created. Default, False.
        ddp_kwargs: parameters for `apex.parallel.DistributedDataParallel`.
            More info here:
            https://nvidia.github.io/apex/parallel.html#apex.parallel.DistributedDataParallel
        apex_kwargs: parameters for `apex.amp.initialize`
            except models and optimizers (they will be forwarded automatically).

            Docs for `apex.amp.initialize`:
            https://nvidia.github.io/apex/amp.html#apex.amp.initialize

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.DistributedDataParallelApexEngine(
                ddp_kwargs={"allreduce_always_fp32": True},
                process_group_kwargs={"backend": "nccl"},
                apex_kwargs={"opt_level": "O1"},
            ),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.DistributedDataParallelApexEngine(
                    address="0.0.0.0",
                    port=23234,
                    ddp_kwargs={"allreduce_always_fp32": True},
                    process_group_kwargs={"backend": "nccl"},
                    apex_kwargs={"opt_level": "O1"},
                )
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: DistributedDataParallelApexEngine
            address: 0.0.0.0
            port: 23234
            ddp_kwargs:
                allreduce_always_fp32: true
            process_group_kwargs:
                backend: nccl
            apex_kwargs:
                opt_level: O1

        stages:
            ...

    .. _`convert_syncbn_model`:
        https://nvidia.github.io/apex/parallel.html#apex.parallel.convert_syncbn_model
    """

    def __init__(
        self,
        address: str = "127.0.0.1",
        port: Union[str, int] = 2112,
        world_size: Optional[int] = None,
        workers_dist_rank: int = 0,
        num_node_workers: Optional[int] = None,
        process_group_kwargs: Dict[str, Any] = None,
        sync_bn: bool = False,
        ddp_kwargs: Dict[str, Any] = None,
        apex_kwargs: Dict[str, Any] = None,
    ):
        """Init."""
        # The parent always receives ``ddp_kwargs=None``; the user-supplied
        # kwargs are stored on the instance afterwards and are only used for
        # the Apex DDP wrapper in ``init_components``.
        # NOTE(review): presumably this keeps torch-DDP-specific defaults of
        # the parent out of the Apex wrapper - confirm against the parent class.
        super().__init__(
            address=address,
            port=port,
            world_size=world_size,
            workers_dist_rank=workers_dist_rank,
            num_node_workers=num_node_workers,
            process_group_kwargs=process_group_kwargs,
            sync_bn=sync_bn,
            ddp_kwargs=None,
        )
        self.ddp_kwargs = ddp_kwargs if ddp_kwargs else {}
        self.apex_kwargs = apex_kwargs if apex_kwargs else {}

    def __repr__(self):  # noqa: D105
        engine_name = self.__class__.__name__
        described_fields = ", ".join(
            [
                f"address={self.address}",
                f"port={self.port}",
                f"ddp_kwargs={self.ddp_kwargs}",
                f"process_group_kwargs={self.process_group_kwargs}",
                f"apex_kwargs={self.apex_kwargs}",
            ]
        )
        return f"{engine_name}({described_fields})"

    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None
    ):
        """Inits the runs components.

        ``optimizer_fn`` is called with the model and ``scheduler_fn`` with
        the optimizer; ``model_fn`` and ``criterion_fn`` take no arguments.
        Model and optimizer go through ``amp.initialize`` and the model is
        then wrapped into Apex ``DistributedDataParallel``.

        Returns:
            tuple ``(model, criterion, optimizer, scheduler)``
        """
        model = self.sync_device(model_fn())
        if self._sync_bn:
            model = apex.parallel.convert_syncbn_model(model)

        criterion = self.sync_device(criterion_fn())

        # the optimizer is built from the (possibly syncbn-converted) model
        optimizer = self.sync_device(optimizer_fn(model))

        model, optimizer = amp.initialize(model, optimizer, **self.apex_kwargs)
        model = ApexDistributedDataParallel(model, **self.ddp_kwargs)

        scheduler = self.sync_device(scheduler_fn(optimizer))
        return model, criterion, optimizer, scheduler

    def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``loss.backward()`` step.

        Backpropagates the loss yielded by ``amp.scale_loss``;
        ``model`` is not used by this implementation.
        """
        with amp.scale_loss(loss, optimizer) as loss_for_backward:
            loss_for_backward.backward()

    def pack_checkpoint(
        self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
    ) -> Dict:
        """
        Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
        and some extra info ``**kwargs`` to torch-based checkpoint.
        The current ``amp.state_dict()`` is forwarded as the extra ``amp`` entry.

        Args:
            model: torch model
            criterion: torch criterion
            optimizer: torch optimizer
            scheduler: torch scheduler
            **kwargs: some extra info to pack

        Returns:
            torch-based checkpoint with ``model_state_dict``,
            ``criterion_state_dict``, ``optimizer_state_dict``,
            ``scheduler_state_dict`` keys.
        """
        return super().pack_checkpoint(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            amp=amp.state_dict(),
            **kwargs,
        )

    def unpack_checkpoint(
        self,
        checkpoint: Dict,
        model=None,
        criterion=None,
        optimizer=None,
        scheduler=None,
        **kwargs,
    ) -> None:
        """Load checkpoint from file and unpack the content to a model
        (if not None), criterion (if not None), optimizer (if not None),
        scheduler (if not None).

        An ``amp`` entry, when present in the checkpoint, is loaded into
        Apex AMP after the parent implementation has finished.

        Args:
            checkpoint: checkpoint to load
            model: model where should be updated state
            criterion: criterion where should be updated state
            optimizer: optimizer where should be updated state
            scheduler: scheduler where should be updated state
            kwargs: extra arguments
        """
        super().unpack_checkpoint(
            checkpoint,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **kwargs,
        )

        # checkpoints without Apex state need no further handling
        if "amp" not in checkpoint:
            return

        # NOTE: proper way to load state, docs:
        #   https://nvidia.github.io/apex/amp.html#checkpointing
        amp.load_state_dict(checkpoint["amp"])
# Alternate "Apex"-cased names bound to the very same engine classes; the
# docstring examples in this module refer to the engines by these names.
DataParallelApexEngine = DataParallelAPEXEngine
DistributedDataParallelApexEngine = DistributedDataParallelAPEXEngine
# Names exported by ``from <module> import *``: the three engines plus the
# two alternate spellings defined above.
__all__ = [
"APEXEngine",
"DataParallelAPEXEngine",
"DistributedDataParallelAPEXEngine",
"DataParallelApexEngine",
"DistributedDataParallelApexEngine",
]