Shortcuts

Engines

See the engines overview in the examples/engines section.

AMP

AMPEngine

class catalyst.engines.amp.AMPEngine(device: str = 'cuda', scaler_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.torch.DeviceEngine

Pytorch.AMP single training device engine.

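Parameters
  • device – used device, default is “cuda”.

  • scaler_kwargs – parameters for torch.cuda.amp.GradScaler. Possible parameters: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler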

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.AMPEngine("cuda:1"),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.AMPEngine("cuda:1")
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: AMPEngine
    device: cuda:1

stages:
    ...

DataParallelAMPEngine

class catalyst.engines.amp.DataParallelAMPEngine(scaler_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.amp.AMPEngine

AMP multi-gpu training device engine.

Parameters

scaler_kwargs – parameters for torch.cuda.amp.GradScaler. Possible parameters: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.DataParallelAMPEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.DataParallelAMPEngine()
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: DataParallelAMPEngine

stages:
    ...

DistributedDataParallelAMPEngine

class catalyst.engines.amp.DistributedDataParallelAMPEngine(address: str = '127.0.0.1', port: Union[str, int] = 2112, world_size: Optional[int] = None, workers_dist_rank: int = 0, num_node_workers: Optional[int] = None, process_group_kwargs: Dict[str, Any] = None, sync_bn: bool = False, ddp_kwargs: Dict[str, Any] = None, scaler_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.torch.DistributedDataParallelEngine

Distributed AMP multi-gpu training device engine.

Parameters
  • address – master node (rank 0)’s address, should be either the IP address or the hostname of node 0, for single node multi-proc training, can simply be 127.0.0.1

  • port – master node (rank 0)’s free port that needs to be used for communication during distributed training

  • world_size – the number of processes to use for distributed training. Should be less than or equal to the number of GPUs

  • workers_dist_rank – the rank of the first process to run on the node. It should be a number between number of initialized processes and world_size - 1, the other processes on the node will have ranks # of initialized processes + 1, # of initialized processes + 2, …, # of initialized processes + num_node_workers - 1

  • num_node_workers – the number of processes to launch on the node. For GPU training, this is recommended to be set to the number of GPUs on the current node so that each process can be bound to a single GPU

  • process_group_kwargs – parameters for torch.distributed.init_process_group. More info here: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group

  • sync_bn – boolean flag for batchnorm synchronization during distributed training. If True, applies PyTorch convert_sync_batchnorm to the model for native torch distributed only. Default: False.

  • ddp_kwargs – parameters for torch.nn.parallel.DistributedDataParallel. More info here: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel

  • scaler_kwargs – parameters for torch.cuda.amp.GradScaler. Possible parameters: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.DistributedDataParallelAMPEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        """Return the distributed AMP engine used by this runner."""
        # process_group_kwargs is forwarded to torch.distributed.init_process_group,
        # which has no ``port`` argument; the master port is set via ``port=``.
        return dl.DistributedDataParallelAMPEngine(
            address="0.0.0.0",
            port=23234,
            ddp_kwargs={"find_unused_parameters": False},
            process_group_kwargs={"backend": "nccl"},
            scaler_kwargs={"growth_factor": 1.5}
        )
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

# Engine section of the config: distributed AMP training.
engine:
    _target_: DistributedDataParallelAMPEngine
    address: 0.0.0.0
    port: 23234
    ddp_kwargs:
        find_unused_parameters: false
    # forwarded to torch.distributed.init_process_group, which takes no `port`;
    # the master port is configured by `port` above
    process_group_kwargs:
        backend: nccl
    scaler_kwargs:
        growth_factor: 1.5

stages:
    ...

Apex

APEXEngine

class catalyst.engines.apex.APEXEngine(device: str = 'cuda', apex_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.torch.DeviceEngine

Apex single training device engine.

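Parameters
  • device – used device, default is “cuda”.

  • apex_kwargs – parameters for apex.amp.initialize (for example opt_level and keep_batchnorm_fp32).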

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.APEXEngine(apex_kwargs=dict(opt_level="O1", keep_batchnorm_fp32=False)),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.APEXEngine(apex_kwargs=dict(opt_level="O1", keep_batchnorm_fp32=False))
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: APEXEngine
    apex_kwargs:
        opt_level: O1
        keep_batchnorm_fp32: false

stages:
    ...

DataParallelApexEngine

catalyst.engines.apex.DataParallelApexEngine

alias of catalyst.engines.apex.DataParallelAPEXEngine

DistributedDataParallelApexEngine

catalyst.engines.apex.DistributedDataParallelApexEngine

alias of catalyst.engines.apex.DistributedDataParallelAPEXEngine

DeepSpeed

DistributedDataParallelDeepSpeedEngine

class catalyst.engines.deepspeed.DistributedDataParallelDeepSpeedEngine(address: str = '127.0.0.1', port: Union[str, int] = 2112, world_size: Optional[int] = None, workers_dist_rank: int = 0, num_node_workers: Optional[int] = None, process_group_kwargs: Dict[str, Any] = None, train_batch_size: int = 256, deepspeed_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.torch.DistributedDataParallelEngine

Distributed DeepSpeed MultiGPU training device engine.

Parameters
  • address – master node (rank 0)’s address, should be either the IP address or the hostname of node 0, for single node multi-proc training, can simply be 127.0.0.1

  • port – master node (rank 0)’s free port that needs to be used for communication during distributed training

  • world_size – the number of processes to use for distributed training. Should be less than or equal to the number of GPUs

  • workers_dist_rank – the rank of the first process to run on the node. It should be a number between number of initialized processes and world_size - 1, the other processes on the node will have ranks # of initialized processes + 1, # of initialized processes + 2, …, # of initialized processes + num_node_workers - 1

  • num_node_workers – the number of processes to launch on the node. For GPU training, this is recommended to be set to the number of GPUs on the current node so that each process can be bound to a single GPU

  • process_group_kwargs – parameters for torch.distributed.init_process_group. More info here: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group

  • train_batch_size – shortcut for train batch size for deepspeed scaling (default: 256) for proper configuration, please use deepspeed_kwargs[‘config’] instead

  • deepspeed_kwargs – parameters for deepspeed.initialize. More info here: https://deepspeed.readthedocs.io/en/latest/initialize.html

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.DistributedDataParallelDeepSpeedEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.DistributedDataParallelDeepSpeedEngine(
            address="0.0.0.0",
            port=23234,
            process_group_kwargs={"port": 12345},
            deepspeed_kwargs={"config": {"train_batch_size": 64}}
        )
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: DistributedDataParallelDeepSpeedEngine
    address: 0.0.0.0
    port: 23234
    process_group_kwargs:
        port: 12345
    deepspeed_kwargs:
        config:
            train_batch_size: 64

stages:
    ...

FairScale

PipelineParallelFairScaleEngine

class catalyst.engines.fairscale.PipelineParallelFairScaleEngine(pipe_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.torch.DeviceEngine

FairScale multi-gpu training device engine.

Parameters

pipe_kwargs – parameters for fairscale.nn.Pipe. Docs for fairscale.nn.Pipe: https://fairscale.readthedocs.io/en/latest/api/nn/pipe.html

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.PipelineParallelFairScaleEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.PipelineParallelFairScaleEngine(
            pipe_kwargs={"balance": [3, 1]}
        )
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: PipelineParallelFairScaleEngine
    pipe_kwargs:
        balance: [3, 1]

stages:
    ...

SharedDataParallelFairScaleEngine

class catalyst.engines.fairscale.SharedDataParallelFairScaleEngine(address: str = '127.0.0.1', port: Union[str, int] = 2112, world_size: Optional[int] = None, workers_dist_rank: int = 0, num_node_workers: Optional[int] = None, process_group_kwargs: Dict[str, Any] = None, sync_bn: bool = False, ddp_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.torch.DistributedDataParallelEngine

Distributed FairScale MultiGPU training device engine.

Parameters

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.SharedDataParallelFairScaleEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        """Return the FairScale sharded data-parallel engine used by this runner."""
        # process_group_kwargs is forwarded to torch.distributed.init_process_group,
        # which has no ``port`` argument; the master port is set via ``port=``.
        return dl.SharedDataParallelFairScaleEngine(
            address="0.0.0.0",
            port=23234,
            ddp_kwargs={"find_unused_parameters": False},
            process_group_kwargs={"backend": "nccl"},
        )
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

# Engine section of the config: FairScale sharded data-parallel training.
engine:
    _target_: SharedDataParallelFairScaleEngine
    address: 0.0.0.0
    port: 23234
    ddp_kwargs:
        find_unused_parameters: false
    # forwarded to torch.distributed.init_process_group, which takes no `port`;
    # the master port is configured by `port` above
    process_group_kwargs:
        backend: nccl

stages:
    ...

SharedDataParallelFairScaleAMPEngine

class catalyst.engines.fairscale.SharedDataParallelFairScaleAMPEngine(address: str = '127.0.0.1', port: Union[str, int] = 2112, world_size: Optional[int] = None, workers_dist_rank: int = 0, num_node_workers: Optional[int] = None, process_group_kwargs: Dict[str, Any] = None, sync_bn: bool = False, ddp_kwargs: Dict[str, Any] = None, scaler_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.fairscale.SharedDataParallelFairScaleEngine

Distributed FairScale MultiGPU training device engine.

Parameters
  • address – master node (rank 0)’s address, should be either the IP address or the hostname of node 0, for single node multi-proc training, can simply be 127.0.0.1

  • port – master node (rank 0)’s free port that needs to be used for communication during distributed training

  • world_size – the number of processes to use for distributed training. Should be less than or equal to the number of GPUs

  • workers_dist_rank – the rank of the first process to run on the node. It should be a number between number of initialized processes and world_size - 1, the other processes on the node will have ranks # of initialized processes + 1, # of initialized processes + 2, …, # of initialized processes + num_node_workers - 1

  • num_node_workers – the number of processes to launch on the node. For GPU training, this is recommended to be set to the number of GPUs on the current node so that each process can be bound to a single GPU

  • process_group_kwargs – parameters for torch.distributed.init_process_group. More info here: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group

  • sync_bn – boolean flag for batchnorm synchronization during distributed training. If True, applies PyTorch convert_sync_batchnorm to the model for native torch distributed only. Default: False.

  • ddp_kwargs – parameters for fairscale.nn.data_parallel.ShardedDataParallel. Docs for fairscale.nn.ShardedDataParallel: https://fairscale.readthedocs.io/en/latest/api/nn/sharded_ddp.html

  • scaler_kwargs – parameters for fairscale.optim.grad_scaler.ShardedGradScaler. Possible parameters: https://fairscale.readthedocs.io/en/latest/api/index.html

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.SharedDataParallelFairScaleAMPEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        """Return the FairScale sharded data-parallel AMP engine used by this runner."""
        # process_group_kwargs is forwarded to torch.distributed.init_process_group,
        # which has no ``port`` argument; the master port is set via ``port=``.
        return dl.SharedDataParallelFairScaleAMPEngine(
            address="0.0.0.0",
            port=23234,
            ddp_kwargs={"find_unused_parameters": False},
            process_group_kwargs={"backend": "nccl"},
            scaler_kwargs={"growth_factor": 1.5}
        )
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

# Engine section of the config: FairScale sharded data-parallel training with AMP.
engine:
    _target_: SharedDataParallelFairScaleAMPEngine
    address: 0.0.0.0
    port: 23234
    ddp_kwargs:
        find_unused_parameters: false
    # forwarded to torch.distributed.init_process_group, which takes no `port`;
    # the master port is configured by `port` above
    process_group_kwargs:
        backend: nccl
    scaler_kwargs:
        growth_factor: 1.5

stages:
    ...

FullySharedDataParallelFairScaleEngine

class catalyst.engines.fairscale.FullySharedDataParallelFairScaleEngine(address: str = '127.0.0.1', port: Union[str, int] = 2112, world_size: Optional[int] = None, workers_dist_rank: int = 0, num_node_workers: Optional[int] = None, process_group_kwargs: Dict[str, Any] = None, sync_bn: bool = False, ddp_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.fairscale.SharedDataParallelFairScaleEngine

Distributed FairScale MultiGPU training device engine.

Parameters

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.FullySharedDataParallelFairScaleEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        """Return the FairScale fully-sharded data-parallel engine used by this runner."""
        # process_group_kwargs is forwarded to torch.distributed.init_process_group,
        # which has no ``port`` argument; the master port is set via ``port=``.
        return dl.FullySharedDataParallelFairScaleEngine(
            address="0.0.0.0",
            port=23234,
            ddp_kwargs={"find_unused_parameters": False},
            process_group_kwargs={"backend": "nccl"},
        )
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

# Engine section of the config: FairScale fully-sharded data-parallel training.
engine:
    _target_: FullySharedDataParallelFairScaleEngine
    address: 0.0.0.0
    port: 23234
    ddp_kwargs:
        find_unused_parameters: false
    # forwarded to torch.distributed.init_process_group, which takes no `port`;
    # the master port is configured by `port` above
    process_group_kwargs:
        backend: nccl

stages:
    ...

Torch

DeviceEngine

class catalyst.engines.torch.DeviceEngine(device: str = None)[source]

Bases: catalyst.core.engine.IEngine

Single training device engine.

Parameters

device – use device, default is “cpu”.

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.DeviceEngine("cuda:1"),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.DeviceEngine("cuda:1")
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: DeviceEngine
    device: cuda:1

stages:
    ...

DataParallelEngine

class catalyst.engines.torch.DataParallelEngine[source]

Bases: catalyst.engines.torch.DeviceEngine

MultiGPU training device engine.

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.DataParallelEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.DataParallelEngine()
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: DataParallelEngine

stages:
    ...

DistributedDataParallelEngine

class catalyst.engines.torch.DistributedDataParallelEngine(address: str = '127.0.0.1', port: Union[str, int] = 2112, world_size: Optional[int] = None, workers_dist_rank: int = 0, num_node_workers: Optional[int] = None, process_group_kwargs: Dict[str, Any] = None, sync_bn: bool = False, ddp_kwargs: Dict[str, Any] = None)[source]

Bases: catalyst.engines.torch.DeviceEngine

Distributed MultiGPU training device engine.

Parameters
  • address – master node (rank 0)’s address, should be either the IP address or the hostname of node 0, for single node multi-proc training, can simply be 127.0.0.1

  • port – master node (rank 0)’s free port that needs to be used for communication during distributed training

  • world_size – the number of processes to use for distributed training. Should be less than or equal to the number of GPUs

  • workers_dist_rank – the rank of the first process to run on the node. It should be a number between number of initialized processes and world_size - 1, the other processes on the node will have ranks # of initialized processes + 1, # of initialized processes + 2, …, # of initialized processes + num_node_workers - 1

  • num_node_workers – the number of processes to launch on the node. For GPU training, this is recommended to be set to the number of GPUs on the current node so that each process can be bound to a single GPU

  • process_group_kwargs – parameters for torch.distributed.init_process_group. More info here: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group

  • sync_bn – boolean flag for batchnorm synchronization during distributed training. If True, applies PyTorch convert_sync_batchnorm to the model for native torch distributed only. Default: False.

  • ddp_kwargs – parameters for torch.nn.parallel.DistributedDataParallel. More info here: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel

Examples:

from catalyst import dl

runner = dl.SupervisedRunner()
runner.train(
    engine=dl.DistributedDataParallelEngine(),
    ...
)
from catalyst import dl

class MyRunner(dl.IRunner):
    # ...
    def get_engine(self):
        return dl.DistributedDataParallelEngine(
            address="0.0.0.0",
            port=23234,
            ddp_kwargs={"find_unused_parameters": False},
            process_group_kwargs={"backend": "nccl"},
        )
    # ...
args:
    logs: ...

model:
    _target_: ...
    ...

engine:
    _target_: DistributedDataParallelEngine
    address: 0.0.0.0
    port: 23234
    ddp_kwargs:
        find_unused_parameters: false
    process_group_kwargs:
        backend: nccl

stages:
    ...

XLA

XLAEngine

class catalyst.engines.xla.XLAEngine[source]

Bases: catalyst.engines.torch.DeviceEngine

XLA SingleTPU training device engine.

Examples:

import os
from datetime import datetime

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.contrib import (
    ImageToTensor, NormalizeImage, Compose, CIFAR10, ResidualBlock
)

def conv_block(in_channels, out_channels, pool=False):
    """Build a 3x3 Conv2d -> BatchNorm2d -> ReLU block.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        pool: if True, append a ``MaxPool2d(2)`` layer after the activation.

    Returns:
        ``nn.Sequential`` holding the three (or four, with pooling) layers.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    if not pool:
        return nn.Sequential(conv, norm, activation)
    return nn.Sequential(conv, norm, activation, nn.MaxPool2d(2))


def resnet9(in_channels: int, num_classes: int, size: int = 16):
    """Assemble a compact ResNet9-style classifier as an ``nn.Sequential``.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the final linear layer's output.
        size: base channel width; later stages use 2x, 4x and 8x this value.

    Returns:
        ``nn.Sequential`` of conv blocks, two residual blocks and a linear head.
    """
    w1, w2, w4, w8 = (size * factor for factor in (1, 2, 4, 8))

    def residual(width):
        # Two same-width conv blocks wrapped in a residual connection.
        return ResidualBlock(
            nn.Sequential(conv_block(width, width), conv_block(width, width))
        )

    # Layers are created in network order, same as a direct nn.Sequential(...) call.
    backbone = [
        conv_block(in_channels, w1),
        conv_block(w1, w2, pool=True),
        residual(w2),
        conv_block(w2, w4, pool=True),
        conv_block(w4, w8, pool=True),
        residual(w8),
    ]
    head = nn.Sequential(
        nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(w8, num_classes)
    )
    return nn.Sequential(*backbone, head)

class CustomRunner(dl.IRunner):
    """Runner that trains a ResNet9 classifier on CIFAR10 with ``dl.XLAEngine``."""

    def __init__(self, logdir):
        """Remember the directory used by the loggers and the checkpoint callback."""
        super().__init__()
        self._logdir = logdir

    def get_engine(self):
        """Return the engine the experiment runs on."""
        return dl.XLAEngine()

    def get_loggers(self):
        """Return console, CSV and TensorBoard loggers; file loggers write to the logdir."""
        return {
            "console": dl.ConsoleLogger(),
            "csv": dl.CSVLogger(logdir=self._logdir),
            "tensorboard": dl.TensorboardLogger(logdir=self._logdir),
        }

    @property
    def stages(self):
        """Names of the stages to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Return the length of ``stage`` (always 3)."""
        return 3

    def get_loaders(self, stage: str):
        """Build the ``"train"`` and ``"valid"`` CIFAR10 data loaders.

        Distributed samplers are attached only when the engine reports ``is_ddp``.
        """
        transform = Compose(
            [ImageToTensor(), NormalizeImage((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )
        # Train on the training split; the held-out split is used for validation only.
        train_data = CIFAR10(os.getcwd(), train=True, download=True, transform=transform)
        valid_data = CIFAR10(os.getcwd(), train=False, download=True, transform=transform)

        if self.engine.is_ddp:
            # Each process gets its own shard of the data.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=True,
            )
            valid_sampler = torch.utils.data.distributed.DistributedSampler(
                valid_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=False,
            )
        else:
            train_sampler = valid_sampler = None

        return {
            "train": DataLoader(train_data, batch_size=32, sampler=train_sampler),
            "valid": DataLoader(valid_data, batch_size=32, sampler=valid_sampler),
        }

    def get_model(self, stage: str):
        """Return the already attached model, or build a fresh ResNet9 if there is none."""
        if self.model is not None:
            return self.model
        return resnet9(in_channels=3, num_classes=10)

    def get_criterion(self, stage: str):
        """Return the cross-entropy loss."""
        return nn.CrossEntropyLoss()

    def get_optimizer(self, stage: str, model):
        """Return an Adam optimizer over the model parameters (lr=1e-3)."""
        return optim.Adam(model.parameters(), lr=1e-3)

    def get_scheduler(self, stage: str, optimizer):
        """Return a MultiStepLR scheduler with milestones 5 and 8 and gamma 0.3."""
        return optim.lr_scheduler.MultiStepLR(optimizer, [5, 8], gamma=0.3)

    def get_callbacks(self, stage: str):
        """Return the criterion, optimizer, scheduler, accuracy, checkpoint and tqdm callbacks."""
        return {
            "criterion": dl.CriterionCallback(
                metric_key="loss", input_key="logits", target_key="targets"
            ),
            "optimizer": dl.OptimizerCallback(metric_key="loss"),
            "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
            "accuracy": dl.AccuracyCallback(
                input_key="logits", target_key="targets", topk_args=(1, 3, 5)
            ),
            # Keep only the single best checkpoint by validation accuracy.
            "checkpoint": dl.CheckpointCallback(
                self._logdir,
                loader_key="valid",
                metric_key="accuracy",
                minimize=False,
                save_n_best=1,
            ),
            "tqdm": dl.TqdmCallback(),
        }

    def handle_batch(self, batch):
        """Run the model on one ``(features, targets)`` batch and store the results.

        The keys written to ``self.batch`` match the ``input_key``/``target_key``
        values used by the callbacks.
        """
        x, y = batch
        logits = self.model(x)

        self.batch = {
            "features": x,
            "targets": y,
            "logits": logits,
        }

# Log into a timestamped directory so repeated runs do not overwrite each other.
run_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
logdir = f"logs/{run_stamp}"
runner = CustomRunner(logdir)
runner.run()

DistributedXLAEngine

class catalyst.engines.xla.DistributedXLAEngine[source]

Bases: catalyst.engines.torch.DeviceEngine

Distributed XLA MultiTPU training device engine.

Examples:

import os
from datetime import datetime

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.contrib import (
    ImageToTensor, NormalizeImage, Compose, CIFAR10, ResidualBlock
)

def conv_block(in_channels, out_channels, pool=False):
    """Build a 3x3 Conv2d -> BatchNorm2d -> ReLU block.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        pool: if True, append a ``MaxPool2d(2)`` layer after the activation.

    Returns:
        ``nn.Sequential`` holding the three (or four, with pooling) layers.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    if not pool:
        return nn.Sequential(conv, norm, activation)
    return nn.Sequential(conv, norm, activation, nn.MaxPool2d(2))


def resnet9(in_channels: int, num_classes: int, size: int = 16):
    """Assemble a compact ResNet9-style classifier as an ``nn.Sequential``.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the final linear layer's output.
        size: base channel width; later stages use 2x, 4x and 8x this value.

    Returns:
        ``nn.Sequential`` of conv blocks, two residual blocks and a linear head.
    """
    w1, w2, w4, w8 = (size * factor for factor in (1, 2, 4, 8))

    def residual(width):
        # Two same-width conv blocks wrapped in a residual connection.
        return ResidualBlock(
            nn.Sequential(conv_block(width, width), conv_block(width, width))
        )

    # Layers are created in network order, same as a direct nn.Sequential(...) call.
    backbone = [
        conv_block(in_channels, w1),
        conv_block(w1, w2, pool=True),
        residual(w2),
        conv_block(w2, w4, pool=True),
        conv_block(w4, w8, pool=True),
        residual(w8),
    ]
    head = nn.Sequential(
        nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(w8, num_classes)
    )
    return nn.Sequential(*backbone, head)

class CustomRunner(dl.IRunner):
    """Runner that trains a ResNet9 classifier on CIFAR10 with ``dl.DistributedXLAEngine``."""

    def __init__(self, logdir):
        """Remember the directory used by the loggers and the checkpoint callback."""
        super().__init__()
        self._logdir = logdir

    def get_engine(self):
        """Return the engine the experiment runs on."""
        return dl.DistributedXLAEngine()

    def get_loggers(self):
        """Return console, CSV and TensorBoard loggers; file loggers write to the logdir."""
        return {
            "console": dl.ConsoleLogger(),
            "csv": dl.CSVLogger(logdir=self._logdir),
            "tensorboard": dl.TensorboardLogger(logdir=self._logdir),
        }

    @property
    def stages(self):
        """Names of the stages to run: a single ``"train"`` stage."""
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        """Return the length of ``stage`` (always 3)."""
        return 3

    def get_loaders(self, stage: str):
        """Build the ``"train"`` and ``"valid"`` CIFAR10 data loaders.

        Distributed samplers are attached only when the engine reports ``is_ddp``.
        """
        transform = Compose(
            [ImageToTensor(), NormalizeImage((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )
        # Train on the training split; the held-out split is used for validation only.
        train_data = CIFAR10(os.getcwd(), train=True, download=True, transform=transform)
        valid_data = CIFAR10(os.getcwd(), train=False, download=True, transform=transform)

        if self.engine.is_ddp:
            # Each process gets its own shard of the data.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=True,
            )
            valid_sampler = torch.utils.data.distributed.DistributedSampler(
                valid_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=False,
            )
        else:
            train_sampler = valid_sampler = None

        return {
            "train": DataLoader(train_data, batch_size=32, sampler=train_sampler),
            "valid": DataLoader(valid_data, batch_size=32, sampler=valid_sampler),
        }

    def get_model(self, stage: str):
        """Return the already attached model, or build a fresh ResNet9 if there is none."""
        if self.model is not None:
            return self.model
        return resnet9(in_channels=3, num_classes=10)

    def get_criterion(self, stage: str):
        """Return the cross-entropy loss."""
        return nn.CrossEntropyLoss()

    def get_optimizer(self, stage: str, model):
        """Return an Adam optimizer over the model parameters (lr=1e-3)."""
        return optim.Adam(model.parameters(), lr=1e-3)

    def get_scheduler(self, stage: str, optimizer):
        """Return a MultiStepLR scheduler with milestones 5 and 8 and gamma 0.3."""
        return optim.lr_scheduler.MultiStepLR(optimizer, [5, 8], gamma=0.3)

    def get_callbacks(self, stage: str):
        """Return the criterion, optimizer, scheduler, accuracy, checkpoint and tqdm callbacks."""
        return {
            "criterion": dl.CriterionCallback(
                metric_key="loss", input_key="logits", target_key="targets"
            ),
            "optimizer": dl.OptimizerCallback(metric_key="loss"),
            "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
            "accuracy": dl.AccuracyCallback(
                input_key="logits", target_key="targets", topk_args=(1, 3, 5)
            ),
            # Keep only the single best checkpoint by validation accuracy.
            "checkpoint": dl.CheckpointCallback(
                self._logdir,
                loader_key="valid",
                metric_key="accuracy",
                minimize=False,
                save_n_best=1,
            ),
            "tqdm": dl.TqdmCallback(),
        }

    def handle_batch(self, batch):
        """Run the model on one ``(features, targets)`` batch and store the results.

        The keys written to ``self.batch`` match the ``input_key``/``target_key``
        values used by the callbacks.
        """
        x, y = batch
        logits = self.model(x)

        self.batch = {
            "features": x,
            "targets": y,
            "logits": logits,
        }

# Log into a timestamped directory so repeated runs do not overwrite each other.
run_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
logdir = f"logs/{run_stamp}"
runner = CustomRunner(logdir)
runner.run()