# Source code for catalyst.dl.callbacks.logging

from typing import Dict, List  # isort:skip
import logging
import os
import sys

from tqdm import tqdm

from catalyst.dl import utils
from catalyst.dl.core import LoggerCallback, RunnerState
from catalyst.dl.utils.formatters import TxtMetricsFormatter
from catalyst.utils.tensorboard import SummaryWriter


class VerboseLogger(LoggerCallback):
    """
    Logs the params into console: one ``tqdm`` progress bar per loader,
    with the current batch metrics shown as the bar's postfix.
    """

    def __init__(
        self, always_show: List[str] = None, never_show: List[str] = None
    ):
        """
        Logs the params into console

        Args:
            always_show (List[str]): list of metrics to always show
                if None default is ``["_timers/_fps"]``
                to remove always_show metrics set it to an empty list ``[]``
            never_show (List[str]): list of metrics which will not be shown

        Raises:
            ValueError: if ``always_show`` and ``never_show`` share a metric
        """
        super().__init__()
        self.tqdm: tqdm = None
        self.step = 0
        self.always_show = always_show \
            if always_show is not None else ["_timers/_fps"]
        self.never_show = never_show if never_show is not None else []

        intersection = set(self.always_show) & set(self.never_show)
        if intersection:
            raise ValueError(
                f"Intersection of always_show and "
                f"never_show has common values: {intersection}"
            )

    def _need_show(self, key: str) -> bool:
        """
        Decides whether metric ``key`` goes into the progress-bar postfix.

        ``never_show`` wins over everything. ``always_show`` wins over the
        default rule, which hides metrics whose names start with ``_base``
        or ``_timers``.
        """
        if key in self.never_show:
            return False
        if key in self.always_show:
            return True
        return not key.startswith(("_base", "_timers"))

    @staticmethod
    def _format_value(value) -> str:
        """
        Formats one metric value for the progress bar.

        Values whose magnitude exceeds 1e-3 use fixed-point notation;
        values closer to zero use scientific notation. The comparison uses
        the magnitude, so negative metrics are formatted like positive ones.
        """
        if abs(value) > 1e-3:
            return "{:3.3f}".format(value)
        return "{:1.3e}".format(value)

    def on_loader_start(self, state: RunnerState):
        """Opens a fresh progress bar sized to the current loader."""
        self.step = 0
        self.tqdm = tqdm(
            total=state.loader_len,
            desc=f"{state.stage_epoch_log}/{state.num_epochs}"
            f" * Epoch ({state.loader_name})",
            leave=True,
            ncols=0,
            file=sys.stdout
        )

    def on_batch_end(self, state: RunnerState):
        """Shows the selected batch metrics and advances the bar by one."""
        self.tqdm.set_postfix(
            **{
                k: self._format_value(v)
                for k, v in sorted(state.metrics.batch_values.items())
                if self._need_show(k)
            }
        )
        self.tqdm.update()

    def on_loader_end(self, state: RunnerState):
        """Closes and forgets the progress bar of the finished loader."""
        self.tqdm.close()
        self.tqdm = None
        self.step = 0

    def on_exception(self, state: RunnerState):
        """
        Handles ``KeyboardInterrupt`` gracefully: prints a notice and tells
        the runner not to re-raise the exception.
        """
        exception = state.exception
        if not utils.is_exception(exception):
            return

        if isinstance(exception, KeyboardInterrupt):
            # ``self.tqdm`` is None outside on_loader_start/on_loader_end.
            # The class-level ``tqdm.write`` works with or without an open
            # bar and still plays nicely with any bar that is open.
            tqdm.write("Early exiting")
            state.need_reraise_exception = False
class ConsoleLogger(LoggerCallback):
    """
    Logger callback, translates ``state.metrics`` to console and text file
    """

    def __init__(self):
        super().__init__()
        self.logger = None

    @staticmethod
    def _get_logger(logdir):
        """
        Attaches a text-file handler (``{logdir}/log.txt``) and a stdout
        handler, both using ``TxtMetricsFormatter`` at INFO level, to the
        named logger ``"metrics_logger"``.

        Note: ``logging.getLogger`` returns the same object for the same
        name, so this logger is shared process-wide.

        Args:
            logdir: directory where ``log.txt`` is created

        Returns:
            logging.Logger: the configured ``"metrics_logger"`` logger
        """
        logger = logging.getLogger("metrics_logger")
        logger.setLevel(logging.INFO)

        fh = logging.FileHandler(f"{logdir}/log.txt")
        fh.setLevel(logging.INFO)
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.INFO)
        # @TODO: fix json logger (a JsonMetricsFormatter-based FileHandler
        # writing to f"{logdir}/metrics.json")

        txt_formatter = TxtMetricsFormatter()
        fh.setFormatter(txt_formatter)
        ch.setFormatter(txt_formatter)

        logger.addHandler(fh)
        logger.addHandler(ch)
        return logger

    def _close_handlers(self):
        """Detaches and closes every handler attached to ``self.logger``."""
        # Iterate over a copy: removeHandler mutates ``logger.handlers``.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
            handler.close()

    def on_stage_start(self, state: RunnerState):
        """Creates ``state.logdir`` if needed and (re)configures the logger."""
        assert state.logdir is not None
        state.logdir.mkdir(parents=True, exist_ok=True)
        # If a previous stage never reached on_stage_end (e.g. it crashed),
        # its handlers are still attached to the shared named logger.
        # Adding new ones on top would duplicate every line and leak the
        # old log file handle.
        if self.logger is not None:
            self._close_handlers()
        self.logger = self._get_logger(state.logdir)

    def on_stage_end(self, state):
        """Closes and detaches all handlers of the metrics logger."""
        if self.logger is None:
            return
        self._close_handlers()

    def on_epoch_end(self, state):
        """Emits one record; the formatter renders it from ``state``."""
        self.logger.info("", extra={"state": state})
class TensorboardLogger(LoggerCallback):
    """
    Logger callback, translates state.metrics to tensorboard
    """

    def __init__(
        self,
        metric_names: List[str] = None,
        log_on_batch_end: bool = True,
        log_on_epoch_end: bool = True
    ):
        """
        Args:
            metric_names: List of metric names to log.
                If none - logs everything.
            log_on_batch_end: Logs per-batch metrics if set True.
            log_on_epoch_end: Logs per-epoch metrics if set True.
        """
        super().__init__()
        self.metrics_to_log = metric_names
        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end

        assert self.log_on_batch_end or self.log_on_epoch_end, \
            "You have to log something!"

        # One SummaryWriter per loader name, created lazily in
        # on_loader_start and written to "<logdir>/<loader>_log".
        self.loggers = {}

    def _log_metrics(
        self, metrics: Dict[str, float], step: int, mode: str, suffix=""
    ):
        """
        Writes the selected entries of ``metrics`` as scalars to the writer
        of loader ``mode``, tagged ``f"{name}{suffix}"`` at ``step``.

        When no ``metric_names`` were given, all metrics are logged in
        sorted order. Requested names absent from ``metrics`` are skipped.
        """
        if self.metrics_to_log is None:
            metrics_to_log = sorted(metrics)
        else:
            metrics_to_log = self.metrics_to_log

        for name in metrics_to_log:
            if name in metrics:
                self.loggers[mode].add_scalar(
                    f"{name}{suffix}", metrics[name], step
                )

    def on_loader_start(self, state):
        """Creates the writer for the current loader on first use."""
        lm = state.loader_name
        if lm not in self.loggers:
            log_dir = os.path.join(state.logdir, f"{lm}_log")
            self.loggers[lm] = SummaryWriter(log_dir)

    def on_batch_end(self, state: RunnerState):
        """Logs per-batch metrics (suffix ``/batch``) at ``state.step``."""
        if self.log_on_batch_end:
            mode = state.loader_name
            metrics_ = state.metrics.batch_values
            self._log_metrics(
                metrics=metrics_, step=state.step, mode=mode, suffix="/batch"
            )

    def on_loader_end(self, state: RunnerState):
        """
        Logs per-epoch metrics (suffix ``/epoch``) at ``state.epoch_log``,
        then flushes every writer.
        """
        if self.log_on_epoch_end:
            mode = state.loader_name
            metrics_ = state.metrics.epoch_values[mode]
            self._log_metrics(
                metrics=metrics_,
                step=state.epoch_log,
                mode=mode,
                suffix="/epoch"
            )
        for logger in self.loggers.values():
            logger.flush()

    def on_stage_end(self, state: RunnerState):
        """Closes all writers and forgets them."""
        for logger in self.loggers.values():
            logger.close()
        # Without this, the next stage's on_loader_start would find the
        # loader name already present and keep writing to a closed writer.
        self.loggers.clear()
# Public callbacks exported by ``from ... import *``.
__all__ = [
    "VerboseLogger",
    "ConsoleLogger",
    "TensorboardLogger",
]