import logging
import zipfile
from pathlib import Path
from typing import Iterator, Optional, Dict, List

import numpy as np
import requests
import torch
from TTS.api import TTS, ModelManager

from sttts.api.message import ModelError, ModelNotFoundError
from sttts.api.model import Synthesizer, SynthesizerModel
from sttts.utils.utils import PerfCounter


class _CoquiSynthesizer(Synthesizer):
    """
    Synthesizer bound to an already loaded Coqui :class:`TTS` instance.

    Each utterance is synthesized in one go and emitted as a single chunk of 16-bit little-endian PCM samples.
    """

    def __init__(self, tts: TTS, **kwargs) -> None:
        """
        :param TTS tts: Loaded Coqui model to synthesize with.
        :param kwargs: Extra options passed through to :meth:`TTS.tts()` on every call.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__.lstrip("_"))
        self._tts: TTS = tts
        self._options: Dict = kwargs

    def sample_rate(self) -> int:
        """Output sample rate (samples per second) as reported by the underlying Coqui synthesizer."""
        return self._tts.synthesizer.output_sample_rate

    @staticmethod
    def _to_pcm16(wav: List) -> bytes:
        """
        Convert float samples to 16-bit little-endian PCM bytes.

        Audio louder than unity is attenuated so its peak fits the int16 range; quieter audio is scaled as if its
        peak were 1.0, i.e. not amplified further. An empty input yields empty bytes.
        """
        samples: np.ndarray = np.asarray(wav, dtype=np.single)
        if samples.size == 0:
            # np.max() has no identity for empty arrays and would raise
            return b""
        peak: float = float(np.max(np.abs(samples)))
        # 32767 rather than 32768: a full-scale sample of 1.0 would otherwise overflow int16 and wrap to -32768
        samples = samples * (32767.0 / max(1.0, peak))
        return samples.astype(np.dtype(np.int16).newbyteorder("<")).tobytes("A")

    def generate(self, utterance: str) -> Iterator[bytes]:
        """
        Synthesize ``utterance`` and yield the result as a single chunk of PCM audio.

        The synthesized duration in milliseconds is reported to the performance counter.

        :param str utterance: Text to speak; sentence splitting is expected to have happened upstream.
        :raises ModelError: If the Coqui model rejects the input with a :class:`ValueError`.
        """
        with PerfCounter(self._logger, logging.INFO, "msec") as counter:
            try:
                # split_sentences=False: segmentation is already done upstream, see SentenceSegmenter
                wav: List = self._tts.tts(utterance, split_sentences=False, **self._options)
            except ValueError as e:
                raise ModelError(self.__class__.__name__, self._tts.model_name, str(e)) from e
            counter(round(len(wav) / self.sample_rate() * 1000.0))
            buffer: bytes = self._to_pcm16(wav)
        yield buffer


class CoquiSynthesizerModel(SynthesizerModel):
    """
    Use the `TTS <https://github.com/coqui-ai/TTS>`__ text-to-speech library from `Coqui <http://coqui.ai/>`__.
    It was originally forked from `Mozilla TTS <https://github.com/mozilla/TTS>`__; both projects seem to be
    discontinued, though. A wide range of natural-sounding
    `models <https://github.com/coqui-ai/TTS/blob/dev/TTS/.models.json>`__ is available; some examples can be found
    at `Coqui-TTS Voice Samples <https://mbarnig.github.io/TTS-Models-Comparison/>`__.
    Comes with lots of additional dependencies, such as ``espeak``, ``ffmpeg``, ``libav``, or ``rustc``.
    """

    def __init__(self, *,
                 model_name: Optional[str] = None, download: bool = False, device: Optional[str] = None,
                 **kwargs) -> None:
        """
        :param str model_name: Model to use, in the format *type/language/dataset/model* with ``tts_models`` type.
               For example ``tts_models/en/ljspeech/tacotron2-DDC``.
               If omitted, :class:`ModelNotFoundError` is raised, listing the available models.
        :param bool download: Opt-in automatic downloading of models to ``~/.local/share/tts/``.
               Otherwise, the model must already exist locally, which is checked here.
        :param str device: Device to use, default ``cuda`` if available, otherwise ``cpu``.
               Any value other than ``cpu`` enables GPU usage.
        :param kwargs: Extra options passed to :meth:`TTS.tts()`.
        :raises ModelNotFoundError: If no model name is given, or it is not found locally without ``download``.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        # NOTE(review): only a GPU on/off flag is derived here; a specific device such as "cuda:1" is not forwarded
        self._use_gpu: bool = torch.cuda.is_available() if device is None else device != "cpu"

        if model_name is None:
            raise ModelNotFoundError(self.__class__.__name__, None, None, list(self._list_models()))
        elif download:
            self._model_name: str = model_name
        else:
            self._model_name = self._find_model(model_name)

        self._options: Dict = kwargs

    @staticmethod
    def _model_manager() -> ModelManager:
        """Create a silent Coqui model manager using the models file path reported by :class:`TTS`."""
        return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False)

    @classmethod
    def _list_models(cls) -> Iterator[str]:
        """Yield the names of all known models of the ``tts_models`` type."""
        for model_name in cls._model_manager().list_models():
            if model_name.startswith("tts_models/"):
                yield model_name

    @classmethod
    def _find_model(cls, model_name: str) -> str:
        """
        Check that ``model_name`` has already been downloaded and return it unchanged.

        Local model directories encode the name with ``--`` in place of ``/``.

        :raises ModelNotFoundError: If no matching directory exists below the manager's output prefix.
        """
        # cannot really control downloading, so need a duplicate existence check beforehand
        output_prefix: Path = Path(cls._model_manager().output_prefix)
        if output_prefix.is_dir() and any(
                "tts_models" in fn.name and fn.name.replace("--", "/") == model_name and fn.is_dir()
                for fn in output_prefix.iterdir()
        ):
            return model_name
        raise ModelNotFoundError(cls.__name__, model_name, f"Checked '{output_prefix}'.", list(cls._list_models()))

    def __enter__(self) -> Synthesizer:
        """
        Load the configured model and wrap it in a synthesizer.

        :raises ModelError: If loading or downloading the model fails with a known error type.
        """
        try:
            self._logger.info("Entering <%s>[%s]", TTS.__name__, self._model_name)
            return _CoquiSynthesizer(TTS(model_name=self._model_name, progress_bar=False, gpu=self._use_gpu),
                                     **self._options)
        except (ValueError, zipfile.BadZipFile, requests.RequestException) as e:
            raise ModelError(self.__class__.__name__, self._model_name, str(e)) from e

    def __exit__(self, *args) -> None:
        """Nothing to release; the loaded model is owned by the returned synthesizer."""
        pass