# Source code for catalyst.loggers.wandb
from typing import Dict, Optional, TYPE_CHECKING
import os
import pickle
import warnings
import numpy as np
from catalyst.core.logger import ILogger
from catalyst.settings import SETTINGS
if SETTINGS.wandb_required:
import wandb
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
class WandbLogger(ILogger):
    """Wandb logger for parameters, metrics, images and other artifacts.

    W&B documentation: https://docs.wandb.com

    Args:
        project: Name of the project in W&B to log to.
        name: Name of the run in W&B to log to.
        entity: Name of W&B entity (team) to log to.
        log_batch_metrics: boolean flag to log batch metrics
            (default: SETTINGS.log_batch_metrics or False).
        log_epoch_metrics: boolean flag to log epoch metrics
            (default: SETTINGS.log_epoch_metrics or True).
        kwargs: Optional,
            additional keyword arguments to be passed directly to ``wandb.init``
            (e.g. ``config`` - a configuration dictionary for the experiment).

    Python API examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            ...,
            loggers={"wandb": dl.WandbLogger(project="wandb_test", name="experiment_1")}
        )

    .. code-block:: python

        from catalyst import dl

        class CustomRunner(dl.IRunner):
            # ...

            def get_loggers(self):
                return {
                    "console": dl.ConsoleLogger(),
                    "wandb": dl.WandbLogger(project="wandb_test", name="experiment_1")
                }

            # ...

        runner = CustomRunner().run()
    """

    def __init__(
        self,
        project: str,
        name: Optional[str] = None,
        entity: Optional[str] = None,
        log_batch_metrics: bool = SETTINGS.log_batch_metrics,
        log_epoch_metrics: bool = SETTINGS.log_epoch_metrics,
        **kwargs,
    ) -> None:
        super().__init__(
            log_batch_metrics=log_batch_metrics, log_epoch_metrics=log_epoch_metrics
        )
        if self.log_batch_metrics:
            # All metrics share one x-axis (see ``step`` in ``log_metrics``), so
            # enabling batch metrics switches everything to sample-based steps.
            warnings.warn(
                "Wandb does NOT support several x-axes for logging. "
                "For this reason, everything has to be logged in the batch-based regime."
            )
        self.project = project
        self.name = name
        self.entity = entity
        # allow_val_change=True presumably lets ``log_hparams`` overwrite config
        # keys already set at init time (e.g. via ``config=`` in kwargs).
        self.run = wandb.init(
            project=self.project,
            name=self.name,
            entity=self.entity,
            allow_val_change=True,
            **kwargs,
        )

    @property
    def logger(self):
        """Internal logger/experiment/etc. from the monitoring system."""
        return self.run

    def _log_metrics(
        self, metrics: Dict[str, float], step: int, loader_key: str, prefix=""
    ):
        """Log ``metrics`` at ``step`` under ``"{key}_{prefix}/{loader_key}"`` names.

        All metrics are sent in a single ``run.log`` call; nothing is sent when
        ``metrics`` is empty.
        """
        if not metrics:
            return
        self.run.log(
            {f"{key}_{prefix}/{loader_key}": value for key, value in metrics.items()},
            step=step,
        )

    def log_artifact(
        self,
        tag: str,
        runner: "IRunner",
        artifact: object = None,
        path_to_artifact: str = None,
        scope: str = None,
    ) -> None:
        """Logs artifact (arbitrary file like audio, video, weights) to the logger.

        If ``artifact`` is given, it is pickled to
        ``wandb/<run id>/artifact_dumps/<tag>`` and that file is uploaded;
        otherwise the existing file at ``path_to_artifact`` is uploaded.

        Args:
            tag: file name for the pickled dump of ``artifact``.
            runner: current runner; its ``loader_key`` goes into the metadata.
            artifact: python object to pickle and upload (takes priority).
            path_to_artifact: path to an existing file to upload when
                ``artifact`` is None.
            scope: logging scope, stored in the artifact metadata.

        Raises:
            ValueError: if both ``artifact`` and ``path_to_artifact`` are None.
        """
        if artifact is None and path_to_artifact is None:
            raise ValueError("Both artifact and path_to_artifact cannot be None")

        # Separate name for the W&B container so the user's ``artifact`` object
        # is not shadowed (otherwise the container itself would be pickled).
        wandb_artifact = wandb.Artifact(
            # (sic) spelling kept so existing W&B artifact names stay stable.
            name=self.run.id + "_aritfacts",
            type="artifact",
            metadata={"loader_key": runner.loader_key, "scope": scope},
        )

        if artifact is not None:
            art_file_dir = os.path.join("wandb", self.run.id, "artifact_dumps")
            os.makedirs(art_file_dir, exist_ok=True)
            art_file_path = os.path.join(art_file_dir, tag)
            # Context manager guarantees the handle is closed even if pickling fails.
            with open(art_file_path, "wb") as art_file:
                pickle.dump(artifact, art_file)
            wandb_artifact.add_file(art_file_path)
        else:
            wandb_artifact.add_file(path_to_artifact)
        self.run.log_artifact(wandb_artifact)

    def log_image(
        self,
        tag: str,
        image: np.ndarray,
        runner: "IRunner",
        scope: str = None,
    ) -> None:
        """Logs image to the logger.

        The image is logged under a name built from ``tag`` plus, depending on
        ``scope``, the epoch step and the loader key, with a ``.png`` suffix.

        Raises:
            ValueError: if ``scope`` is not one of "batch", "loader", "epoch",
                "experiment" or None.
        """
        if scope in ("batch", "loader"):
            log_path = "_".join(
                [tag, f"epoch-{runner.epoch_step:04d}", f"loader-{runner.loader_key}"]
            )
        elif scope == "epoch":
            log_path = "_".join([tag, f"epoch-{runner.epoch_step:04d}"])
        elif scope == "experiment" or scope is None:
            log_path = tag
        else:
            raise ValueError(f"Unknown scope for image logging: {scope!r}")

        step = runner.sample_step if self.log_batch_metrics else runner.epoch_step
        self.run.log({f"{log_path}.png": wandb.Image(image)}, step=step)

    def log_hparams(self, hparams: Dict, runner: "IRunner" = None) -> None:
        """Logs hyperparameters to the logger by updating the W&B run config."""
        self.run.config.update(hparams)

    def log_metrics(
        self,
        metrics: Dict[str, float],
        scope: str,
        runner: "IRunner",
    ) -> None:
        """Logs batch and epoch metrics to wandb.

        Args:
            metrics: metrics to log. For ``scope == "epoch"`` this is a mapping
                whose ``"_epoch_"`` entry holds the metrics dict to log.
            scope: "batch", "loader" or "epoch"; any other value logs nothing.
            runner: current runner (provides step counters and ``loader_key``).
        """
        # Single x-axis for everything: sample-based when batch metrics are on.
        step = runner.sample_step if self.log_batch_metrics else runner.epoch_step
        if scope == "batch" and self.log_batch_metrics:
            metrics = {k: float(v) for k, v in metrics.items()}
            self._log_metrics(
                metrics=metrics,
                step=step,
                loader_key=runner.loader_key,
                prefix="batch",
            )
        elif scope == "loader" and self.log_epoch_metrics:
            self._log_metrics(
                metrics=metrics,
                step=step,
                loader_key=runner.loader_key,
                prefix="epoch",
            )
        elif scope == "epoch" and self.log_epoch_metrics:
            loader_key = "_epoch_"
            per_loader_metrics = metrics[loader_key]
            self._log_metrics(
                metrics=per_loader_metrics,
                step=step,
                loader_key=loader_key,
                prefix="epoch",
            )

    def flush_log(self) -> None:
        """Flushes the logger (no-op for this logger)."""
        pass

    def close_log(self, scope: str = None) -> None:
        """Closes the logger by finishing the underlying W&B run."""
        self.run.finish()
# Names exported by ``from catalyst.loggers.wandb import *``.
__all__ = [
    "WandbLogger",
]