Source code for catalyst.callbacks.batch_transform
from typing import Callable, List, Union
from catalyst.core import Callback, CallbackOrder, IRunner
def _tuple_wrapper(transform: Callable):
def wrapper(*inputs):
"""Function wrapper for tuple output"""
output = transform(*inputs)
return (output,)
return wrapper
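
# A sketch of why the wrapper exists: a transform that returns a single
# tensor is wrapped into a 1-tuple so the result can later be zipped with
# ``output_key``. Hypothetical usage, assuming ``torch`` is available:
#     wrapped = _tuple_wrapper(torch.sigmoid)
#     wrapped(torch.zeros(2))  # -> (tensor([0.5000, 0.5000]),)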
class BatchTransformCallback(Callback):
"""
Preprocess your batch with specified function.
Args:
transform (Callable): Function to apply.
scope (str): ``"on_batch_end"`` (post-processing model output) or
``"on_batch_start"`` (pre-processing model input).
        input_key (Union[List[str], str], optional): Keys in the batch dict
            to apply the function to. Defaults to ``None``.
        output_key (Union[List[str], str], optional): Keys for the output.
            If ``None``, the function is applied in place to the
            ``input_key`` values. Defaults to ``None``.

    Raises:
        TypeError: When a key is not a ``str`` or a list of ``str``,
            or when ``scope`` is not one of
            ``["on_batch_end", "on_batch_start"]``.
Examples:
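
    A minimal sketch: apply ``torch.sigmoid`` to ``batch["logits"]`` after
    each batch and store the result under ``batch["scores"]``:

    .. code-block:: python

        import torch

        from catalyst.callbacks.batch_transform import BatchTransformCallback

        callback = BatchTransformCallback(
            transform=torch.sigmoid,
            scope="on_batch_end",
            input_key="logits",
            output_key="scores",
        )

    A complete training example with ``dl.SupervisedRunner``: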
.. code-block:: python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl
# sample data
num_users, num_features, num_items = int(1e4), int(1e1), 10
X = torch.rand(num_users, num_features)
y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)
# pytorch loaders
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}
# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, num_items)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
# model training
        runner = dl.SupervisedRunner()
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
num_epochs=3,
verbose=True,
callbacks=[
                dl.BatchTransformCallback(
                    input_key="logits", output_key="scores",
                    transform=torch.sigmoid, scope="on_batch_end"
                ),
dl.CriterionCallback(
input_key="logits", target_key="targets", metric_key="loss"
),
# uncomment for extra metrics:
# dl.AUCCallback(
# input_key="scores", target_key="targets"
# ),
# dl.HitrateCallback(
# input_key="scores", target_key="targets", topk_args=(1, 3, 5)
# ),
# dl.MRRCallback(
# input_key="scores", target_key="targets", topk_args=(1, 3, 5)
# ),
# dl.MAPCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
# dl.NDCGCallback(
# input_key="scores", target_key="targets", topk_args=(1, 3, 5)
# ),
dl.OptimizerCallback(metric_key="loss"),
dl.SchedulerCallback(),
                dl.CheckpointCallback(
                    # "map01" is only logged if the MAP callback above is
                    # uncommented, so track the validation loss instead
                    logdir="./logs", loader_key="valid",
                    metric_key="loss", minimize=True
                ),
]
)
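
    The same callback can also pre-process model input with a custom runner;
    the example below applies Kornia augmentations to ``batch["features"]``
    at ``on_batch_start``: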
.. code-block:: python
        import os

        import torch
        from torch.nn import functional as F
        from torch.utils.data import DataLoader
        from torchvision.datasets import MNIST
        from torchvision.transforms import ToTensor
        from kornia import augmentation

        from catalyst import dl, metrics
        from catalyst.callbacks.batch_transform import BatchTransformCallback

        class CustomRunner(dl.Runner):
def handle_batch(self, batch):
logits = self.model(
batch["features"].view(batch["features"].size(0), -1)
)
loss = F.cross_entropy(logits, batch["targets"])
accuracy01, accuracy03 = metrics.accuracy(
logits, batch["targets"], topk=(1, 3)
)
self.batch_metrics.update(
{"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03}
)
if self.is_train_loader:
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
class MnistDataset(torch.utils.data.Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, item):
return {"features": self.dataset[item][0], "targets": self.dataset[item][1]}
def __len__(self):
return len(self.dataset)
model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
loaders = {
"train": DataLoader(
MnistDataset(
                    MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
),
batch_size=32,
),
"valid": DataLoader(
MnistDataset(
MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
),
batch_size=32,
),
}
        transforms = [
augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)),
]
runner = CustomRunner()
# model training
runner.train(
model=model,
optimizer=optimizer,
loaders=loaders,
logdir="./logs",
num_epochs=5,
verbose=False,
load_best_on_end=True,
check=True,
callbacks=[
                BatchTransformCallback(
                    # a list is not callable, so chain the augmentations
                    transform=torch.nn.Sequential(*transforms),
                    scope="on_batch_start",
                    input_key="features",
                )
],
)
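
    Note: when ``output_key`` is omitted, as in the example above, the
    transformed values overwrite the ``input_key`` values in the batch.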
"""
    def __init__(
self,
transform: Callable,
scope: str,
input_key: Union[List[str], str] = None,
output_key: Union[List[str], str] = None,
):
"""
Preprocess your batch with specified function.
Args:
transform (Callable): Function to apply.
scope (str): ``"on_batch_end"`` (post-processing model output) or
``"on_batch_start"`` (pre-processing model input).
            input_key (Union[List[str], str], optional): Keys in the batch dict
                to apply the function to. Defaults to ``None``.
            output_key (Union[List[str], str], optional): Keys for the output.
                If ``None``, the function is applied in place to the
                ``input_key`` values. Defaults to ``None``.

        Raises:
            TypeError: When a key is not a ``str`` or a list of ``str``,
                or when ``scope`` is not one of
                ``["on_batch_end", "on_batch_start"]``.
"""
super().__init__(order=CallbackOrder.Internal)
        if input_key is not None:
            if not isinstance(input_key, (list, str)):
                raise TypeError("input_key should be str or a list of str.")
            elif isinstance(input_key, str):
                input_key = [input_key]
            # apply the transform to the selected batch values only
            self._handle_batch = self._handle_value
        else:
            # no keys given: pass the whole batch dict through the transform
            self._handle_batch = self._handle_key_value
output_key = output_key or input_key
if output_key is not None:
if input_key is None:
raise TypeError("You should define input_key in " "case if output_key is not None")
if not isinstance(output_key, (list, str)):
raise TypeError("output key should be str or a list of str.")
if isinstance(output_key, str):
output_key = [output_key]
transform = _tuple_wrapper(transform)
if isinstance(scope, str) and scope in ["on_batch_end", "on_batch_start"]:
self.scope = scope
else:
            raise TypeError('Expected scope to be one of ["on_batch_end", "on_batch_start"]')
self.input_key = input_key
self.output_key = output_key
self.transform = transform
    def _handle_value(self, runner):
        # gather the selected values, apply the transform, and write the
        # results back under ``output_key``
        batch_in = [runner.batch[key] for key in self.input_key]
        batch_out = self.transform(*batch_in)
        runner.batch.update(**dict(zip(self.output_key, batch_out)))
    def _handle_key_value(self, runner):
        # no ``input_key``: the transform consumes and returns the whole batch dict
        runner.batch = self.transform(runner.batch)
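
    # A hypothetical sketch of this key-value path: with ``input_key=None``
    # the transform receives the whole batch dict, for example
    #     BatchTransformCallback(
    #         transform=lambda batch: {**batch, "features": batch["features"].float()},
    #         scope="on_batch_start",
    #     )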
def on_batch_start(self, runner: "IRunner") -> None:
"""On batch start action.
Args:
runner: runner for the experiment.
"""
if self.scope == "on_batch_start":
self._handle_batch(runner)
def on_batch_end(self, runner: "IRunner") -> None:
"""On batch end action.
Args:
runner: runner for the experiment.
"""
if self.scope == "on_batch_end":
self._handle_batch(runner)
__all__ = ["BatchTransformCallback"]