Shortcuts

Source code for catalyst.contrib.dl.callbacks.wandb

from typing import Dict, List

import wandb

from catalyst import utils
from catalyst.core.callback import (
    Callback,
    CallbackNode,
    CallbackOrder,
    CallbackScope,
)
from catalyst.core.runner import _Runner


class WandbLogger(Callback):
    """Logger callback, translates ``runner.*_metrics`` to Weights & Biases.

    Read about Weights & Biases here https://docs.wandb.com/

    Example:
        .. code-block:: python

            from catalyst import dl
            import torch
            import torch.nn as nn
            import torch.optim as optim
            from torch.utils.data import DataLoader, TensorDataset

            class Projector(nn.Module):
                def __init__(self, input_size):
                    super().__init__()
                    self.linear = nn.Linear(input_size, 1)

                def forward(self, X):
                    return self.linear(X).squeeze(-1)

            X = torch.rand(16, 10)
            y = torch.rand(X.shape[0])
            model = Projector(X.shape[1])
            dataset = TensorDataset(X, y)
            loader = DataLoader(dataset, batch_size=8)
            runner = dl.SupervisedRunner()

            runner.train(
                model=model,
                loaders={
                    "train": loader,
                    "valid": loader
                },
                criterion=nn.MSELoss(),
                optimizer=optim.Adam(model.parameters()),
                logdir="log_example",
                callbacks=[
                    dl.callbacks.WandbLogger(
                        project="wandb_logger_example"
                    )
                ],
                num_epochs=10
            )
    """

    def __init__(
        self,
        metric_names: List[str] = None,
        log_on_batch_end: bool = False,
        log_on_epoch_end: bool = True,
        **logging_params,
    ):
        """
        Args:
            metric_names (List[str]): list of metric names to log,
                if None - logs everything
            log_on_batch_end (bool): logs per-batch metrics if set True
            log_on_epoch_end (bool): logs per-epoch metrics if set True
            **logging_params: any parameters of function `wandb.init`
                except `reinit` which is automatically set to `True`
                and `dir` which is set to `<logdir>`

        Raises:
            ValueError: if both ``log_on_batch_end`` and
                ``log_on_epoch_end`` are disabled
        """
        super().__init__(
            order=CallbackOrder.Logging,
            node=CallbackNode.Master,
            scope=CallbackScope.Experiment,
        )
        self.metrics_to_log = metric_names
        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end

        if not (self.log_on_batch_end or self.log_on_epoch_end):
            raise ValueError("You have to log something!")

        # Suffixes are only needed to tell batch-level and epoch-level
        # series apart; with a single granularity the names stay clean.
        if self.log_on_batch_end and self.log_on_epoch_end:
            self.batch_log_suffix = "_batch"
            self.epoch_log_suffix = "_epoch"
        else:
            self.batch_log_suffix = ""
            self.epoch_log_suffix = ""

        self.logging_params = logging_params

    def _log_metrics(
        self,
        metrics: Dict[str, float],
        step: int,
        mode: str,
        suffix="",
        commit=True,
    ):
        """Filter, rename and send ``metrics`` to ``wandb.log``.

        Each kept metric is logged under ``<name>/<mode><suffix>``.

        Args:
            metrics: mapping from metric name to value
            step: value passed to ``wandb.log`` as ``step``
            mode: name placed after the ``/`` in the logged key
            suffix: text appended after ``mode`` in the logged key
            commit: value passed to ``wandb.log`` as ``commit``
        """
        wanted = self.metrics_to_log

        def key_locate(key: str):
            """
            Wandb uses first symbol _ for it service purposes
            because of that fact, we can not send original metric names

            Args:
                key: metric name

            Returns:
                formatted metric name
            """
            return key[1:] if key.startswith("_") else key

        # ``wanted is None`` means "log everything"; there is no need to
        # build a sorted key list just to test membership against itself.
        payload = {
            f"{key_locate(key)}/{mode}{suffix}": value
            for key, value in metrics.items()
            if wanted is None or key in wanted
        }
        wandb.log(payload, step=step, commit=commit)

    def on_stage_start(self, runner: _Runner):
        """Initialize Weights & Biases."""
        wandb.init(**self.logging_params, reinit=True, dir=str(runner.logdir))

    def on_stage_end(self, runner: _Runner):
        """Finish logging to Weights & Biases."""
        wandb.join()

    def on_batch_end(self, runner: _Runner):
        """Translate batch metrics to Weights & Biases."""
        if not self.log_on_batch_end:
            return
        self._log_metrics(
            metrics=runner.batch_metrics,
            step=runner.global_sample_step,
            mode=runner.loader_name,
            suffix=self.batch_log_suffix,
            commit=True,
        )

    def on_loader_end(self, runner: _Runner):
        """Translate loader metrics to Weights & Biases."""
        if not self.log_on_epoch_end:
            return
        # commit=False: the row is committed by the call in ``on_epoch_end``.
        self._log_metrics(
            metrics=runner.loader_metrics,
            step=runner.global_epoch,
            mode=runner.loader_name,
            suffix=self.epoch_log_suffix,
            commit=False,
        )

    def on_epoch_end(self, runner: _Runner):
        """Translate epoch metrics to Weights & Biases."""
        if not self.log_on_epoch_end:
            # Nothing will be logged, so skip splitting the metrics.
            return

        extra_mode = "_base"
        splitted_epoch_metrics = utils.split_dict_to_subdicts(
            dct=runner.epoch_metrics,
            prefixes=list(runner.loaders.keys()),
            extra_key=extra_mode,
        )
        # The ``extra_mode`` entry may be absent when every epoch metric
        # carries a loader prefix; fall back to an empty dict instead of
        # raising KeyError. The call is still made so that the loader
        # metrics logged with commit=False are committed for this step.
        self._log_metrics(
            metrics=splitted_epoch_metrics.get(extra_mode, {}),
            step=runner.global_epoch,
            mode=extra_mode,
            suffix=self.epoch_log_suffix,
            commit=True,
        )
# Names exported by ``from <this module> import *``.
__all__ = ["WandbLogger"]