Source code for catalyst.callbacks.batch_transform

from typing import Any, Callable, Dict, List, Union
from functools import partial

from catalyst.core import Callback, CallbackOrder, IRunner
from catalyst.registry import REGISTRY


class _TupleWrapper:
    """Wrapper that packs a single transform output into a one-element tuple."""

    def __init__(self, transform: Callable) -> None:
        """Init."""
        self.transform = transform

    def __call__(self, *inputs) -> Any:
        """Call."""
        output = self.transform(*inputs)
        return (output,)
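

# Illustrative note: ``_TupleWrapper`` exists so that a transform returning a single
# value can be consumed by the same ``zip(output_key, batch_out)`` loop as a
# transform returning several values. A minimal sketch of the difference, assuming
# ``torch`` is available:
#
#     wrapped = _TupleWrapper(torch.sigmoid)
#     wrapped(torch.zeros(2, 3))        # -> a 1-tuple holding one (2, 3) tensor
#     torch.sigmoid(torch.zeros(2, 3))  # -> a bare tensor; zipping it with
#                                       #    ``output_key`` would pair the key with
#                                       #    row 0 instead of the whole tensor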


class BatchTransformCallback(Callback):
    """
    Preprocess your batch with the specified function.

    Args:
        transform: Function to apply. If a string, the function is taken from the registry.
        scope: ``"on_batch_end"`` (post-processing model output) or
            ``"on_batch_start"`` (pre-processing model input).
        input_key: Keys in the batch dict to apply the function to. Defaults to ``None``.
        output_key: Keys to store the output under.
            If ``None``, the function is applied in place to ``input_key``.
            Defaults to ``None``.
        transform_kwargs: Kwargs for the transform.

    Raises:
        TypeError: When ``input_key`` or ``output_key`` is neither a str nor a list of str,
            or when ``scope`` is not in ``["on_batch_end", "on_batch_start"]``.

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # sample data
        num_users, num_features, num_items = int(1e4), int(1e1), 10
        X = torch.rand(num_users, num_features)
        y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=3,
            verbose=True,
            callbacks=[
                dl.BatchTransformCallback(
                    input_key="logits", output_key="scores", transform="F.sigmoid"
                ),
                dl.CriterionCallback(
                    input_key="logits", target_key="targets", metric_key="loss"
                ),
                dl.OptimizerCallback(metric_key="loss"),
                dl.SchedulerCallback(),
                dl.CheckpointCallback(
                    logdir="./logs",
                    loader_key="valid",
                    metric_key="map01",
                    minimize=False,
                ),
            ],
        )

    .. code-block:: python

        # imports assumed by this example: torchvision for MNIST/ToTensor,
        # kornia for batch-level augmentations
        import os

        import torch
        from torch.nn import functional as F
        from torch.utils.data import DataLoader
        from torchvision.datasets import MNIST
        from torchvision.transforms import ToTensor
        from kornia import augmentation

        from catalyst import dl, metrics

        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                logits = self.model(
                    batch["features"].view(batch["features"].size(0), -1)
                )
                loss = F.cross_entropy(logits, batch["targets"])
                accuracy01, accuracy03 = metrics.accuracy(
                    logits, batch["targets"], topk=(1, 3)
                )
                self.batch_metrics.update({
                    "loss": loss,
                    "accuracy01": accuracy01,
                    "accuracy03": accuracy03,
                })

                if self.is_train_loader:
                    self.engine.backward(loss)
                    self.optimizer.step()
                    self.optimizer.zero_grad()

        class MnistDataset(torch.utils.data.Dataset):
            def __init__(self, dataset):
                self.dataset = dataset

            def __getitem__(self, item):
                return {
                    "features": self.dataset[item][0],
                    "targets": self.dataset[item][1],
                }

            def __len__(self):
                return len(self.dataset)

        model = torch.nn.Linear(28 * 28, 10)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MnistDataset(
                    MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
                ),
                batch_size=32,
            ),
            "valid": DataLoader(
                MnistDataset(
                    MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
                ),
                batch_size=32,
            ),
        }

        transforms = [
            augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)),
        ]

        runner = CustomRunner()
        # model training
        runner.train(
            model=model,
            optimizer=optimizer,
            loaders=loaders,
            logdir="./logs",
            num_epochs=5,
            verbose=False,
            load_best_on_end=True,
            check=True,
            callbacks=[
                dl.BatchTransformCallback(
                    transform=transforms,
                    scope="on_batch_start",
                    input_key="features",
                )
            ],
        )

    .. code-block:: yaml

        ...
        callbacks:
            transform:
                _target_: BatchTransformCallback
                transform: catalyst.ToTensor
                scope: on_batch_start
                input_key: features
    """

    def __init__(
        self,
        transform: Union[Callable, str],
        scope: str,
        input_key: Union[List[str], str] = None,
        output_key: Union[List[str], str] = None,
        transform_kwargs: Dict[str, Any] = None,
    ):
        """
        Preprocess your batch with the specified function.

        Args:
            transform: Function to apply.
                If a string, the function is taken from the registry.
            scope: ``"on_batch_end"`` (post-processing model output) or
                ``"on_batch_start"`` (pre-processing model input).
            input_key: Keys in the batch dict to apply the function to. Defaults to ``None``.
            output_key: Keys to store the output under.
                If ``None``, the function is applied in place to ``input_key``.
                Defaults to ``None``.
            transform_kwargs: Kwargs for the transform.

        Raises:
            TypeError: When ``input_key`` or ``output_key`` is neither a str nor a list of str,
                or when ``scope`` is not in ``["on_batch_end", "on_batch_start"]``.
        """
        super().__init__(order=CallbackOrder.Internal)
        if isinstance(transform, str):
            transform = REGISTRY.get(transform)
        if transform_kwargs is not None:
            transform = partial(transform, **transform_kwargs)

        if input_key is not None:
            if not isinstance(input_key, (list, str)):
                raise TypeError("input_key should be str or a list of str.")
            elif isinstance(input_key, str):
                input_key = [input_key]
            self._handle_batch = self._handle_value
        else:
            self._handle_batch = self._handle_key_value

        output_key = output_key or input_key
        if output_key is not None:
            if input_key is None:
                raise TypeError("You should define input_key if output_key is not None.")
            if not isinstance(output_key, (list, str)):
                raise TypeError("output_key should be str or a list of str.")
            if isinstance(output_key, str):
                output_key = [output_key]
            transform = _TupleWrapper(transform)

        if isinstance(scope, str) and scope in ["on_batch_end", "on_batch_start"]:
            self.scope = scope
        else:
            raise TypeError(
                'Expected scope to be one of ["on_batch_end", "on_batch_start"]'
            )
        self.input_key = input_key
        self.output_key = output_key
        self.transform = transform

    def _handle_value(self, runner):
        batch_in = [runner.batch[key] for key in self.input_key]
        batch_out = self.transform(*batch_in)
        runner.batch.update(
            **{key: value for key, value in zip(self.output_key, batch_out)}
        )

    def _handle_key_value(self, runner):
        runner.batch = self.transform(runner.batch)
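
    # Dispatch summary: with ``input_key`` given, ``_handle_value`` reads the
    # selected entries from ``runner.batch``, applies the transform, and writes
    # the results back under ``output_key``; with no keys, ``_handle_key_value``
    # passes the whole batch dict to the transform, which must return the new
    # batch dict.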

    def on_batch_start(self, runner: "IRunner") -> None:
        """Event handler."""
        if self.scope == "on_batch_start":
            self._handle_batch(runner)

    def on_batch_end(self, runner: "IRunner") -> None:
        """Event handler."""
        if self.scope == "on_batch_end":
            self._handle_batch(runner)

__all__ = ["BatchTransformCallback"]
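

# A minimal usage sketch, assuming ``torch`` is installed: it exercises the
# key-based path with a hypothetical stand-in "runner" that only carries the
# ``batch`` dict attribute accessed by ``_handle_value``.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    callback = BatchTransformCallback(
        transform=torch.sigmoid,
        scope="on_batch_end",
        input_key="logits",
        output_key="scores",
    )
    # hypothetical stand-in object, not a real catalyst IRunner
    fake_runner = SimpleNamespace(batch={"logits": torch.randn(4, 3)})
    callback.on_batch_end(fake_runner)
    # sigmoid outputs lie in (0, 1); the original "logits" entry stays untouched
    print(fake_runner.batch["scores"].min().item(), fake_runner.batch["scores"].max().item())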