import logging
from typing import Any, Dict, Iterator, List, Optional

import gpt4all
import requests

from sttts.api.message import ModelError, ModelNotFoundError
from sttts.api.model import ProcessorModel, Processor


class _Gpt4AllProcessor(Processor):
    def __init__(self, model: gpt4all.GPT4All, system_prompt: Optional[str], **kwargs) -> None:
        self._model: gpt4all.GPT4All = model
        self._system_prompt: Optional[str] = system_prompt
        self._options: Dict[str, Any] = kwargs
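        # Prompt template precedence: an explicit "promptTemplate" option, then the
        # model's own config, then the library default.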
        self._prompt_template: str = self._options.pop(
            "promptTemplate",
            self._model.config.get("promptTemplate", gpt4all.gpt4all.DEFAULT_PROMPT_TEMPLATE),
        )

    def generate(self, utterance: List[str]) -> Iterator[str]:
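        # Each call opens a fresh chat session, so no conversation state carries over
        # between utterances; tokens are yielded as soon as they are generated.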
        with self._model.chat_session(system_prompt=self._system_prompt, prompt_template=self._prompt_template):
            yield from self._model.generate("".join(utterance).rstrip(), streaming=True, **self._options)

    def __exit__(self, *args) -> None:
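        # Nothing to release here; the owning Gpt4AllProcessorModel closes the underlying model.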
        pass


class Gpt4AllProcessorModel(ProcessorModel):
    """
    Run language models through the `GPT4All <https://www.nomic.ai/gpt4all>`__
    `Python <https://docs.gpt4all.io/gpt4all_python/home.html>`__
    `client <https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/python>`__, which wraps the
    `Nomic <https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-backend>`__ and/or
    `llama.cpp <https://github.com/ggerganov/llama.cpp>`__ backends.
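
    A minimal usage sketch, assuming the model file is already present in *model_path*::

        with Gpt4AllProcessorModel(model_name="Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf") as processor:
            for token in processor.generate(["Tell me a joke."]):
                print(token, end="", flush=True)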
    """

    def __init__(self, *,
                 model_name: Optional[str] = None, model_path: str = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY.as_posix(),
                 download: bool = False, device: Optional[str] = None, max_tokens: int = 200, n_ctx: int = 2048,
                 system_prompt: Optional[str] = None,
                 **kwargs) -> None:
        """
        :param str model_name: Model to use, for example ``Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf``.
               Omit to list the available remote models if *download* is enabled.
        :param str model_path: Path for model files, default ``~/.cache/gpt4all/``.
        :param bool download: Opt in to model search and automatic download.
        :param str device: Explicit processing unit, such as ``cpu``; auto-detected by default.
        :param int max_tokens: Maximum number of tokens to generate (default 200).
        :param int n_ctx: Maximum size of the context window (default 2048).
        :param str system_prompt: Override the model's initial system instruction.
        :param kwargs: Extra options passed to :meth:`gpt4all.GPT4All.generate`; a ``promptTemplate``
               key, if present, overrides the model's prompt template.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._verbose: bool = self._logger.getEffectiveLevel() <= logging.DEBUG

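        # Fail fast without a model name; if downloads are enabled, include the remote
        # model catalogue in the error to help the user pick one.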
        if model_name is None:
            raise ModelNotFoundError(self.__class__.__name__, None, None, [
                f"{model_config['filename']} ({model_config['filesize']}B, {model_config['ramrequired']}GB)"
                for model_config in gpt4all.GPT4All.list_models()
            ] if download else None)

        self._model_path: str = model_path
        self._model_name: str = model_name
        self._download: bool = download
        self._max_tokens: int = max_tokens
        self._n_ctx: int = n_ctx
        self._device: Optional[str] = device
        self._system_prompt: Optional[str] = system_prompt
        self._model: Optional[gpt4all.GPT4All] = None
        self._options: Dict[str, Any] = kwargs

    def __enter__(self) -> Processor:
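        # The model is loaded on enter rather than in __init__, so that loading and
        # download failures surface as ModelError/ModelNotFoundError at a defined point.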
        assert self._model is None
        try:
            self._model = gpt4all.GPT4All(model_name=self._model_name, model_path=self._model_path,
                                          allow_download=self._download, verbose=self._verbose,
                                          n_ctx=self._n_ctx, device=self._device).__enter__()
            self._logger.info(f"Entering {self._model}[{self._model_name}]")
        except FileNotFoundError as e:
            raise ModelNotFoundError(self.__class__.__name__, self._model_name, str(e), None) from e
        except (requests.RequestException, RuntimeError, ValueError) as e:
            raise ModelError(self.__class__.__name__, self._model_name, str(e)) from e
        return _Gpt4AllProcessor(self._model, system_prompt=self._system_prompt, max_tokens=self._max_tokens,
                                 **self._options)

    def __exit__(self, *args) -> None:
        assert self._model is not None
        self._model.__exit__(*args)
        self._model = None  # allow re-entering the context with a fresh model instance