import logging
import os
import resource
import signal
import subprocess
import sys
import threading
import time
from typing import Any, Optional, List, Dict, ContextManager, IO


class BackgroundProcess(ContextManager[None]):
    """Context running a process in the background, logging its stdio.

    On entry the process is started with stdin connected to ``/dev/null`` and
    stdout/stderr piped to two reader threads, which log every non-empty line
    (stdout at DEBUG, stderr at INFO). On exit the process is terminated and
    reaped, the reader threads are joined, and the pipes are closed, so the
    same instance can be entered again afterwards.
    """

    def __init__(
        self,
        logger: logging.Logger,
        env: Dict[str, str],
        executable: str,
        *args: str,
        terminate_timeout: Optional[float] = None,
    ) -> None:
        """
        :param logger: Logger receiving lifecycle messages and the child's output.
        :param env: Extra environment variables, layered over ``os.environ``.
        :param executable: Program to run.
        :param args: Command-line arguments passed to the program.
        :param terminate_timeout: Seconds to wait after ``terminate()`` before
            escalating to ``kill()``; ``None`` (the default) waits indefinitely.
        """
        self._logger: logging.Logger = logger
        self._env: Dict[str, str] = {**os.environ, **env}
        self._args: List[str] = [executable] + list(args)
        self._terminate_timeout: Optional[float] = terminate_timeout
        self._process: Optional[subprocess.Popen] = None
        self._out_thread: Optional[threading.Thread] = None
        self._err_thread: Optional[threading.Thread] = None

    def _worker(self, pipe: IO[str], level: int) -> None:
        """Log each non-blank line read from ``pipe`` at ``level`` until EOF."""
        for line in pipe:
            line = line.rstrip()
            if line:
                self._logger.log(level, line)

    def __enter__(self) -> None:
        """Start the process and the two stdio reader threads."""
        assert self._process is None
        assert self._out_thread is None
        assert self._err_thread is None
        self._process = subprocess.Popen(
            self._args,
            stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            close_fds=True, shell=False, env=self._env,
            encoding="utf-8", errors="replace",
        )
        self._logger.info(f"Starting {self._args[0]}, PID {self._process.pid}")
        self._out_thread = threading.Thread(target=self._worker, args=(self._process.stdout, logging.DEBUG))
        self._err_thread = threading.Thread(target=self._worker, args=(self._process.stderr, logging.INFO))
        self._out_thread.start()
        self._err_thread.start()

    def __exit__(self, *args: Any) -> None:
        """Terminate and reap the process, drain its output, release the pipes."""
        assert self._process is not None
        assert self._out_thread is not None
        assert self._err_thread is not None

        process: subprocess.Popen = self._process
        self._logger.info(f"Terminating PID {process.pid}")
        process.terminate()
        try:
            rv: int = process.wait(timeout=self._terminate_timeout)
        except subprocess.TimeoutExpired:
            self._logger.warning(
                f"PID {process.pid} did not exit {self._terminate_timeout} sec after SIGTERM, killing"
            )
            process.kill()
            rv = process.wait()
        # Being killed by the SIGTERM we just sent is the expected outcome,
        # so only other non-zero statuses are worth a warning.
        if rv not in (0, -signal.SIGTERM):
            self._logger.warning(f"PID {process.pid} exited with {rv}")

        # The readers stop at EOF, which the pipes normally reach once the
        # process has exited (unless a grandchild inherited them).
        self._out_thread.join()
        self._err_thread.join()
        # Popen only closes its pipes when used as a context manager itself;
        # close them explicitly to avoid leaking file descriptors.
        for pipe in (process.stdout, process.stderr):
            if pipe is not None:
                pipe.close()

        # Reset so the asserts in __enter__ allow re-entering this instance.
        self._process = None
        self._out_thread = None
        self._err_thread = None

    def __str__(self) -> str:
        return " ".join(self._args)


class Rusage(ContextManager[None]):
    """Measure runtime, (overall) CPU time, and max RSS usage.

    On exit, logs one line at the configured level containing:
    the peak RSS of this process, in parentheses the peak RSS of this process
    plus that reported for its waited-for children, the total user + system
    CPU time of this process and its waited-for children (cumulative, not
    limited to the ``with`` block), and the wall-clock duration of the block.
    """

    def __init__(self, logger: logging.Logger, level: int) -> None:
        """
        :param logger: Parent logger; messages go to its ``Rusage`` child.
        :param level: Logging level used for the summary line.
        """
        self._logger: logging.Logger = logger.getChild(self.__class__.__name__)
        self._level: int = level
        self._start_at: float = time.monotonic()

    @staticmethod
    def _maxrss_bytes(maxrss: int, platform: str = sys.platform) -> int:
        """Convert a ``ru_maxrss`` value to bytes.

        getrusage(2) reports ``ru_maxrss`` in bytes on macOS but in kilobytes
        on Linux and the BSDs.
        """
        return maxrss if platform == "darwin" else maxrss * 1024

    def __enter__(self) -> None:
        self._start_at = time.monotonic()

    def __exit__(self, *args: Any) -> None:
        duration: float = time.monotonic() - self._start_at
        ru_self: resource.struct_rusage = resource.getrusage(resource.RUSAGE_SELF)
        ru_chld: resource.struct_rusage = resource.getrusage(resource.RUSAGE_CHILDREN)

        # RUSAGE_CHILDREN's maxrss is that of the largest child rather than a
        # sum, so rss_all is an estimate of the combined peak.
        rss: int = round(self._maxrss_bytes(ru_self.ru_maxrss) / 1_000_000)
        rss_all: int = round(self._maxrss_bytes(ru_self.ru_maxrss + ru_chld.ru_maxrss) / 1_000_000)
        cpu: float = ru_self.ru_utime + ru_chld.ru_utime + ru_self.ru_stime + ru_chld.ru_stime

        self._logger.log(self._level, f"{rss} ({rss_all}) MB RSS, {cpu:.3f} sec CPU, {duration:.3f} sec")