import contextlib
import logging
import logging.config
import os
import sys
import warnings
from io import StringIO
from threading import Thread
from typing import ContextManager, TextIO
from typing import TypeVar, Callable, Iterable, Iterator, Any

from rich import progress
from rich.console import Console
from rich.highlighter import NullHighlighter
from rich.logging import RichHandler

# Generic placeholder for the value/item types that the Status helpers below hand back unchanged.
T = TypeVar('T')


class NullStdErr(ContextManager):
    """
    Context that silences file descriptor 2 by pointing it at the null device while the block runs.
    The original descriptor is duplicated on entry and put back on exit. The context is not reentrant: entering it
    again before leaving trips the internal assertion.
    """

    def __init__(self) -> None:
        self._orig_fd: int | None = None  # duplicate of the original fd 2 while the context is active

    def __enter__(self) -> None:
        """Redirect fd 2 to the null device, remembering a duplicate of the previous target."""
        assert self._orig_fd is None
        # os.devnull instead of a literal path, so this also works where the null device isn't named /dev/null
        with open(os.devnull, "wt") as null_fp:
            saved_fd: int = os.dup(2)
            try:
                os.dup2(null_fp.fileno(), 2)
            except BaseException:
                os.close(saved_fd)  # don't leak the duplicate when the redirect itself failed
                raise
            self._orig_fd = saved_fd
        # fd 2 now holds its own reference to the null device, so closing null_fp doesn't undo the redirect

    def __exit__(self, *args, **kwargs) -> None:
        """Point fd 2 back at the remembered descriptor and release the duplicate."""
        assert self._orig_fd is not None
        os.dup2(self._orig_fd, 2)
        os.close(self._orig_fd)
        self._orig_fd = None


class DupStdErr(ContextManager):
    """
    Context that reroutes file descriptor 2 into a pipe and logs every line read from it.

    Native code from for example libmpg123 or tensorflow might directly write to fd 1/2 instead of Python's
    stdout/err stream, so replacing ``sys.stderr`` alone does not catch it. While the context is active a worker
    thread reads the pipe line by line and passes each line, with trailing whitespace removed, to the given logger at
    the given level. The context is not reentrant.
    """

    def __init__(self, logger: logging.Logger, level: int) -> None:
        self._logger: logging.Logger = logger
        self._level: int = level
        self._pipe_fd: tuple[int, int] | None = None  # r/w
        self._orig_fd: int | None = None  # duplicate of the original fd 2 while active
        self._thread: Thread | None = None  # pipe reader, see _worker

    def _worker(self) -> None:
        """Read the pipe until EOF and log each line; any failure is logged instead of raised."""
        assert self._pipe_fd is not None
        try:
            # Whatever writes to fd 2 gives no encoding guarantees. With strict decoding a single stray byte would
            # raise, end this thread and close the read end, so all later output would be lost.
            with os.fdopen(self._pipe_fd[0], "r", errors="replace", closefd=True) as fp:
                for line in fp:
                    self._logger.log(self._level, line.rstrip())
        except BaseException as e:
            self._logger.exception(f"reading from stderr pipe failed: {str(e)}", exc_info=e)

    def __enter__(self) -> None:
        """Create the pipe, put its writing end in place of fd 2 and start the reader thread."""
        assert self._pipe_fd is None
        assert self._orig_fd is None
        assert self._thread is None
        self._pipe_fd = os.pipe2(0)
        self._orig_fd = os.dup(2)
        os.dup2(self._pipe_fd[1], 2)
        self._thread = Thread(target=self._worker)
        self._thread.start()

    def __exit__(self, *args, **kwargs) -> None:
        """Restore fd 2, close the writing end and wait for the reader to drain the pipe."""
        assert self._pipe_fd is not None
        assert self._orig_fd is not None
        assert self._thread is not None
        # fd 2 must stop referring to the pipe first, the reader only sees EOF once every writing end is closed
        os.dup2(self._orig_fd, 2)
        os.close(self._orig_fd)
        self._orig_fd = None
        os.close(self._pipe_fd[1])  # close writing end -> eof
        self._thread.join()
        self._thread = None
        self._pipe_fd = None  # the reading end was already closed by the worker (closefd=True)


class LogIO(StringIO):
    """
    IO sink that passes written text to the given logger, one record per completed non-empty line.

    ``print`` hands its arguments, separators and terminator to ``write`` in separate calls, so logging every call
    on its own would tear a single ``print("a", "b")`` into several records. Text is therefore collected until a line
    boundary (as understood by ``str.splitlines``) arrives; whatever is still unterminated is emitted by ``flush()``
    and ``close()``. Nothing is ever stored in the underlying ``StringIO`` buffer.
    """

    def __init__(self, logger: logging.Logger, level: int) -> None:
        self._logger: logging.Logger = logger
        self._level: int = level
        self._partial: str = ""  # text received after the most recent line boundary
        super().__init__()

    def _emit_partial(self) -> None:
        """Log the unterminated remainder, if any, and forget it."""
        pending, self._partial = self._partial, ""
        if pending:
            self._logger.log(self._level, pending)

    def write(self, s: str) -> int:
        """Log every line completed by ``s`` and keep an unterminated tail for later; returns the length written."""
        text: str = str(s)
        chunks = (self._partial + text).splitlines(keepends=True)
        self._partial = ""
        for chunk in chunks:
            line = chunk.splitlines(keepends=False)[0]
            if line == chunk:
                # nothing was stripped, so there's no terminator yet - only possible for the very last chunk
                self._partial = chunk
            elif line:
                self._logger.log(self._level, line)
        return len(text)

    def flush(self) -> None:
        """Treat an explicit flush as a line boundary, so deliberately unterminated output still gets logged."""
        self._emit_partial()
        super().flush()

    def close(self) -> None:
        """Emit what is left before closing, so a missing final newline doesn't drop the last line."""
        self._emit_partial()
        super().close()

    def isatty(self) -> bool:
        # prevent auto-detection for example for progress bars from tqdm
        return False


def redirect_warnings(logger: logging.Logger, level: int, action: str = "default") -> None:
    """
    Route `warnings <https://docs.python.org/3/library/warnings.html>`__ to the given logger.

    Applies ``action`` as simple filter and replaces ``warnings.showwarning`` with a hook that formats each warning
    with ``warnings.formatwarning`` and logs it at ``level`` as a single message, the lines of the formatted text
    being joined by ``␤``. The ``file`` argument the hook receives is not used.
    """

    def _forward(message, category, filename, lineno, file=None, line=None) -> None:
        formatted: str = warnings.formatwarning(message, category, filename, lineno, line)
        pieces = formatted.strip().splitlines(keepends=False)
        logger.log(level, "␤".join(pieces))

    warnings.simplefilter(action)
    warnings.showwarning = _forward


class redirect_logging(ContextManager[TextIO]):
    """
    Context that swaps ``sys.stdout`` and ``sys.stderr`` for sinks forwarding to the given logger.
    This catches libraries with the bad habit of directly using ``print`` statements. Native code writing straight
    to the file-descriptors is not affected, for that we'd need to ``dup()`` to a pipe instead.
    Entering yields the ``sys.stdout`` object that was active right before the redirection.
    """

    def __init__(self, logger: logging.Logger, out_level: int = logging.INFO, err_level: int = logging.WARNING) -> None:
        self._logger: logging.Logger = logger
        out_sink = LogIO(self._logger, level=out_level)
        self._out: ContextManager = contextlib.redirect_stdout(out_sink)
        err_sink = LogIO(self._logger, level=err_level)
        self._err: ContextManager = contextlib.redirect_stderr(err_sink)

    def __enter__(self) -> TextIO:
        previous_stdout: TextIO = sys.stdout  # grab it before it gets replaced
        for redirection in (self._out, self._err):
            redirection.__enter__()
        return previous_stdout

    def __exit__(self, *args) -> None:
        # leave in reverse order of entering
        for redirection in (self._err, self._out):
            redirection.__exit__(*args)


class Status:
    """
    Handle given to callbacks: a logger plus total/done/pending counters for one task of a shared progress display.

    The ``pending`` counter is only accumulated here and folded into ``done`` by :meth:`finish_progress`, it is
    never passed to the progress display by itself. Before :meth:`start_progress` is called there is no task, so the
    counters are maintained but nothing is displayed.
    """

    def __init__(self, bar: progress.Progress, logger: logging.Logger) -> None:
        self._progress: progress.Progress = bar
        self._logger: logging.Logger = logger
        self._task: progress.TaskID | None = None  # set by start_progress()
        self._total: int = 0
        self._done: int = 0
        self._pending: int = 0

    def get_substatus(self) -> 'Status':
        """Create a fresh status with its own counters, sharing this one's progress display and logger."""
        return Status(self._progress, self._logger)  # or own contextmanager?

    def start_progress(self, description: str, total: int | None = None) -> None:
        """Stop the current task (if any), add a new one to the progress display and reset all counters."""
        previous = self._task
        if previous is not None:
            self._progress.stop_task(previous)
        self._task = self._progress.add_task(description=description, total=total)
        self._total = 0 if total is None else total
        self._done = self._pending = 0

    def finish_progress(self) -> None:
        """Count everything pending as done, make that number the total, and stop the current task (if any)."""
        completed: int = self._pending + self._done
        self._pending = 0
        self._done = completed
        self._total = completed
        self._update()
        if self._task is None:
            return
        self._progress.stop_task(self._task)

    def _update(self) -> None:
        """Push the done and total counters to the current task, nothing happens without a task."""
        if self._task is None:
            return
        self._progress.update(self._task, completed=self._done, total=self._total)

    def task_add(self, n: int = 1) -> None:
        """Raise the total by ``n``."""
        self._total = self._total + n
        self._update()

    def task_done(self, n: int = 1) -> None:
        """Raise the done counter by ``n``."""
        self._done = self._done + n
        self._update()

    def task_done_wrap(self, n: int, val: T) -> T:
        """Raise the done counter by ``n`` and hand ``val`` back unchanged."""
        self.task_done(n)
        return val

    def task_pending(self, n: int = 1) -> None:
        """Raise the pending counter by ``n``, the display is not touched."""
        self._pending = self._pending + n

    def task_add_counter(self, items: Iterable[T]) -> Iterator[T]:
        """Lazily pass ``items`` through, raising the total by one right before each item is yielded."""
        for entry in items:
            self.task_add()
            yield entry

    def task_done_counter(self, items: Iterable[T]) -> Iterator[T]:
        """Lazily pass ``items`` through, raising the done counter by one right before each item is yielded."""
        for entry in items:
            self.task_done()
            yield entry

    def task_pending_counter(self, items: Iterable[T]) -> Iterator[T]:
        """Lazily pass ``items`` through, raising the pending counter by one right before each item is yielded."""
        for entry in items:
            self.task_pending()
            yield entry

    @property
    def logger(self) -> logging.Logger:
        """The logger passed in at construction."""
        return self._logger

    @property
    def console(self) -> Console:
        """The console of the underlying progress display."""
        return self._progress.console


class _Progress(progress.Progress):
    """
    Progress display with a fixed column layout (description, spinner, bar, percentage, m-of-n, elapsed and
    remaining time) on the given console. The ``redirect_stdout``/``redirect_stderr`` options of the base class are
    switched off.
    """

    def __init__(self, console: Console) -> None:
        columns = (
            progress.TextColumn("[progress.description]{task.description}"),
            progress.SpinnerColumn(finished_text="✓"),
            progress.BarColumn(),
            progress.TaskProgressColumn(),
            progress.MofNCompleteColumn(),
            progress.TimeElapsedColumn(),
            progress.TimeRemainingColumn(),
        )
        super().__init__(*columns, console=console, redirect_stdout=False, redirect_stderr=False)


def run(callback: Callable, **kwargs) -> Any:
    """
    Set up logging and the progress display, then invoke ``callback(status, **kwargs)`` and return its result.

    Logging goes through a rich handler on a console bound to the current ``sys.stdout``, at DEBUG level if the
    ``debug`` keyword argument is truthy and INFO otherwise. Warnings are forwarded to the ``main.warnings`` logger
    (as WARNING with ``debug``, else as DEBUG). While the callback runs, ``sys.stdout``/``sys.stderr`` are forwarded
    to ``main.std``, file descriptor 2 to ``main.dup`` and the progress display is active. All keyword arguments,
    ``debug`` included, are passed on to the callback.
    """
    debug: bool = kwargs.get("debug", False)
    console: Console = Console(file=sys.stdout)
    rich_handler = RichHandler(console=console, show_path=False, log_time_format="[%X]", highlighter=NullHighlighter())
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format="%(name)s: %(message)s",
        handlers=[rich_handler],
    )
    logger: logging.Logger = logging.getLogger("main")

    warning_level: int = logging.WARNING if debug else logging.DEBUG
    redirect_warnings(logger.getChild("warnings"), level=warning_level)
    with (
        redirect_logging(logger.getChild("std"), out_level=logging.INFO, err_level=logging.WARNING),
        DupStdErr(logger.getChild("dup"), logging.DEBUG),
        _Progress(console=console) as progress_bar,
    ):
        return callback(Status(progress_bar, logger), **kwargs)