# Source code for catalyst.loggers.tensorboard

from typing import Dict, TYPE_CHECKING
import os

import numpy as np

from tensorboardX import SummaryWriter
import torch

from catalyst.core.logger import ILogger
from catalyst.settings import SETTINGS

if TYPE_CHECKING:
    from catalyst.core.runner import IRunner

def _image_to_tensor(image: np.ndarray) -> torch.Tensor:
    Creates tensor from RGB image.

        image: RGB image stored as np.ndarray

    image = np.moveaxis(image, -1, 0)
    image = np.ascontiguousarray(image)
    image = torch.from_numpy(image)
    return image

class TensorboardLogger(ILogger):
    """Tensorboard logger for parameters, metrics, images and other artifacts.

    Args:
        logdir: path to logdir for tensorboard.
        use_logdir_postfix: boolean flag to append an extra ``tensorboard``
            postfix to the logdir (i.e. write to ``<logdir>/tensorboard``).
        log_batch_metrics: boolean flag to log batch metrics
            (default: SETTINGS.log_batch_metrics or False).
        log_epoch_metrics: boolean flag to log epoch metrics
            (default: SETTINGS.log_epoch_metrics or True).

    .. note::
        This logger is used by default by ``dl.Runner`` and ``dl.SupervisedRunner``
        in case of specified logdir during ``runner.train(..., logdir=/path/to/logdir)``.

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            ...,
            loggers={"tensorboard": dl.TensorboardLogger(logdir="./logdir/tensorboard")}
        )

    .. code-block:: python

        from catalyst import dl

        class CustomRunner(dl.IRunner):
            # ...

            def get_loggers(self):
                return {
                    "console": dl.ConsoleLogger(),
                    "tensorboard": dl.TensorboardLogger(logdir="./logdir/tensorboard")
                }

            # ...

        runner = CustomRunner().run()
    """

    def __init__(
        self,
        logdir: str,
        use_logdir_postfix: bool = False,
        log_batch_metrics: bool = SETTINGS.log_batch_metrics,
        log_epoch_metrics: bool = SETTINGS.log_epoch_metrics,
    ):
        """Init."""
        super().__init__(
            log_batch_metrics=log_batch_metrics, log_epoch_metrics=log_epoch_metrics
        )
        if use_logdir_postfix:
            logdir = os.path.join(logdir, "tensorboard")
        self.logdir = logdir
        # loader key -> SummaryWriter; writers are created lazily by
        # ``_check_loader_key``, each in its own ``<logdir>/<loader_key>`` dir.
        self.loggers = {}
        os.makedirs(self.logdir, exist_ok=True)

    @property
    def logger(self):
        """Internal logger/experiment/etc. from the monitoring system."""
        return self.loggers

    def _check_loader_key(self, loader_key: str):
        """Creates a SummaryWriter for ``loader_key`` if none exists yet."""
        if loader_key not in self.loggers:
            logdir = os.path.join(self.logdir, f"{loader_key}")
            self.loggers[loader_key] = SummaryWriter(logdir)

    def _log_metrics(
        self, metrics: Dict[str, float], step: int, loader_key: str, suffix=""
    ):
        """Writes every metric as a scalar named ``<key><suffix>`` at ``step``.

        The writer for ``loader_key`` must already exist
        (see ``_check_loader_key``).
        """
        for key, value in metrics.items():
            self.loggers[loader_key].add_scalar(f"{key}{suffix}", float(value), step)

    def log_image(
        self,
        tag: str,
        image: np.ndarray,
        runner: "IRunner",
        scope: str = None,
    ) -> None:
        """Logs image to Tensorboard for current scope on current step.

        The image goes to the writer of ``runner.loader_key`` at
        ``runner.epoch_step``; ``scope`` is accepted for interface
        compatibility but is not used here.
        """
        assert runner.loader_key is not None
        self._check_loader_key(loader_key=runner.loader_key)
        tensor = _image_to_tensor(image)
        self.loggers[runner.loader_key].add_image(
            f"{tag}", tensor, global_step=runner.epoch_step
        )

    def log_metrics(
        self,
        metrics: Dict[str, float],
        scope: str,
        runner: "IRunner",
    ) -> None:
        """Logs batch and epoch metrics to Tensorboard.

        - ``"batch"`` scope: written under the current loader with the
          ``/batch`` suffix at ``runner.sample_step``.
        - ``"loader"`` scope: written under the current loader with the
          ``/epoch`` suffix at ``runner.epoch_step``.
        - ``"epoch"`` scope: ``metrics["_epoch_"]`` is written to a separate
          ``_epoch_`` writer with the ``/epoch`` suffix at ``runner.epoch_step``.

        Any other scope, or a scope disabled via ``log_batch_metrics`` /
        ``log_epoch_metrics``, is ignored.
        """
        if scope == "batch" and self.log_batch_metrics:
            self._check_loader_key(loader_key=runner.loader_key)
            self._log_metrics(
                metrics=metrics,
                step=runner.sample_step,
                loader_key=runner.loader_key,
                suffix="/batch",
            )
        elif scope == "loader" and self.log_epoch_metrics:
            self._check_loader_key(loader_key=runner.loader_key)
            self._log_metrics(
                metrics=metrics,
                step=runner.epoch_step,
                loader_key=runner.loader_key,
                suffix="/epoch",
            )
        elif scope == "epoch" and self.log_epoch_metrics:
            # @TODO: remove naming magic
            # Epoch-level metrics are read from the "_epoch_" entry of
            # ``metrics`` and logged under a writer with the same name.
            loader_key = "_epoch_"
            per_loader_metrics = metrics[loader_key]
            self._check_loader_key(loader_key=loader_key)
            self._log_metrics(
                metrics=per_loader_metrics,
                step=runner.epoch_step,
                loader_key=loader_key,
                suffix="/epoch",
            )

    def flush_log(self) -> None:
        """Flushes the loggers."""
        for logger in self.loggers.values():
            logger.flush()

    def close_log(self) -> None:
        """Closes the loggers."""
        for logger in self.loggers.values():
            logger.close()
# Public names exported by ``from catalyst.loggers.tensorboard import *``.
__all__ = [
    "TensorboardLogger",
]