Source code for catalyst.callbacks.quantization

from typing import Dict, Optional, TYPE_CHECKING, Union
from pathlib import Path

import torch

from catalyst.core import Callback, CallbackNode, CallbackOrder
from catalyst.utils import quantize_model

    from catalyst.core import IRunner

class QuantizationCallback(Callback):
    """Callback for model quantization.

    At the end of a stage the runner's model is passed through
    ``quantize_model`` and the result is saved as a checkpoint.

    Args:
        logdir: path to folder for saving
        filename: filename
        qconfig_spec (Dict, optional): quantization config in PyTorch format.
            Defaults to None.
        dtype (Union[str, Optional[torch.dtype]], optional):
            Type of weights after quantization.
            Defaults to "qint8".

    Example:

    .. code-block:: python

        import os

        import torch
        from torch import nn
        from torch.utils.data import DataLoader

        from catalyst import dl
        # NOTE: the ToTensor import path may differ between Catalyst versions.
        from catalyst.data.transforms import ToTensor
        from catalyst.contrib.datasets import MNIST
        from catalyst.contrib.nn.modules import Flatten

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                batch_size=32,
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                batch_size=32,
            ),
        }

        model = nn.Sequential(Flatten(), nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            callbacks=[dl.QuantizationCallback(logdir="./logs")],
            loaders=loaders,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=1,
            logdir="./logs",
        )
    """

    def __init__(
        self,
        logdir: Optional[Union[str, Path]] = None,
        filename: str = "quantized.pth",
        qconfig_spec: Optional[Dict] = None,
        dtype: Union[str, Optional[torch.dtype]] = "qint8",
    ):
        """Init.

        Args:
            logdir: folder the quantized checkpoint is written to; when
                ``None``, ``filename`` is used as the path as given.
            filename: name of the checkpoint file.
            qconfig_spec: quantization config forwarded to ``quantize_model``.
            dtype: weights type forwarded to ``quantize_model``.
        """
        # ExternalExtra order so the callback is applied after
        # CheckpointCallback; master node because it saves a file.
        super().__init__(
            order=CallbackOrder.ExternalExtra, node=CallbackNode.master
        )
        self.qconfig_spec = qconfig_spec
        self.dtype = dtype
        # Keeps the original attribute types: a ``Path`` when ``logdir`` is
        # given, otherwise the plain ``filename`` string.
        self.filename = filename if logdir is None else Path(logdir) / filename

    def on_stage_end(self, runner: "IRunner") -> None:
        """Event handler: quantize the runner's model and save it.

        Args:
            runner: current runner; its ``model`` is quantized and its
                ``engine`` packs and saves the resulting checkpoint.
        """
        # NOTE(review): ``.cpu()`` is called on the runner's own model;
        # presumably an ``nn.Module``, which is moved to CPU in place.
        q_model = quantize_model(
            runner.model.cpu(), qconfig_spec=self.qconfig_spec, dtype=self.dtype
        )
        checkpoint = runner.engine.pack_checkpoint(model=q_model)
        # Create the target folder up front so that a ``logdir`` which does
        # not exist yet does not make saving fail.
        Path(self.filename).parent.mkdir(parents=True, exist_ok=True)
        runner.engine.save_checkpoint(checkpoint=checkpoint, path=self.filename)
# Names exported by ``from <this module> import *``.
__all__ = ["QuantizationCallback"]