from typing import Dict, Union
from collections import OrderedDict
import torch
from torch import nn
from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine
from catalyst.settings import SETTINGS
from catalyst.typing import RunnerModel, RunnerOptimizer
from catalyst.utils.misc import get_fn_default_params
if SETTINGS.apex_required:
import apex
import apex.amp as amp
from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel
def _initialize_apex(model, optimizer=None, **engine_params):
"""
    Prepares a model and an optimizer for use with Nvidia Apex.
Args:
model: torch model
optimizer: torch optimizer
**engine_params: extra params for ``apex.amp.initialize``
Returns:
model and optimizer, wrapped with Nvidia Apex initialization
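
    Example:
        a minimal sketch (assuming Apex is installed and a CUDA device
        is available; the ``"main"`` key is illustrative only):

        .. code-block:: python

            model = {"main": torch.nn.Linear(10, 2).cuda()}
            optimizer = {
                "main": torch.optim.SGD(model["main"].parameters(), lr=1e-2)
            }
            # dict inputs are converted to lists for apex
            # and cast back to OrderedDict with the original keys
            model, optimizer = _initialize_apex(model, optimizer, opt_level="O1")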
"""
amp_params = get_fn_default_params(apex.amp.initialize, ["models", "optimizers"])
amp_params["opt_level"] = "O0"
for dp in engine_params:
if dp in amp_params:
amp_params[dp] = engine_params[dp]
    # NVIDIA apex supports only:
    #   model: nn.Module or list of modules
    #   optimizer: None, torch.Optimizer or list of optimizers
    # while key-value mappings are preferred in `catalyst`.
    # So if the model/optimizer is a dict, convert it to lists of keys
    # and values first, then cast it back to a dict after apex initialization.
model_keys, optimizer_keys = None, None
if isinstance(model, dict):
model_keys, model = list(model.keys()), list(model.values())
if isinstance(optimizer, dict):
optimizer_keys = list(optimizer.keys())
optimizer = list(optimizer.values())
amp_result = apex.amp.initialize(model, optimizer, **amp_params)
if optimizer is not None:
model, optimizer = amp_result
else:
model = amp_result
    # convert model/optimizer back to dict if needed
if model_keys is not None:
        model = OrderedDict(zip(model_keys, model))
if optimizer_keys is not None:
        optimizer = OrderedDict(zip(optimizer_keys, optimizer))
return model, optimizer
# taken from https://github.com/catalyst-team/catalyst/blob/master/catalyst/utils/components.py
def _patch_forward(model):
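    """Patches ``model.forward`` for Apex AMP casting.

    Floating point inputs are cast to the AMP ``cast_model_type`` and
    floating point outputs are cast to ``cast_model_outputs``
    (``torch.float32`` by default); other tensors pass through unchanged.

    Args:
        model: torch model to patch

    Returns:
        the same model with a wrapped ``forward``
    """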
input_caster_lambda = (
lambda tensor: tensor.to(
apex.amp._amp_state.opt_properties.options["cast_model_type"]
) # noqa: WPS437
if tensor.is_floating_point()
else tensor
)
output_caster_lambda = (
lambda tensor: tensor.to(
apex.amp._amp_state.opt_properties.options.get(
"cast_model_outputs", torch.float32
) # noqa: WPS437
)
if tensor.is_floating_point()
else tensor
)
def new_fwd(
*args,
old_fwd=model.forward,
input_caster=input_caster_lambda,
output_caster=output_caster_lambda,
**kwargs,
):
return apex.amp._initialize.applier( # noqa: WPS437
old_fwd(
*apex.amp._initialize.applier(args, input_caster), # noqa: WPS437
**apex.amp._initialize.applier(kwargs, input_caster), # noqa: WPS437
),
output_caster,
)
model.forward = new_fwd
return model
# taken from https://github.com/catalyst-team/catalyst/blob/master/catalyst/utils/components.py
# apex issue https://github.com/deepset-ai/FARM/issues/210
# solution: https://github.com/NVIDIA/apex/issues/503#issuecomment-566181771
def _wrap_into_data_parallel_with_apex(
model: RunnerModel, optimizer: RunnerOptimizer, distributed_params: Dict
):
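    """Wraps a model into ``nn.DataParallel`` together with Apex initialization.

    Following the workaround referenced above, the model is packed into
    ``nn.Sequential``, initialized with Apex, unpacked, wrapped into
    ``torch.nn.DataParallel`` and re-patched with ``_patch_forward``
    so that AMP input/output casting is preserved.

    Args:
        model: torch model or dict of models
        optimizer: torch optimizer or dict of optimizers
        distributed_params: extra params for ``apex.amp.initialize``

    Returns:
        model and optimizer wrapped with Apex and ``DataParallel``
    """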
if isinstance(model, nn.Module):
model = nn.Sequential(model)
model, optimizer = _initialize_apex(model, optimizer, **distributed_params)
model = torch.nn.DataParallel(model[0])
model = _patch_forward(model)
elif isinstance(model, dict):
model = {k: nn.Sequential(v) for k, v in model.items()}
model, optimizer = _initialize_apex(model, optimizer, **distributed_params)
model = {k: nn.DataParallel(v[0]) for k, v in model.items()}
model = {k: _patch_forward(v) for k, v in model.items()}
else:
raise NotImplementedError()
return model, optimizer
class APEXEngine(DeviceEngine):
"""Apex single training device engine.
Args:
device: use device, default is `"cuda"`.
opt_level: optimization level, should be one of ``"O0"``,
``"O1"``, ``"O2"`` or ``"O3"``.
- ``"O0"`` - no-op training
- ``"O1"`` - mixed precision (FP16) training (default)
- ``"O2"`` - "almost" mixed precision training
- ``"O3"`` - another implementation of mixed precision training
Details about levels can be found here:
https://nvidia.github.io/apex/amp.html#opt-levels
keep_batchnorm_fp32: To enhance precision and enable CUDNN batchnorm
(which improves performance),
it’s often beneficial to keep batchnorm weights in FP32 even
if the rest of the model is FP16.
        loss_scale: If loss_scale is a float value,
            use this value as the static (fixed) loss scale.
            If loss_scale is the string "dynamic",
            adaptively adjust the loss scale over time.
            Dynamic loss scale adjustments are performed by Amp automatically.
Examples:
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.APEXEngine(opt_level="O1", keep_batchnorm_fp32=False)
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: APEXEngine
opt_level: O1
keep_batchnorm_fp32: false
stages:
...
"""
def __init__(
self,
device: str = "cuda",
opt_level: str = "O1",
keep_batchnorm_fp32: bool = None,
loss_scale: Union[float, str] = None,
):
"""Init."""
super().__init__(device)
self.opt_level = opt_level
self.keep_batchnorm_fp32 = keep_batchnorm_fp32
self.loss_scale = loss_scale
def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(device='{self.device}',opt_level='{self.opt_level}')"
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
):
"""Inits the runs components."""
# model
model = model_fn()
# model = _patch_forward(model)
model = self.sync_device(model)
# criterion
criterion = criterion_fn()
criterion = self.sync_device(criterion)
# optimizer
optimizer = optimizer_fn()
optimizer = self.sync_device(optimizer)
# from official docs:
# https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
model, optimizer = _initialize_apex(
model,
optimizer,
opt_level=self.opt_level,
keep_batchnorm_fp32=self.keep_batchnorm_fp32,
loss_scale=self.loss_scale,
)
# scheduler
scheduler = scheduler_fn()
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler
def backward_loss(self, loss, model, optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
def pack_checkpoint(
self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs,
) -> Dict:
"""
Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
and some extra info ``**kwargs`` to torch-based checkpoint.
Args:
model: torch model
criterion: torch criterion
optimizer: torch optimizer
scheduler: torch scheduler
**kwargs: some extra info to pack
Returns:
torch-based checkpoint with ``model_state_dict``,
``criterion_state_dict``, ``optimizer_state_dict``,
``scheduler_state_dict`` keys.
"""
checkpoint = {"amp": amp.state_dict()}
checkpoint = super().pack_checkpoint(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**checkpoint,
)
return checkpoint
def unpack_checkpoint(
self,
checkpoint: Dict,
model=None,
criterion=None,
optimizer=None,
scheduler=None,
**kwargs,
) -> None:
"""Load checkpoint from file and unpack the content to a model
(if not None), criterion (if not None), optimizer (if not None),
scheduler (if not None).
Args:
checkpoint: checkpoint to load
model: model where should be updated state
criterion: criterion where should be updated state
optimizer: optimizer where should be updated state
scheduler: scheduler where should be updated state
kwargs: extra arguments
"""
super().unpack_checkpoint(
checkpoint,
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**kwargs,
)
        # NOTE: proper way to load state, docs:
# https://nvidia.github.io/apex/amp.html#checkpointing
if "amp" in checkpoint:
amp.load_state_dict(checkpoint["amp"])
class DataParallelApexEngine(APEXEngine):
"""Apex multi-gpu training device engine.
Args:
opt_level: optimization level, should be one of ``"O0"``,
``"O1"``, ``"O2"`` or ``"O3"``.
- ``"O0"`` - no-op training
- ``"O1"`` - mixed precision (FP16) training (default)
- ``"O2"`` - "almost" mixed precision training
- ``"O3"`` - another implementation of mixed precision training
Details about levels can be found here:
https://nvidia.github.io/apex/amp.html#opt-levels
Examples:
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DataParallelApexEngine(opt_level="O1")
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DataParallelApexEngine
opt_level: O1
stages:
...
"""
def __init__(self, opt_level: str = "O1"):
"""Init."""
super().__init__(f"cuda:{torch.cuda.current_device()}", opt_level)
self.device_count = torch.cuda.device_count()
def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(device='{self.device}',opt_level='{self.opt_level}')"
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
):
"""Inits the runs components."""
model = model_fn()
model = self.sync_device(model)
# criterion
criterion = criterion_fn()
criterion = self.sync_device(criterion)
# optimizer
optimizer = optimizer_fn()
optimizer = self.sync_device(optimizer)
model, optimizer = _wrap_into_data_parallel_with_apex(
model, optimizer, distributed_params={"opt_level": self.opt_level}
)
# scheduler
scheduler = scheduler_fn()
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler
class DistributedDataParallelApexEngine(DistributedDataParallelEngine):
"""Distributed Apex MultiGPU training device engine.
Args:
address: process address to use
(required for PyTorch backend), default is `"localhost"`.
port: process port to listen
(required for PyTorch backend), default is `"12345"`.
backend: multiprocessing backend to use,
default is `"nccl"`.
world_size: number of processes.
opt_level: optimization level, should be one of ``"O0"``,
``"O1"``, ``"O2"`` or ``"O3"``.
- ``"O0"`` - no-op training
- ``"O1"`` - mixed precision (FP16) training (default)
- ``"O2"`` - "almost" mixed precision training
- ``"O3"`` - another implementation of mixed precision training
Details about levels can be found here:
https://nvidia.github.io/apex/amp.html#opt-levels
keep_batchnorm_fp32: To enhance precision and
enable CUDNN batchnorm (which improves performance),
it’s often beneficial to keep batchnorm weights in FP32 even
if the rest of the model is FP16.
loss_scale: If loss_scale is a float value,
use this value as the static (fixed) loss scale.
If loss_scale is the string "dynamic",
adaptively adjust the loss scale over time.
Dynamic loss scale adjustments are performed by Amp automatically.
delay_all_reduce (bool): boolean flag for delayed all reduce,
default is `True`.
Examples:
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DistributedDataParallelApexEngine(
port=12345,
opt_level="O1"
)
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DistributedDataParallelApexEngine
port: 12345
opt_level: O1
stages:
...
"""
def __init__(
self,
address: str = "localhost",
port: str = "12345",
backend: str = "nccl",
world_size: int = None,
opt_level: str = "O1",
keep_batchnorm_fp32: bool = None,
loss_scale: Union[float, str] = None,
delay_all_reduce: bool = True,
):
"""Init."""
super().__init__()
self.address = address
self.port = port
self.backend = backend
self._rank = 0
self._world_size = world_size or torch.cuda.device_count()
self.device = None
self.opt_level = opt_level
self.delay_all_reduce = delay_all_reduce
self.keep_batchnorm_fp32 = keep_batchnorm_fp32
self.loss_scale = loss_scale
def __repr__(self): # noqa: D105
return (
f"{self.__class__.__name__}(address={self.address}, "
f"port={self.port}, backend='{self.backend}', "
f"rank={self._rank}, world_size={self._world_size}, "
f"opt_level='{self.opt_level}')"
)
def init_components(
self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
):
"""Inits the runs components."""
model = model_fn()
model = self.sync_device(model)
criterion = criterion_fn()
criterion = self.sync_device(criterion)
optimizer = optimizer_fn()
optimizer = self.sync_device(optimizer)
model, optimizer = amp.initialize(
model,
optimizer,
opt_level=self.opt_level,
keep_batchnorm_fp32=self.keep_batchnorm_fp32,
loss_scale=self.loss_scale,
)
model = ApexDistributedDataParallel(model, delay_allreduce=self.delay_all_reduce)
scheduler = scheduler_fn()
scheduler = self.sync_device(scheduler)
return model, criterion, optimizer, scheduler
def backward_loss(self, loss, model, optimizer) -> None:
"""Abstraction over ``loss.backward()`` step."""
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
def pack_checkpoint(
self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs,
) -> Dict:
"""
Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
and some extra info ``**kwargs`` to torch-based checkpoint.
Args:
model: torch model
criterion: torch criterion
optimizer: torch optimizer
scheduler: torch scheduler
**kwargs: some extra info to pack
Returns:
torch-based checkpoint with ``model_state_dict``,
``criterion_state_dict``, ``optimizer_state_dict``,
``scheduler_state_dict`` keys.
"""
checkpoint = {"amp": amp.state_dict()}
checkpoint = super().pack_checkpoint(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**checkpoint,
)
return checkpoint
def unpack_checkpoint(
self,
checkpoint: Dict,
model=None,
criterion=None,
optimizer=None,
scheduler=None,
**kwargs,
) -> None:
"""Load checkpoint from file and unpack the content to a model
(if not None), criterion (if not None), optimizer (if not None),
scheduler (if not None).
Args:
checkpoint: checkpoint to load
model: model where should be updated state
criterion: criterion where should be updated state
optimizer: optimizer where should be updated state
scheduler: scheduler where should be updated state
kwargs: extra arguments
"""
super().unpack_checkpoint(
checkpoint,
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
**kwargs,
)
        # NOTE: proper way to load state, docs:
# https://nvidia.github.io/apex/amp.html#checkpointing
if "amp" in checkpoint:
amp.load_state_dict(checkpoint["amp"])
__all__ = ["APEXEngine", "DataParallelApexEngine", "DistributedDataParallelApexEngine"]