import logging
from typing import List, Iterator, Dict, Optional, ContextManager

import httpcore
import httpx
import ollama

from sttts.api.message import ModuleError, ModelError, ModelNotFoundError
from sttts.api.model import ProcessorModel, Processor
from sttts.utils.process import BackgroundProcess
from sttts.utils.utils import NoopCtx


class _OllamaProcessor(Processor):
    """
    Stream completions for one utterance at a time from a fixed model, using an already-configured Ollama client.
    """

    def __init__(self, client: ollama.Client, model_name: str, system_prompt: Optional[str], **kwargs) -> None:
        """
        :param client: Client used for generation requests. It is not closed by this processor.
        :param model_name: Name of the model to generate with.
        :param system_prompt: Optional system message; ``None`` is sent as an empty string.
        :param kwargs: Fields for :class:`ollama.Options`, passed with every request.
        """
        self._client: ollama.Client = client
        self._model_name: str = model_name
        self._system_prompt: Optional[str] = system_prompt
        self._options: ollama.Options = ollama.Options(**kwargs)  # type: ignore

    def generate(self, utterance: List[str]) -> Iterator[str]:
        """
        Stream the model's response to the given utterance.

        :param utterance: Text fragments, concatenated and right-stripped to form the prompt.
        :return: Iterator over the ``response`` text of each streamed chunk.
        :raises ModelError: On client, HTTP or connection errors, also when raised while iterating.
        """
        try:
            for chunk in self._client.generate(model=self._model_name,
                                               prompt="".join(utterance).rstrip(),
                                               system=self._system_prompt or "",  # None is sent as ""
                                               options=self._options, stream=True):
                yield chunk["response"]  # type: ignore
        except (ollama.RequestError, ollama.ResponseError, httpx.HTTPError, httpcore.ConnectError) as e:
            # Same exception set as the model's availability checks, so raw connection errors are wrapped too.
            raise ModelError(self.__class__.__name__, self._model_name, str(e)) from e

    def __exit__(self, *args) -> None:
        # The shared client is owned and closed by the model that created this processor.
        pass


class _OllamaClient(ollama.Client):
    """
    Ollama client with connection retries on its HTTP transport, usable via explicit ``__enter__``/``__exit__``
    calls to close the underlying HTTP client when done.
    """

    def __init__(self, host: Optional[str] = None) -> None:
        # Retrying connection attempts covers the window until a freshly started server process is listening.
        retrying_transport = httpx.HTTPTransport(retries=5)
        super().__init__(host=host, transport=retrying_transport)

    def __enter__(self) -> None:
        """Nothing to acquire here."""
        return None

    def __exit__(self, *args) -> None:
        """Close the underlying HTTP client if it is a synchronous ``httpx.Client``."""
        http_client = self._client
        if not isinstance(http_client, httpx.Client):
            return
        http_client.close()


class OllamaProcessorModel(ProcessorModel):
    """
    Run language models on a remote or locally started `Ollama <https://ollama.com/>`__ server, using the client
    provided by the `ollama <https://github.com/ollama/ollama-python>`__ package.

    If no `installation <https://ollama.com/download/linux>`__ as system daemon is needed, the self-contained binary
    can simply be downloaded, for example from `<https://ollama.com/download/ollama-linux-amd64.tgz>`__.
    """

    def __init__(self, *,
                 model_name: Optional[str] = None, host: str = "127.0.0.1:11434", download: bool = False,
                 serve: bool = False, serve_exe: str = "ollama", serve_env: Optional[Dict[str, str]] = None,
                 device: Optional[str] = None, system_prompt: Optional[str] = None,
                 **kwargs) -> None:
        """
        :param str model_name: Model to use, for example ``llama3``. Omitted to list locally available ones.
               Remotely available models can be browsed in the official `model library <https://ollama.com/library>`__.
        :param str host: API host, default ``127.0.0.1:11434``.
        :param bool download: Opt-in automatic model pull, usually to ``~/.ollama/models/``.
        :param bool serve: Run ``ollama serve`` in an own subprocess.
        :param str serve_exe: Path to local binary when using internal serving instead of ``ollama``.
        :param dict serve_env: Extra environment variables when using internal serving, see ``ollama serve --help``.
        :param str device: Disable CUDA when using internal serving and set to ``cpu``.
        :param str system_prompt: Override system message from what is defined in the Modelfile.
        :param kwargs: Extra :class:`ollama.Options` passed to :meth:`ollama.Client.generate()`.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._model_name: Optional[str] = model_name
        self._download: bool = download
        self._system_prompt: Optional[str] = system_prompt
        self._options: Dict = kwargs
        self._client: _OllamaClient = _OllamaClient(host)

        # Without internal serving, a no-op context stands in so that enter/exit can be called unconditionally.
        self._serve: ContextManager[None] = BackgroundProcess(
            self._logger.getChild("serve"),
            {
                "OLLAMA_HOST": host,
                # "OLLAMA_DEBUG": "1" if self._logger.getEffectiveLevel() == logging.DEBUG else "0",
                **({"CUDA_VISIBLE_DEVICES": ""} if device == "cpu" else {}),  # hide all GPUs from the server
                **(serve_env if serve_env is not None else {})
            },
            serve_exe, "serve",
        ) if serve else NoopCtx()

    def _check_model(self, model_name: str) -> bool:
        """
        Query the server for the given model and log the keys of its metadata.

        :return: ``True`` if the query succeeded. ``False`` on any client, HTTP or connection error, which is logged
                 as a warning. An unreachable server is therefore reported the same way as an unknown model.
        """
        try:
            for key, _val in self._client.show(model_name).items():
                self._logger.debug(f"{model_name}: {key}")
            return True
        except (ollama.RequestError, ollama.ResponseError, httpx.HTTPError, httpcore.ConnectError) as e:
            self._logger.warning(f"{model_name}: {str(e)}")
            return False

    def _list_models(self) -> Iterator[str]:
        """Yield a ``name (sizeB)`` description for each model returned by the server's list endpoint."""
        for model in self._client.list()["models"]:
            yield f"{model['name']} ({model['size']}B)"

    def _shutdown(self, *args) -> None:
        """Close the client, then stop the optional serve subprocess. The subprocess is stopped even if closing the client fails."""
        try:
            self._client.__exit__(*args)
        finally:
            self._serve.__exit__(*args)

    def __enter__(self) -> Processor:
        """
        Start serving if configured, make sure the model is available, and return a processor for it.

        :raises ModuleError: If the serve subprocess cannot be started.
        :raises ModelNotFoundError: If no model name is given (listing the available ones instead), or if the model
                                    cannot be queried and ``download`` is not set.
        :raises ModelError: On client, HTTP or connection errors while listing or pulling models.
        """
        try:
            self._serve.__enter__()
            self._client.__enter__()
        except OSError as e:
            raise ModuleError(self.__class__.__name__, f"Cannot run '{str(self._serve)}': {str(e)}") from None

        try:
            if self._model_name is None:
                raise ModelNotFoundError(self.__class__.__name__, None, None, list(self._list_models()))
            elif self._check_model(self._model_name):
                pass
            elif not self._download:
                raise ModelNotFoundError(self.__class__.__name__, self._model_name, None, None)
            else:
                self._client.pull(self._model_name)
            self._logger.info(f"Entering {self._client}[{self._model_name}]")
            return _OllamaProcessor(self._client, self._model_name, self._system_prompt, **self._options)
        except (ollama.RequestError, ollama.ResponseError, httpx.HTTPError, httpcore.ConnectError) as e:
            self._shutdown(None, None, None)
            raise ModelError(self.__class__.__name__, self._model_name, str(e)) from e
        except BaseException:
            # Includes ModelNotFoundError and interrupts: nothing may keep the subprocess running on failure.
            self._shutdown(None, None, None)
            raise

    def __exit__(self, *args) -> None:
        self._shutdown(*args)