"""Logging callbacks: source code for ``catalyst.core.callbacks.logging``."""

from typing import Dict, List
import logging
import os
import sys

from tqdm import tqdm

from catalyst.contrib.utils.tools.tensorboard import SummaryWriter
from catalyst.core import Callback, CallbackNode, CallbackOrder, State, utils

from . import formatters


class VerboseLogger(Callback):
    """Logs the params into console."""

    def __init__(
        self, always_show: List[str] = None, never_show: List[str] = None,
    ):
        """
        Args:
            always_show (List[str]): list of metrics to always show
                if None default is ``["_timer/_fps"]``
                to remove always_show metrics set it to an empty list ``[]``
            never_show (List[str]): list of metrics which will not be shown

        Raises:
            ValueError: if ``always_show`` and ``never_show`` share any key
        """
        super().__init__(order=CallbackOrder.Logging, node=CallbackNode.Master)
        self.tqdm: tqdm = None
        self.step = 0
        self.always_show = (
            always_show if always_show is not None else ["_timer/_fps"]
        )
        self.never_show = never_show if never_show is not None else []

        intersection = set(self.always_show) & set(self.never_show)
        if intersection:
            raise ValueError(
                f"Intersection of always_show and "
                f"never_show has common values: {intersection}"
            )

    def _need_show(self, key: str) -> bool:
        """Decide whether metric ``key`` goes into the progress-bar postfix.

        ``never_show`` always wins; otherwise a key is shown when it is listed
        in ``always_show`` or is not a service metric (``_base*``/``_timer*``).
        """
        not_is_never_shown: bool = key not in self.never_show
        is_always_shown: bool = key in self.always_show
        # Timer metrics use the ``_timer/`` prefix (cf. the ``"_timer/_fps"``
        # default above); the old ``"_timers"`` prefix never matched them.
        # ``"_timer"`` still matches any ``_timers*`` key as well.
        not_basic = not key.startswith(("_base", "_timer"))
        return not_is_never_shown and (is_always_shown or not_basic)

    def on_loader_start(self, state: State):
        """Init tqdm progress bar."""
        self.step = 0
        self.tqdm = tqdm(
            total=state.loader_len,
            desc=f"{state.epoch}/{state.num_epochs}"
            f" * Epoch ({state.loader_name})",
            leave=True,
            ncols=0,
            file=sys.stdout,
        )

    def on_loader_end(self, state: State):
        """Cleanup and close tqdm progress bar."""
        self.tqdm.clear()
        self.tqdm.close()
        self.tqdm = None
        self.step = 0

    def on_batch_end(self, state: State):
        """Update tqdm progress bar at the end of each batch.

        Values whose magnitude exceeds ``1e-3`` are printed in fixed-point,
        smaller magnitudes in scientific notation.
        """
        self.tqdm.set_postfix(
            **{
                # ``abs`` so ordinary negative values are not forced into
                # scientific notation.
                k: (
                    "{:3.3f}".format(v)
                    if abs(v) > 1e-3
                    else "{:1.3e}".format(v)
                )
                for k, v in sorted(state.batch_metrics.items())
                if self._need_show(k)
            }
        )
        self.tqdm.update()

    def on_exception(self, state: State):
        """Called if an Exception was raised.

        On ``KeyboardInterrupt`` prints "Early exiting" and tells the state
        not to re-raise the exception.
        """
        exception = state.exception
        if not utils.is_exception(exception):
            return

        if isinstance(exception, KeyboardInterrupt):
            # ``tqdm.write`` is a classmethod: calling it on the class works
            # even when no progress bar is open (``self.tqdm is None``).
            tqdm.write("Early exiting")
            state.need_exception_reraise = False


class ConsoleLogger(Callback):
    """Logger callback, translates ``state.*_metrics`` to console and text file.
    """

    def __init__(self):
        """Init ``ConsoleLogger``."""
        super().__init__(order=CallbackOrder.Logging, node=CallbackNode.Master)
        self.logger = None

    @staticmethod
    def _close_handlers(logger: logging.Logger) -> None:
        """Close and detach every handler currently attached to ``logger``."""
        # Iterate over a copy: ``removeHandler`` mutates ``logger.handlers``.
        for handler in list(logger.handlers):
            handler.close()
            logger.removeHandler(handler)

    @staticmethod
    def _get_logger(logdir):
        """Build the ``"metrics_logger"`` logger.

        Attaches a stdout handler and, when ``logdir`` is truthy, a
        ``{logdir}/log.txt`` file handler; both use ``TxtMetricsFormatter``.
        """
        logger = logging.getLogger("metrics_logger")
        logger.setLevel(logging.INFO)
        # ``getLogger`` returns the same object on every call. Drop handlers
        # left over from a stage that never reached ``on_stage_end`` (e.g. it
        # was aborted), otherwise every metrics line would be emitted twice.
        ConsoleLogger._close_handlers(logger)

        txt_formatter = formatters.TxtMetricsFormatter()

        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.INFO)
        ch.setFormatter(txt_formatter)
        logger.addHandler(ch)

        if logdir:
            fh = logging.FileHandler(f"{logdir}/log.txt")
            fh.setLevel(logging.INFO)
            fh.setFormatter(txt_formatter)
            logger.addHandler(fh)
        return logger

    def on_stage_start(self, state: State):
        """Prepare ``state.logdir`` for the current stage."""
        if state.logdir:
            state.logdir.mkdir(parents=True, exist_ok=True)
        self.logger = self._get_logger(state.logdir)

    def on_stage_end(self, state: State):
        """Called at the end of each stage; closes the logger's handlers."""
        if self.logger is None:
            return
        self._close_handlers(self.logger)

    def on_epoch_end(self, state: State):
        """
        Translate ``state.metric_manager`` to console and text file
        at the end of an epoch.
        """
        self.logger.info("", extra={"state": state})


class TensorboardLogger(Callback):
    """Logger callback, translates ``state.metric_manager`` to tensorboard."""

    def __init__(
        self,
        metric_names: List[str] = None,
        log_on_batch_end: bool = True,
        log_on_epoch_end: bool = True,
    ):
        """
        Args:
            metric_names (List[str]): list of metric names to log,
                if none - logs everything
            log_on_batch_end (bool): logs per-batch metrics if set True
            log_on_epoch_end (bool): logs per-epoch metrics if set True

        Raises:
            ValueError: if both ``log_on_batch_end`` and ``log_on_epoch_end``
                are False
        """
        super().__init__(order=CallbackOrder.Logging, node=CallbackNode.Master)
        self.metrics_to_log = metric_names
        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end

        if not (self.log_on_batch_end or self.log_on_epoch_end):
            raise ValueError("You have to log something!")

        # mode name ("_base" or a loader name) -> SummaryWriter
        self.loggers = {}

    def _log_metrics(
        self, metrics: Dict[str, float], step: int, mode: str, suffix=""
    ):
        """Write ``metrics`` as scalars to the writer registered for ``mode``.

        Only names from ``self.metrics_to_log`` (or every key, sorted, when it
        is None) that are present in ``metrics`` are written, each under the
        tag ``f"{name}{suffix}"``.
        """
        if self.metrics_to_log is None:
            metrics_to_log = sorted(metrics.keys())
        else:
            metrics_to_log = self.metrics_to_log

        for name in metrics_to_log:
            if name in metrics:
                self.loggers[mode].add_scalar(
                    f"{name}{suffix}", metrics[name], step
                )

    def on_stage_start(self, state: State):
        """Open the ``_base`` writer under ``{state.logdir}/_base_log``."""
        assert state.logdir is not None

        extra_mode = "_base"
        log_dir = os.path.join(state.logdir, f"{extra_mode}_log")
        self.loggers[extra_mode] = SummaryWriter(log_dir)

    def on_loader_start(self, state: State):
        """Prepare tensorboard writers for the current stage."""
        # Same guard as the other hooks; avoids os.path.join(None, ...).
        if state.logdir is None:
            return
        if state.loader_name not in self.loggers:
            log_dir = os.path.join(state.logdir, f"{state.loader_name}_log")
            self.loggers[state.loader_name] = SummaryWriter(log_dir)

    def on_batch_end(self, state: State):
        """Translate batch metrics to tensorboard."""
        if state.logdir is None:
            return

        if self.log_on_batch_end:
            self._log_metrics(
                metrics=state.batch_metrics,
                step=state.global_step,
                mode=state.loader_name,
                suffix="/batch",
            )

    def on_epoch_end(self, state: "State"):
        """Translate epoch metrics to tensorboard."""
        if state.logdir is None:
            return

        if self.log_on_epoch_end:
            per_mode_metrics = utils.split_dict_to_subdicts(
                dct=state.epoch_metrics,
                prefixes=list(state.loaders.keys()),
                extra_key="_base",
            )

            for mode, metrics in per_mode_metrics.items():
                self._log_metrics(
                    metrics=metrics,
                    step=state.global_epoch,
                    mode=mode,
                    suffix="/epoch",
                )

        for logger in self.loggers.values():
            logger.flush()

    def on_stage_end(self, state: State):
        """Close opened tensorboard writers."""
        if state.logdir is None:
            return

        for logger in self.loggers.values():
            logger.close()
        # Forget the closed writers so the next stage's ``on_loader_start``
        # opens fresh ones instead of reusing closed instances.
        self.loggers = {}


# Public API of this module.
__all__ = ["ConsoleLogger", "TensorboardLogger", "VerboseLogger"]