#!/usr/bin/env python3

"""
MPD Dynamic Playlist Daemon.
Uses playlist history context to heuristically determine similar songs which are then added to the playlist in order to
ensure there always is a certain number ahead of the currently played title.
Enabled when neither random, repeat, nor consume is set.
"""

import argparse
import configparser
import logging
import os
import random
import re
import socket
import sys
import threading

from abc import abstractmethod
from collections import defaultdict
from csv import reader as CsvReader
from dataclasses import dataclass
from difflib import SequenceMatcher
from functools import lru_cache
from pathlib import Path, PurePath
from typing import Dict, List, Set, Optional, Tuple, Iterator, DefaultDict, Any

import numpy as np
import requests
from mpd import MPDClient, MPDError
from mpd.base import escape as mpd_escape


class MPDIdle:
    """
    Subscribe to MPD state changes via the 'idle' command.
    Uses a dedicated connection maintained in an own thread, as an idle timeout would cause the socket to be reset,
    leading to frequent re-connects.

    The worker thread sets an internal event whenever one of the subscribed subsystems is reported as changed;
    `wait()` blocks on and consumes that event.
    """

    def __init__(self, host: str, port: int, subsystems: List[str]) -> None:
        """Prepare (but do not yet open) a connection to `host`:`port` that listens for the given subsystems."""
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._host: str = host
        self._port: int = port
        self._subsystems: Set[str] = set(subsystems)

        self._client: MPDClient = MPDClient()
        self._client.timeout = 10.0
        self._client.idletimeout = None  # no timeout for the blocking idle call itself

        self._thread: Optional[threading.Thread] = None
        self._event: threading.Event = threading.Event()

    def _worker(self) -> None:
        """Thread body: block in 'idle' and flag relevant changes until the connection fails."""
        try:
            while True:
                self._client.ping()  # fatal timeout here
                changes: List[str] = self._client.idle()
                if set(changes) & self._subsystems:
                    self._event.set()
        except (OSError, MPDError) as e:
            # Also the regular way out: stop() shuts the socket down to interrupt the blocking call.
            self._logger.warning(f"Stopping idle listener: {str(e)}")

    def start(self) -> None:
        """
        Connect and launch the listener thread; a no-op if already started.
        Raises RuntimeError if the connection cannot be established.
        """
        if self._thread is None:
            try:
                self._client.connect(self._host, self._port)
            except (OSError, MPDError) as e:
                raise RuntimeError(f"Cannot connect to MPD at {self._host}:{self._port}: {str(e)}") from e
            else:
                self._logger.info(f"Connected to MPD at {self._host}:{self._port}")

            self._thread = threading.Thread(target=self._worker, name=self.__class__.__name__, daemon=True)
            self._thread.start()

    def stop(self) -> None:
        """Wake up and join the listener thread, then close the connection. A no-op if never started."""
        if self._thread is None:
            return

        # `_sock` is private to the client and may be unset - presumably once the library noticed a lost
        # connection (TODO confirm). Must not fail here, as this is called during shutdown/cleanup.
        sock = getattr(self._client, "_sock", None)
        if sock is None:
            self._logger.debug("No socket to close, listener presumably stopped already")
        else:
            try:
                sock.shutdown(socket.SHUT_RD)
            except OSError as e:
                self._logger.warning(f"Cannot close socket, ignoring: {str(e)}")
            else:
                self._logger.debug("Closed socket, wakeup listener")

        self._thread.join(timeout=self._client.timeout)
        if self._thread.is_alive():
            self._logger.warning("Listener did not terminate timely")
        self._thread = None  # NB: daemonic thread

        # Best-effort teardown, the connection might be gone already.
        try:
            self._client.close()
        except (OSError, MPDError):
            pass
        try:
            self._client.disconnect()
        except OSError:
            pass

    def wait(self, timeout: Optional[float]) -> bool:
        """
        Block until a subscribed subsystem changed or `timeout` seconds passed (None waits indefinitely).
        Returns True if a change was signalled (which is thereby consumed), False on timeout.
        Raises RuntimeError if the listener thread is not running (anymore).
        """
        if self._thread is None or not self._thread.is_alive():
            raise RuntimeError("Listener not connected")
        if self._event.wait(timeout):
            self._event.clear()
            return True
        else:
            return False


class MPD:
    """
    MPDClient wrapper for the commands in use.
    https://mpd.readthedocs.io/en/latest/protocol.html

    Socket and library errors of issued commands are uniformly re-raised as RuntimeError.
    """

    def __init__(self, host: str, port: int) -> None:
        """Prepare (but do not yet open) a connection to `host`:`port`."""
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._host: str = host
        self._port: int = port

        self._client: MPDClient = MPDClient()
        self._client.timeout = 10.0
        self._client.idletimeout = self._client.timeout

    def connect(self) -> None:
        """Open the connection, raising RuntimeError if that fails."""
        try:
            self._client.connect(self._host, self._port)
        except (OSError, MPDError) as e:
            raise RuntimeError(f"Cannot connect to MPD at {self._host}:{self._port}: {str(e)}") from e
        else:
            self._logger.info(f"Connected to MPD at {self._host}:{self._port}")

    def disconnect(self) -> None:
        """Best-effort teardown that never raises for an already broken or closed connection."""
        try:
            self._client.close()
        except (OSError, MPDError):
            pass
        try:
            self._client.disconnect()
        except OSError:
            pass

    def _command(self, cmd: str, *args, **kwargs) -> Any:
        """
        Invoke the client method named `cmd` with the given arguments and return its result.
        Raises RuntimeError for unknown commands as well as for any socket or library error.
        """
        # MPDClient dynamically creates handlers and does not provide typing anyway, so can unify error handling
        handler = getattr(self._client, cmd, None)
        if not callable(handler):
            raise RuntimeError(f"Unknown MPD command '{cmd}'")
        try:
            return handler(*args, **kwargs)
        except (OSError, MPDError) as e:
            raise RuntimeError(f"MPD command '{cmd}' failed: {str(e)}") from e

    def status(self) -> Dict[str, str]:
        """Give the result of the 'status' command."""
        return self._command("status")

    def add_file(self, uri: str) -> None:
        """Add the given URI to the playlist via 'addid'."""
        self._command("addid", uri)

    def find_file(self, filename: str) -> Optional[Dict[str, str]]:
        """Look up a file by its exact name, giving None unless there is exactly one match."""
        results = self._command("find", f'(file == "{mpd_escape(filename)}")')
        return results[0] if len(results) == 1 else None

    def playlist_song_info(self, song_id: str) -> Optional[Dict[str, str]]:
        """Give the playlist entry with the given id, or None if it cannot be obtained (logged as warning)."""
        try:
            results = self._command("playlistid", song_id)  # or: currentsong()
        except RuntimeError as e:
            self._logger.warning(str(e))
            return None
        # An empty result must not surface as IndexError, callers rely on the Optional contract.
        return results[0] if results else None

    def playlist_info(self, start: int) -> List[Dict[str, str]]:
        """Give the playlist entries from position `start` (clamped to 0) on, or all of them if that fails."""
        try:
            return self._command("playlistinfo", f"{max(0, start)}:")
        except RuntimeError as e:  # race for 'bad song index' ahead
            self._logger.warning(str(e))
            return self._command("playlistinfo")

    def playlist_files(self, start: int) -> List[str]:
        """Give the filenames (empty string if missing) of the playlist entries from position `start` on."""
        return [_.get("file", "") for _ in self.playlist_info(start)]  # or: playlistfind("file", uri)

    # NB: memoized per (instance, key, val, fuzzy), so cached queries do not reflect later library changes.
    @lru_cache(maxsize=1024)
    def _search(self, key: str, val: str, fuzzy: Optional[bool]) -> List[Dict[str, str]]:
        """
        Search the library for tag `key` matching `val`.
        `fuzzy` selects the query: True for the plain key/value form, False for an '==' filter expression,
        and None for a 'contains_cs' filter expression.
        """
        if fuzzy is True:
            return self._command("search", key, val)
        elif fuzzy is False:  # still not case-sensitive, in contrast to `find`
            return self._command("search", f'({key} == "{mpd_escape(val)}")')
        else:
            return self._command("search", f'({key} contains_cs "{mpd_escape(val)}")')

    def search_artist(self, artist: str, fuzzy: bool = False) -> List[str]:
        """Give the filenames of all library entries found for the given artist."""
        return [_.get("file", "") for _ in self._search("artist", artist, fuzzy)]

    def search_genre(self, genre: str, fuzzy: bool = False) -> List[str]:
        """Give the filenames of all library entries found for the given genre."""
        return [_.get("file", "") for _ in self._search("genre", genre, fuzzy)]

    def search_filename(self, filename: PurePath) -> List[str]:
        """Give the filenames of all library entries whose path contains the name part of `filename`."""
        return [_.get("file", "") for _ in self._search("file", filename.name, fuzzy=None)]


@dataclass(frozen=True)
class Song:
    """Immutable subset of the song metadata dictionaries obtained from MPD, as needed for suggestions."""

    artist: str
    file: str
    song_id: str  # actually an integer
    genre: Optional[str]

    @staticmethod
    def _first(value: Any) -> Any:
        """Reduce a multi-valued tag given as list/tuple to its first value (None if empty), pass through others."""
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        return value

    @classmethod
    def from_dict(cls, dct: Optional[Dict[str, str]]) -> Optional['Song']:
        """
        Build a song from a metadata dictionary, giving None if that is missing or lacks 'artist' or 'file'.
        The id is taken from 'songid', else 'id', else left empty. An empty genre becomes None.
        """
        if dct is None:
            return None
        try:
            # Tags that occur several times presumably arrive as lists instead of strings (TODO confirm against
            # the client library). Only the first value is kept to honour the declared `str` fields, as lists
            # would neither be hashable for cached lookups nor usable for string operations later on.
            return cls(
                artist=cls._first(dct["artist"]),
                file=cls._first(dct["file"]),
                song_id=cls._first(dct.get("songid", dct.get("id", ""))),
                genre=cls._first(dct.get("genre", None)) or None,
            )
        except (ValueError, TypeError, KeyError, IndexError):
            return None


class SongHistory:
    """
    Bounded FIFO of the most recently played songs, from which upcoming suggestions are derived.
    """

    def __init__(self, max_len: int) -> None:
        """Create an empty history that keeps at most `max_len` songs."""
        self._max_len: int = max_len
        self._history: List[Song] = []

    def push(self, song: Song) -> bool:
        """
        Append a song, dropping the oldest entries beyond the maximum length.
        Returns False (without adding) if the newest entry already refers to the same file, True otherwise.
        """
        if self._history and self._history[-1].file == song.file:
            return False

        self._history.append(song)
        while len(self._history) > self._max_len:
            del self._history[0]  # evict from the old end
        return True

    def clear(self) -> None:
        """Forget all songs."""
        del self._history[:]

    @property
    def max_len(self) -> int:
        """The maximum number of songs kept."""
        return self._max_len

    @property
    def last(self) -> Song:
        """The most recently added song, raises IndexError if the history is empty."""
        return self._history[-1]

    def __iter__(self) -> Iterator[Song]:
        """Iterate from the oldest to the newest song."""
        return iter(self._history)

    def __len__(self) -> int:
        """The number of songs currently kept."""
        return len(self._history)


class Suggester:
    """
    Base for the strategies that derive candidate files to be added from a single song of the history.
    """

    def __init__(self, mpd: MPD, weight: float = 1.0) -> None:
        """Keep the library connection to query and the base weight of this strategy."""
        self._mpd: MPD = mpd
        self._weight: float = weight  # base weight, zero basically disables this instance
        self._logger: logging.Logger = logging.getLogger(type(self).__name__)

    @property
    def weight(self) -> float:
        """The base weight all results of this strategy are to be scaled with."""
        return self._weight

    # NB: merely a marker, as the class does not use an ABC metaclass that would enforce overriding.
    @abstractmethod
    def suggest(self, history: Song) -> Iterator[Tuple[float, str]]:
        """Give pairs of relative weight and filename for the given song. To be implemented by subclasses."""
        raise NotImplementedError()


class ArtistSuggester(Suggester):
    """Suggests by identical artist."""

    def suggest(self, history: Song) -> Iterator[Tuple[float, str]]:
        """Yield every library file found for the artist of the given song, each with full weight."""
        yield from ((1.0, path) for path in self._mpd.search_artist(history.artist))


class GenreSuggester(Suggester):
    """Suggests by identical genre."""

    def suggest(self, history: Song) -> Iterator[Tuple[float, str]]:
        """Yield every library file found for the genre of the given song, each with full weight."""
        genre = history.genre
        if not genre:  # nothing to go by
            return
        for path in self._mpd.search_genre(genre):
            yield 1.0, path


class FuzzyGenreSuggester(Suggester):
    """Suggests by the individual words a genre consists of."""

    def suggest(self, history: Song) -> Iterator[Tuple[float, str]]:
        """
        Search the library for each alphanumeric word (of at least three characters) of the genre,
        weighting the results by the similarity between that word and the whole genre.
        """
        # for example: 'Progressive Rock' -> 'Rock', 'Progressive Metal', 'Hard Rock', ...
        genre = history.genre
        if not genre:
            return

        whole: str = genre.lower()
        for word in re.split("[ ,/&-]+", genre):
            if len(word) < 3 or not word.isalnum():
                continue
            matcher = SequenceMatcher(lambda char: char in " ,/&-'", whole, word.lower())
            similarity: float = matcher.ratio()
            for path in self._mpd.search_genre(word, fuzzy=True):
                yield similarity, path


class LastFMSuggester(Suggester):
    """Suggests by the similar artists the LastFM web API reports. Disabled (zero weight) without API key."""

    def __init__(self, mpd: MPD, weight: float, api_key: Optional[str]) -> None:
        """Set up the HTTP session, forcing the weight to zero if no API key is given."""
        super().__init__(mpd, weight=weight if api_key is not None else 0.0)
        self._api_key: Optional[str] = api_key
        self._fuzzy: bool = True  # sent as the 'autocorrect' request parameter
        self._session = requests.Session()
        if self.weight <= 0.0:
            self._logger.warning("LastFM suggestions disabled")

    # NB: memoized per (instance, artist); failures raise and are thus not cached.
    @lru_cache(maxsize=1024)
    def _request(self, artist: str) -> List[Tuple[float, str]]:
        """
        Query the similar artists of `artist`, giving pairs of match value and artist name.
        Responses of unexpected structure give an empty result, unusable entries are skipped.
        Raises RuntimeError if the request fails or the response is not valid JSON.
        """
        if self._api_key is None:
            return []

        try:
            response: requests.Response = self._session.get("https://ws.audioscrobbler.com/2.0/", timeout=10.0, params={
                "method": "artist.getSimilar",
                "artist": artist,
                "autocorrect": "1" if self._fuzzy else "0",
                "limit": "50",
                "format": "json",
                "api_key": self._api_key,
            })
            response.raise_for_status()
            data: Dict = response.json()
        except (OSError, ValueError, requests.RequestException) as e:
            # ValueError: a body that is not valid JSON is not necessarily reported as RequestException,
            # and callers only handle RuntimeError.
            raise RuntimeError(f"LastFM API call failed: {str(e)}") from None

        # Expected structure: {"similarartists": {"artist": [{"name": ..., "match": ...}, ...]}}
        similar = data.get("similarartists") if isinstance(data, dict) else None
        entries = similar.get("artist") if isinstance(similar, dict) else None
        if not isinstance(entries, list):
            return []

        results: List[Tuple[float, str]] = []  # not yielding here to allow for simple caching
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            try:
                results.append((float(entry.get("match", 1.0)), str(entry["name"])))
            except (ValueError, TypeError, KeyError):
                pass  # skip entries without name or with a non-numeric match
        return results

    def suggest(self, history: Song) -> Iterator[Tuple[float, str]]:
        """Search LastFM for similar artists (weighted) and give all their files from the library."""
        for weight, result in self._request(history.artist):
            for file in self._mpd.search_artist(result):
                yield weight, file


class DistanceMatrixSuggester(Suggester):
    """
    Suggests by smallest euclidean distance between per-file coordinates read from a CSV file.
    Each CSV line consists of a filename followed by its numeric coordinates. Disabled (zero weight) without file.
    """

    def __init__(self, mpd: MPD, weight: float, csv_filename: Optional[str]) -> None:
        """Load the coordinates and precompute all pairwise distances, forcing zero weight if no file is given."""
        super().__init__(mpd, weight=weight if csv_filename is not None else 0.0)
        self._filenames, coordinates = self._parse_coords(csv_filename)
        self._distances: np.ndarray = self._compute_distances(coordinates)
        self._file_map: Dict[str, int] = {f.name: i for i, f in enumerate(self._filenames)}  # assume no duplicate names
        self._limit: int = min(100, len(self._file_map) // 10)  # top 10%
        self._logger.info(f"Loaded {len(self._file_map)} distances")

    def _parse_coords(self, fn: Optional[str]) -> Tuple[List[PurePath], np.ndarray]:
        """Give the filenames and, in the same order, an array with one row of coordinates per file."""
        filenames: List[PurePath] = []
        coordinates: List[List[float]] = []
        if fn is not None:
            for f, c in self._read_coords(Path(fn)):
                filenames.append(f)
                coordinates.append(c)
        return filenames, np.asarray(coordinates)

    def _read_coords(self, fn: Path) -> Iterator[Tuple[PurePath, List[float]]]:
        """
        Yield filename and coordinates per non-empty line of the CSV file.
        Reading stops at the first error, which is logged; what was yielded until then is kept by the caller.
        """
        from csv import Error as CsvError  # only needed here

        try:
            with fn.open("r", newline="") as fp:
                for values in CsvReader(fp, delimiter=","):
                    if not values:  # blank lines give empty rows, which have no filename to index
                        continue
                    yield PurePath(values[0]), [float(_) for _ in values[1:]]
        except (OSError, ValueError, CsvError) as e:
            self._logger.error(f"Cannot parse '{fn.name}': {str(e)}")
            return

    def _compute_distances(self, coordinates: np.ndarray) -> np.ndarray:
        """
        Give the square matrix of euclidean distances between all rows, each line scaled by its own maximum
        such that values are in [0, 1]. Lines that are all zero (single file, identical coordinates) stay zero.
        """
        num: int = coordinates.shape[0]
        self._logger.debug(f"Computing {num}x{num} matrix")
        distances: np.ndarray = np.zeros(shape=(num, num))
        for i in range(num):
            # one vectorized pass per file instead of one call per pair
            row: np.ndarray = np.linalg.norm(coordinates - coordinates[i], axis=1)  # euclidean
            peak = row.max()
            distances[i] = row / peak if peak > 0.0 else row  # normalize per line/file, avoiding 0/0
        return distances

    def _find_index(self, filename: str) -> Optional[int]:
        """Give the matrix index for the name part of the given filename, or None if unknown."""
        try:
            return self._file_map[PurePath(filename).name]
        except (KeyError, ValueError, TypeError):
            return None

    def suggest(self, history: Song) -> Iterator[Tuple[float, str]]:
        """Given a CSV file with filenames followed by coordinates, return those with smallest euclidean distance."""

        idx: Optional[int] = self._find_index(history.file)
        if idx is None:
            if len(self._filenames):
                self._logger.info(f"Cannot find weights for '{history.file}'")
            return

        # The nearest entries include the song itself (distance zero), which is skipped below.
        for i, dist in sorted(enumerate(self._distances[idx]), key=lambda _: _[1])[:self._limit]:
            self._logger.debug(f"{history.file} -> {self._filenames[i]}: {dist}")
            if i != idx:
                for file in self._mpd.search_filename(self._filenames[i]):
                    yield 1.0 - dist, file  # the closer, the higher the weight


class DynPlaylist:
    """
    Main logic that reacts to playlist changes and adds newly suggested songs as needed.
    """

    def __init__(self, client: MPD, config: configparser.ConfigParser) -> None:
        """
        Apply the configuration (sections 'history', 'playlist', 'weights', 'lastfm', 'distance') and set up all
        suggesters. Raises RuntimeError if the configuration cannot be applied.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._client: MPD = client

        try:
            self._history: SongHistory = SongHistory(config.getint("history", "history_len", fallback=10))
            self._duplicate_history_max_len: int = config.getint("history", "duplicate_history_max_len", fallback=100)
            self._duplicate_history_len: int = config.getint("history", "duplicate_history_len", fallback=25)
            self._approved_timeout: float = config.getfloat("history", "approved_timeout", fallback=30.0)

            self._suggest_len: int = config.getint("playlist", "suggest_len", fallback=5)
            # NB: the fallback is evaluated in any case, so the key is always removed from the environment
            self._suggesters: List[Suggester] = [
                ArtistSuggester(self._client, config.getfloat("weights", "artist", fallback=1.0)),
                GenreSuggester(self._client, config.getfloat("weights", "genre", fallback=0.5)),
                FuzzyGenreSuggester(self._client, config.getfloat("weights", "fuzzy_genre", fallback=0.2)),
                LastFMSuggester(self._client,
                                config.getfloat("weights", "lastfm", fallback=1.0),
                                config.get("lastfm", "api_key", fallback=os.environ.pop("LASTFM_API_KEY", "")) or None),
                DistanceMatrixSuggester(self._client,
                                        config.getfloat("weights", "distance", fallback=1.0),
                                        config.get("distance", "csv_file", fallback="") or None)
            ]
            blacklist: Optional[str] = config.get("playlist", "blacklist", fallback=None) or None
            self._blacklist: Optional[re.Pattern] = re.compile(blacklist) if blacklist is not None else None
        except (configparser.Error, re.error, UnicodeDecodeError, ValueError, TypeError) as e:
            raise RuntimeError(f"Cannot apply configuration: {str(e)}") from e

    @property
    def poll_interval(self) -> Optional[float]:
        """The maximum time in seconds that should pass between two calls of `run()`."""
        return self._approved_timeout  # NB: should actually be less for more exact sampling

    def _is_blacklisted(self, file: str) -> bool:
        """Whether a file is to be ignored: empty names, URLs with scheme, and matches of the blacklist pattern."""
        if not file or re.match("^[a-z]+://", file) is not None:
            return True
        return self._blacklist is not None and self._blacklist.search(file) is not None

    def _playlist_bootstrap_history(self, status: Dict[str, str]) -> None:
        """When starting up, read the end of the current playlist for bootstrapping suggestions."""
        files: List[Dict[str, str]] = self._client.playlist_info(
            int(status.get("playlistlength", 0)) - self._history.max_len
        )[-self._history.max_len:]
        for file in files:
            song: Optional[Song] = Song.from_dict(file)
            if song is not None and not self._is_blacklisted(song.file):
                if self._history.push(song):
                    self._logger.info(f"Added to history: {song.file}")

    def _playlist_get_duplicate_weights(self, status: Dict[str, str]) -> Dict[str, float]:
        """
        To prevent adding duplicates, reduce the weight for files already in (at the end of) the playlist.
        Of the last `duplicate_history_max_len` entries, the newest `duplicate_history_len` ones get factor 0.001
        and the older ones 0.1. For files occurring more than once, the newest occurrence counts.
        """
        playlist: List[str] = self._client.playlist_files(
            int(status.get("playlistlength", 0)) - self._duplicate_history_max_len
        )[-self._duplicate_history_max_len:] if self._duplicate_history_max_len > 0 else []
        return {e: 0.1 if i < len(playlist) - self._duplicate_history_len else 0.001 for i, e in enumerate(playlist)}

    def _is_enabled(self, status: Dict[str, str]) -> bool:
        """For control on whether we should be active (or reset), check the 'random', 'repeat', and 'consume' flags."""
        return \
            status.get("random", "0") == "0" and \
            status.get("repeat", "0") == "0" and \
            status.get("consume", "0") == "0"

    def _song_is_approved(self, status: Dict[str, str]) -> bool:
        """Only add songs that are not skipped for a certain time to the history. (This is the reason for polling.)"""
        return float(status.get("elapsed", 0.0)) >= self._approved_timeout

    def _get_suggest_len(self, status: Dict[str, str]) -> int:
        """How many songs should be added to the playlist to maintain the configured threshold"""
        # number of entries after the current one, compared to the desired amount
        return max(0, self._suggest_len - (int(status.get("playlistlength", 0)) - int(status.get("song", 0)) - 1))

    def _update_history(self, status: Dict[str, str]) -> None:
        """Get the currently played song and add it to the history."""
        if len(self._history) and self._history.last.song_id == status.get("songid", ""):  # early exit
            return
        if not self._song_is_approved(status):
            self._logger.debug("Current song not (yet) approved")
            return

        song: Optional[Song] = Song.from_dict(self._client.playlist_song_info(status.get("songid", "")))
        if song is not None and not self._is_blacklisted(song.file):
            if self._history.push(song):
                self._logger.info(f"Approved for history: {song.file}")

    def _get_suggestions(self, weights: Dict[str, float]) -> Dict[str, float]:
        """
        Run all suggestion approaches on the current history, accumulating weights.
        `weights` gives optional per-file factors (default 1.0). Suggesters that fail with RuntimeError are logged
        and skipped for the respective song.
        """
        suggestions: DefaultDict[str, float] = defaultdict(float)  # duplicates leading to additional weights
        for song in self._history:  # NB: same weights, regardless of history position
            for suggester in self._suggesters:
                if suggester.weight <= 0.0:  # enabled?
                    continue
                try:
                    # exhaust the generator first, such that a failure does not leave partial results
                    results: List[Tuple[float, str]] = [
                        (weight, suggestion)
                        for weight, suggestion in suggester.suggest(song)
                        if not self._is_blacklisted(suggestion)
                    ]
                except RuntimeError as e:
                    self._logger.error(f"{suggester.__class__.__name__} failed: {str(e)}")
                    continue
                for weight, suggestion in results:
                    suggestions[suggestion] += weight * suggester.weight * weights.get(suggestion, 1.0)
        return suggestions

    def _append_suggestions(self, status: Dict[str, str]) -> None:
        """Generate suggestions and heuristically select the needed amount to be added to the playlist."""
        num: int = self._get_suggest_len(status)
        if num > 0 and len(self._history) > 0:
            suggestions: Dict[str, float] = self._get_suggestions(self._playlist_get_duplicate_weights(status))
            self._logger.info(f"Got {len(suggestions)} suggestions for {num} entries from {len(self._history)} history")
        else:
            suggestions = {}

        # Entries without positive weight can never be drawn, but would make the weighted choice fail as soon as
        # no other ones are left - so ignore them upfront. (This also rules out NaN.)
        candidates: Dict[str, float] = {file: weight for file, weight in suggestions.items() if weight > 0.0}
        while num > 0 and candidates:
            num -= 1
            suggestion: str = random.choices(list(candidates),  # nosec
                                             weights=list(candidates.values()),
                                             k=1)[0]
            self._logger.info(f"Appending: {suggestion}")
            self._client.add_file(suggestion)
            del candidates[suggestion]  # draw without replacement

    def test(self, filename: str) -> None:
        """Look up the given file in the library and log all suggestions for it, sorted by descending weight."""
        song: Optional[Song] = Song.from_dict(self._client.find_file(filename))
        if song is None:
            self._logger.info(f"Not found: {filename}")
        else:
            self._logger.info(f"Found {filename}: {str(song)}")
            self._history.push(song)
            for file, weight in sorted(self._get_suggestions({}).items(), key=lambda _: _[1], reverse=True):
                self._logger.info(f"{weight:0.3f} {file}")

    def run(self) -> None:
        """Periodic handler that - if enabled - updates history and adds suggestions to the playlist accordingly."""
        status: Dict[str, str] = self._client.status()

        if status.get("state", "") != "play":
            return  # neither update history nor add suggestions

        if not self._is_enabled(status):
            if len(self._history):
                self._logger.info("Not enabled, resetting state")
                self._history.clear()  # bootstrap again next time
            return
        if not len(self._history):
            self._logger.info("Bootstrapping state")
            self._playlist_bootstrap_history(status)

        self._update_history(status)
        self._append_suggestions(status)


def main() -> int:
    """
    Command line entry point: parse arguments, read the optional configuration, and run the daemon loop.

    Returns the process exit code: 0 on keyboard interrupt, 1 if the configuration cannot be read or on
    RuntimeError (e.g. connection problems), and 2 on any unexpected error.
    """
    parser = argparse.ArgumentParser(description=__doc__.strip(),
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--debug", action="store_true", default=False,
                        help="enable more verbose log output")
    parser.add_argument("--host", type=str, default="127.0.0.1",
                        help="MPD hostname or ip address for TCP connections")
    parser.add_argument("--port", type=int, default=6600,
                        help="MPD port number for TCP connections")
    parser.add_argument("--config", metavar="CONFIG.INI", type=str, default=None,
                        help="configuration file")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO,
                        format="%(levelname)-8s %(name)s: %(message)s")
    logger: logging.Logger = logging.getLogger(main.__name__)

    config: configparser.ConfigParser = configparser.ConfigParser()
    if args.config:
        try:
            # Open explicitly, as ConfigParser.read() silently skips files that are missing or cannot be opened -
            # which would go unnoticed and leave all settings at their defaults.
            with open(args.config, "r") as fp:
                config.read_file(fp)
        except (OSError, configparser.Error, UnicodeDecodeError) as e:
            logger.error(f"Cannot read configuration from '{args.config}': {str(e)}")
            return 1
        else:
            logger.info(f"Read configuration from '{args.config}'")

    random.seed()
    mpd: MPD = MPD(args.host, args.port)
    mpd_idle: MPDIdle = MPDIdle(args.host, args.port, ["player", "playlist", "options"])

    try:
        playlist: DynPlaylist = DynPlaylist(mpd, config)
        logger.info("Starting up")

        mpd.connect()
        mpd_idle.start()
        while True:
            # Act on every relevant state change, but at least once per poll interval.
            playlist.run()
            mpd_idle.wait(playlist.poll_interval)
    except KeyboardInterrupt:
        logger.info("Interrupt, shutting down")
        return 0
    except RuntimeError as e:
        logger.error(str(e))
        return 1
    except Exception as e:
        logger.exception(repr(e))
        return 2
    finally:
        mpd.disconnect()
        mpd_idle.stop()


if __name__ == "__main__":
    # Run as a script: use the result of main() as process exit code.
    raise SystemExit(main())