import contextlib
import logging.config
import sys
import warnings
from io import StringIO
from typing import ContextManager, TextIO, Optional, Dict

import colorlog


class LogIO(StringIO):
    """
    Text sink that hands every non-empty written line to a logger.

    ``write`` is overridden without delegating to :class:`StringIO`, so nothing is kept in the buffer: each call is
    split into lines right away and every non-empty line becomes one log record at the configured level.
    """

    def __init__(self, logger: logging.Logger, level: int) -> None:
        super().__init__()
        self._logger: logging.Logger = logger
        self._level: int = level

    def write(self, s: str) -> int:
        text: str = str(s)
        # filter(None, ...) drops the empty strings produced by blank lines or a lone line terminator
        for entry in filter(None, text.splitlines()):
            self._logger.log(self._level, entry)
        # report the whole input as consumed, like a regular text stream would
        return len(s)

    def isatty(self) -> bool:
        # claim not to be a terminal so that tools such as ``tqdm`` do not render interactive progress bars here
        return False


def _log_warnings(logger: logging.Logger, level: int) -> None:
    """Enable `warnings <https://docs.python.org/3/library/warnings.html>`__, but forward to the given logger."""

    def log_warning(message, category, filename, lineno, file=None, line=None) -> None:
        logger.log(level, "␤".join(
            warnings.formatwarning(message, category, filename, lineno, line).strip().splitlines(keepends=False)
        ))

    warnings.simplefilter("default")
    warnings.showwarning = log_warning


def setup_logging(config: Optional[Dict], name: str, level: str) -> logging.Logger:
    """
    Set up the logging subsystem from the given config, or with defaults writing to ``stderr``.

    :param config: A :func:`logging.config.dictConfig` dictionary, or ``None`` to log to ``stderr`` with a default
        format (colored via ``colorlog`` when ``stderr`` is a terminal).
    :param name: Name of the logger to return; its ``warnings`` child receives forwarded Python warnings at DEBUG.
    :param level: Level (for example ``"INFO"``) applied to the root logger.
    :return: The logger called ``name``.

    Afterwards the root logger's ``setLevel`` is replaced by a no-op so that other code cannot change the global
    level. A repeated call lifts that lock first, so the newly requested ``level`` still takes effect.
    """

    root: logging.Logger = logging.getLogger()
    # drop the no-op override installed by a previous call; otherwise this call's level would be silently ignored
    vars(root).pop("setLevel", None)

    if config is not None:
        logging.config.dictConfig(config)
    elif sys.stderr.isatty():
        colorlog.basicConfig(level=logging.getLevelName(level), stream=sys.stderr,
                             format="%(log_color)s%(levelname)-8s%(reset)s %(name)s: %(message)s")
    else:
        logging.basicConfig(level=logging.getLevelName(level), stream=sys.stderr,
                            format="%(levelname)s %(name)s: %(message)s")
    # set unconditionally: basicConfig() does nothing (level included) when the root already has handlers
    root.setLevel(level)

    # prevent persistent manipulation of the global log level configuration, for example from:
    # TTS.utils.audio.processor -> scipy.optimize -> numpy.distutils.log -> set_verbosity(0, force=True)
    root.setLevel = lambda _: None  # type: ignore

    logger: logging.Logger = logging.getLogger(name)
    _log_warnings(logger.getChild("warnings"), logging.DEBUG)
    return logger


class redirect_logging(ContextManager[TextIO]):
    """
    Context manager that temporarily replaces ``sys.stdout``/``sys.stderr`` with sinks forwarding to a logger.

    Text written through the Python-level streams (for example by libraries that call ``print``) is logged via the
    ``print`` child of the given logger: stdout lines at DEBUG, stderr lines at WARNING. Native code writing directly
    to the file descriptors is not caught; that would require ``dup()``-ing the descriptors onto a pipe.

    Entering returns the original ``sys.stdout`` so the caller can still write to the real console.
    """

    def __init__(self, logger: logging.Logger) -> None:
        sink_logger: logging.Logger = logger.getChild("print")
        self._logger: logging.Logger = sink_logger
        self._out: ContextManager = contextlib.redirect_stdout(LogIO(sink_logger, level=logging.DEBUG))
        self._err: ContextManager = contextlib.redirect_stderr(LogIO(sink_logger, level=logging.WARNING))

    def __enter__(self) -> TextIO:
        console: TextIO = sys.stdout  # captured before redirecting, so the caller keeps access to it
        for redirection in (self._out, self._err):
            redirection.__enter__()
        return console

    def __exit__(self, *args) -> None:
        # undo the redirections in the reverse order of entering them
        for redirection in (self._err, self._out):
            redirection.__exit__(*args)