"""Weights & Biases logger callback for Catalyst (``catalyst.contrib.dl.callbacks.wandb``)."""
from typing import Dict, List, Optional

import wandb

from catalyst import utils
from catalyst.core.callback import (
    Callback,
    CallbackNode,
    CallbackOrder,
    CallbackScope,
)
from catalyst.core.runner import _Runner
class WandbLogger(Callback):
    """Logger callback, translates ``runner.*_metrics`` to Weights & Biases.

    Read about Weights & Biases here https://docs.wandb.com/

    Example:

    .. code-block:: python

        from catalyst import dl
        import torch
        import torch.nn as nn
        import torch.optim as optim
        from torch.utils.data import DataLoader, TensorDataset

        class Projector(nn.Module):
            def __init__(self, input_size):
                super().__init__()
                self.linear = nn.Linear(input_size, 1)

            def forward(self, X):
                return self.linear(X).squeeze(-1)

        X = torch.rand(16, 10)
        y = torch.rand(X.shape[0])
        model = Projector(X.shape[1])
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=8)
        runner = dl.SupervisedRunner()

        runner.train(
            model=model,
            loaders={
                "train": loader,
                "valid": loader
            },
            criterion=nn.MSELoss(),
            optimizer=optim.Adam(model.parameters()),
            logdir="log_example",
            callbacks=[
                dl.callbacks.WandbLogger(
                    project="wandb_logger_example"
                )
            ],
            num_epochs=10
        )
    """

    def __init__(
        self,
        metric_names: Optional[List[str]] = None,
        log_on_batch_end: bool = False,
        log_on_epoch_end: bool = True,
        **logging_params,
    ):
        """
        Args:
            metric_names (List[str]): list of metric names to log,
                if None - logs everything
            log_on_batch_end (bool): logs per-batch metrics if set True
            log_on_epoch_end (bool): logs per-epoch metrics if set True
            **logging_params: any parameters of function `wandb.init`;
                `reinit` is always forced to `True` and `dir` is set to
                `<logdir>` whenever the runner has a logdir (values
                supplied here for those two keys are overridden)

        Raises:
            ValueError: if both ``log_on_batch_end`` and
                ``log_on_epoch_end`` are disabled
        """
        super().__init__(
            order=CallbackOrder.Logging,
            node=CallbackNode.Master,
            scope=CallbackScope.Experiment,
        )
        self.metrics_to_log = metric_names
        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end

        if not (self.log_on_batch_end or self.log_on_epoch_end):
            raise ValueError("You have to log something!")

        # Suffixes are only needed to tell batch- and epoch-level charts
        # apart when both granularities are logged; with a single one the
        # plain ``<metric>/<mode>`` key is used.
        if self.log_on_batch_end and self.log_on_epoch_end:
            self.batch_log_suffix = "_batch"
            self.epoch_log_suffix = "_epoch"
        else:
            self.batch_log_suffix = ""
            self.epoch_log_suffix = ""

        self.logging_params = logging_params

    def _log_metrics(
        self,
        metrics: Dict[str, float],
        step: int,
        mode: str,
        suffix="",
        commit=True,
    ) -> None:
        """Send ``metrics`` to W&B under keys ``<name>/<mode><suffix>``.

        Args:
            metrics: mapping of metric name to value
            step: value forwarded as ``step`` to ``wandb.log``
            mode: key part placed after the metric name
                (a loader name or ``"_base"`` in this class)
            suffix: optional key suffix, e.g. ``"_batch"`` / ``"_epoch"``
            commit: forwarded to ``wandb.log``; ``False`` lets a later
                call with ``commit=True`` commit these values
        """
        # ``None`` means "log everything"; otherwise use a set so each
        # membership test is O(1) instead of scanning a list.
        allowed = (
            None if self.metrics_to_log is None else set(self.metrics_to_log)
        )

        def key_locate(key: str) -> str:
            """Strip leading underscores from a metric name.

            W&B reserves keys starting with ``_`` for its own service
            fields, so original names like ``_loss`` cannot be sent as-is.
            All leading underscores are removed (not just the first) so
            names such as ``__loss`` do not still start with ``_``.

            Args:
                key: metric name

            Returns:
                formatted metric name
            """
            return key.lstrip("_")

        payload = {
            f"{key_locate(key)}/{mode}{suffix}": value
            for key, value in metrics.items()
            if allowed is None or key in allowed
        }
        wandb.log(payload, step=step, commit=commit)

    def on_stage_start(self, runner: _Runner) -> None:
        """Initialize Weights & Biases."""
        init_params = dict(self.logging_params)
        # Set ``reinit``/``dir`` in the dict instead of passing them next to
        # ``**logging_params``: that would raise ``TypeError`` (multiple
        # values for keyword argument) if the user also supplied them.
        init_params["reinit"] = True
        # ``str(None)`` would make W&B write into a directory literally
        # named "None"; only point W&B at the logdir when there is one.
        if runner.logdir is not None:
            init_params["dir"] = str(runner.logdir)
        wandb.init(**init_params)

    def on_stage_end(self, runner: _Runner) -> None:
        """Finish logging to Weights & Biases."""
        # ``wandb.finish`` supersedes the deprecated ``wandb.join``; fall
        # back to ``join`` for older clients that do not provide ``finish``.
        finish = getattr(wandb, "finish", None) or wandb.join
        finish()

    def on_batch_end(self, runner: _Runner) -> None:
        """Translate batch metrics to Weights & Biases."""
        if not self.log_on_batch_end:
            return
        self._log_metrics(
            metrics=runner.batch_metrics,
            step=runner.global_sample_step,
            mode=runner.loader_name,
            suffix=self.batch_log_suffix,
            commit=True,
        )

    def on_loader_end(self, runner: _Runner) -> None:
        """Translate loader metrics to Weights & Biases."""
        if not self.log_on_epoch_end:
            return
        # NOTE(review): batch-level calls use ``global_sample_step`` as the
        # W&B step while epoch-level calls use ``global_epoch``. W&B expects
        # ``step`` to be non-decreasing, so with both ``log_on_batch_end``
        # and ``log_on_epoch_end`` enabled these entries may be dropped —
        # confirm and consider using a single step counter.
        # ``commit=False``: the row is committed later by ``on_epoch_end``.
        self._log_metrics(
            metrics=runner.loader_metrics,
            step=runner.global_epoch,
            mode=runner.loader_name,
            suffix=self.epoch_log_suffix,
            commit=False,
        )

    def on_epoch_end(self, runner: _Runner) -> None:
        """Translate epoch metrics to Weights & Biases."""
        if not self.log_on_epoch_end:
            return
        extra_mode = "_base"
        # Epoch metrics whose names do not start with a loader name are
        # grouped under ``extra_mode``.
        splitted_epoch_metrics = utils.split_dict_to_subdicts(
            dct=runner.epoch_metrics,
            prefixes=list(runner.loaders.keys()),
            extra_key=extra_mode,
        )
        # NOTE(review): the ``extra_mode`` group is presumably absent when
        # every epoch metric is loader-prefixed — verify against
        # ``utils.split_dict_to_subdicts``. Always log (possibly an empty
        # dict) so the ``commit=False`` loader metrics still get committed.
        self._log_metrics(
            metrics=splitted_epoch_metrics.get(extra_mode, {}),
            step=runner.global_epoch,
            mode=extra_mode,
            suffix=self.epoch_log_suffix,
            commit=True,
        )
# Public API of this module: only the W&B logger callback is exported.
__all__ = [
    "WandbLogger",
]