import argparse
import logging
import sys
from functools import cached_property
from pathlib import Path
from queue import Queue, Empty as QueueEmpty
from threading import Thread, Event
from typing import Generic, TypeVar, Callable, ContextManager, Iterable, Iterator

import librosa
import librosa.feature
import numpy as np
from mutagen import File, FileType, MutagenError

from .features import Features, AudioFeatures, TagFeatures, FileFeatures
from .log import run, Status
from .utils import add_args, CSVDictReader, CSVDictWriter, M3UReader

T = TypeVar('T')
F = TypeVar('F')


class ThreadPool(ContextManager, Generic[F, T]):
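    """Fixed-size pool of worker threads wired through an input and an output queue.

    Items pushed via push() are passed to `handler` on a worker thread; each result
    (or the exception it raised) is retrieved with pop(). The usual entry point is
    feed_loop(), e.g.:

        pool: ThreadPool[int, int] = ThreadPool(lambda x: x * x, concurrency=4)
        for result in pool.feed_loop(range(10)):
            ...
    """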
    def __init__(self, handler: Callable[[F], T], concurrency: int) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._concurrency: int = max(1, min(concurrency, 16))
        self._callback: Callable[[F], T] = handler
        self._queue_in: Queue[F | None] = Queue()
        self._queue_out: Queue[T | Exception] = Queue()
        self._threads: list[Thread] = []
        self._cancel: Event = Event()

    def __enter__(self) -> "ThreadPool[F, T]":
        assert self._concurrency > 0
        assert len(self._threads) == 0
        self._logger.info(f"Using {self._concurrency} thread(s)")
        self._cancel.clear()
        self._threads = [Thread(target=self._handler) for _ in range(self._concurrency)]
        for t in self._threads:
            t.start()
        return self

    def __exit__(self, *args) -> None:
        self._cancel.set()
        self.close()  # wake workers blocked on an empty input queue so join() cannot deadlock
        for t in self._threads:
            t.join()
        self._threads.clear()

    def _handler(self) -> None:
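        """Worker loop: take items until cancelled or a None sentinel arrives, putting each result (or raised exception) on the output queue."""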
        while not self._cancel.is_set():
            item: F | None = self._queue_in.get()
            if item is None:
                self._queue_in.task_done()
                break
            try:
                self._queue_out.put(self._callback(item))
            except Exception as e:
                self._queue_out.put(e)
            finally:
                self._queue_in.task_done()

    def close(self) -> None:
        # enqueue one sentinel per worker so every thread can leave its loop
        for _ in range(self._concurrency):
            self._queue_in.put(None)

    def push(self, item: F) -> None:
        self._queue_in.put(item)

    def is_empty(self) -> bool:
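        """True once all queued inputs have been handled and all results retrieved via pop(); relies on Queue.unfinished_tasks (put() calls not yet matched by task_done())."""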
        return self._queue_in.unfinished_tasks == 0 and self._queue_out.unfinished_tasks == 0

    def pop(self, block: bool) -> T | Exception:
        result: T | Exception = self._queue_out.get(block=block)
        self._queue_out.task_done()
        return result

    def feed_loop(self, items: Iterable[F]) -> Iterator[T | Exception]:
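        """Start the workers, stream `items` into the pool, and yield results (or exceptions) as they become available, draining the output queue before shutting down."""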
        with self:
            for item in items:
                self.push(item)
                try:
                    yield self.pop(block=False)
                except QueueEmpty:
                    pass
            self.close()
            while not self.is_empty():
                yield self.pop(block=True)


class SoundFile:
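    """A mono-decoded audio file plus the librosa-based spectral and rhythmic descriptors computed from it."""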
    def __init__(self, filename: Path) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._filename: Path = filename
        self._samples: np.ndarray
        self._rate: int
        self._samples, self._rate = librosa.load(filename, sr=None, mono=True)
        self._n_bins: int = 16
        self._fft_window: int = self._find_pow2(self._rate, narrow=1)  # ~1 sec (largest power of two <= rate), should cover at least one beat smoothly
        self._fft_window_short: int = self._find_pow2(self._rate // 10, narrow=1)  # matches librosa's 2048/512 n_fft/hop defaults at 22050 Hz
        self._fft_frequencies = librosa.fft_frequencies(sr=self._rate, n_fft=self._fft_window).tolist()
        self._n_freq: int = len(self._fft_frequencies)  # n_fft // 2 + 1, e.g. 16385 for a 32768-sample window

    @property
    def rate(self) -> int:
        return self._rate

    @property
    def num_samples(self) -> int:
        return self._samples.shape[0]

    @property
    def length_sec(self) -> float:
        return self.num_samples / self.rate

    @classmethod
    def _clamp(cls, data: np.ndarray, axis: int | None = None, percentile: int = 1) -> np.ndarray:
        # clamp vast outliers early
        return data.clip(np.percentile(data, percentile, axis=axis, keepdims=True),
                         np.percentile(data, 100 - percentile, axis=axis, keepdims=True))

    @classmethod
    def _avg_med_dev_l(cls, data: np.ndarray, axis: int | None, shape: int | None, percentile: int = 1) -> tuple[list[float], list[float], list[float]]:
        data = cls._clamp(data, axis, percentile)
        avg, med, dev = np.mean(data, axis=axis), np.percentile(data, 50, axis=axis), np.std(data, axis=axis)
        if shape is not None:
            assert data.shape[0] == shape, data.shape
            assert avg.shape == (shape,), avg.shape
            assert med.shape == (shape,)
            assert dev.shape == (shape,)
        return avg.tolist(), med.tolist(), dev.tolist()

    @classmethod
    def _avg_med_dev(cls, data: np.ndarray) -> tuple[float, float, float]:
        avg, med, dev = cls._avg_med_dev_l(data, axis=1, shape=1)
        return avg[0], med[0], dev[0]

    @classmethod
    def _find_pow2(cls, target: int, narrow: int = 0) -> int:
        """Next greater (narrow=0) or smaller/equal (narrow=1) power of two."""
        return 2 ** (int(target).bit_length() - narrow)

    @cached_property
    def _fft(self) -> np.ndarray:
        return librosa.stft(y=self._samples, n_fft=self._fft_window)

    @property
    def _fft_abs(self) -> np.ndarray:
        return np.abs(self._fft)  # energy (magnitude) spectrum instead of power

    @property
    def _fft_mag(self) -> np.ndarray:
        return librosa.magphase(self._fft)[0]

    @property
    def _fft_pow(self) -> np.ndarray:
        return librosa.util.abs2(self._fft)

    def normalize(self) -> None:
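        """Trim leading and trailing silence and log the basic signal parameters."""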
        self._samples, _ = librosa.effects.trim(self._samples)
        self._logger.info(f"{self._filename.name}: {len(self._samples)} samples @ {self._rate} Hz, {self._fft_window} window @ {self._n_freq} frequencies")

    def centroid(self) -> tuple[float, float, float]:
        data: np.ndarray = librosa.feature.spectral_centroid(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def bandwidth(self) -> tuple[float, float, float]:
        data: np.ndarray = librosa.feature.spectral_bandwidth(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def flatness(self) -> tuple[float, float, float]:
        data: np.ndarray = librosa.feature.spectral_flatness(y=self._samples, S=self._fft_mag, n_fft=self._fft_window)
        data = librosa.amplitude_to_db(data)
        return self._avg_med_dev(data)

    def crossing(self) -> tuple[float, float, float]:
        data: np.ndarray = librosa.feature.zero_crossing_rate(y=self._samples, frame_length=self._fft_window, hop_length=self._fft_window // 4)
        return self._avg_med_dev(data)

    def rollon(self) -> tuple[float, float, float]:
        data: np.ndarray = librosa.feature.spectral_rolloff(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window, roll_percent=0.10)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def rolloff(self) -> tuple[float, float, float]:
        data: np.ndarray = librosa.feature.spectral_rolloff(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window, roll_percent=0.90)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def spectrum(self) -> tuple[list[float], list[float], list[float]]:
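        """RMS energy per frame in six coarse Bark-style bands (band upper edges in Hz), summarized per band; frequencies above the last edge are ignored."""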
        bark_bands: list[int] = [100, 510, 1080, 2000, 3700, 15500]
        spec = self._fft_pow
        total, num = np.zeros((len(bark_bands), spec.shape[1])), np.array([[0]] * len(bark_bands))
        for i, f in enumerate(self._fft_frequencies):
            for j, b in enumerate(bark_bands):
                if f < b:
                    total[j] += spec[i]
                    num[j][0] += 1
                    break
        rms = np.sqrt(total / num.clip(1, None))  # as using power spectrum
        return self._avg_med_dev_l(rms, axis=1, shape=len(bark_bands))

    def tempogram(self) -> tuple[float, float, float]:
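        """Global tempo from beat tracking, plus flatness- and deviation-style summaries over the first 12 lag rows of the tempogram."""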
        hop_length: int = self._fft_window_short // 4
        win_length: int = 384  # hop already scales with the sample rate, so win_length=384 frames covers ~8.9 s at 22050 Hz
        onset = librosa.onset.onset_strength(y=self._samples, sr=self._rate, aggregate=np.median, hop_length=hop_length, n_fft=self._fft_window_short)
        tempo, _ = librosa.beat.beat_track(sr=self._rate, onset_envelope=onset, hop_length=hop_length)
        tempogram = librosa.feature.tempogram(onset_envelope=onset, sr=self._rate, hop_length=hop_length, win_length=win_length)[:12]
        dev = np.mean(self._clamp(np.std(tempogram, axis=0)))
        flatness = np.mean(self._clamp(1 - np.square(
            np.exp(np.mean(np.log(tempogram.clip(0.001, None)), axis=0, keepdims=False)) /
            np.mean(tempogram.clip(0.001, None), axis=0, keepdims=False)
        )))  # similar to spectral_flatness
        return tempo.item(), flatness, dev

    def hpss(self) -> tuple[float, float, float]:
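        """Harmonic/percussive separation of the STFT; returns mean/median/std of per-frame RMS(percussive) - RMS(harmonic)."""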
        H, P = librosa.decompose.hpss(S=self._fft)
        harm = librosa.feature.rms(S=librosa.magphase(H)[0], frame_length=self._fft_window, hop_length=self._fft_window // 4)
        perc = librosa.feature.rms(S=librosa.magphase(P)[0], frame_length=self._fft_window, hop_length=self._fft_window // 4)
        return self._avg_med_dev(perc - harm)


class FeatureExtractor:
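    """Extracts tag and audio features for a single file, reusing previously cached Features where available."""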
    def __init__(self, cache: list[Features], cache_only: bool) -> None:
        self._cache: dict[Path, Features] = {entry.file.filename: entry for entry in cache}
        self._cache_only: bool = cache_only

    def _parse_audio(self, filename: Path) -> AudioFeatures:
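        """Decode the file and compute all audio descriptors; tracks shorter than 2 min or longer than 20 min are rejected."""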
        samples: SoundFile = SoundFile(filename)
        if not 2 * 60 <= samples.length_sec <= 20 * 60:
            raise OverflowError(f"{filename.name}: Skipped, {round(samples.length_sec)} sec")
        samples.normalize()

        return AudioFeatures(
            rate=samples.rate,
            length=samples.length_sec,
            tempogram=samples.tempogram(),
            centroid=samples.centroid(),
            rollon=samples.rollon(),
            rolloff=samples.rolloff(),
            crossing=samples.crossing(),
            flatness=samples.flatness(),
            bandwidth=samples.bandwidth(),
            hpss=samples.hpss(),  # extra slow
            spectrum=samples.spectrum(),
        )

    def _parse_tags(self, filename: Path) -> TagFeatures:
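        """Read artist and genre from the file's tags via mutagen (empty strings if missing)."""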
        try:
            file_info: FileType | None = File(filename, easy=True)
        except (MutagenError, OSError) as e:
            raise OSError(f"{filename.name}: {str(e)}") from None
        if file_info is None:
            raise RuntimeError(f"{filename}: Cannot parse")

        artist: list[str] = file_info.get("artist", [])
        genre: list[str] = file_info.get("genre", [])
        return TagFeatures(artist=artist[0] if artist else "", genre=genre[0] if genre else "")

    def handle(self, job: tuple[int, Path]) -> Features:
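        """Process one (index, path) job: validate the file and its size, then return cached features or freshly extracted ones."""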
        file: FileFeatures = FileFeatures(index=job[0], filename=job[1])

        if not file.filename.is_file():
            raise ValueError(f"{file.filename.name}: Not a file")
        elif not 1_000_000 <= file.filename.stat().st_size <= 50_000_000:
            raise OverflowError(f"{file.filename.name}: Skipped, {round(file.filename.stat().st_size / 1_000_000)} MB")

        if file.filename in self._cache:
            return Features(
                file=file,
                tags=self._cache[file.filename].tags,
                audio=self._cache[file.filename].audio,
            )
        elif self._cache_only:
            raise LookupError(f"{file.filename.name}: No cache entry found")
        else:
            return Features(
                file=file,
                tags=self._parse_tags(file.filename),
                audio=self._parse_audio(file.filename),
            )


def run_main(status: Status, *,
             out_dir: Path,
             concurrency: int,
             extract_cache: bool | None,
             extract_limit: int | None,
             extract_in_playlist: Path,
             extract_features: Path,
             **kwargs) -> None:
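    """Read the input playlist, extract (or load cached) features for each track in a thread pool, and write them to the features CSV."""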

    cache: list[Features] = []
    out_dir.mkdir(parents=False, exist_ok=True)
    if (out_dir / extract_features).is_file():
        if extract_cache is True:
            status.logger.warning("Skipping, cache exists")
            return
        elif extract_cache is None:
            status.start_progress("Reading cache")
            with CSVDictReader(out_dir / extract_features) as reader:
                cache = list(Features.from_dict(row) for row in status.task_pending_counter(reader.read()))
            status.finish_progress()

    status.start_progress("Processing audio files")
    with CSVDictWriter(out_dir / extract_features, Features.fieldnames()) as writer:
        extractor: FeatureExtractor = FeatureExtractor(cache, cache_only=extract_cache is True)
        threads: ThreadPool[tuple[int, Path], Features] = ThreadPool(extractor.handle, concurrency=concurrency)
        with M3UReader(extract_in_playlist) as reader:
            for result in status.task_done_counter(threads.feed_loop(status.task_add_counter(reader.read_limited(extract_limit)))):
                if isinstance(result, Exception):
                    status.logger.warning(f"{result.__class__.__name__}: {str(result)}")
                else:
                    status.logger.debug(f"Done: {str(result)}")
                    writer.write(result.to_dict())


def main() -> int:
    parser = argparse.ArgumentParser()
    args = add_args(parser).parse_args()
    run(run_main, **vars(args))
    return 0


if __name__ == "__main__":
    sys.exit(main())