import logging
from os import getenv
from pathlib import Path
from typing import Optional, Iterator, Dict

import torch
import whisper

from sttts.api.message import ModelNotFoundError, Message, SttMessageType
from sttts.api.model import RecognizerModel, Recognizer
from sttts.utils.utils import StringUtils, MathUtils, PerfCounter


class WhisperRecognizer(RecognizerModel, Recognizer):
    """
    Audio transcription using the `OpenAI Whisper <https://github.com/openai/whisper>`__ speech recognition models,
    which can be problematic on low-end hardware.
    """

    def __init__(self, keywords: Dict[str, SttMessageType], *,
                 model_name: Optional[str] = None, language: str = "en", download: bool = False,
                 device: Optional[str] = None, **kwargs) -> None:
        """
        :param keywords: Maps a normalized utterance (punctuation stripped, lower-cased) to the message type that
               is emitted for it instead of a plain ``SttMessageType.Utterance``.
        :param str model_name: Name of the model to use, for example ``base.en``. Omitted to list available ones.
        :param str language: Indicate input language, default ``en``. Especially important for multilingual models.
        :param bool download: Opt-in automatic downloading of models to ``~/.cache/whisper/``.
               Otherwise, ensure the model exists beforehand.
        :param str device: :mod:`torch` device to use, default ``cuda`` if available, otherwise ``cpu``.
        :param kwargs: Extra arguments passed to :meth:`whisper.Whisper.transcribe()`. They take precedence over
               the defaults used by :meth:`accept` (``condition_on_previous_text=False``, and ``fp16`` enabled
               only when the model lives on a CUDA device).
        :raises ModelNotFoundError: If ``model_name`` is omitted (the error carries the available model names), or
               if ``download`` is not set and no matching checkpoint exists in the local Whisper cache.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = 16000
        self._language: str = language
        self._options: Dict = kwargs
        self._keywords: Dict[str, SttMessageType] = keywords

        if model_name is None:
            raise ModelNotFoundError(self.__class__.__name__, None, None, whisper.available_models())

        self._model_name: str = model_name
        # Without opt-in downloading, verify the checkpoint is already cached; load_model would otherwise fetch it.
        self._model: whisper.Whisper = whisper.load_model(
            model_name if download else self._find_model(model_name),
            device=device,
        )

    @classmethod
    def _find_model(cls, model_name: str) -> str:
        """
        Ensure a checkpoint named ``<model_name>.pt`` exists in the Whisper cache directory.

        :param model_name: Model name as accepted by :func:`whisper.load_model`.
        :return: The unchanged ``model_name``, suitable for :func:`whisper.load_model`.
        :raises ModelNotFoundError: If the checkpoint file is not present.
        """
        # cannot explicitly disable downloading, so need to duplicate the existence check beforehand
        download_root: Path = Path(getenv("XDG_CACHE_HOME", Path.home() / ".cache")) / "whisper"
        # NOTE(review): only files named exactly "<model_name>.pt" are recognized; aliases whose checkpoint file
        # carries a different name would not be found here — confirm against whisper's model table if needed.
        if (download_root / f"{model_name}.pt").is_file():
            return model_name
        raise ModelNotFoundError(cls.__name__, model_name, f"Checked '{download_root}'.", None)

    def sample_rate(self) -> int:
        """:return: The input sample rate in Hz expected by :meth:`accept`."""
        return self._sample_rate

    def _transform_utterance_type(self, utterance: str) -> Message[SttMessageType]:
        """
        Wrap a transcribed utterance into a message.

        If the normalized text (punctuation stripped, lower-cased) is a configured keyword, the keyword's message
        type is used with the stripped text. Otherwise an ``Utterance`` message with a trailing newline is returned.
        """
        try:
            return Message(self._keywords[StringUtils.strip_punctuation(utterance).lower()], utterance.strip())
        except KeyError:
            return Message(SttMessageType.Utterance, utterance.strip() + "\n")

    def accept(self, buffer: bytes) -> Iterator[Message[SttMessageType]]:
        """
        Transcribe one buffer of audio and yield at most one message for it.

        :param buffer: Raw audio bytes; the duration logging assumes 2 bytes per sample at :meth:`sample_rate`
               (presumably 16-bit mono PCM, as converted by ``MathUtils.buf2arr`` — confirm).
        :return: Iterator over zero or one message; nothing is yielded for whitespace-only transcriptions.
        """
        # Defaults first so user-supplied transcribe() options override them instead of raising a
        # "multiple values for keyword argument" TypeError.
        options: Dict = {
            "language": self._language,
            "condition_on_previous_text": False,
            # Decide FP16 by where the model actually lives, not by whether CUDA exists at all.
            "fp16": self._model.device.type == "cuda",
            **self._options,
        }

        with PerfCounter(self._logger, logging.INFO, "msec") as counter:
            utterance: str = self._model.transcribe(MathUtils.buf2arr(buffer), **options)["text"]
            counter(round(len(buffer) / 2.0 / self._sample_rate * 1000.0))

        # Yield outside the timed region so consumer processing time is not attributed to transcription.
        if utterance.strip():
            yield self._transform_utterance_type(utterance)

    def __enter__(self) -> Recognizer:
        """Log the loaded model (summary at INFO, full module structure at DEBUG) and return ``self``."""
        self._logger.info(f"Entering <{self._model.__class__.__name__}>[{self._model_name}]")
        for line in str(self._model).splitlines(keepends=False):
            self._logger.debug(line)
        return self

    def __exit__(self, *args) -> None:
        """No resources to release; the model stays loaded for the object's lifetime."""
        pass