import logging
from os import getenv
from pathlib import Path
from typing import Optional, Iterator, Dict
import torch
import whisper
from sttts.api.message import ModelNotFoundError, Message, SttMessageType
from sttts.api.model import RecognizerModel, Recognizer
from sttts.utils.utils import StringUtils, MathUtils, PerfCounter
class WhisperRecognizer(RecognizerModel, Recognizer):
    """
    Audio transcription using the `OpenAI Whisper <https://github.com/openai/whisper>`__ speech recognition models,
    which can be problematic on low-end hardware.
    """

    def __init__(self, keywords: Dict[str, SttMessageType], *,
                 model_name: Optional[str] = None, language: str = "en", download: bool = False,
                 device: Optional[str] = None, **kwargs) -> None:
        """
        :param dict keywords: Maps phrases to message types. Lookup is done with the utterance after stripping
                              punctuation and lower-casing it, so keys are expected to be lower-case.
        :param str model_name: Name of the model to use, for example ``base.en``. Omit to list available ones.
        :param str language: Indicate input language, default ``en``. Especially important for multilingual models.
        :param bool download: Opt-in automatic downloading of models to ``~/.cache/whisper/``.
                              Otherwise, ensure the model exists beforehand.
        :param str device: :mod:`torch` device to use, default ``cuda`` if available, otherwise ``cpu``.
        :param kwargs: Extra arguments passed to :meth:`whisper.Whisper.transcribe()`. These take precedence over
                       the defaults chosen by :meth:`accept`.
        :raises ModelNotFoundError: If no model name is given (the error then carries the available model names),
                                    or if downloading is disabled and the model is not in the cache directory.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = 16000
        self._language: str = language
        self._options: Dict = kwargs
        self._keywords: Dict[str, SttMessageType] = keywords

        if model_name is None:
            raise ModelNotFoundError(self.__class__.__name__, None, None, whisper.available_models())

        self._model_name: str = model_name
        # Without the download opt-in, verify the cached file first so that loading cannot trigger a download.
        checked_name: str = model_name if download else self._find_model(model_name)
        self._model: whisper.Whisper = whisper.load_model(checked_name, device=device)

    @classmethod
    def _find_model(cls, model_name: str) -> str:
        """
        Check that a checkpoint for the given model is present in the whisper cache directory.

        :param str model_name: Model name as it will be passed to :func:`whisper.load_model`.
        :return: The unchanged model name, if a matching file was found.
        :raises ModelNotFoundError: If the cache directory or the checkpoint file does not exist.
        """
        # cannot explicitly disable downloading, so need to duplicate the existence check beforehand
        cache_home: Optional[str] = getenv("XDG_CACHE_HOME")
        # Only resolve the home directory when the environment does not already provide a cache location.
        cache_base: Path = Path(cache_home) if cache_home is not None else Path.home() / ".cache"
        download_root: Path = cache_base / "whisper"

        candidates = {model_name + ".pt"}
        # NOTE(review): whisper stores a checkpoint under the basename of its download URL, which differs from
        # "<name>.pt" for aliases (e.g. "large"). This reads whisper's private name-to-URL table; if it is
        # missing or has an unexpected shape, only "<name>.pt" is checked.
        url = getattr(whisper, "_MODELS", {}).get(model_name)
        if isinstance(url, str) and url:
            candidates.add(url.rsplit("/", 1)[-1])

        if download_root.is_dir():
            if any(entry.name in candidates and entry.is_file() for entry in download_root.iterdir()):
                return model_name
        raise ModelNotFoundError(cls.__name__, model_name, f"Checked '{download_root}'.", None)

    def sample_rate(self) -> int:
        """
        :return: Sample rate in Hz that :meth:`accept` assumes for incoming audio buffers.
        """
        return self._sample_rate

    def _transform_utterance_type(self, utterance: str) -> Message[SttMessageType]:
        """
        Wrap a transcription into a message, using the configured keyword type if the whole utterance matches.

        :param str utterance: Raw transcription text.
        :return: A keyword message carrying the stripped text, or otherwise a plain utterance message whose
                 stripped text is terminated by a newline.
        """
        text: str = utterance.strip()
        keyword: str = StringUtils.strip_punctuation(utterance).lower()
        try:
            return Message(self._keywords[keyword], text)
        except KeyError:
            return Message(SttMessageType.Utterance, text + "\n")

    def accept(self, buffer: bytes) -> Iterator[Message[SttMessageType]]:
        """
        Transcribe one audio buffer and yield at most one message for it.

        :param bytes buffer: Raw audio data, converted by ``MathUtils.buf2arr()`` before transcription.
        :return: Iterator over the resulting message; empty if the transcription is blank.
        """
        # Defaults first, caller-supplied options last: passing e.g. ``fp16`` through the constructor kwargs
        # overrides the default instead of raising a duplicate keyword argument TypeError.
        options: Dict = {
            "language": self._language,
            "condition_on_previous_text": False,
            # Half precision only when the loaded model actually lives on a CUDA device.
            "fp16": self._model.device.type == "cuda",
            **self._options,
        }

        with PerfCounter(self._logger, logging.INFO, "msec") as counter:
            utterance: str = self._model.transcribe(MathUtils.buf2arr(buffer), **options)["text"]
            # Buffer duration in milliseconds; assumes 2 bytes per sample (16-bit mono) - TODO confirm.
            counter(round(len(buffer) / 2.0 / self._sample_rate * 1000.0))

        # Yield outside of the timed section, so time spent by the consumer is not measured.
        if utterance.strip():
            yield self._transform_utterance_type(utterance)

    def __enter__(self) -> Recognizer:
        """
        Log the loaded model (summary at info level, its full representation at debug level).

        :return: This instance.
        """
        self._logger.info("Entering <%s>[%s]", self._model.__class__.__name__, self._model_name)
        for line in str(self._model).splitlines():
            self._logger.debug(line)
        return self

    def __exit__(self, *args) -> None:
        """
        Nothing to release explicitly; exceptions from the ``with`` body are not suppressed.
        """
        pass