import contextlib
import logging
import logging.config
import os
import sys
import warnings
from io import StringIO
from threading import Thread
from typing import ContextManager, TextIO
from typing import TypeVar, Callable, Iterable, Iterator, Any
from rich import progress
from rich.console import Console
from rich.highlighter import NullHighlighter
from rich.logging import RichHandler
# Generic item type used by the pass-through helpers of ``Status`` (counters and ``task_done_wrap``).
T = TypeVar("T")
class NullStdErr(ContextManager):
    """
    Context that silences file-descriptor 2 by pointing it at the null device.

    Unlike ``contextlib.redirect_stderr`` this acts on the descriptor itself, so it also silences native code that
    writes to fd 2 directly. The original descriptor is duplicated on enter and restored (and closed) on exit.
    An instance can be reused sequentially but is not reentrant.
    """

    def __init__(self) -> None:
        self._orig_fd: int | None = None  # duplicate of the original fd 2 while the context is active

    @staticmethod
    def _flush_stderr() -> None:
        """Flush Python's ``sys.stderr`` (if present) so already-buffered text reaches the fd it was written for."""
        stream = sys.stderr
        if stream is None:
            return
        try:
            stream.flush()
        except (OSError, ValueError):
            pass  # best effort: a closed or broken stream must not prevent swapping the descriptor

    def __enter__(self) -> None:
        assert self._orig_fd is None
        # text buffered before entering belongs to the real stderr, not to the null device
        self._flush_stderr()
        orig_fd = os.dup(2)
        try:
            # os.devnull instead of a hard-coded "/dev/null": portable across platforms
            with open(os.devnull, "wb") as null_fp:
                os.dup2(null_fp.fileno(), 2)
        except BaseException:
            os.close(orig_fd)  # do not leak the duplicate nor leave the instance looking "active"
            raise
        self._orig_fd = orig_fd

    def __exit__(self, *args, **kwargs) -> None:
        assert self._orig_fd is not None
        # text buffered while silenced is discarded here instead of surfacing after the restore
        self._flush_stderr()
        os.dup2(self._orig_fd, 2)
        os.close(self._orig_fd)
        self._orig_fd = None
class DupStdErr(ContextManager):
    """
    Context that captures file-descriptor 2 through a pipe and forwards every line to *logger* at *level*.

    Native code (for example libmpg123 or tensorflow) might write directly to fd 1/2 instead of Python's
    ``sys.stdout``/``sys.stderr``, which ``contextlib.redirect_stderr`` cannot intercept. While active, fd 2 is the
    write end of a pipe whose read end is drained by a background thread. Not reentrant.
    """

    def __init__(self, logger: logging.Logger, level: int) -> None:
        self._logger: logging.Logger = logger
        self._level: int = level
        self._pipe_fd: tuple[int, int] | None = None  # (read end, write end)
        self._orig_fd: int | None = None  # duplicate of the original fd 2, restored on exit
        self._thread: Thread | None = None  # reader thread draining the pipe into the logger

    @staticmethod
    def _flush_stderr() -> None:
        """Flush Python's ``sys.stderr`` (if present) so already-buffered text reaches the fd it was written for."""
        stream = sys.stderr
        if stream is None:
            return
        try:
            stream.flush()
        except (OSError, ValueError):
            pass  # best effort: a closed or broken stream must not prevent swapping the descriptor

    def _worker(self) -> None:
        """Log each line read from the pipe until EOF, i.e. until every copy of the write end is closed."""
        assert self._pipe_fd is not None
        try:
            # errors="replace": native output is not guaranteed to be valid in the locale encoding. A strict decoder
            # would raise, abort this loop, close the read end and silently lose all further fd-2 output.
            with os.fdopen(self._pipe_fd[0], "r", errors="replace", closefd=True) as fp:
                for line in fp:
                    self._logger.log(self._level, line.rstrip())
        except BaseException as e:
            self._logger.exception(f"reading from stderr pipe failed: {str(e)}", exc_info=e)

    def __enter__(self) -> None:
        assert self._pipe_fd is None
        assert self._orig_fd is None
        assert self._thread is None
        self._flush_stderr()  # text buffered before entering belongs to the original stderr
        # os.pipe() is portable (os.pipe2 is Linux/BSD-only) and its fds are non-inheritable; fd 2 itself is
        # inheritable after os.dup2, so child processes still write into the pipe
        read_fd, write_fd = os.pipe()
        try:
            orig_fd = os.dup(2)
            os.dup2(write_fd, 2)
        except BaseException:
            os.close(read_fd)
            os.close(write_fd)
            raise
        self._pipe_fd = (read_fd, write_fd)
        self._orig_fd = orig_fd
        self._thread = Thread(target=self._worker, name="DupStdErr")
        self._thread.start()

    def __exit__(self, *args, **kwargs) -> None:
        assert self._pipe_fd is not None
        assert self._orig_fd is not None
        assert self._thread is not None
        self._flush_stderr()  # hand any Python-buffered stderr text to the pipe before disconnecting it
        os.dup2(self._orig_fd, 2)
        os.close(self._orig_fd)
        self._orig_fd = None
        # close writing end -> eof; NOTE: a still-running child that inherited fd 2 keeps the pipe open and join waits
        os.close(self._pipe_fd[1])
        self._thread.join()
        self._thread = None
        self._pipe_fd = None
class LogIO(StringIO):
    """
    IO sink that passes written text to the given logger, one record per non-blank line.

    Intended as a stand-in for ``sys.stdout``/``sys.stderr`` (see ``contextlib.redirect_stdout``). Every ``write``
    call is processed on its own, so fragments written without a newline in between still become separate records.
    """

    def __init__(self, logger: logging.Logger, level: int) -> None:
        self._logger: logging.Logger = logger
        self._level: int = level
        super().__init__()

    def write(self, s: str) -> int:
        """Log every non-blank line of *s* at the configured level; report the whole text as consumed."""
        text = str(s)
        for line in text.splitlines(keepends=False):
            # skip whitespace-only fragments: e.g. ``print(a, b)`` writes its " " separator in a separate call,
            # which would otherwise show up as a blank log record
            if line.strip():
                self._logger.log(self._level, line)
        return len(text)

    def isatty(self) -> bool:
        # prevent auto-detection for example for progress bars from tqdm
        return False
def redirect_warnings(logger: logging.Logger, level: int, action: str = "default") -> None:
    """
    Enable `warnings <https://docs.python.org/3/library/warnings.html>`__, but forward to the given logger.

    Installs *action* as the global warnings filter and replaces ``warnings.showwarning`` process-wide, so every
    warning shown afterwards becomes a single log record at *level*. Neither change is undone by this function.
    The ``file`` argument passed to the hook is ignored.
    """
    def log_warning(message, category, filename, lineno, file=None, line=None) -> None:
        formatted = warnings.formatwarning(message, category, filename, lineno, line)
        # one record per warning: join the stripped, non-empty lines with a single space so that multi-line
        # messages (and the trailing source line) are not glued together word-to-word
        parts = (part.strip() for part in formatted.splitlines(keepends=False))
        logger.log(level, " ".join(part for part in parts if part))

    warnings.simplefilter(action)
    warnings.showwarning = log_warning
class redirect_logging(ContextManager[TextIO]):
    """
    Context that swaps Python's standard streams for logger-backed sinks.

    While active, text written through ``sys.stdout`` is logged at *out_level* and text written through
    ``sys.stderr`` at *err_level*. That covers libraries with the bad habit of calling ``print`` directly.
    Native code writing straight to the file descriptors is not affected; catching it needs the descriptor to
    be ``dup()``-ed onto a pipe instead. Entering returns the original ``sys.stdout``.
    """

    def __init__(self, logger: logging.Logger, out_level: int = logging.INFO, err_level: int = logging.WARNING) -> None:
        self._logger: logging.Logger = logger
        out_sink = LogIO(self._logger, level=out_level)
        err_sink = LogIO(self._logger, level=err_level)
        self._out: ContextManager = contextlib.redirect_stdout(out_sink)
        self._err: ContextManager = contextlib.redirect_stderr(err_sink)

    def __enter__(self) -> TextIO:
        real_stdout: TextIO = sys.stdout  # captured before the swap so callers can still reach it
        for redirect in (self._out, self._err):
            redirect.__enter__()
        return real_stdout

    def __exit__(self, *args) -> None:
        # unwind in the reverse order of entering
        for redirect in (self._err, self._out):
            redirect.__exit__(*args)
class Status:
    """
    Handle combining a logger with one task on a shared rich progress display.

    Keeps three counters (``total``, ``done`` and ``pending``) and pushes ``done``/``total`` to the display whenever
    ``total`` or ``done`` changes. ``pending`` is only folded into ``done`` by ``finish_progress``.
    """

    def __init__(self, bar: progress.Progress, logger: logging.Logger) -> None:
        self._progress: progress.Progress = bar
        self._logger: logging.Logger = logger
        self._task: progress.TaskID | None = None  # active task on the display, None before start_progress
        self._total: int = 0
        self._done: int = 0
        self._pending: int = 0

    def get_substatus(self) -> 'Status':
        """Create a fresh ``Status`` on the same display and logger, with no task and zeroed counters."""
        # TODO: or own contextmanager?
        return Status(self._progress, self._logger)

    def start_progress(self, description: str, total: int | None = None) -> None:
        """Stop the current task (if any) and add a new one; internal counters restart from zero."""
        previous = self._task
        if previous is not None:
            self._progress.stop_task(previous)
        self._task = self._progress.add_task(description=description, total=total)
        self._total = 0 if total is None else total
        self._done = self._pending = 0

    def finish_progress(self) -> None:
        """Count pending items as done, set the total to the done count, refresh and stop the task."""
        self._done += self._pending
        self._pending = 0
        self._total = self._done
        self._update()
        if self._task is not None:
            self._progress.stop_task(self._task)

    def _update(self) -> None:
        """Push done/total to the active task; does nothing before ``start_progress``."""
        if self._task is None:
            return
        self._progress.update(self._task, completed=self._done, total=self._total)

    def task_add(self, n: int = 1) -> None:
        """Raise the expected total by *n* and refresh."""
        self._total += n
        self._update()

    def task_done(self, n: int = 1) -> None:
        """Mark *n* more items as done and refresh."""
        self._done += n
        self._update()

    def task_done_wrap(self, n: int, val: T) -> T:
        """Mark *n* items as done and hand *val* back unchanged (handy inside expressions)."""
        self.task_done(n)
        return val

    def task_pending(self, n: int = 1) -> None:
        """Record *n* items as pending; the display is not refreshed."""
        self._pending += n

    def task_add_counter(self, items: Iterable[T]) -> Iterator[T]:
        """Lazily re-yield *items*, calling ``task_add()`` right before each one is handed out."""
        for entry in items:
            self.task_add()
            yield entry

    def task_done_counter(self, items: Iterable[T]) -> Iterator[T]:
        """Lazily re-yield *items*, calling ``task_done()`` right before each one is handed out."""
        for entry in items:
            self.task_done()
            yield entry

    def task_pending_counter(self, items: Iterable[T]) -> Iterator[T]:
        """Lazily re-yield *items*, calling ``task_pending()`` right before each one is handed out."""
        for entry in items:
            self.task_pending()
            yield entry

    @property
    def logger(self) -> logging.Logger:
        """Logger this status was created with."""
        return self._logger

    @property
    def console(self) -> Console:
        """Console the progress display renders to."""
        return self._progress.console
class _Progress(progress.Progress):
    """
    Rich progress display with a fixed column layout, rendering to *console*.

    Layout: description, spinner (a check mark when finished), bar, percentage, M/N count, elapsed and remaining
    time. Rich's own stdout/stderr redirection is disabled.
    """

    def __init__(self, console: Console) -> None:
        columns = (
            progress.TextColumn("[progress.description]{task.description}"),
            progress.SpinnerColumn(finished_text="✓"),
            progress.BarColumn(),
            progress.TaskProgressColumn(),
            progress.MofNCompleteColumn(),
            progress.TimeElapsedColumn(),
            progress.TimeRemainingColumn(),
        )
        super().__init__(*columns, redirect_stdout=False, redirect_stderr=False, console=console)
def run(callback: Callable, **kwargs) -> Any:
    """
    Set up logging, stream redirection and a progress display, then invoke *callback*.

    *callback* is called with a ``Status`` bound to the progress display and the ``"main"`` logger, followed by all
    *kwargs* unchanged. The optional ``debug`` keyword (default ``False``) lowers the root log level to ``DEBUG``
    and logs Python warnings at ``WARNING`` instead of ``DEBUG``. Returns whatever *callback* returns.
    """
    debug: bool = kwargs.get("debug", False)
    # bound to the real stdout now, before any redirection below takes effect
    console: Console = Console(file=sys.stdout)
    rich_handler = RichHandler(console=console, show_path=False, log_time_format="[%X]", highlighter=NullHighlighter())
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format="%(name)s: %(message)s",
        handlers=[rich_handler],
    )
    main_logger: logging.Logger = logging.getLogger("main")
    warning_level = logging.WARNING if debug else logging.DEBUG
    redirect_warnings(main_logger.getChild("warnings"), level=warning_level)
    # entered left to right, exited right to left, exactly like nested with-statements
    with (
        redirect_logging(main_logger.getChild("std"), out_level=logging.INFO, err_level=logging.WARNING),
        DupStdErr(main_logger.getChild("dup"), logging.DEBUG),
        _Progress(console=console) as progress_bar,
    ):
        return callback(Status(progress_bar, main_logger), **kwargs)