import logging
from typing import List, Iterator, Dict, Optional, ContextManager
import httpcore
import httpx
import ollama
from sttts.api.message import ModuleError, ModelError, ModelNotFoundError
from sttts.api.model import ProcessorModel, Processor
from sttts.utils.process import BackgroundProcess
from sttts.utils.utils import NoopCtx
class _OllamaProcessor(Processor):
def __init__(self, client: ollama.Client, model_name: str, system_prompt: Optional[str], **kwargs) -> None:
self._client: ollama.Client = client
self._model_name: str = model_name
self._system_prompt: Optional[str] = system_prompt
self._options: ollama.Options = ollama.Options(**kwargs) # type: ignore
def generate(self, utterance: List[str]) -> Iterator[str]:
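        # Join the buffered utterance fragments into a single prompt and stream the completion chunk by chunk.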
try:
for chunk in self._client.generate(model=self._model_name,
prompt="".join(utterance).rstrip(),
system=self._system_prompt or "",
options=self._options, stream=True):
yield chunk["response"] # type: ignore
except (ollama.RequestError, ollama.ResponseError, httpx.HTTPError) as e:
raise ModelError(self.__class__.__name__, self._model_name, str(e))
def __exit__(self, *args) -> None:
pass
class _OllamaClient(ollama.Client):
def __init__(self, host: Optional[str] = None) -> None:
super().__init__(host=host, transport=httpx.HTTPTransport(retries=5)) # retry until process is up and listening
def __enter__(self) -> None:
pass
def __exit__(self, *args) -> None:
if isinstance(self._client, httpx.Client):
self._client.close()
class OllamaProcessorModel(ProcessorModel):
"""
Run language models on a remote or locally started `Ollama <https://ollama.com/>`__ server, using the client
provided by the `ollama <https://github.com/ollama/ollama-python>`__ package.
    If an `installation <https://ollama.com/download/linux>`__ as a system daemon is not desired, the self-contained
    binary can simply be downloaded instead, for example from `<https://ollama.com/download/ollama-linux-amd64.tgz>`__.
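
    A minimal usage sketch, assuming the server is reachable at the default host and ``llama3`` is either already
    available there or may be pulled:

    .. code-block:: python

        with OllamaProcessorModel(model_name="llama3", download=True) as processor:
            for chunk in processor.generate(["Briefly explain what a Modelfile is."]):
                print(chunk, end="", flush=True)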
"""
def __init__(self, *,
model_name: Optional[str] = None, host: str = "127.0.0.1:11434", download: bool = False,
serve: bool = False, serve_exe: str = "ollama", serve_env: Optional[Dict[str, str]] = None,
device: Optional[str] = None, system_prompt: Optional[str] = None,
**kwargs) -> None:
"""
        :param str model_name: Model to use, for example ``llama3``. Omit to list the locally available ones.
Remotely available models can be browsed in the official `model library <https://ollama.com/library>`__.
:param str host: API host, default ``127.0.0.1:11434``.
        :param bool download: Opt in to automatically pulling the model, usually into ``~/.ollama/models/``.
        :param bool serve: Run ``ollama serve`` in its own subprocess.
        :param str serve_exe: Path to the local binary to use for internal serving instead of the default ``ollama``.
        :param dict serve_env: Extra environment variables for internal serving, see ``ollama serve --help``.
        :param str device: Set to ``cpu`` to disable CUDA when using internal serving.
        :param str system_prompt: Override the system message defined in the model's Modelfile.
:param kwargs: Extra :class:`ollama.Options` passed to :meth:`ollama.Client.generate()`.
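
        A sketch passing :class:`ollama.Options` fields such as ``temperature`` and ``num_ctx`` directly,
        with internal serving on CPU:

        .. code-block:: python

            OllamaProcessorModel(model_name="llama3", serve=True, device="cpu",
                                 system_prompt="Answer concisely.",
                                 temperature=0.2, num_ctx=4096)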
"""
self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self._model_name: Optional[str] = model_name
self._download: bool = download
self._system_prompt: Optional[str] = system_prompt
self._options: Dict = kwargs
self._client: _OllamaClient = _OllamaClient(host)
self._serve: ContextManager[None] = BackgroundProcess(
self._logger.getChild("serve"),
{
"OLLAMA_HOST": host,
# "OLLAMA_DEBUG": "1" if self._logger.getEffectiveLevel() == logging.DEBUG else "0",
**({"CUDA_VISIBLE_DEVICES": ""} if device == "cpu" else {}),
**(serve_env if serve_env is not None else {})
},
serve_exe, "serve",
) if serve else NoopCtx()
def _check_model(self, model_name: str) -> bool:
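        # Probe the server for the model's metadata; success means the model is already available there.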
try:
for key, _val in self._client.show(model_name).items():
self._logger.debug(f"{model_name}: {key}")
return True
except (ollama.RequestError, ollama.ResponseError, httpx.HTTPError, httpcore.ConnectError) as e:
self._logger.warning(f"{model_name}: {str(e)}")
return False
def _list_models(self) -> Iterator[str]:
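        # Enumerate the models known to the server, together with their size in bytes.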
for model in self._client.list()["models"]:
yield f"{model['name']} ({model['size']}B)"
def __enter__(self) -> Processor:
try:
self._serve.__enter__()
self._client.__enter__()
except OSError as e:
raise ModuleError(self.__class__.__name__, f"Cannot run '{str(self._serve)}': {str(e)}") from None
try:
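            # Resolve the model: list the available ones if none was given, use it if present, otherwise pull it if opted in.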
if self._model_name is None:
raise ModelNotFoundError(self.__class__.__name__, None, None, list(self._list_models()))
elif self._check_model(self._model_name):
pass
elif not self._download:
raise ModelNotFoundError(self.__class__.__name__, self._model_name, None, None)
else:
self._client.pull(self._model_name)
self._logger.info(f"Entering {self._client}[{self._model_name}]")
return _OllamaProcessor(self._client, self._model_name, self._system_prompt, **self._options)
except (ollama.RequestError, ollama.ResponseError, httpx.HTTPError, httpcore.ConnectError) as e:
self._client.__exit__(None, None, None)
self._serve.__exit__(None, None, None)
raise ModelError(self.__class__.__name__, self._model_name, str(e))
except BaseException:
self._client.__exit__(None, None, None)
self._serve.__exit__(None, None, None)
raise
def __exit__(self, *args) -> None:
self._client.__exit__(*args)
self._serve.__exit__(*args)