#!/usr/bin/env python3

"""
Start a webserver and run youtube-dl on the requests, returning a playlist with audio URLs.
Request filenames are ignored, query parameters are appended to 'https://www.youtube.com/watch?'.
Only audio URLs will be resolved by youtube-dl, all results are returned as a single M3U playlist.
"""

import argparse
import logging
import signal
import sys
import threading
import time
from http.server import HTTPServer, ThreadingHTTPServer, BaseHTTPRequestHandler
from threading import Lock
from typing import Type, Optional, List, Tuple, Dict, ClassVar, Iterator
from urllib.parse import urlparse

from youtube_dl import YoutubeDL
from youtube_dl.utils import YoutubeDLError, std_headers, random_user_agent


class UrlExtractor:
    """Resolve a page URL to direct audio stream URLs by means of youtube-dl.

    The youtube-dl options are fixed at construction time: simulate mode,
    no playlist expansion, no cache directory, limited retries, and a format
    selector that only asks for the best audio format.
    """

    def __init__(self, preferred_ext: Optional[str] = None) -> None:
        """Prepare the youtube-dl options.

        Args:
            preferred_ext: Audio file extension to prefer, for example
                ``"m4a"``. The selector falls back to plain ``bestaudio`` if
                no such format exists; ``None`` asks for ``bestaudio`` only.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._opts: Dict = {
            "forceurl": True, "simulate": True, "skip_download": False, "noplaylist": True,
            "verbose": False, "no_color": True, "logger": self._logger,
            "call_home": False, "cachedir": False, "retries": 3, "fragment_retries": 3,
            "format": f"bestaudio[ext={preferred_ext}]/bestaudio" if preferred_ext is not None else "bestaudio",
        }

    def _extract_urls(self, info) -> Iterator[str]:
        """Yield all audio URLs found in an info dict returned by youtube-dl.

        The top-level ``"url"`` is yielded if present, followed by the URL of
        every entry in ``"requested_formats"`` that is not explicitly marked
        as having no audio codec.

        Raises:
            RuntimeError: If ``info`` is not a dict. As this is a generator,
                the error surfaces when iteration starts.
        """
        if not isinstance(info, dict):
            raise RuntimeError(f"Unexpected return value, got {type(info)}")
        if "url" in info:
            yield info["url"]
        for fmt in info.get("requested_formats", ()):
            # A missing "acodec" means "unknown", not "no audio": use .get() so such
            # a format neither raises KeyError nor gets dropped.
            if "url" in fmt and fmt.get("acodec") != "none":
                yield fmt["url"]

    def extract_urls(self, url: str, user_agent: Optional[str] = None) -> List[str]:
        """Run youtube-dl on ``url`` and return the resolved audio URLs.

        Args:
            url: Page URL to hand to youtube-dl.
            user_agent: User-Agent header to use for the lookup; a random one
                is picked if ``None``.

        Returns:
            The distinct URLs found, sorted.

        Raises:
            RuntimeError: If youtube-dl fails with ``OSError`` or
                ``YoutubeDLError`` (available as ``__cause__``), or if it
                returns something other than a dict.
        """
        # XXX: std_headers is a module-level dict of youtube_dl.utils, so this
        # assignment is global state shared by all extractions running concurrently.
        std_headers["User-Agent"] = user_agent if user_agent is not None else random_user_agent()
        try:
            with YoutubeDL(self._opts) as ydl:
                info = ydl.extract_info(url)
        except (OSError, YoutubeDLError) as e:
            raise RuntimeError(str(e)) from e
        else:
            return sorted(set(self._extract_urls(info)))


class _Cache:
    """In-memory mapping of keys to URL lists that expire after ``ttl`` seconds.

    Every access to the underlying dict happens while holding a lock. Entries
    carry the wall-clock second at which they were stored; each lookup first
    drops all entries that have reached the time-to-live.
    """

    def __init__(self) -> None:
        self._lock: Lock = Lock()
        # key -> (second of insertion, stored value)
        self._cache: Dict[str, Tuple[int, List[str]]] = {}
        # lifetime of entries in seconds; with 0 every entry is expired on lookup
        self.ttl: int = 0

    def get(self, key: str) -> Optional[List[str]]:
        """Return the value stored for ``key``, or ``None`` if absent or expired."""
        with self._lock:
            cutoff: int = int(time.time()) - self.ttl
            # Collect first, delete afterwards: the dict must not change size while
            # being iterated. All entries are checked, no insertion order is assumed.
            expired = [name for name, (stored_at, _) in self._cache.items() if stored_at <= cutoff]
            for name in expired:
                del self._cache[name]
            entry = self._cache.get(key)
            if entry is None:
                return None
            return entry[1]

    def set(self, key: str, val: List[str]) -> None:
        """Store ``val`` for ``key``, stamped with the current time."""
        now: int = int(time.time())
        with self._lock:
            self._cache[key] = (now, val)


class _RequestHandler(BaseHTTPRequestHandler):
    """Answer GET requests with a playlist of audio URLs for a YouTube video.

    The query string of the request is appended to
    ``https://www.youtube.com/watch?``; the request path itself is not used.
    Results are kept in a cache shared by all handler instances.
    """

    # shared by all requests; ttl and preferred_ext are set by main() before serving
    cache: ClassVar[_Cache] = _Cache()
    preferred_ext: ClassVar[Optional[str]] = None

    def log_message(self, fmt: str, *args):
        """Send the server's log lines to the "Server" logger instead of stderr."""
        client = self.address_string()
        text = fmt % args
        logging.getLogger("Server").info("[{}] {}".format(client, text))

    def do_GET(self) -> None:
        """Look up the audio URLs for the requested video and send them.

        Responds with 400 if the request path cannot be parsed, 404 if it has
        no query string, and 502 if the lookup fails.
        """
        try:
            query_string: str = urlparse(self.path).query
        except ValueError:
            self.send_error(400)
            return
        if not query_string:
            self.send_error(404)
            return

        watch_url: str = f"https://www.youtube.com/watch?{query_string}"
        playlist: Optional[List[str]] = self.cache.get(watch_url)
        if playlist is None:
            # cache miss: ask youtube-dl, passing the client's own User-Agent along
            try:
                playlist = UrlExtractor(preferred_ext=self.preferred_ext).extract_urls(
                    url=watch_url, user_agent=self.headers.get("user-agent", None)
                )
            except RuntimeError as err:
                # report the class of the underlying error if there is one
                if err.__cause__ is not None:
                    reason = err.__cause__.__class__.__name__
                else:
                    reason = err.__class__.__name__
                self.send_error(502, reason, str(err))
                return
            self.cache.set(watch_url, playlist)

        self._send_playlist_response(playlist)

    def _send_playlist_response(self, urls: List[str]) -> None:
        """Send ``urls`` as a playlist body, or 204 without a body if there are none."""
        cache_control: str = "private, max-age={}".format(self.cache.ttl)

        if not urls:
            self.send_response(204)
            self.send_header("Cache-Control", cache_control)
            self.send_header("Content-Length", "0")
            self.end_headers()
            return

        # one URL per line, every line CRLF-terminated
        payload: bytes = ("\r\n".join(urls) + "\r\n").encode("utf-8", errors="strict")
        self.send_response(200)
        for header, value in (
            ("Cache-Control", cache_control),
            ("Referrer-Policy", "no-referrer"),
            ("Content-Type", "audio/x-mpegurl"),  # ExtM3uPlaylistPlugin
            ("Content-Length", str(len(payload))),
        ):
            self.send_header(header, value)
        self.end_headers()
        try:
            self.wfile.write(payload)
        except OSError:  # BrokenPipeError
            self.close_connection = True


def _serve(server: HTTPServer) -> bool:
    """Serve requests until SIGINT or SIGTERM arrives, then close the server.

    ``server.shutdown()`` has to be called from another thread than the one
    running ``serve_forever()``, so a helper thread waits for the signal
    handlers to set an event and then performs the shutdown.

    Args:
        server: The bound server to run. It is closed before returning.

    Returns:
        Always ``True``; errors of the server loop propagate as exceptions.
    """
    logger: logging.Logger = logging.getLogger("Server")
    shutdown_requested: threading.Event = threading.Event()

    def _handler(signum: int, frame) -> None:
        shutdown_requested.set()

    def _shutdown() -> None:
        shutdown_requested.wait()
        server.shutdown()

    # Install the handlers before the helper thread exists, so that a failing
    # signal.signal() cannot leave a thread behind that waits forever.
    signal.signal(signal.SIGINT, _handler)
    signal.signal(signal.SIGTERM, _handler)
    logger.info("Serving on {}:{}".format(*server.server_address))

    thread: threading.Thread = threading.Thread(target=_shutdown)
    thread.start()
    try:
        server.serve_forever(poll_interval=2.0)
    finally:
        # Also reached if serve_forever() raised: wake up the helper thread, as it
        # is not a daemon and would otherwise keep the process alive forever.
        shutdown_requested.set()
        logger.info("Shutting down")
        try:
            server.server_close()
        finally:
            thread.join()
    return True


def _create_server(bind_all: bool, port: int) -> HTTPServer:
    """Create a threading HTTP server for ``_RequestHandler``, bound and listening.

    Args:
        bind_all: Bind to all IPv4 interfaces if true, to the loopback address otherwise.
        port: TCP port to listen on.
    """
    if bind_all:
        host = "0.0.0.0"
    else:
        host = "127.0.0.1"
    return ThreadingHTTPServer((host, port), _RequestHandler, bind_and_activate=True)


def main() -> int:
    """Parse the command line, configure handler and logging, and run the server.

    Invalid option values end the program through ``parser.error()``, i.e.
    with a usage message and exit status 2.

    Returns:
        The process exit status: 0 if ``_serve()`` reports success, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        # __doc__ is None if docstrings have been stripped (python -OO)
        description=(__doc__ or "").strip(),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--bind-all", action="store_true",
                        help="bind to all interfaces instead of localhost only")
    parser.add_argument("--port", metavar="PORTNUM", type=int, default=8984,
                        help="port to bind to")
    parser.add_argument("--cache-ttl", metavar="SECONDS", type=int, default=3600,
                        help="cache lookups up to this duration in seconds, 0 to disable")
    parser.add_argument("--preferred-ext", metavar="EXT", type=str, default=None,
                        help="preferred audio file extension, for example m4a")
    args = parser.parse_args()

    # Reject values that would otherwise fail late: an out-of-range port raises
    # OverflowError when binding, a negative TTL ends up as a negative max-age.
    if not 0 <= args.port <= 65535:
        parser.error("--port must be in range 0-65535")
    if args.cache_ttl < 0:
        parser.error("--cache-ttl must not be negative")

    _RequestHandler.cache.ttl = args.cache_ttl
    _RequestHandler.preferred_ext = args.preferred_ext
    logging.basicConfig(level=logging.DEBUG, format="%(levelname)s %(name)s: %(message)s")

    return 0 if _serve(_create_server(args.bind_all, args.port)) else 1


# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())