from typing import Any, Dict, Optional, Union
import torch
from torch import nn
import torch.cuda.amp as amp
from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine
from catalyst.typing import Model, Optimizer


class AMPEngine(DeviceEngine):
"""Pytorch.AMP single training device engine.
Args:
device: used device, default is `"cuda"`.
scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`.
Possible parameters:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
Examples:
.. code-block:: python
from catalyst import dl
runner = dl.SupervisedRunner()
runner.train(
engine=dl.AMPEngine("cuda:1"),
...
)
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.AMPEngine("cuda:1")
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: AMPEngine
device: cuda:1
stages:
...
"""

    def __init__(self, device: str = "cuda", scaler_kwargs: Dict[str, Any] = None):
        """Init."""
        super().__init__(device)
        if scaler_kwargs is None:
            scaler_kwargs = {}
        # TODO: add OptimizerWithScaler abstraction?
        self.scaler_kwargs = scaler_kwargs
        self.scaler = amp.GradScaler(**self.scaler_kwargs)

    def __repr__(self) -> str:  # noqa: D105
        return (
            f"{self.__class__.__name__}(device='{self._device}', "
            f"scaler_kwargs={self.scaler_kwargs})"
        )

    def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``loss.backward()`` step."""
        self.scaler.scale(loss).backward()

    def optimizer_step(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``optimizer.step()`` step."""
        self.scaler.step(optimizer)
        self.scaler.update()
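
    # Note: taken together, ``backward_loss`` and ``optimizer_step`` reproduce the
    # standard ``torch.cuda.amp.GradScaler`` training step. A minimal sketch of the
    # equivalent "plain" loop (``model``, ``optimizer``, ``criterion`` and ``loader``
    # here are placeholders, not part of this module):
    #
    #     scaler = amp.GradScaler()
    #     for features, targets in loader:
    #         optimizer.zero_grad()
    #         with amp.autocast():
    #             loss = criterion(model(features), targets)
    #         scaler.scale(loss).backward()   # what ``backward_loss`` does
    #         scaler.step(optimizer)          # what ``optimizer_step`` does
    #         scaler.update()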

    def pack_checkpoint(
        self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
    ) -> Dict:
        """
        Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
        and some extra info ``**kwargs`` to torch-based checkpoint.

        Args:
            model: torch model
            criterion: torch criterion
            optimizer: torch optimizer
            scheduler: torch scheduler
            **kwargs: some extra info to pack

        Returns:
            torch-based checkpoint with ``model_state_dict``,
            ``criterion_state_dict``, ``optimizer_state_dict``,
            ``scheduler_state_dict`` keys (the ``GradScaler`` state is
            passed through ``**kwargs`` under the ``scaler`` key).
        """
        kwargs["scaler"] = self.scaler.state_dict()
        checkpoint = super().pack_checkpoint(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **kwargs,
        )
        return checkpoint

    def unpack_checkpoint(
        self,
        checkpoint: Dict,
        model=None,
        criterion=None,
        optimizer=None,
        scheduler=None,
        **kwargs,
    ) -> None:
        """Unpacks the checkpoint content into the model (if not None),
        criterion (if not None), optimizer (if not None),
        scheduler (if not None) and the engine's ``GradScaler``.

        Args:
            checkpoint: checkpoint to load
            model: model where the state should be updated
            criterion: criterion where the state should be updated
            optimizer: optimizer where the state should be updated
            scheduler: scheduler where the state should be updated
            kwargs: extra arguments
        """
        super().unpack_checkpoint(
            checkpoint,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **kwargs,
        )
        if "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

    # TODO: should be used with forward method? (similar to criterion)
    def autocast(self):
        """AMP context."""
        return amp.autocast()
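
# Note: ``AMPEngine.autocast`` is meant to wrap the forward pass so it runs in
# mixed precision, while ``backward_loss``/``optimizer_step`` handle gradient
# scaling. A minimal sketch, assuming a hypothetical custom runner that exposes
# ``self.engine``, ``self.model`` and a ``handle_batch`` hook (none of these are
# defined in this module):
#
#     def handle_batch(self, batch):
#         features, targets = batch
#         with self.engine.autocast():
#             logits = self.model(features)
#         self.batch = {"features": features, "logits": logits, "targets": targets}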


class DataParallelAMPEngine(AMPEngine):
"""AMP multi-gpu training device engine.
Args:
scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`.
Possible parameters:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
Examples:
.. code-block:: python
from catalyst import dl
runner = dl.SupervisedRunner()
runner.train(
engine=dl.DataParallelAMPEngine(),
...
)
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DataParallelAMPEngine()
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DataParallelAMPEngine
stages:
...
"""

    def __init__(self, scaler_kwargs: Dict[str, Any] = None):
        """Init."""
        super().__init__(f"cuda:{torch.cuda.current_device()}", scaler_kwargs)
        self.device_count = torch.cuda.device_count()

    def __repr__(self) -> str:  # noqa: D105
        return (
            f"{self.__class__.__name__}(device='{self._device}', "
            f"scaler_kwargs={self.scaler_kwargs})"
        )

    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None
    ):
        """Inits the run's components."""
        model = model_fn()
        model = self.sync_device(model)

        if isinstance(model, nn.Module):
            model = nn.DataParallel(model)
        elif isinstance(model, dict):
            model = {k: nn.DataParallel(v) for k, v in model.items()}
        else:
            raise ValueError("Model should be ``nn.Module`` or ``dict``")

        # criterion
        criterion = criterion_fn()
        criterion = self.sync_device(criterion)

        # optimizer
        optimizer = optimizer_fn()
        optimizer = self.sync_device(optimizer)

        # scheduler
        scheduler = scheduler_fn()
        scheduler = self.sync_device(scheduler)

        return model, criterion, optimizer, scheduler
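
# Note: ``DataParallelAMPEngine.init_components`` wraps either a single module or
# every entry of a dict of modules in ``nn.DataParallel``. A minimal sketch of the
# dict branch (``Encoder``/``Decoder`` and ``engine`` are hypothetical placeholders,
# not part of this module):
#
#     model_fn = lambda: {"encoder": Encoder(), "decoder": Decoder()}
#     model, criterion, optimizer, scheduler = engine.init_components(
#         model_fn=model_fn, criterion_fn=..., optimizer_fn=..., scheduler_fn=...
#     )
#     # model == {"encoder": nn.DataParallel(...), "decoder": nn.DataParallel(...)}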


class DistributedDataParallelAMPEngine(DistributedDataParallelEngine):
"""Distributed AMP multi-gpu training device engine.
Args:
address: master node (rank 0)'s address, should be either the IP address or the hostname
of node 0, for single node multi-proc training, can simply be 127.0.0.1
port: master node (rank 0)'s free port that needs to be used for communication
during distributed training
world_size: the number of processes to use for distributed training.
Should be less or equal to the number of GPUs
workers_dist_rank: the rank of the first process to run on the node.
It should be a number between `number of initialized processes` and `world_size - 1`,
the other processes on the node wiil have ranks `# of initialized processes + 1`,
`# of initialized processes + 2`, ...,
`# of initialized processes + num_node_workers - 1`
num_node_workers: the number of processes to launch on the node.
For GPU training, this is recommended to be set to the number of GPUs
on the current node so that each process can be bound to a single GPU
process_group_kwargs: parameters for `torch.distributed.init_process_group`.
More info here:
https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
sync_bn: boolean flag for batchnorm synchonization during disributed training.
if True, applies PyTorch `convert_sync_batchnorm`_ to the model for native torch
distributed only. Default, False.
ddp_kwargs: parameters for `torch.nn.parallel.DistributedDataParallel`.
More info here:
https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel
scaler_kwargs: parameters for `torch.cuda.amp.GradScaler`.
Possible parameters:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
Examples:
.. code-block:: python
from catalyst import dl
runner = dl.SupervisedRunner()
runner.train(
engine=dl.DistributedDataParallelAMPEngine(),
...
)
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DistributedDataParallelAMPEngine(
address="0.0.0.0",
port=23234,
ddp_kwargs={"find_unused_parameters": False},
process_group_kwargs={"port": 12345},
scaler_kwargs={"growth_factor": 1.5}
)
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DistributedDataParallelAMPEngine
address: 0.0.0.0
port: 23234
ddp_kwargs:
find_unused_parameters: false
process_group_kwargs:
port: 12345
scaler_kwargs:
growth_factor: 1.5
stages:
...
.. _convert_sync_batchnorm:
https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
torch.nn.SyncBatchNorm.convert_sync_batchnorm
"""

    def __init__(
        self,
        address: str = "127.0.0.1",
        port: Union[str, int] = 2112,
        world_size: Optional[int] = None,
        workers_dist_rank: int = 0,
        num_node_workers: Optional[int] = None,
        process_group_kwargs: Dict[str, Any] = None,
        sync_bn: bool = False,
        ddp_kwargs: Dict[str, Any] = None,
        scaler_kwargs: Dict[str, Any] = None,
    ):
        """Init."""
        super().__init__(
            address=address,
            port=port,
            world_size=world_size,
            workers_dist_rank=workers_dist_rank,
            num_node_workers=num_node_workers,
            process_group_kwargs=process_group_kwargs,
            sync_bn=sync_bn,
            ddp_kwargs=ddp_kwargs,
        )
        if scaler_kwargs is None:
            scaler_kwargs = {}
        self.scaler_kwargs = scaler_kwargs
        self.scaler = amp.GradScaler(**self.scaler_kwargs)

    def __repr__(self):  # noqa: D105
        return (
            f"{self.__class__.__name__}(address={self.address}, "
            f"port={self.port}, "
            f"ddp_kwargs={self.ddp_kwargs}, "
            f"process_group_kwargs={self.process_group_kwargs}, "
            f"scaler_kwargs={self.scaler_kwargs})"
        )

    def backward_loss(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``loss.backward()`` step."""
        self.scaler.scale(loss).backward()

    def optimizer_step(self, loss: torch.Tensor, model: Model, optimizer: Optimizer) -> None:
        """Abstraction over ``optimizer.step()`` step."""
        self.scaler.step(optimizer)
        self.scaler.update()

    def pack_checkpoint(
        self, model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
    ) -> Dict:
        """
        Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
        and some extra info ``**kwargs`` to torch-based checkpoint.

        Args:
            model: torch model
            criterion: torch criterion
            optimizer: torch optimizer
            scheduler: torch scheduler
            **kwargs: some extra info to pack

        Returns:
            torch-based checkpoint with ``model_state_dict``,
            ``criterion_state_dict``, ``optimizer_state_dict``,
            ``scheduler_state_dict`` keys (the ``GradScaler`` state is
            passed through ``**kwargs`` under the ``scaler`` key).
        """
        kwargs["scaler"] = self.scaler.state_dict()
        checkpoint = super().pack_checkpoint(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **kwargs,
        )
        return checkpoint

    def unpack_checkpoint(
        self,
        checkpoint: Dict,
        model=None,
        criterion=None,
        optimizer=None,
        scheduler=None,
        **kwargs,
    ) -> None:
        """Unpacks the checkpoint content into the model (if not None),
        criterion (if not None), optimizer (if not None),
        scheduler (if not None) and the engine's ``GradScaler``.

        Args:
            checkpoint: checkpoint to load
            model: model where the state should be updated
            criterion: criterion where the state should be updated
            optimizer: optimizer where the state should be updated
            scheduler: scheduler where the state should be updated
            kwargs: extra arguments
        """
        super().unpack_checkpoint(
            checkpoint,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **kwargs,
        )
        if "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

    def autocast(self):
        """AMP context."""
        return amp.autocast()
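
# Note: both AMP engines serialize the ``GradScaler`` state under the ``"scaler"``
# key of the checkpoint, so loss-scale statistics survive a save/load cycle.
# A minimal sketch, assuming ``engine``, ``model`` and ``optimizer`` already exist
# (the file name is a placeholder):
#
#     checkpoint = engine.pack_checkpoint(model=model, optimizer=optimizer)
#     torch.save(checkpoint, "checkpoint.pth")
#     ...
#     checkpoint = torch.load("checkpoint.pth")
#     engine.unpack_checkpoint(checkpoint, model=model, optimizer=optimizer)
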
__all__ = ["AMPEngine", "DataParallelAMPEngine", "DistributedDataParallelAMPEngine"]