#!/usr/bin/env python3

"""
Logsearch HTTP Server and Web-GUI
"""

import argparse
import json
import gzip
import os
import re
import signal
import sys
import threading
import logging.config

from pathlib import Path
from base64 import b64encode
from urllib.parse import urlparse, parse_qs
from hmac import compare_digest

from http import HTTPStatus
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from ssl import wrap_socket

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Iterator, Dict, DefaultDict, List, Optional, Tuple, TextIO, Callable


@dataclass
class HttpRequest:
    """A parsed request as handed to the application handler.

    ``query`` holds the URL query parameters merged with the urlencoded request body,
    in the ``{name: [values, ...]}`` form produced by ``urllib.parse.parse_qs``.
    """

    path: str  # URL path, without the query string
    query: Dict[str, List[str]] = field(default_factory=dict)  # fresh dict per instance


@dataclass
class HttpResponse:
    """A response to be written back to the client.

    The server adds ``Content-Length`` from ``body`` itself and title-cases header names on output.
    """

    code: int  # HTTP status code
    headers: Dict[str, str] = field(default_factory=dict)  # fresh dict per instance
    body: bytes = b""


class AuthHTTPServer(ThreadingHTTPServer):
    """Threaded HTTP(S) server with optional HTTP Basic authentication.

    Each GET/POST request is authenticated (when credentials are configured). Its URL query and
    urlencoded body are merged into an HttpRequest and passed to a single application ``handler``.
    The handler's HttpResponse is sent back to the client.
    """

    class AuthHTTPRequestHandler(BaseHTTPRequestHandler):
        """Per-connection handler delegating all GET/POST requests to the owning server."""

        server: 'AuthHTTPServer'

        def do_GET(self) -> None:
            self._handle()

        def do_POST(self) -> None:
            self._handle()

        def _handle(self) -> None:
            """Authenticate, read the request body, dispatch to the server and write the response."""
            response: Optional[HttpResponse] = self.server.check_auth(self.headers.get("Authorization", None))
            if response is None:
                try:
                    content_length: int = int(self.headers.get("Content-Length", "0"))
                    if content_length < 0:
                        # rfile.read() with a negative size reads until EOF, i.e. it would block
                        # until the client gives up on the connection.
                        raise ValueError(f"Invalid Content-Length: {content_length}")
                    post_data: str = self.rfile.read(content_length).decode("utf-8", errors="replace")
                except (ValueError, OSError) as e:
                    self.server.logger.warning(f"Cannot read request body: {str(e)}")
                    response = HttpResponse(400, {"X-Exception": e.__class__.__name__})
                else:
                    response = self.server.handle(self.path, post_data)

            self.send_response(response.code)
            for header, value in response.headers.items():
                self.send_header(header.title(), value)
            self.send_header("Content-Length", str(len(response.body)))
            self.end_headers()
            self.wfile.write(response.body)

        def log_request(self, code='-', size='-') -> None:
            """Access log: successful (200) requests at DEBUG, all others at INFO."""
            if isinstance(code, HTTPStatus):
                code = code.value
            self._log(logging.DEBUG if code == 200 else logging.INFO,
                      '"%s" %s %s',
                      self.requestline, str(code), str(size))

        def log_message(self, format, *args) -> None:
            """Route BaseHTTPRequestHandler's remaining messages (e.g. from log_error) to the server logger."""
            self._log(logging.WARNING, format, *args)

        def _log(self, level: int, format, *args) -> None:
            """Log ``format % args`` prefixed with client address and timestamp, like the stdlib handler."""
            self.server.logger.log(level, "%s - - [%s] %s",
                                   self.address_string(), self.log_date_time_string(), format % args)

    def __init__(self,
                 server_address: Tuple[str, int],
                 handler: Callable[[HttpRequest], HttpResponse],
                 auth: Optional[str],
                 ssl_cert: Optional[Path], ssl_key: Optional[Path]) -> None:
        """Bind and listen on ``server_address``, optionally enabling Basic auth and TLS.

        :param handler: application callable producing a response for each (authenticated) request.
        :param auth: ``"user:password"`` credentials, or None to disable authentication.
        :param ssl_cert: PEM certificate (chain) enabling HTTPS, or None for plain HTTP.
        :param ssl_key: PEM private key; if None, the key is expected inside ``ssl_cert``.
        """
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        super().__init__(server_address, self.AuthHTTPRequestHandler, bind_and_activate=True)
        self._auth: Optional[str] = b64encode(auth.encode("utf-8")).decode("ascii") if auth is not None else None
        self._handler: Callable[[HttpRequest], HttpResponse] = handler
        if ssl_cert is not None:
            try:
                # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12.
                import ssl
                context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
                context.load_cert_chain(certfile=ssl_cert, keyfile=ssl_key)
                self.socket = context.wrap_socket(self.socket, server_side=True)
            except Exception:
                self.server_close()  # don't leak the already bound listening socket
                raise

    def check_auth(self, authorization: Optional[str]) -> Optional[HttpResponse]:
        """Return None if the request may proceed, otherwise a 401 challenge response.

        :param authorization: value of the ``Authorization`` request header, if any.
        """
        if self._auth is None:
            return None
        if authorization is not None:
            # Compare as bytes: compare_digest() raises TypeError for str arguments containing
            # non-ASCII characters, which a client can freely put into the header.
            received: bytes = authorization.encode("latin-1", errors="replace")
            expected: bytes = ("Basic " + self._auth).encode("ascii")
            if compare_digest(received, expected):
                return None
        return HttpResponse(401, {"WWW-Authenticate": "Basic realm=\"logsearch\""})

    def handle(self, path: str, body: str) -> HttpResponse:
        """Build an HttpRequest from the request path and body, and call the application handler.

        Returns 400 if the request cannot be parsed and 500 if the handler raises.
        Body parameters override URL query parameters of the same name.
        """
        try:
            parsed = urlparse(path)
            path = parsed.path
            query: Dict[str, List[str]] = parse_qs(parsed.query)
            query.update(parse_qs(body))
        except ValueError as e:
            self.logger.warning(f"Cannot parse request: {str(e)}")
            return HttpResponse(400, {"X-Exception": e.__class__.__name__})
        try:
            return self._handler(HttpRequest(path, query))
        except Exception as e:
            self.logger.error(f"Cannot handle request: {str(e)}")
            return HttpResponse(500, {"X-Exception": e.__class__.__name__})

    def serve(self) -> bool:
        """Serve in a background thread until SIGINT or SIGTERM, then shut down; returns True."""
        shutdown_requested: threading.Event = threading.Event()

        def _handler(signum: int, frame) -> None:
            shutdown_requested.set()

        signal.signal(signal.SIGINT, _handler)
        signal.signal(signal.SIGTERM, _handler)

        thread: threading.Thread = threading.Thread(target=self.serve_forever)
        thread.start()

        self.logger.info("Serving on {}:{}".format(*self.server_address))
        shutdown_requested.wait()
        self.logger.info("Shutting down")
        self.shutdown()
        self.server_close()
        thread.join()
        return True

    @classmethod
    def run(cls, localhost: bool, port: int, handler: Callable[[HttpRequest], HttpResponse],
            auth: Optional[str], ssl_cert: Optional[Path], ssl_key: Optional[Path]) -> bool:
        """Create a server on 127.0.0.1 (``localhost``) or all interfaces and serve until signalled.

        :raises RuntimeError: if the server cannot be created (e.g. port in use, bad certificate).
        """
        try:
            httpd: AuthHTTPServer = AuthHTTPServer(
                server_address=("127.0.0.1" if localhost else "0.0.0.0", port),
                handler=handler,
                auth=auth,
                ssl_cert=ssl_cert, ssl_key=ssl_key,
            )
        except Exception as e:
            raise RuntimeError(str(e)) from None
        else:
            return httpd.serve()


class LogDir:
    """A single directory of log files, grouped by name prefix ("key")."""

    def __init__(self, log_dir: Path) -> None:
        self._root: Path = log_dir
        # Strips trailing rotation/compression suffixes (".log", ".<number>", ".log", ".old", ".gz",
        # in this order) to derive a file's key. Every group is optional, so search() always matches.
        # At worst it is an empty match at the end of the name, in which case the key is the full name.
        self._file_rx: re.Pattern = re.compile("(\\.log)?(\\.[0-9]+)?(\\.log)?(\\.old)?(\\.gz)?$")

    def _crawl(self, include: Optional[List[str]] = None) -> Iterator[Dict]:
        """Yield ``dict(key, name, size, mtime)`` for each non-empty regular file in the directory.

        Symlinks are not followed, so they are never listed.

        :param include: if given, only consider files with one of these names.
        :raises RuntimeError: if the directory cannot be listed.
        """
        # Set for O(1) lookups; non-string items can never equal a file name, so drop them.
        wanted = frozenset(name for name in include if isinstance(name, str)) if include is not None else None
        try:
            with os.scandir(self._root) as it:
                for entry in it:  # type: os.DirEntry
                    if wanted is not None and entry.name not in wanted:
                        continue
                    match: Optional[re.Match] = self._file_rx.search(entry.name)
                    if match is None:
                        continue
                    try:
                        if not entry.is_file(follow_symlinks=False):
                            continue
                        ss: os.stat_result = entry.stat(follow_symlinks=False)
                    except OSError:
                        # The entry vanished or became inaccessible after listing (e.g. log rotation).
                        # Skip it instead of failing the whole crawl.
                        continue
                    if ss.st_size > 0:
                        yield dict(key=entry.name[:match.start()],
                                   name=entry.name,
                                   size=ss.st_size,
                                   mtime=int(ss.st_mtime))
        except OSError as e:
            raise RuntimeError(str(e)) from None

    def crawl(self) -> Dict[str, List[Dict]]:
        """Group all files by key: ``{key: [entry, ...]}``, entries in directory listing order.

        :raises RuntimeError: if the directory cannot be listed.
        """
        tree: Dict[str, List[Dict]] = {}
        for entry in self._crawl():
            tree.setdefault(entry['key'], []).append(entry)
        return tree

    def find(self, files: List[str]) -> Iterator[Path]:
        """Yield the paths of the requested files that exist, in the order requested.

        Only names reported by the directory listing are joined to the root, so unknown names
        (including path traversal attempts) are ignored. Duplicate requests yield a path once.

        :raises RuntimeError: if the directory cannot be listed.
        """
        found: Dict[str, Path] = {entry['name']: self._root / entry['name'] for entry in self._crawl(files)}
        for name in files:
            path: Optional[Path] = found.pop(name, None) if isinstance(name, str) else None
            if path is not None:
                yield path


class LogDirs:
    """All configured log directories, each addressed by its path string."""

    def __init__(self, dirs: List[Path], file_limit: Optional[int], mtime_sort: bool) -> None:
        """
        :param dirs: directories to crawl.
        :param file_limit: maximum number of files listed per key, or None for all.
        :param mtime_sort: list the most recently modified files first (name order breaks ties).
        """
        self._file_limit: Optional[int] = file_limit
        self._mtime_sort: bool = mtime_sort
        self._logdirs: Dict[str, LogDir] = {
            str(path): LogDir(path) for path in dirs
        }

    def _file_sort_key(self, file: Dict) -> Tuple:
        """Sort key for a file entry: dot-separated name parts, optionally preceded by -mtime.

        Each part is tagged so that numbers and strings are never compared directly. Without the
        tag, ``"app.log.1"`` vs. ``"app.log.old"`` would raise TypeError. Numeric parts compare
        numerically and sort before textual parts, which compare as strings.
        """
        filename_parts: Tuple = tuple((0, int(part), "") if part.isdecimal() else (1, 0, part)
                                      for part in file.get("name", "").split("."))
        return ((-1 * file.get("mtime", 0),) + filename_parts) if self._mtime_sort else filename_parts

    def crawl(self) -> Dict[str, Dict[str, List[Dict]]]:
        """Return ``{dir: {key: [file entries]}}`` for all directories.

        Keys are sorted case-insensitively. Each key's files are sorted by ``_file_sort_key``
        and truncated to the file limit.

        :raises RuntimeError: if a directory cannot be listed.
        """
        tree: Dict[str, Dict[str, List[Dict]]] = {}
        for path, logdir in self._logdirs.items():
            entries: Dict[str, List[Dict]] = logdir.crawl()
            tree[path] = {k: sorted(entries[k], key=self._file_sort_key)[:self._file_limit]
                          for k in sorted(entries.keys(), key=str.lower)}
        return tree

    def find(self, files: List[Tuple[str, str]]) -> Iterator[Path]:
        """Yield the paths of the requested ``(dir, file name)`` pairs that exist.

        Pairs are grouped by directory, in order of first appearance. Directories that are not
        configured are ignored.
        """
        combined: DefaultDict[str, List[str]] = defaultdict(list)
        for p, f in files:
            combined[p].append(f)

        for p, fs in combined.items():
            logdir: Optional[LogDir] = self._logdirs.get(p)
            if logdir is None:
                continue  # not a configured directory
            # Outside any try: errors raised while searching the directory are not swallowed.
            yield from logdir.find(fs)


class LogSearch:
    """Regex search over log files, with a per-file limit on the number of matching lines."""

    class Pattern:
        """Search expression: a regular expression, negated by a leading ``!``.

        A leading ``\\!`` stands for a literal ``!``. An expression that is empty after removing
        the prefix matches every line.
        """

        def __init__(self, pattern: str) -> None:
            """:raises ValueError: if the expression is not a valid regular expression."""
            if pattern.startswith("!"):
                self._negate: bool = True
                pattern = pattern[1:]
            elif pattern.startswith("\\!"):
                self._negate = False
                pattern = pattern[1:]
            else:
                self._negate = False
            try:
                self._rx: Optional[re.Pattern] = re.compile(pattern) if pattern else None
            except re.error as e:
                raise ValueError(str(e)) from None

        def match(self, line: str) -> bool:
            """True if the regex is found in the line (or, when negated, is not found)."""
            return self._rx is None or (self._rx.search(line) is not None) != self._negate

    class LineReader:
        """Iterates the lines of a plain or ``.gz`` compressed file, decoding UTF-8 leniently."""

        def __init__(self, filename: Path) -> None:
            self._filename: Path = filename

        def readlines(self) -> Iterator[str]:
            fp: TextIO
            if self._filename.suffix == ".gz":
                with gzip.open(self._filename, "rt", encoding="utf-8", errors="replace") as fp:
                    yield from fp
            else:
                with self._filename.open("r", encoding="utf-8", errors="replace") as fp:
                    yield from fp

    class Limit:
        """Caps the number of lines passed through, marking omitted lines with ``[…]``."""

        def __init__(self, limit: Optional[int], tail: bool) -> None:
            """
            :param limit: maximum number of lines to keep (negative counts as 0), or None for all.
            :param tail: keep the last lines instead of the first ones.
            """
            self._limit: Optional[int] = max(0, limit) if limit is not None else None
            self._tail: bool = tail

        def __call__(self, it: Iterator[str]) -> Iterator[str]:
            if self._limit is None:
                # This is a generator function: a plain ``return it`` would end the generator
                # without producing a single line, so pass everything through explicitly.
                yield from it
                return

            if not self._tail:  # head
                lineno: int = 0
                for line in it:
                    lineno += 1
                    if lineno > self._limit:  # one more s.t. no marker indicates nothing was skipped
                        yield "[…]\n"
                        break
                    yield line
            else:  # tail
                from collections import deque
                buf = deque(maxlen=self._limit)  # bounded queue: appending evicts the oldest line in O(1)
                truncated: bool = False
                for line in it:
                    if not truncated and len(buf) == self._limit:  # this line will push one out
                        yield "[…]\n"
                        truncated = True
                    buf.append(line)
                yield from buf

    def __init__(self, pattern: str, limit: Optional[int], tail: bool) -> None:
        """:raises ValueError: if ``pattern`` is not a valid regular expression."""
        self._pattern = self.Pattern(pattern)
        self._limit = self.Limit(limit, tail)

    def _search(self, lines: Iterator[str]) -> Iterator[str]:
        """Yield the matching lines with trailing whitespace normalized to a single newline."""
        for line in lines:
            if self._pattern.match(line):
                yield line.rstrip() + "\n"

    def search_all(self, files: Iterator[Path]) -> Iterator[str]:
        """For each file, yield a ``==== name ====`` header and its limited matches.

        Files are separated by an empty line. A file that cannot be read, or a truncated ``.gz``
        file, yields ``[ExceptionName]`` after any lines already emitted.
        """
        first: bool = True
        for file in files:
            if not first:
                yield "\n"
            first = False
            yield f"==== {file.name} ====\n"

            try:
                reader: LogSearch.LineReader = self.LineReader(file)
                yield from self._limit(self._search(reader.readlines()))
            except (OSError, EOFError) as e:
                # gzip raises EOFError (not an OSError) for a truncated archive, e.g. one still being written
                yield f"[{e.__class__.__name__}]\n"

    def search_all_buffered(self, files: Iterator[Path]) -> bytes:
        """Return the complete output of :meth:`search_all` as UTF-8 bytes."""
        # TODO: chunked responses, with frontend support
        return "".join(self.search_all(files)).encode(encoding="utf-8", errors="replace")


class Handlers:
    """Application routes of the web GUI.

    ``/`` serves the single-page UI, ``/l`` the JSON listing of all log files, and ``/q`` the
    plain-text search results. Any other path is answered with 404.
    """

    def __init__(self, logs: LogDirs) -> None:
        self._logs: LogDirs = logs
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)

    def _handle_index(self) -> HttpResponse:
        """Serve the static HTML/CSS/JS user interface."""
        return HttpResponse(
            code=200,
            headers={"Content-Type": "text/html; charset=utf-8"},
            body=INDEX_HTML,
        )

    def _handle_list(self) -> HttpResponse:
        """Return ``{log dir: {key: [file entries]}}`` as JSON, or 400 if a directory cannot be read."""
        try:
            return HttpResponse(
                code=200,
                headers={"Content-Type": "application/json; charset=utf-8"},
                body=json.dumps(self._logs.crawl(), ensure_ascii=False).encode(encoding="utf-8", errors="replace"),
            )
        except RuntimeError as e:
            self._logger.warning(f"Cannot crawl logs: {str(e)}")
            return HttpResponse(400, {"X-Exception": e.__class__.__name__})

    def _handle_search(self, q: Dict[str, List[str]]) -> HttpResponse:
        """Search the selected files and return the matching lines as plain text.

        Parameters:
        ``f`` -- JSON list of ``{"p": log dir, "f": file name}`` objects;
        ``q`` -- search pattern (see LogSearch.Pattern);
        ``l`` -- per-file line limit, absent or blank for unlimited;
        ``t`` -- non-zero to return the last rather than the first matching lines.
        Malformed parameters and unreadable log directories are answered with 400.
        """
        try:
            files: List[Tuple[str, str]] = [(_["p"], _["f"]) for _ in json.loads(q.get("f", ["[]"])[0])]
            if not all(isinstance(p, str) and isinstance(f, str) for p, f in files):
                raise TypeError("file entries must be objects with string \"p\" and \"f\" members")
            query: str = q.get("q", [""])[0]
            limit: Optional[int] = int(q.get("l", ["0"])[0]) if q.get("l") not in [None, [""]] else None
            tail: bool = int(q.get("t", ["0"])[0]) != 0
            search: LogSearch = LogSearch(query, limit, tail)
        except (json.JSONDecodeError, KeyError, IndexError, TypeError, ValueError, RuntimeError) as e:
            # TypeError: valid JSON of the wrong shape, e.g. an object or a list of strings
            self._logger.warning(f"Cannot handle search: {str(e)}")
            return HttpResponse(400, {"X-Exception": e.__class__.__name__})

        try:
            body: bytes = search.search_all_buffered(self._logs.find(files))
        except RuntimeError as e:  # a log directory could not be listed; same response as _handle_list()
            self._logger.warning(f"Cannot search logs: {str(e)}")
            return HttpResponse(400, {"X-Exception": e.__class__.__name__})

        return HttpResponse(
            code=200,
            headers={"Content-Type": "text/plain; charset=utf-8"},
            body=body,
        )

    def handle(self, r: HttpRequest) -> HttpResponse:
        """Dispatch a request to its route by path."""
        if r.path == "/":
            return self._handle_index()
        elif r.path == "/l":
            return self._handle_list()
        elif r.path == "/q":
            return self._handle_search(r.query)
        else:
            return HttpResponse(404)


def main(localhost: bool, port: int,
         auth: Optional[str], ssl_cert: Optional[Path], ssl_key: Optional[Path],
         file_limit: Optional[int], mtime_sort: bool,
         log_dirs: List[Path]) -> bool:
    """Serve the web GUI for ``log_dirs`` until the server is shut down.

    Returns the result of ``AuthHTTPServer.run``, or False if it raised RuntimeError
    (e.g. the server could not be created), which is logged.
    """
    routes: Handlers = Handlers(LogDirs(log_dirs, file_limit, mtime_sort))
    try:
        served: bool = AuthHTTPServer.run(localhost, port, routes.handle, auth, ssl_cert, ssl_key)
    except RuntimeError as err:
        logging.getLogger(None).error(str(err))
        served = False
    return served


# Stylesheet of the web GUI, inlined into INDEX_HTML's <style> element via str.format().
# Layout: #container is a grid with the search bar spanning the top row, and the side bar and
# content below it. The side bar column has width 0 until the hidden #side-bar-toggle checkbox
# is checked (via its label in the search bar); it then widens to min-content.
# NOTE(review): the selector "#settings > label, input" also matches every <input> on the page,
# not only those inside #settings. It was probably meant as "#settings > input". Confirm before
# changing it, since e.g. the 50%-wide search field may rely on the box-sizing it gets from this rule.
# language=CSS
_INDEX_CSS: str = r"""
    html, body {
        background-color: #ffffff;
        color: #000000;
        margin: 0;
        padding: 0;
    }
    body {
        height: 100vh;
        width: 100vw;
    }

    body, input {
        font-family: monospace;
    }

    input[type=button] {
        min-width: 3em;
    }

    .noselect {
        -webkit-touch-callout: none;
        -webkit-user-select: none;
        -khtml-user-select: none;
        -moz-user-select: none;
        -ms-user-select: none;
        user-select: none;
    }

    #container {
        display: grid;
        width: 100%;
        height: 100%;
        grid-template-columns: 0 auto;
        grid-template-rows: min-content auto;
    }

    #side-bar-toggle:checked ~ #container {
        grid-template-columns: min-content auto;
    }

    #side-bar-toggle-label {
        display: inline-block;
        box-sizing: border-box;
        background-color: #dddddd;
        width: 3rem;
        text-align: center;
        float: left;
    }

    #side-bar-toggle {
        display: none;
    }

    #search-bar {
        grid-column: span 2;
        background-color: #aaaaaa;
        border-bottom: #aaaaaa 1px solid;
        width: 100%;
        line-height: 3rem;
        text-align: center;
        white-space: nowrap;
    }

    #search-bar #query {
        width: 50%;
    }

    #content {
        box-sizing: border-box;
        padding: 1em;
        width: 100%;
        height: 100%;
        overflow: scroll;
    }

    #log {
        white-space: pre;
    }

    .log-key.filtered {
        display: none;
    }

    #side-bar {
        background-color: #dddddd;
        border-right: #aaaaaa 1px solid;
        overflow: auto;
        white-space: nowrap;
        height: 100%;
        width: auto;
        max-width: 90vw;
    }

    #side-bar > div {
        padding: 1em;
    }

    #settings {
        display: grid;
        grid-template-columns: max-content auto;
        grid-gap: 0.2em 1em;
        border-bottom: #aaaaaa 1px solid;
    }

    #settings > label, input {
        box-sizing: border-box;
        align-self: center;
    }

    #settings > label {
        justify-self: end;
    }

    #settings input[type=number], #settings input[type=search] {
        width: 100%;
    }

    #settings input[type=checkbox] {
        justify-self: start;
    }

    #side-bar input[type=checkbox] {
        padding: 0;
        margin: 0 0.5em 0 0;
    }

    #files-select ul {
        padding: 0 0 0 1em;
        margin: 0;
        list-style-type: none;
    }

    #files-select > ul {
        padding: 0;
    }
"""  # noqa

# language=JS
_INDEX_JS: str = r"""
    "use_strict";

    let filter_timeout = null;
    function filter_logs(needle) {
        clearTimeout(filter_timeout);
        filter_timeout = setTimeout(function () {
            let logs = document.getElementsByClassName("log-key");
            for (i = 0; i < logs.length; i++) {
                if (needle.length > 0 && !logs[i].innerText.toLowerCase().includes(needle.toLowerCase())) {
                    logs[i].classList.add("filtered");
                    logs[i].classList.remove("unfiltered");
                } else {
                    logs[i].classList.remove("filtered");
                    logs[i].classList.add("unfiltered");
                }
            }
        }, 50);
    }

    function select_logs(what) {
        const cbs = document.getElementById("files-select").querySelectorAll(".log-key.unfiltered input[type=checkbox]");
        for (var i=0; i<cbs.length; i++) {
            if (what == 'all') {
                cbs[i].checked = true;
            } else if (what == 'none') {
                cbs[i].checked = false;
            } else if (what == 'first') {
                if (cbs[i].parentElement.parentElement.previousSibling == null) {
                    cbs[i].checked = true;
                } else {
                    cbs[i].checked = false;
                }
            }
        }
    }

    function request(method, url, body, callback) {
        let xhr = new XMLHttpRequest();
        xhr.withCredentials = true;
        xhr.open(method, url, true);

        if (body !== null) {
            xhr.setRequestHeader("Content-Type", "application/x-www-form-urlencoded"); // Content-Length will be set
        }

        xhr.onreadystatechange = function() {
            if (this.readyState !== XMLHttpRequest.DONE) {
                return;
            }
            if (this.status !== 200) {
                callback(null);
            } else {
                callback(this.responseText);
            }
        }
        xhr.send(body);
    }

    document.getElementById("side-bar-toggle").checked = false;
    request("GET", "l", null, function (resp) {
        if (resp == null) {
            return;
        }
        const response = JSON.parse(resp);

        let ul = document.createElement("ul");
        for (var p in response) { // log path
            var li = document.createElement("li");
            li.innerText = p;

            let pul = document.createElement("ul");
            for (var f in response[p]) {  // log key
                var pli = document.createElement("li");
                pli.innerText = f;
                pli.classList.add("log-key");

                let lul = document.createElement("ul");
                for (var l in response[p][f]) {  // array entry
                    var lli = document.createElement("li");

                    var lbl = document.createElement("label");
                    var ch = document.createElement("input");
                    ch.setAttribute("type", "checkbox");
                    ch.dataset.p = p;
                    ch.dataset.f = response[p][f][l]["name"];
                    lbl.appendChild(ch);
                    lbl.appendChild(document.createTextNode(response[p][f][l]["name"]));

                    lli.appendChild(lbl);
                    lul.appendChild(lli);
                }
                pli.appendChild(lul);
                pul.appendChild(pli);
            }
            li.appendChild(pul);
            ul.appendChild(li);
        }
        document.getElementById("files-select").appendChild(ul);
        filter_logs(document.getElementById("filter").value);
        document.getElementById("side-bar-toggle").checked = true;
    });

    function search() {
        let inp = document.getElementById("query");
        let btn = document.getElementById("search");
        inp.disabled = true;
        btn.disabled = true;

        let files = [];
        const cbs = document.getElementById("files-select").querySelectorAll("input[type=checkbox]:checked");
        for (var i=0; i<cbs.length; i++) {
            files.push({
                p: cbs[i].dataset.p,
                f: cbs[i].dataset.f
            });
        }
        const query = document.getElementById("query").value;
        const limit = document.getElementById("limit").value;
        const tail = document.getElementById("tail").checked ? 1 : 0;

        request("POST", "q" +
                       "?q=" + encodeURIComponent(query) +
                       "&l=" + encodeURIComponent(limit) +
                       "&t=" + encodeURIComponent(tail),
                       "f=" + encodeURIComponent(JSON.stringify(files)),
                       function (resp) {
            if (resp == null) {
                document.getElementById("log").innerText = "[request error]\n";
            } else {
                document.getElementById("log").innerText = resp;
            }
            // document.getElementById("content").scrollTo(0, 0); // ???
            inp.disabled = false;
            btn.disabled = false;
        });
    }
"""  # noqa

# Single-page web GUI served for "/", pre-encoded to UTF-8 bytes at import time.
# The two positional "{}" placeholders receive _INDEX_CSS and _INDEX_JS through str.format().
# Any further literal brace added to this template must be doubled ("{{" / "}}") to survive format().
# language=HTML
INDEX_HTML: bytes = r"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="content-type" content="text/html; charset=utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1,user-scalable=no">
<title>logsearch</title>
<link rel="icon" href="data:,">
<style>{}</style>
</head>
<body>
<input id="side-bar-toggle" type="checkbox">
<div id="container">
    <div id="search-bar"><label id="side-bar-toggle-label" for="side-bar-toggle" class="noselect" title="Toggle sidebar">≡</label><input id="query" placeholder=".*" title="Regular expression to search for, prefix with '!' to negate" type="search" results="10" autosave="logsearch" onkeydown="if (event.keyCode == 13) search();">&nbsp;<input id="search" type="button" value="🔍" onclick="search()"></div>
    <div id="side-bar">
        <div id="settings">
            <label for="limit">Limit</label><input type="number" id="limit" value="1000" min="0" step="100" title="Maximum number of lines to return per logfile">
            <label for="tail">Tail</label><input type="checkbox" id="tail" title="Return the last matching lines, not the first ones" checked>
            <label for="filter">Filter</label><input type="search" id="filter" onkeyup="filter_logs(this.value)" onchange="filter_logs(this.value)" title="Filter shown logfiles by name">
            <label>Select</label><div><input type="button" value="✓" onclick="select_logs('all')" title="Select all shown logfiles">&nbsp;<input type="button" value="✛" onclick="select_logs('first')" title="Select the first shown logfile per prefix">&nbsp;<input type="button" value="✗" onclick="select_logs('none')" title="Unselect all shown logfiles"></div>
        </div>
        <div id="files-select"></div>
    </div>
    <div id="content"><div id="log"></div></div>
</div>
<script>{}</script>
</body>
</html>
""".format(_INDEX_CSS, _INDEX_JS).encode("utf-8", errors="replace")  # noqa


def setup_logging(debug: bool) -> None:
    """Send all log records to stderr as bare messages.

    :param debug: put the root logger at DEBUG instead of INFO level.
    """
    # Global switches of the logging package: silently ignore errors raised while emitting
    # records, and don't collect process / multiprocessing information per record.
    logging.raiseExceptions = False
    logging.logThreads = True
    logging.logMultiprocessing = False
    logging.logProcesses = False

    root_logger = {
        'handlers': ['default'],
        'level': 'DEBUG' if debug else 'INFO',
        'propagate': False,
    }
    stderr_handler = {
        'formatter': 'standard',
        'class': 'logging.StreamHandler',
        'stream': 'ext://sys.stderr',
    }
    logging.config.dictConfig(dict(
        version=1,
        formatters=dict(standard=dict(format='%(message)s')),
        handlers=dict(default=stderr_handler),
        loggers={'': root_logger},
    ))


if __name__ == "__main__":
    # Command line interface; ArgumentDefaultsHelpFormatter shows every default in --help.
    # __doc__ is None when running under "python -OO".
    parser = argparse.ArgumentParser(description=(__doc__ or "").strip(),
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--verbose', action='store_true',
                        help='enable debug logging')
    parser.add_argument('--localhost', action='store_true',
                        help='bind to localhost only')
    parser.add_argument('--port', type=int, default=8080,
                        help='port to bind to')
    parser.add_argument('--auth', metavar='USER:PASS', type=str, default=None,
                        help='enable HTTP Basic authentication')
    parser.add_argument('--ssl-cert', metavar='CERT.PEM', type=Path, default=None,
                        help='enable HTTPS server')
    parser.add_argument('--ssl-key', metavar='KEY.PEM', type=Path, default=None,
                        help='certificate keyfile')
    parser.add_argument('--file-limit', metavar='NUM', type=int, default=None,
                        help='limit number of per-prefix shown log files')
    parser.add_argument('--mtime-sort', action='store_true',
                        help='sort log files by modification time')
    parser.add_argument('LOGDIR', type=Path, nargs='+',
                        help='directory containing log files to crawl')
    args = parser.parse_args()

    # Reject combinations that would otherwise be silently ignored or misbehave.
    # The key is only used together with a certificate. A negative limit would be applied as a
    # list slice, dropping files from the end of each group instead of limiting it.
    if args.ssl_key is not None and args.ssl_cert is None:
        parser.error("--ssl-key requires --ssl-cert")
    if args.file_limit is not None and args.file_limit < 0:
        parser.error("--file-limit must not be negative")

    setup_logging(args.verbose)
    sys.exit(0 if main(args.localhost, args.port,
                       args.auth, args.ssl_cert, args.ssl_key,
                       args.file_limit, args.mtime_sort,
                       args.LOGDIR) else 1)