from typing import Any, Callable, Dict, Optional
import numpy as np
import torch
from torch.utils.data import DataLoader
from catalyst.engines.torch import DeviceEngine
from catalyst.settings import SETTINGS
if SETTINGS.xla_required:
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import ParallelLoader
import torch_xla.distributed.xla_multiprocessing as xmp
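# torch_xla is only available on XLA-enabled hosts, hence the guarded import above.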
class XLAEngine(DeviceEngine):
"""XLA SingleTPU training device engine.
Examples:
.. code-block:: python
import os
from datetime import datetime
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.contrib import (
ImageToTensor, NormalizeImage, Compose, CIFAR10, ResidualBlock
)
def conv_block(in_channels, out_channels, pool=False):
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
]
if pool:
layers.append(nn.MaxPool2d(2))
return nn.Sequential(*layers)
def resnet9(in_channels: int, num_classes: int, size: int = 16):
sz, sz2, sz4, sz8 = size, size * 2, size * 4, size * 8
return nn.Sequential(
conv_block(in_channels, sz),
conv_block(sz, sz2, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz2, sz2), conv_block(sz2, sz2))),
conv_block(sz2, sz4, pool=True),
conv_block(sz4, sz8, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz8, sz8), conv_block(sz8, sz8))),
nn.Sequential(
nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(sz8, num_classes)
),
)
class CustomRunner(dl.IRunner):
def __init__(self, logdir):
super().__init__()
self._logdir = logdir
def get_engine(self):
return dl.XLAEngine()
def get_loggers(self):
return {
"console": dl.ConsoleLogger(),
"csv": dl.CSVLogger(logdir=self._logdir),
"tensorboard": dl.TensorboardLogger(logdir=self._logdir),
}
@property
def stages(self):
return ["train"]
def get_stage_len(self, stage: str) -> int:
return 3
def get_loaders(self, stage: str):
transform = Compose(
[ImageToTensor(), NormalizeImage((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
train_data = CIFAR10(os.getcwd(), train=True, download=True, transform=transform)
valid_data = CIFAR10(os.getcwd(), train=False, download=True, transform=transform)
if self.engine.is_ddp:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_data,
num_replicas=self.engine.world_size,
rank=self.engine.rank,
shuffle=True
)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_data,
num_replicas=self.engine.world_size,
rank=self.engine.rank,
shuffle=False
)
else:
train_sampler = valid_sampler = None
return {
"train": DataLoader(train_data, batch_size=32, sampler=train_sampler),
"valid": DataLoader(valid_data, batch_size=32, sampler=valid_sampler),
}
def get_model(self, stage: str):
model = self.model \
if self.model is not None \
else resnet9(in_channels=3, num_classes=10)
return model
def get_criterion(self, stage: str):
return nn.CrossEntropyLoss()
def get_optimizer(self, stage: str, model):
return optim.Adam(model.parameters(), lr=1e-3)
def get_scheduler(self, stage: str, optimizer):
return optim.lr_scheduler.MultiStepLR(optimizer, [5, 8], gamma=0.3)
def get_callbacks(self, stage: str):
return {
"criterion": dl.CriterionCallback(
metric_key="loss", input_key="logits", target_key="targets"
),
"optimizer": dl.OptimizerCallback(metric_key="loss"),
"scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
"accuracy": dl.AccuracyCallback(
input_key="logits", target_key="targets", topk_args=(1, 3, 5)
),
"checkpoint": dl.CheckpointCallback(
self._logdir,
loader_key="valid",
metric_key="accuracy",
minimize=False,
save_n_best=1,
),
"tqdm": dl.TqdmCallback(),
}
def handle_batch(self, batch):
x, y = batch
logits = self.model(x)
self.batch = {
"features": x,
"targets": y,
"logits": logits,
}
logdir = f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
runner = CustomRunner(logdir)
runner.run()
"""
def __init__(self):
"""Init."""
super().__init__()
self._device = xm.xla_device()
def optimizer_step(self, loss, model, optimizer) -> None:
"""Abstraction over ``optimizer.step()`` step."""
xm.optimizer_step(optimizer, barrier=True)
class DistributedXLAEngine(DeviceEngine):
"""Distributed XLA MultiTPU training device engine.
Examples:
.. code-block:: python
import os
from datetime import datetime
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.contrib import (
ImageToTensor, NormalizeImage, Compose, CIFAR10, ResidualBlock
)
def conv_block(in_channels, out_channels, pool=False):
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
]
if pool:
layers.append(nn.MaxPool2d(2))
return nn.Sequential(*layers)
def resnet9(in_channels: int, num_classes: int, size: int = 16):
sz, sz2, sz4, sz8 = size, size * 2, size * 4, size * 8
return nn.Sequential(
conv_block(in_channels, sz),
conv_block(sz, sz2, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz2, sz2), conv_block(sz2, sz2))),
conv_block(sz2, sz4, pool=True),
conv_block(sz4, sz8, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz8, sz8), conv_block(sz8, sz8))),
nn.Sequential(
nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(sz8, num_classes)
),
)
class CustomRunner(dl.IRunner):
def __init__(self, logdir):
super().__init__()
self._logdir = logdir
def get_engine(self):
return dl.DistributedXLAEngine()
def get_loggers(self):
return {
"console": dl.ConsoleLogger(),
"csv": dl.CSVLogger(logdir=self._logdir),
"tensorboard": dl.TensorboardLogger(logdir=self._logdir),
}
@property
def stages(self):
return ["train"]
def get_stage_len(self, stage: str) -> int:
return 3
def get_loaders(self, stage: str):
transform = Compose(
[ImageToTensor(), NormalizeImage((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
train_data = CIFAR10(os.getcwd(), train=True, download=True, transform=transform)
valid_data = CIFAR10(os.getcwd(), train=False, download=True, transform=transform)
if self.engine.is_ddp:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_data,
num_replicas=self.engine.world_size,
rank=self.engine.rank,
shuffle=True
)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_data,
num_replicas=self.engine.world_size,
rank=self.engine.rank,
shuffle=False
)
else:
train_sampler = valid_sampler = None
return {
"train": DataLoader(train_data, batch_size=32, sampler=train_sampler),
"valid": DataLoader(valid_data, batch_size=32, sampler=valid_sampler),
}
def get_model(self, stage: str):
model = self.model \
if self.model is not None \
else resnet9(in_channels=3, num_classes=10)
return model
def get_criterion(self, stage: str):
return nn.CrossEntropyLoss()
def get_optimizer(self, stage: str, model):
return optim.Adam(model.parameters(), lr=1e-3)
def get_scheduler(self, stage: str, optimizer):
return optim.lr_scheduler.MultiStepLR(optimizer, [5, 8], gamma=0.3)
def get_callbacks(self, stage: str):
return {
"criterion": dl.CriterionCallback(
metric_key="loss", input_key="logits", target_key="targets"
),
"optimizer": dl.OptimizerCallback(metric_key="loss"),
"scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
"accuracy": dl.AccuracyCallback(
input_key="logits", target_key="targets", topk_args=(1, 3, 5)
),
"checkpoint": dl.CheckpointCallback(
self._logdir,
loader_key="valid",
metric_key="accuracy",
minimize=False,
save_n_best=1,
),
"tqdm": dl.TqdmCallback(),
}
def handle_batch(self, batch):
x, y = batch
logits = self.model(x)
self.batch = {
"features": x,
"targets": y,
"logits": logits,
}
logdir = f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
runner = CustomRunner(logdir)
runner.run()
"""
def __init__(self):
"""Init."""
super().__init__()
self._device = None
self._rank = 0
self._world_size = 8
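# NOTE: 8 matches the core count of a standard Cloud TPU (v2/v3) device;
# ``setup_process`` replaces it with the actual world size after spawn.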
self._backend = "xla"
@property
def rank(self) -> int:
"""Process rank for distributed training."""
return self._rank
@property
def world_size(self) -> int:
"""Process world size for distributed training."""
return self._world_size
@property
def backend(self) -> Optional[str]:
"""String identifier for distributed backend."""
return self._backend
def barrier(self) -> None:
"""
Synchronizes all processes.
This collective blocks processes until the whole group enters this function.
"""
xm.rendezvous("barrier")
def spawn(self, fn: Callable, *args: Any, **kwargs: Any) -> None:
"""Spawns abstraction for``nprocs`` creation with specified ``fn`` and ``args``/``kwargs``.
Args:
fn (function): Function is called as the entrypoint of the
spawned process. This function must be defined at the top
level of a module so it can be pickled and spawned. This
is a requirement imposed by multiprocessing.
The function is called as ``fn(i, *args)``, where ``i`` is
the process index and ``args`` is the passed through tuple
of arguments.
*args: Arguments passed to spawn method.
**kwargs: Keyword-arguments passed to spawn method.
Returns:
the result of the ``xmp.spawn`` call.
"""
return xmp.spawn(
fn, args=(self._world_size,), nprocs=self._world_size, start_method="fork"
)
def setup_process(self, rank: int = -1, world_size: int = 1):
"""Initialize DDP variables and processes.
Args:
rank: process rank. Default is `-1`.
world_size: number of devices in the network to expect for training. Default is `1`.
"""
self._rank = rank
self._world_size = world_size
self._device = xm.xla_device()
def sync_tensor(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
"""Syncs ``tensor`` over ``world_size`` in distributed mode.
Args:
tensor: tensor to sync across the processes.
mode: tensor synchronization type,
should be one of 'sum' or 'mean'.
Returns:
torch.Tensor with synchronized values.
Raises:
ValueError: if mode is out of ``sum``, ``mean``.
"""
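# ``xm.all_reduce`` returns the reduced tensor when given a single tensor;
# 'mean' is computed as a sum all-reduce scaled by ``1 / world_size``.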
if mode not in {"sum", "mean"}:
raise ValueError(f"Unknown sync_type '{mode}'")
if mode == "sum":
return xm.all_reduce("sum", tensor)
elif mode == "mean":
return xm.all_reduce("sum", tensor, scale=1.0 / self.world_size)
def sync_metrics(self, metrics: Dict) -> Dict:
"""Syncs ``metrics`` over ``world_size`` in the distributed mode."""
metrics = {
k: xm.mesh_reduce(k, v.item() if isinstance(v, torch.Tensor) else v, np.mean)
for k, v in metrics.items()
}
return metrics
def optimizer_step(self, loss, model, optimizer) -> None:
"""Abstraction over ``optimizer.step()`` step."""
xm.optimizer_step(optimizer)
def autocast_loader(self, loader: DataLoader):
"""Loader wrapper for the distributed mode."""
return ParallelLoader(loader, [self.device]).per_device_loader(self.device)
__all__ = ["XLAEngine", "DistributedXLAEngine"]
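# A minimal usage sketch (hypothetical runner name; mirrors the docstring
# examples above) - switching between single- and multi-TPU training only
# requires changing the engine returned by ``get_engine``:
#
#     class MyRunner(dl.IRunner):
#         def get_engine(self):
#             return DistributedXLAEngine()  # or XLAEngine() for a single core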