from typing import Dict, List  # isort:skip
import logging
import os
import sys
from urllib.parse import quote_plus
from urllib.request import Request, urlopen
from tqdm import tqdm
from catalyst.dl import utils
from catalyst.dl.core import LoggerCallback, RunnerState
from catalyst.dl.utils.formatters import TxtMetricsFormatter
from catalyst.utils import format_metric
from catalyst.utils.tensorboard import SummaryWriter
class VerboseLogger(LoggerCallback):
    """
    Logs the params into console

    Draws one ``tqdm`` progress bar per loader and shows the selected
    batch metrics as the postfix of that bar.
    """

    def __init__(
        self,
        always_show: List[str] = None,
        never_show: List[str] = None,
    ):
        """
        Args:
            always_show (List[str]): list of metrics to always show
                if None default is ``["_timers/_fps"]``
                to remove always_show metrics set it to an empty list ``[]``
            never_show (List[str]): list of metrics which will not be shown

        Raises:
            ValueError: if some metric is listed both in ``always_show``
                and in ``never_show``
        """
        super().__init__()
        self.tqdm: tqdm = None
        self.step = 0
        self.always_show = (
            always_show if always_show is not None else ["_timers/_fps"]
        )
        self.never_show = never_show if never_show is not None else []

        # A metric cannot be forced on and forced off at the same time.
        intersection = set(self.always_show) & set(self.never_show)
        if intersection:
            raise ValueError(
                f"Intersection of always_show and "
                f"never_show has common values: {intersection}"
            )

    def _need_show(self, key: str) -> bool:
        """
        Decide whether the metric ``key`` goes to the progress bar.

        ``never_show`` wins over everything; ``always_show`` wins over the
        default rule, which hides keys starting with ``_base`` or
        ``_timers``.

        Args:
            key (str): metric name

        Returns:
            bool: True if the metric should be displayed
        """
        if key in self.never_show:
            return False
        if key in self.always_show:
            return True
        return not key.startswith(("_base", "_timers"))

    @staticmethod
    def _format_value(value) -> str:
        """
        Render a metric value for the progress bar postfix.

        Fixed-point notation is used when the magnitude is above ``1e-3``,
        scientific notation otherwise. The magnitude (not the signed value)
        is compared, so negative metrics are not forced into scientific
        notation.
        """
        if abs(value) > 1e-3:
            return "{:3.3f}".format(value)
        return "{:1.3e}".format(value)

    def on_loader_start(self, state: RunnerState):
        """Init tqdm progress bar"""
        self.step = 0
        self.tqdm = tqdm(
            total=state.loader_len,
            desc=f"{state.stage_epoch_log}/{state.num_epochs}"
            f" * Epoch ({state.loader_name})",
            leave=True,
            ncols=0,
            file=sys.stdout,
        )

    def on_batch_end(self, state: RunnerState):
        """Update tqdm progress bar at the end of each batch"""
        postfix = {
            name: self._format_value(value)
            for name, value in sorted(state.metrics.batch_values.items())
            if self._need_show(name)
        }
        self.tqdm.set_postfix(**postfix)
        self.tqdm.update()

    def on_loader_end(self, state: RunnerState):
        """Cleanup and close tqdm progress bar"""
        self.tqdm.close()
        self.tqdm = None
        self.step = 0

    def on_exception(self, state: RunnerState):
        """
        Called if an Exception was raised

        A ``KeyboardInterrupt`` is reported as "Early exiting" and marked
        as not to be re-raised; any other exception is left untouched.
        """
        exception = state.exception
        if not utils.is_exception(exception):
            return

        if isinstance(exception, KeyboardInterrupt):
            # ``tqdm.write`` is a classmethod, so call it on the class:
            # ``self.tqdm`` is None whenever no loader is running, and the
            # interrupt may arrive exactly then.
            tqdm.write("Early exiting")
            state.need_reraise_exception = False
class ConsoleLogger(LoggerCallback):
    """
    Logger callback, translates ``state.metrics`` to console and text file
    """

    def __init__(self):
        """Init ``ConsoleLogger``"""
        super().__init__()
        self.logger = None

    @staticmethod
    def _release_handlers(logger):
        """Detach and close every handler currently attached to ``logger``."""
        # Iterate over a copy: ``removeHandler`` mutates ``logger.handlers``.
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()

    @staticmethod
    def _get_logger(logdir):
        """
        Configure the ``"metrics_logger"`` logger for ``logdir``.

        The logger gets two INFO-level handlers sharing one
        ``TxtMetricsFormatter``: a file handler writing to
        ``{logdir}/log.txt`` and a stream handler writing to ``sys.stdout``.

        Args:
            logdir: directory where ``log.txt`` is created

        Returns:
            logging.Logger: the configured logger
        """
        logger = logging.getLogger("metrics_logger")
        logger.setLevel(logging.INFO)

        # ``getLogger`` returns the same process-wide object on every call.
        # Drop handlers left over from a stage that did not reach
        # ``on_stage_end``; otherwise each record would be emitted twice
        # and the old ``log.txt`` handle would stay open.
        ConsoleLogger._release_handlers(logger)

        fh = logging.FileHandler(f"{logdir}/log.txt")
        fh.setLevel(logging.INFO)
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.INFO)
        # @TODO: fix json logger
        # jh = logging.FileHandler(f"{logdir}/metrics.json")
        # jh.setLevel(logging.INFO)

        txt_formatter = TxtMetricsFormatter()
        # json_formatter = JsonMetricsFormatter()
        fh.setFormatter(txt_formatter)
        ch.setFormatter(txt_formatter)
        # jh.setFormatter(json_formatter)

        # add the handlers to the logger
        logger.addHandler(fh)
        logger.addHandler(ch)
        # logger.addHandler(jh)
        return logger

    def on_stage_start(self, state: RunnerState):
        """Prepare ``state.logdir`` for the current stage"""
        assert state.logdir is not None
        state.logdir.mkdir(parents=True, exist_ok=True)
        self.logger = self._get_logger(state.logdir)

    def on_stage_end(self, state):
        """
        Called at the end of each stage

        Closes and detaches the handlers of the metrics logger. Does nothing
        if the logger was never created (``on_stage_start`` did not run).
        """
        if self.logger is None:
            return
        self._release_handlers(self.logger)

    def on_epoch_end(self, state):
        """
        Translate ``state.metrics`` to console and text file
        at the end of an epoch
        """
        # The message itself is empty: ``state`` is handed to the formatter
        # through ``extra``.
        self.logger.info("", extra={"state": state})
class TensorboardLogger(LoggerCallback):
    """
    Logger callback, translates ``state.metrics`` to tensorboard

    One ``SummaryWriter`` is kept per loader name in ``self.loggers``.
    """

    def __init__(
        self,
        metric_names: List[str] = None,
        log_on_batch_end: bool = True,
        log_on_epoch_end: bool = True,
    ):
        """
        Args:
            metric_names (List[str]): list of metric names to log,
                if none - logs everything
            log_on_batch_end (bool): logs per-batch metrics if set True
            log_on_epoch_end (bool): logs per-epoch metrics if set True

        Raises:
            ValueError: if both ``log_on_batch_end`` and
                ``log_on_epoch_end`` are disabled
        """
        super().__init__()
        self.metrics_to_log = metric_names
        self.log_on_batch_end = log_on_batch_end
        self.log_on_epoch_end = log_on_epoch_end

        if not self.log_on_batch_end and not self.log_on_epoch_end:
            raise ValueError("You have to log something!")

        # loader name -> SummaryWriter
        self.loggers = {}

    def _log_metrics(
        self, metrics: Dict[str, float], step: int, mode: str, suffix=""
    ):
        """
        Write the selected ``metrics`` as scalars to the writer of ``mode``.

        Args:
            metrics (Dict[str, float]): metric name -> value
            step (int): step passed to ``add_scalar``
            mode (str): key of the writer in ``self.loggers``
            suffix (str): appended to every metric name to build the tag
        """
        names = self.metrics_to_log
        if names is None:
            # No explicit selection: log everything, in sorted order.
            names = sorted(metrics)

        for name in names:
            if name not in metrics:
                continue
            self.loggers[mode].add_scalar(
                f"{name}{suffix}", metrics[name], step
            )

    def on_loader_start(self, state):
        """Prepare tensorboard writers for the current stage"""
        loader_name = state.loader_name
        if loader_name in self.loggers:
            return
        writer_dir = os.path.join(state.logdir, f"{loader_name}_log")
        self.loggers[loader_name] = SummaryWriter(writer_dir)

    def on_batch_end(self, state: RunnerState):
        """Translate batch metrics to tensorboard"""
        if not self.log_on_batch_end:
            return
        loader_name = state.loader_name
        batch_metrics = state.metrics.batch_values
        self._log_metrics(
            metrics=batch_metrics,
            step=state.step,
            mode=loader_name,
            suffix="/batch",
        )

    def on_loader_end(self, state: RunnerState):
        """Translate epoch metrics to tensorboard"""
        if self.log_on_epoch_end:
            loader_name = state.loader_name
            epoch_metrics = state.metrics.epoch_values[loader_name]
            self._log_metrics(
                metrics=epoch_metrics,
                step=state.epoch_log,
                mode=loader_name,
                suffix="/epoch",
            )
        # Flush every writer, whether or not epoch metrics were logged.
        for writer in self.loggers.values():
            writer.flush()

    def on_stage_end(self, state: RunnerState):
        """Close opened tensorboard writers"""
        for writer in self.loggers.values():
            writer.close()
class TelegramLogger(LoggerCallback):
    """
    Logger callback, translates ``state.metrics`` to telegram channel
    """

    # Seconds to wait for the Telegram API before giving up on a message,
    # so that a stalled connection cannot block the training run.
    _request_timeout = 10

    def __init__(
        self,
        token: str = None,
        chat_id: str = None,
        metric_names: List[str] = None,
        log_on_stage_start: bool = True,
        log_on_loader_start: bool = True,
        log_on_loader_end: bool = True,
        log_on_stage_end: bool = True,
        log_on_exception: bool = True,
    ):
        """
        Args:
            token (str): telegram bot's token,
                see https://core.telegram.org/bots
                falls back to the ``CATALYST_TELEGRAM_TOKEN``
                environment variable
            chat_id (str): Chat unique identifier
                falls back to the ``CATALYST_TELEGRAM_CHAT_ID``
                environment variable
            metric_names: List of metric names to log.
                if none - logs everything.
            log_on_stage_start (bool): send notification on stage start
            log_on_loader_start (bool): send notification on loader start
            log_on_loader_end (bool): send notification on loader end
            log_on_stage_end (bool): send notification on stage end
            log_on_exception (bool): send notification on exception
        """
        super().__init__()
        # @TODO: replace this logic with global catalyst config at ~/.catalyst
        self._token = token or os.environ.get("CATALYST_TELEGRAM_TOKEN", None)
        self._chat_id = (
            chat_id or os.environ.get("CATALYST_TELEGRAM_CHAT_ID", None)
        )
        assert self._token is not None and self._chat_id is not None, (
            "telegram token and chat_id must be passed explicitly or via "
            "CATALYST_TELEGRAM_TOKEN / CATALYST_TELEGRAM_CHAT_ID"
        )
        self._base_url = (
            f"https://api.telegram.org/bot{self._token}/sendMessage"
        )

        self.log_on_stage_start = log_on_stage_start
        self.log_on_loader_start = log_on_loader_start
        self.log_on_loader_end = log_on_loader_end
        self.log_on_stage_end = log_on_stage_end
        self.log_on_exception = log_on_exception

        self.metrics_to_log = metric_names

    def _send_text(self, text: str):
        """
        Send ``text`` to the configured chat, best effort.

        Any failure is logged as a warning and never propagated, so a
        notification problem cannot interrupt the run.

        Args:
            text (str): message body
        """
        try:
            # Both user-controlled query values are percent-encoded.
            url = (
                f"{self._base_url}?"
                f"chat_id={quote_plus(str(self._chat_id), safe='')}&"
                f"disable_web_page_preview=1&"
                f"text={quote_plus(text, safe='')}"
            )
            request = Request(url)
            # The context manager closes the HTTP response (and its socket);
            # the response body is not needed.
            with urlopen(request, timeout=self._request_timeout):
                pass
        except Exception as e:
            logging.getLogger(__name__).warning(f"telegram.send.error:{e}")

    def on_stage_start(self, state: RunnerState):
        """Notify about starting a new stage"""
        if self.log_on_stage_start:
            text = f"{state.stage} stage was started"
            self._send_text(text)

    def on_loader_start(self, state: RunnerState):
        """Notify about starting running the new loader"""
        if self.log_on_loader_start:
            text = f"{state.loader_name} {state.epoch} epoch was started"
            self._send_text(text)

    def on_loader_end(self, state: RunnerState):
        """Translate ``state.metrics`` to telegram channel"""
        if not self.log_on_loader_end:
            return

        metrics = state.metrics.epoch_values[state.loader_name]
        if self.metrics_to_log is None:
            metrics_to_log = sorted(metrics)
        else:
            metrics_to_log = self.metrics_to_log

        # Header line followed by one formatted row per available metric.
        rows: List[str] = [
            f"{state.loader_name} {state.epoch} epoch was finished:"
        ]
        rows.extend(
            format_metric(name, metrics[name])
            for name in metrics_to_log
            if name in metrics
        )
        self._send_text("\n".join(rows))

    def on_stage_end(self, state: RunnerState):
        """Notify about finishing a stage"""
        if self.log_on_stage_end:
            text = f"{state.stage} stage was finished"
            self._send_text(text)

    def on_exception(self, state: RunnerState):
        """
        Notify about raised Exception

        ``KeyboardInterrupt`` is not reported.
        """
        if not self.log_on_exception:
            return

        exception = state.exception
        if utils.is_exception(exception) and not isinstance(
            exception, KeyboardInterrupt
        ):
            text = (
                f"`{type(exception).__name__}` exception was raised:\n"
                f"{exception}"
            )
            self._send_text(text)
# Public names of this module: the four logger callbacks defined above.
# This list is what ``from <module> import *`` exports.
__all__ = [
    "VerboseLogger",
    "ConsoleLogger",
    "TensorboardLogger",
    "TelegramLogger",
]