#!/usr/bin/env python3

"""
Query container status and show the most relevant information in a tabular overview.
"""

import argparse
import json
import os
import re
import socket
import sys
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, fields, field
from datetime import datetime, timezone
from http.client import HTTPResponse, HTTPConnection
from pathlib import Path
from time import time
from typing import Union, Dict, List, Type, TypeVar, Optional, Iterator, ClassVar, Tuple, Callable, Any
from urllib.error import URLError
from urllib.parse import urlencode
from urllib.request import Request, AbstractHTTPHandler, OpenerDirector, build_opener


class UnixHTTPConnection(HTTPConnection):
    """HTTP client connection via Unix Domain Socket."""

    def __init__(self, *args, **kwargs) -> None:
        self._unix_path: Path = kwargs.pop("unix_path")
        super().__init__(*args, **kwargs)

    def connect(self) -> None:
        self.sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
        if self.timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:  # type: ignore[attr-defined]
            self.sock.settimeout(self.timeout)
        self.sock.connect(str(self._unix_path))
        if self._tunnel_host:  # type: ignore[attr-defined]
            self._tunnel()  # type: ignore[attr-defined]


class UnixHTTPHandler(AbstractHTTPHandler):
    """URLLib handler for the unix:// protocol."""

    def __init__(self, unix_path: Path) -> None:
        self._unix_path: Path = unix_path
        super().__init__()

    def unix_open(self, request: Request) -> HTTPResponse:
        return self.do_open(UnixHTTPConnection, request, unix_path=self._unix_path)

    def unix_request(self, request: Request) -> Request:
        return self.do_request_(request)


class RequestError(Exception):
    pass


class RequestNotFoundError(RequestError):
    pass


class ResponseTypeError(RequestError):
    pass


class DockerClient:
    """HTTP JSON or log requests to the docker socket."""

    def __init__(self, unix_path: Path, timeout: float = 30.0) -> None:
        self._timeout: float = timeout
        self._socket: str = unix_path.as_posix()
        self._base_url: str = "unix://localhost/"
        self._opener: OpenerDirector = build_opener(UnixHTTPHandler(unix_path))
        self._decoder: json.JSONDecoder = json.JSONDecoder()

    def _request(self, url: str, expect_ct: str, **kwargs: Union[str, List[str]]) -> bytes:
        full_url: str = self._base_url + url + ("?" + urlencode(kwargs, doseq=True) if kwargs else "")
        try:
            with self._opener.open(Request(full_url), timeout=self._timeout) as response:
                if response.status in (404, 501):
                    raise RequestNotFoundError(f"{self._socket}: Response status {response.status}: {url}")
                elif response.status != 200:
                    raise RequestError(f"{self._socket}: Response status {response.status}: {url}")
                elif response.headers.get("content-type", "") != expect_ct:
                    raise ResponseTypeError(f"{self._socket}: Unexpected response "
                                            f"'{response.headers.get('content-type', '')}': {url}")
                else:
                    return response.read()
        except (OSError, URLError) as e:
            raise RequestError(f"{self._socket}: Request failed with {str(e)}: {url}") from None

    def request_json(self, url: str, **kwargs: Union[str, List[str]]) -> Any:
        body: bytes = self._request(url, "application/json", **kwargs)
        return self._decoder.decode(body.decode(encoding="utf-8", errors="surrogatepass"))

    def request_logs(self, url: str, **kwargs: Union[str, List[str]]) -> Iterator[str]:
        yield from self._parse_logs(self._request(url, "application/vnd.docker.multiplexed-stream", **kwargs))

    def _parse_logs(self, stream: bytes) -> Iterator[str]:
        # https://docs.docker.com/engine/api/v1.45/#tag/Container/operation/ContainerAttach
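        # Each frame has an 8-byte header: byte 0 is the stream type (stdout/stderr)
        # and bytes 4-7 hold the payload length as a big-endian uint32.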
        while len(stream) >= 8:
            header, stream = stream[:8], stream[8:]
            line_len: int = int.from_bytes(header[4:], byteorder="big", signed=False)
            if line_len > len(stream):
                raise ResponseTypeError(f"{self._socket}: Premature end of log stream")
            line, stream = stream[:line_len], stream[line_len:]
            yield line.decode(encoding="utf-8", errors="replace").rstrip()
        if len(stream) > 0:
            raise ResponseTypeError(f"{self._socket}: Trailing data after log stream")


@dataclass(frozen=True)
class DockerResponse:
    """Base class for JSON responses, with automatic dict lookups and custom factory callbacks."""

    T = TypeVar("T", bound="DockerResponse")

    @classmethod
    def parse(cls: Type[T], data: Dict, **kwargs: Any) -> T:
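        # Keep only keys that correspond to declared dataclass fields and run each
        # value through its optional "factory" callback from field(metadata=...).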
        f: Dict[str, Callable] = cls.fields()
        return cls(**{k: f[k](v) for k, v in data.items() if v is not None and k in f}, **kwargs)

    @classmethod
    def fields(cls) -> Dict[str, Callable]:
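        # Inside this method body, fields() resolves to dataclasses.fields(), not to
        # this classmethod of the same name.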
        if "_fields" not in cls.__dict__:  # check __dict__ so a subclass never reuses its parent's cache
            identity: Callable = lambda _: _
            setattr(cls, "_fields", {_.name: _.metadata.get("factory", identity) for _ in fields(cls)})
        return getattr(cls, "_fields")


@dataclass(frozen=True)
class Container(DockerResponse):
    Id: str


@dataclass(frozen=True)
class ContainerState(DockerResponse):
    Status: str
    Running: bool
    Pid: int = 0
    StartedAt: Optional[str] = None
    FinishedAt: Optional[str] = None


@dataclass(frozen=True)
class ContainerConfig(DockerResponse):
    Image: str


@dataclass(frozen=True)
class HostLogConfig(DockerResponse):
    Type: str


@dataclass(frozen=True)
class HostConfig(DockerResponse):
    NetworkMode: str
    ReadonlyRootfs: bool
    ShmSize: int = 64 * 1024 * 1024
    Tmpfs: Dict = field(default_factory=dict)
    LogConfig: Optional[HostLogConfig] = field(default=None, metadata={"factory": HostLogConfig.parse})


@dataclass(frozen=True)
class PortBinding(DockerResponse):
    HostIp: str
    HostPort: str


@dataclass(frozen=True)
class NetworkSettings(DockerResponse):
    Networks: Dict[str, Dict] = field(default_factory=dict)
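    # "Ports" maps "<port>/<proto>" to its host bindings; ports that are exposed but
    # not published come back as null, which the factory below turns into an empty list.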
    Ports: Dict[str, List[PortBinding]] = field(default_factory=dict, metadata={
        "factory": lambda _: {k: [PortBinding.parse(b) for b in v] if v is not None else [] for k, v in _.items()}
    })


@dataclass(frozen=True)
class MountPoint(DockerResponse):
    Type: str
    Source: str
    Destination: str
    RW: bool
    Name: Optional[str] = None


@dataclass(frozen=True)
class ContainerInspect(DockerResponse):
    Id: str
    Image: str  # Id
    Name: str
    Path: str
    SizeRootFs: int
    SizeRw: int
    State: ContainerState = field(metadata={"factory": ContainerState.parse})
    Config: ContainerConfig = field(metadata={"factory": ContainerConfig.parse})
    HostConfig: HostConfig = field(metadata={"factory": HostConfig.parse})
    NetworkSettings: NetworkSettings = field(metadata={"factory": NetworkSettings.parse})
    Mounts: List[MountPoint] = field(default_factory=list, metadata={
        "factory": lambda _: [MountPoint.parse(mp) for mp in _]
    })
    Created: Optional[str] = None


@dataclass(frozen=True)
class ImageDf(DockerResponse):
    Id: str
    Created: int
    RepoTags: List[str] = field(default_factory=list)


@dataclass(frozen=True)
class VolumeDf(DockerResponse):
    Name: str
    Size: int


@dataclass(frozen=True)
class SystemDf(DockerResponse):
    Images: Dict[str, ImageDf] = field(default_factory=dict, metadata={
        "factory": lambda _: {df["Id"]: ImageDf.parse(df) for df in _}
    })
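    # Volume sizes live in each entry's "UsageData" sub-object; the volume name is
    # taken from the top-level entry and passed along as an extra keyword argument.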
    Volumes: Dict[str, VolumeDf] = field(default_factory=dict, metadata={
        "factory": lambda _: {df["Name"]: VolumeDf.parse(df.get("UsageData", {}), Name=df["Name"]) for df in _}
    })


@dataclass(frozen=True)
class PidStat:
    """Parsed /proc/<pid>/stat fields."""

    pid: int
    ppid: int
    pgrp: int
    name: str
    state: str
    ctime: int  # utime + stime
    rss: int  # pages (might be inaccurate)
    vsize: int


class ProcStat:
    """Crawl /proc/<pid>/stat, providing a process tree."""

    @classmethod
    def _split(cls, line: str) -> List[str]:
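        # The comm field is wrapped in parentheses and may itself contain spaces or
        # parentheses, so isolate it via the first "(" and the last ")" before splitting.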
        cmd_start: int = line.index("(")
        cmd_end: int = line.rindex(")")
        return [line[:cmd_start - 1], line[cmd_start + 1:cmd_end]] + line[cmd_end + 2:].split(" ")

    @classmethod
    def _parse(cls, pid: int, line: str) -> PidStat:
        # https://man7.org/linux/man-pages/man5/proc.5.html
        stat: List[str] = cls._split(line)
        return PidStat(pid=pid, ppid=int(stat[3]), pgrp=int(stat[4]), name=stat[1], state=stat[2],
                       ctime=int(stat[13]) + int(stat[14]), rss=int(stat[23]), vsize=int(stat[22]))

    @classmethod
    def run(cls) -> Dict[int, PidStat]:
        data: Dict[int, PidStat] = {}
        for child in Path("/proc/").iterdir():
            if child.name.isdigit():
                try:
                    pid: int = int(child.name)
                    data[pid] = cls._parse(pid, (child / "stat").read_text())
                except (OSError, ValueError):
                    pass  # the process exited in the meantime or its stat line was malformed
        return data

    @classmethod
    def find_recursive(cls, pid: int, proc_stat: Dict[int, PidStat], depth: int = 0) -> Iterator[Tuple[int, PidStat]]:
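        # Depth-first walk: yield the process itself, then recurse into every entry
        # whose parent PID matches, passing the nesting depth along.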
        if pid in proc_stat:
            yield depth, proc_stat[pid]
            for stat in proc_stat.values():
                if stat.ppid == pid:
                    yield from cls.find_recursive(stat.pid, proc_stat, depth + 1)


class Formatter:
    """Terminal output and unit conversions."""

    _ansi_codes: ClassVar[Dict[str, str]] = {k: '\x1b[{}m'.format(';'.join(str(_) for _ in v)) for k, v in {
        "reset": (0,), "bold": (1,),
        "grey": (30,), "red": (31,), "green": (32,), "yellow": (33,),
        "blue": (34,), "magenta": (35,), "cyan": (36,), "white": (37,),
        "brightgrey": (90,), "brightred": (91,), "brightgreen": (92,), "brightyellow": (93,),
        "brightblue": (94,), "brightmagenta": (95,), "brightcyan": (96,), "brightwhite": (97,),
        "boldgrey": (1, 30), "boldred": (1, 31), "boldgreen": (1, 32), "boldyellow": (1, 33),
        "boldblue": (1, 34), "boldmagenta": (1, 35), "boldcyan": (1, 36), "boldwhite": (1, 37),
    }.items()}

    _sizes: ClassVar[List[Tuple[int, int, str]]] = [
        (1_000_000_000, 2, "G"),
        (1_000_000, 2, "M"),
        (1_000, 5, "K"),
        (1, 0, "B"),
    ]

    _durations: ClassVar[List[Tuple[int, int, str]]] = [
        (60 * 60 * 24, 3, " days ago"),
        (60 * 60, 2, " hours ago"),
        (60, 3, " minutes ago"),
        (1, 2, " seconds ago"),
        (1, 1, " second ago"),
        (1, 0, " seconds ago"),
    ]

    def __init__(self) -> None:
        self._now: float = time()
        self._jiffies: int = os.sysconf(os.sysconf_names["SC_CLK_TCK"]) if "SC_CLK_TCK" in os.sysconf_names else 100
        self._page: int = os.sysconf(os.sysconf_names["SC_PAGESIZE"]) if "SC_PAGESIZE" in os.sysconf_names else 4096
        self._colors: bool = len(os.getenv("NO_COLOR", "")) == 0 and os.isatty(sys.stdout.fileno())

    @classmethod
    def _format_unit(cls, val: Union[int, float], units: List[Tuple[int, int, str]]) -> str:
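        # Each entry is (divisor, minimum rounded value, suffix); return the first
        # conversion whose rounded quotient reaches its minimum.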
        for fac, thresh, unit in units:
            converted: int = round(val / fac)
            if converted >= thresh:
                return f"{converted}{unit}"
        return str(val)

    def format_size(self, val: Union[int, float]) -> str:
        return self._format_unit(val, self._sizes)

    def format_pages(self, val: Union[int, float]) -> str:
        return self._format_unit(val * self._page, self._sizes)

    def format_duration(self, val: Union[int, float]) -> str:
        return "; ".join((
            datetime.fromtimestamp(val, tz=timezone.utc).astimezone().strftime("%a %Y-%m-%d %H:%M:%S %Z"),
            self._format_unit(self._now - val, self._durations),
        ))

    def format_jiffy(self, val: Union[int, float]) -> str:
        val = round(val / self._jiffies)
        return f"{val // 60}:{val % 60:02d}"

    def format_status(self, val: str, status: str) -> str:
        if status in ["running"]:
            return self.colorize(val, "boldgreen")
        elif status in ["paused", "restarting", "removing"]:
            return self.colorize(val, "boldyellow")
        else:
            return self.colorize(val, "boldred")

    def format_path(self, val: str) -> str:
        if "/" in val:
            parent, name = val.rsplit("/", maxsplit=1)
            return "/".join((parent, self.highlight(name)))
        else:
            return self.highlight(val)

    def highlight(self, val: str) -> str:
        return self.colorize(val, "bold")

    def colorize(self, val: str, color: str) -> str:
        return "".join((self._ansi_codes[color], val, self._ansi_codes['reset'])) if self._colors else val

    def parse_datetime(self, dt: str) -> float:
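        # Docker reports RFC 3339 timestamps with nanosecond precision; truncate the
        # fraction (and a trailing "Z") so fromisoformat() accepts the value as UTC.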
        return datetime.fromisoformat(dt.split(".")[0].rstrip("Z") + "+00:00").timestamp()


class Table:
    """Left- or right-aligned table cells."""

    def __init__(self, alignment: str, sep: str = " ") -> None:
        self._alignment: List[str] = [_ for _ in alignment]  # l/r
        self._columns: int = len(self._alignment)
        self._sep: str = sep
        self._lines: List[List[str]] = []

    def push(self, *args: str) -> None:
        self._lines.append(list(args))

    def flush(self, prefix: str) -> Iterator[str]:
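        # The caller's prefix is printed on the first row only and replaced by spaces
        # on continuation rows; columns that are empty in every row are dropped.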
        widths: List[int] = [max(len(_[i]) for _ in self._lines) if self._lines else 0 for i in range(self._columns)]
        first_row: bool = True
        for line in self._lines:
            line_prefix: str = prefix
            if first_row:
                first_row = False
                prefix = " " * len(prefix)

            columns: List[str] = []
            for i in range(self._columns):
                if widths[i] != 0:
                    val: str = line[i]
                    pad: str = " " * (widths[i] - len(val))
                    columns.append((val + pad) if self._alignment[i] == "l" else (pad + val))
            yield line_prefix + self._sep.join(columns)

    def __len__(self) -> int:
        return len(self._lines)


class ResultProcessor(ABC):
    """Consumer interface for the collected per-container results."""

    @abstractmethod
    def process(self, df: SystemDf, ct: ContainerInspect, ps: Dict[int, PidStat], logs: List[str]) -> None:
        raise NotImplementedError


class ResultPrinter(ResultProcessor):
    """From collected container and process details, print most relevant information in tabular form."""

    def __init__(self) -> None:
        self._f: Formatter = Formatter()
        self._first: bool = True

    def _collect_status(self, df: SystemDf, ct: ContainerInspect) -> Iterator[Tuple[str, str]]:
        yield "Name", self._f.highlight(ct.Name.lstrip("/"))
        yield "Path", self._f.format_path(ct.Path)
        yield "State", self._f.format_status(ct.State.Status, ct.State.Status)

        if ct.State.StartedAt and ct.State.StartedAt != "0001-01-01T00:00:00Z":
            yield "Up since", self._f.format_duration(self._f.parse_datetime(ct.State.StartedAt))
        elif ct.State.FinishedAt and ct.State.FinishedAt != "0001-01-01T00:00:00Z":
            yield "Finished", self._f.format_duration(self._f.parse_datetime(ct.State.FinishedAt))
        if ct.Created and ct.Created != "0001-01-01T00:00:00Z":
            yield "Created", self._f.format_duration(self._f.parse_datetime(ct.Created))

        if ct.Image in df.Images:
            if df.Images[ct.Image].Created > 0:
                yield "Image at", self._f.format_duration(df.Images[ct.Image].Created)
            yield "Image", ", ".join([self._f.highlight(ct.Config.Image)] +
                                     [_ for _ in df.Images[ct.Image].RepoTags if _ != ct.Config.Image])
        else:
            yield "Image", self._f.highlight(ct.Config.Image)

        yield "Size", self._f.format_size(ct.SizeRootFs)

    def _collect_mounts(self, df: SystemDf, ct: ContainerInspect) -> Table:
        table: Table = Table("llllr")
        table.push("/", "image", ct.Image.replace("sha256:", "")[:12],
                   "RO" if ct.HostConfig.ReadonlyRootfs else "RW",
                   self._f.format_size(ct.SizeRw) if ct.SizeRw > 0 else "")

        for m in ct.Mounts:
            name_short: Optional[str] = m.Name[:12] if m.Name and re.fullmatch("[0-9a-f]{64}", m.Name) else m.Name
            if m.Type == "volume" and m.Name and m.Name in df.Volumes:
                table.push(m.Destination, m.Type, name_short or m.Name,
                           "RW" if m.RW else "RO", self._f.format_size(df.Volumes[m.Name].Size))
            else:
                table.push(m.Destination, m.Type, name_short if name_short else m.Source, "RW" if m.RW else "RO", "")

        for t in ct.HostConfig.Tmpfs.keys():
            table.push(t, "tmpfs", "", "", "")
        if ct.HostConfig.ShmSize > 0:
            table.push("/dev/shm", "tmpfs", "", "", self._f.format_size(ct.HostConfig.ShmSize))  # nosec

        return table

    def _collect_ports(self, ct: ContainerInspect) -> Table:
        table: Table = Table("llr")
        for p, bindings in ct.NetworkSettings.Ports.items():
            if bindings:
                for binding in bindings:
                    table.push(f"{binding.HostIp}:{binding.HostPort}", "->", p)
            else:
                table.push("", "", p)
        return table

    def _collect_pid_stats(self, pid: int, ps: Dict[int, PidStat]) -> Table:
        t: Table = Table("lrrrrrr")
        for _, stat in ProcStat.find_recursive(pid, ps):
            t.push(self._f.highlight(stat.name), str(stat.pid), str(stat.pgrp), stat.state,
                   self._f.format_jiffy(stat.ctime), self._f.format_pages(stat.rss),
                   f"({self._f.format_size(stat.vsize)})")
        return t

    def process(self, df: SystemDf, ct: ContainerInspect, ps: Dict[int, PidStat], logs: List[str]) -> None:
        pad_width: int = 13
        if self._first:
            self._first = False
        else:
            print()

        print(" ".join((self._f.format_status('●' if ct.State.Running else '◯', ct.State.Status), ct.Id[:12])))
        for k, v in self._collect_status(df, ct):
            print(f"{k:>{pad_width}}: {v}")

        for line in self._collect_mounts(df, ct).flush(f"{'Mounts':>{pad_width}}: "):
            print(line)

        networks: List[str] = [ct.HostConfig.NetworkMode] + [_ for _ in ct.NetworkSettings.Networks.keys()
                                                             if _ != ct.HostConfig.NetworkMode]
        print(f"{'Networks':>{pad_width}}: {', '.join(networks)}")
        for line in self._collect_ports(ct).flush(" " * (pad_width + 2)):
            print(line)

        processes: Table = self._collect_pid_stats(ct.State.Pid, ps)
        if len(processes):
            for line in processes.flush(f"{'Processes':>{pad_width}}: "):
                print(line)
        elif ct.State.Pid > 0:
            print(f"{'Process':>{pad_width}}: {ct.State.Pid}")

        if logs:
            print()
            for line in logs:
                print(line)


class DockerStatus:
    """Collect all information as configured and pass to per-container output formatter."""

    def __init__(self, requester: DockerClient, processor: ResultProcessor,
                 do_system_df: bool, do_proc_stat: bool, do_all: bool, do_logs: bool,
                 list_filter: Optional[str]) -> None:
        self._requester: DockerClient = requester
        self._processor: ResultProcessor = processor
        self._executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=4,  # as mostly not CPU bound
                                                                thread_name_prefix=self.__class__.__name__)
        self._list_filter: Optional[str] = list_filter
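        # A 12-character hex argument is treated as a short container id, anything else as a name.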
        self._filter_by: str = "id" if list_filter and re.fullmatch("[0-9a-f]{12}", list_filter) is not None else "name"
        self._do_all: bool = True if list_filter else do_all  # filtering implies searching stopped containers, too
        self._do_system_df: bool = do_system_df
        self._do_proc_stat: bool = do_proc_stat
        self._log_tail: int = 10 if do_logs else 0

    def __enter__(self) -> 'DockerStatus':
        self._executor.__enter__()
        return self

    def __exit__(self, *args) -> None:
        self._executor.__exit__(*args)

    def _list_containers(self) -> List[Container]:
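        # The Docker API expects "filters" as a JSON-encoded map of filter name to a list of values.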
        filters = {"filters": json.dumps({self._filter_by: [self._list_filter]})} if self._list_filter else {}
        return [Container.parse(_) for _ in self._requester.request_json(
            "containers/json", all="true" if self._do_all else "false", size="false", **filters
        )]

    def _inspect_container(self, cid: str) -> Tuple[ContainerInspect, List[str]]:
        inspect: ContainerInspect = ContainerInspect.parse(self._requester.request_json(
            f"containers/{cid}/json", size="true"
        ))
        if self._log_tail <= 0 or inspect.HostConfig.LogConfig is None or inspect.HostConfig.LogConfig.Type == "none":
            return inspect, []
        try:
            return inspect, list(self._requester.request_logs(
                f"containers/{cid}/logs", stdout="true", stderr="true", timestamps="true", tail=str(self._log_tail)
            ))
        except RequestNotFoundError:
            return inspect, []

    def _get_system_df(self) -> SystemDf:
        return SystemDf.parse(self._requester.request_json(
            "system/df", type=["image", "volume"]
        )) if self._do_system_df else SystemDf()

    def _get_proc_stat(self) -> Dict[int, PidStat]:
        return ProcStat.run() if self._do_proc_stat else {}

    def run(self) -> None:
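        # Kick off the independent requests first, then fan out one inspect/logs
        # request per container as soon as the listing arrives.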
        containers_fut: Future[List[Container]] = self._executor.submit(self._list_containers)
        system_df_fut: Future[SystemDf] = self._executor.submit(self._get_system_df)
        proc_stat_fut: Future[Dict[int, PidStat]] = self._executor.submit(self._get_proc_stat)
        inspections_fut: List[Future[Tuple[ContainerInspect, List[str]]]] = [
            self._executor.submit(self._inspect_container, _.Id) for _ in containers_fut.result()
        ]

        system_df: SystemDf = system_df_fut.result()
        proc_stat: Dict[int, PidStat] = proc_stat_fut.result()
        for inspection_fut in inspections_fut:
            inspection, logs = inspection_fut.result()
            self._processor.process(system_df, inspection, proc_stat, logs)


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--all", action="store_const", const=True,
                        default=False, help="also show not running containers")
    parser.add_argument("--no-logs", action="store_const", const=True,
                        default=False, help="do not query container logs")
    parser.add_argument("--no-system-df", action="store_const", const=True,
                        default=False, help="disable fetching image and volume details")
    parser.add_argument("--no-proc-stat", action="store_const", const=True,
                        default=False, help="disable scanning /proc/<pid>/stat")
    parser.add_argument("--docker-socket", metavar="PATH", type=Path,
                        default=Path("/var/run/docker.sock"), help="docker socket for API requests")
    parser.add_argument("FILTER", nargs="?", type=str,
                        default=None, help="filter containers by id or name")
    args = parser.parse_args()

    try:
        with DockerStatus(DockerClient(unix_path=args.docker_socket), ResultPrinter(),
                          do_all=args.all, do_logs=not args.no_logs, do_system_df=not args.no_system_df,
                          do_proc_stat=not args.no_proc_stat, list_filter=args.FILTER) as status:
            status.run()
    except RequestError as e:
        print(str(e), file=sys.stderr)
        return 1
    except KeyboardInterrupt:
        return 130
    else:
        return 0


if __name__ == "__main__":
    sys.exit(main())