# Source code for catalyst.callbacks.metrics.r2_squared
from typing import Optional

from catalyst.callbacks.metric import LoaderMetricCallback
from catalyst.metrics._r2_squared import R2Squared
class R2SquaredCallback(LoaderMetricCallback):
    """R2 Squared (coefficient of determination) metric callback.

    Thin wrapper that builds an ``R2Squared`` metric and hands it, together
    with the input/target keys, to ``LoaderMetricCallback``.

    Args:
        input_key: input key to use for r2squared calculation, specifies our
            ``y_pred`` (the model outputs, e.g. ``"logits"``)
        target_key: target key to use for r2squared calculation, specifies
            our ``y_true`` (the ground-truth values, e.g. ``"targets"``)
        prefix: metric prefix, forwarded to ``R2Squared``
        suffix: metric suffix, forwarded to ``R2Squared``

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # data
        num_samples, num_features = int(1e4), int(1e1)
        # targets are shaped (num_samples, 1) to match the (batch, 1) output
        # of ``Linear(num_features, 1)``; a flat (num_samples,) target would
        # silently broadcast to (batch, batch) in elementwise ops
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, 1)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            num_epochs=8,
            verbose=True,
            callbacks=[
                dl.R2SquaredCallback(input_key="logits", target_key="targets")
            ]
        )

    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples # noqa: E501, W505
    """

    def __init__(
        self,
        input_key: str,
        target_key: str,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Init."""
        super().__init__(
            # prefix/suffix customise the metric's reported name and are
            # owned by the metric object, not by the callback
            metric=R2Squared(prefix=prefix, suffix=suffix),
            input_key=input_key,
            target_key=target_key,
        )
# Names exported by ``from <this module> import *``.
__all__ = [
    "R2SquaredCallback",
]