Source code for catalyst.engines.amp
from typing import Optional

import torch
import torch.cuda.amp as amp
from torch.nn.parallel import DataParallel

from catalyst.engines.torch import DeviceEngine, DistributedDataParallelEngine
class AMPEngine(DeviceEngine):
    """PyTorch AMP single-device training engine.

    Args:
        device: device to use, default is ``"cuda"``.
    """

    def __init__(self, device: str = "cuda"):
        """Init."""
        super().__init__(device)
        # GradScaler handles dynamic loss scaling for mixed-precision training
        self.scaler = amp.GradScaler()

    def __repr__(self) -> str:  # noqa: D105
        return f"{self.__class__.__name__}(device='{self.device}')"

    def backward_loss(self, loss, model, optimizer) -> None:
        """Abstraction over ``loss.backward()`` step."""
        # scale the loss before backward to avoid fp16 gradient underflow
        self.scaler.scale(loss).backward()

    def optimizer_step(self, loss, model, optimizer) -> None:
        """Abstraction over ``optimizer.step()`` step."""
        # unscales gradients and skips the step if they contain infs/NaNs
        self.scaler.step(optimizer)
        # adjust the loss scale for the next iteration
        self.scaler.update()

    # TODO: should be used with forward method? (similar to criterion)
    def autocast(self):
        """AMP context manager for mixed-precision forward passes."""
        return amp.autocast()
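
# The three hooks above wrap the canonical ``torch.cuda.amp`` recipe. A minimal
# sketch of the equivalent raw training loop, assuming ``model``, ``criterion``,
# ``optimizer``, and ``loader`` already exist:
#
#     scaler = amp.GradScaler()
#     for batch, targets in loader:
#         optimizer.zero_grad()
#         with amp.autocast():                   # mixed-precision forward pass
#             loss = criterion(model(batch), targets)
#         scaler.scale(loss).backward()          # == engine.backward_loss(...)
#         scaler.step(optimizer)                 # == engine.optimizer_step(...)
#         scaler.update()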
class DataParallelAMPEngine(AMPEngine):
    """AMP multi-GPU training engine based on ``torch.nn.DataParallel``."""

    def __init__(self):
        """Init."""
        # pin the engine to the current device; DataParallel scatters from it
        super().__init__(f"cuda:{torch.cuda.current_device()}")
        self.device_count = torch.cuda.device_count()

    def __repr__(self) -> str:  # noqa: D105
        return f"{self.__class__.__name__}(device='{self.device}')"
    def init_components(
        self, model_fn=None, criterion_fn=None, optimizer_fn=None, scheduler_fn=None,
    ):
        """Inits the run's components."""
        # model
        model = model_fn()
        model = self.sync_device(model)
        model = DataParallel(model)
        # criterion
        criterion = criterion_fn()
        criterion = self.sync_device(criterion)
        # optimizer
        optimizer = optimizer_fn()
        optimizer = self.sync_device(optimizer)
        # scheduler
        scheduler = scheduler_fn()
        scheduler = self.sync_device(scheduler)
        return model, criterion, optimizer, scheduler
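
# ``init_components`` expects zero-argument factories for every component; in a
# full Catalyst run the runner builds these so that the optimizer and scheduler
# close over the freshly created model. A minimal hand-rolled sketch (the model
# and hyperparameters below are illustrative, not part of this module):
#
#     model = torch.nn.Linear(32, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
#     engine = DataParallelAMPEngine()
#     model, criterion, optimizer, scheduler = engine.init_components(
#         model_fn=lambda: model,
#         criterion_fn=torch.nn.CrossEntropyLoss,
#         optimizer_fn=lambda: optimizer,
#         scheduler_fn=lambda: torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
#     )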
class DistributedDataParallelAMPEngine(DistributedDataParallelEngine):
    """Distributed AMP multi-GPU training engine.

    Args:
        address: process address to use (required for PyTorch backend), default is ``"localhost"``.
        port: process port to listen on (required for PyTorch backend), default is ``"12345"``.
        backend: multiprocessing backend to use, default is ``"nccl"``.
        world_size: number of processes.
    """

    def __init__(
        self,
        address: str = "localhost",
        port: str = "12345",
        backend: str = "nccl",
        world_size: Optional[int] = None,
    ):
        """Init."""
        super().__init__(address, port, backend, world_size)
        self.scaler = amp.GradScaler()
    def __repr__(self):  # noqa: D105
        return (
            f"{self.__class__.__name__}(address={self.address}, "
            f"port={self.port}, backend='{self.backend}', "
            f"rank={self._rank}, world_size={self._world_size})"
        )
    def backward_loss(self, loss, model, optimizer) -> None:
        """Abstraction over ``loss.backward()`` step."""
        self.scaler.scale(loss).backward()

    def optimizer_step(self, loss, model, optimizer) -> None:
        """Abstraction over ``optimizer.step()`` step."""
        self.scaler.step(optimizer)
        self.scaler.update()

    # TODO: should be used with forward method? (similar to criterion)
    def autocast(self):
        """AMP context manager for mixed-precision forward passes."""
        return amp.autocast()
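
# A minimal usage sketch: the engine is handed to a Catalyst runner, which drives
# process spawning internally, so a single-process script suffices. The runner and
# ``train`` arguments below are illustrative; exact signatures depend on your
# Catalyst version:
#
#     from catalyst import dl
#
#     runner = dl.SupervisedRunner()
#     runner.train(
#         engine=DistributedDataParallelAMPEngine(),
#         model=model, criterion=criterion, optimizer=optimizer,
#         loaders=loaders, num_epochs=10,
#     )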
__all__ = ["AMPEngine", "DataParallelAMPEngine", "DistributedDataParallelAMPEngine"]