Source code for catalyst.contrib.dl.runner.wandb
from typing import Dict, List # isort:skip
from pathlib import Path
import shutil
import warnings
from deprecation import DeprecatedWarning
import wandb
from catalyst.dl import utils
from catalyst.dl.core import Experiment, Runner
from catalyst.dl.experiment import ConfigExperiment
from catalyst.dl.runner import SupervisedRunner
# Importing this module installs a global "always" warnings filter, so every
# warning is shown on each occurrence instead of being deduplicated or hidden.
# NOTE(review): presumably done so the DeprecationWarning emitted by
# WandbRunner._init is visible to users -- confirm this process-wide side
# effect is intended.
warnings.simplefilter("always")
class WandbRunner(Runner):
    """
    Runner wrapper with wandb integration hooks.

    Metrics are sent to wandb after each batch and/or each epoch, and
    selected checkpoint files are copied into the wandb run directory
    once the experiment has finished.

    Deprecated: ``_init`` emits a deprecation warning that points to
    ``WandbLogger`` as the replacement.
    """

    @staticmethod
    def _log_metrics(metrics: Dict, mode: str, suffix: str = ""):
        """
        Send metrics to wandb under ``<metric name>/<mode><suffix>`` keys.

        Args:
            metrics: mapping of metric name to metric value
            mode: name placed after the slash in every logged key
            suffix: text appended right after ``mode``
        """

        def key_locate(key: str):
            """
            Wandb uses first symbol _ for it service purposes
            because of that fact, we can not send original metric names

            Args:
                key: metric name
            Returns:
                formatted metric name
            """
            if key.startswith("_"):
                return key[1:]
            return key

        metrics = {
            f"{key_locate(key)}/{mode}{suffix}": value
            for key, value in metrics.items()
        }
        wandb.log(metrics)

    def _init(
        self,
        log_on_batch_end: bool = False,
        log_on_epoch_end: bool = True,
        checkpoints_glob: List = None,
    ):
        """
        Store the logging options and emit the deprecation warning.

        Args:
            log_on_batch_end: log batch metrics after every batch
            log_on_epoch_end: log epoch metrics after every epoch
            checkpoints_glob: glob patterns, relative to
                ``<logdir>/checkpoints``, of the files to copy into the
                wandb run directory; ``None`` means "copy nothing"
        """
        super()._init()
        the_warning = DeprecatedWarning(
            self.__class__.__name__,
            deprecated_in="20.03",
            removed_in="20.04",
            details="Use WandbLogger instead."
        )
        warnings.warn(the_warning, category=DeprecationWarning, stacklevel=2)

        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end
        # The default is None, but _post_experiment_hook iterates over this
        # attribute, so normalise None to an empty list of patterns.
        self.checkpoints_glob = (
            [] if checkpoints_glob is None else checkpoints_glob
        )

        # When exactly one of the two logging modes is active the metric keys
        # cannot collide, so no suffix is needed. Otherwise batch and epoch
        # values are told apart by their suffixes.
        only_one_mode = \
            bool(self.log_on_batch_end) != bool(self.log_on_epoch_end)
        if only_one_mode:
            self.batch_log_suffix = ""
            self.epoch_log_suffix = ""
        else:
            self.batch_log_suffix = "_batch"
            self.epoch_log_suffix = "_epoch"

    def _pre_experiment_hook(self, experiment: Experiment):
        """
        Configure the runner from the experiment and start the wandb run.

        The runner-specific keys (``log_on_batch_end``,
        ``log_on_epoch_end``, ``checkpoints_glob``) are taken out of the
        monitoring params; everything that is left, plus ``dir``, is
        forwarded to ``wandb.init``.

        Args:
            experiment: experiment that is about to be run
        """
        # Work on a shallow copy: popping keys from (and adding "dir" to) the
        # experiment's own dict would alter the caller's settings and make a
        # second run of the same experiment lose the logging options.
        monitoring_params = dict(experiment.monitoring_params)
        monitoring_params["dir"] = str(Path(experiment.logdir).absolute())

        log_on_batch_end: bool = \
            monitoring_params.pop("log_on_batch_end", False)
        log_on_epoch_end: bool = \
            monitoring_params.pop("log_on_epoch_end", True)
        checkpoints_glob: List[str] = \
            monitoring_params.pop("checkpoints_glob", [])
        self._init(
            log_on_batch_end=log_on_batch_end,
            log_on_epoch_end=log_on_epoch_end,
            checkpoints_glob=checkpoints_glob,
        )

        if isinstance(experiment, ConfigExperiment):
            # Config-driven experiments also publish their flattened stages
            # config as the wandb run config.
            exp_config = utils.flatten_dict(experiment.stages_config)
            wandb.init(**monitoring_params, config=exp_config)
        else:
            wandb.init(**monitoring_params)

    def _post_experiment_hook(self, experiment: Experiment):
        """
        Copy the selected checkpoints into the wandb run directory.

        Files under ``<logdir>/checkpoints`` that match any pattern from
        ``self.checkpoints_glob`` are copied to
        ``<wandb.run.dir>/checkpoints``.

        Args:
            experiment: experiment that has just been run
        """
        # @TODO: add params for artefacts logging
        # (copying the remaining logdir content is not implemented)
        logdir_src = Path(experiment.logdir)
        checkpoints_src = logdir_src.joinpath("checkpoints")
        checkpoints_dst = Path(wandb.run.dir).joinpath("checkpoints")

        # A set removes files that are matched by more than one pattern.
        checkpoint_paths = set()
        for pattern in self.checkpoints_glob:
            checkpoint_paths.update(checkpoints_src.glob(pattern))

        if checkpoint_paths:
            # shutil.copy2 does not create missing parent directories, so
            # make sure the destination exists before the first copy.
            checkpoints_dst.mkdir(parents=True, exist_ok=True)

        for checkpoint_path in sorted(checkpoint_paths):
            shutil.copy2(
                f"{str(checkpoint_path.absolute())}",
                f"{checkpoints_dst}/{checkpoint_path.name}"
            )

    def _run_batch(self, batch):
        """
        Run one batch, then log its metrics when batch logging is on.

        Args:
            batch: batch forwarded unchanged to the parent runner
        """
        super()._run_batch(batch=batch)
        if self.log_on_batch_end:
            mode = self.state.loader_name
            metrics = self.state.batch_metrics
            self._log_metrics(
                metrics=metrics, mode=mode, suffix=self.batch_log_suffix
            )

    def _run_epoch(self, stage: str, epoch: int):
        """
        Run one epoch, then log its metrics when epoch logging is on.

        Epoch metrics are grouped with ``utils.split_dict_to_subdicts``
        using the loader names as prefixes and ``"_base"`` as the extra
        key; every resulting group is logged separately with the group
        name as ``mode``.

        Args:
            stage: stage name forwarded to the parent runner
            epoch: epoch index forwarded to the parent runner
        """
        super()._run_epoch(stage=stage, epoch=epoch)
        if self.log_on_epoch_end:
            mode_metrics = utils.split_dict_to_subdicts(
                dct=self.state.epoch_metrics,
                prefixes=list(self.state.loaders.keys()),
                extra_key="_base",
            )
            for mode, metrics in mode_metrics.items():
                self._log_metrics(
                    metrics=metrics, mode=mode, suffix=self.epoch_log_suffix
                )

    def run_experiment(self, experiment: Experiment):
        """
        Run the experiment between the wandb pre- and post-hooks.

        Args:
            experiment: experiment to run
        """
        self._pre_experiment_hook(experiment=experiment)
        super().run_experiment(experiment=experiment)
        self._post_experiment_hook(experiment=experiment)
class SupervisedWandbRunner(WandbRunner, SupervisedRunner):
    """
    Combination of ``WandbRunner`` and ``SupervisedRunner``.

    Adds nothing of its own: all behaviour comes from the two base
    classes, with ``WandbRunner`` first in the method resolution order.
    """
# Names exported by ``from <this module> import *``.
__all__ = ["WandbRunner", "SupervisedWandbRunner"]