import logging
import zipfile
from pathlib import Path
from typing import Iterator, Optional, Dict, List
import numpy as np
import requests
import torch
from TTS.api import TTS, ModelManager
from sttts.api.message import ModelError, ModelNotFoundError
from sttts.api.model import Synthesizer, SynthesizerModel
from sttts.utils.utils import PerfCounter
class _CoquiSynthesizer(Synthesizer):
    """
    Synthesizer backed by an already loaded Coqui :class:`TTS` instance, producing 16-bit little-endian PCM.
    """

    def __init__(self, tts: TTS, **kwargs) -> None:
        """
        :param tts: Loaded Coqui TTS instance used for every synthesis call.
        :param kwargs: Extra options passed to :meth:`TTS.tts()` on every call.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__.lstrip("_"))
        self._tts: TTS = tts
        self._options: Dict = kwargs

    def sample_rate(self) -> int:
        """:return: Output sample rate of the underlying Coqui synthesizer, in samples per second."""
        return self._tts.synthesizer.output_sample_rate

    def generate(self, utterance: str) -> Iterator[bytes]:
        """
        Synthesize a single utterance.

        Sentence splitting inside Coqui is disabled, as segmentation is expected to happen beforehand
        (see SentenceSegmenter).

        :param utterance: Text to synthesize.
        :return: Iterator yielding the whole utterance as one buffer of signed 16-bit little-endian samples,
                 or nothing at all if the model returned no samples.
        :raises ModelError: If the Coqui model raises a :class:`ValueError` during synthesis.
        """
        with PerfCounter(self._logger, logging.INFO, "msec") as counter:
            try:
                wav: List = self._tts.tts(utterance, split_sentences=False, **self._options)  # see SentenceSegmenter
            except ValueError as e:
                raise ModelError(self.__class__.__name__, self._tts.model_name, str(e)) from e
            # report the duration of the synthesized audio in milliseconds
            counter(round(len(wav) / self.sample_rate() * 1000.0))

        samples: np.ndarray = np.asarray(wav, dtype=np.single)
        if samples.size == 0:
            return  # no audio; np.max() below would raise on an empty array

        # Attenuate only if the signal exceeds [-1, 1]. Scale to 32767 (not 32768): a full-scale positive
        # sample must stay within int16 range instead of wrapping around to -32768.
        peak: float = max(1.0, float(np.max(np.abs(samples))))
        pcm: np.ndarray = samples * (32767.0 / peak)
        yield pcm.astype(np.dtype(np.int16).newbyteorder("<")).tobytes("A")
class CoquiSynthesizerModel(SynthesizerModel):
    """
    Use the `TTS <https://github.com/coqui-ai/TTS>`__ text-to-speech library from `Coqui <http://coqui.ai/>`__.
    Originally forked from `Mozilla <https://github.com/mozilla/TTS>`__; both projects seem to be discontinued,
    though. A wide range of natural-sounding `models <https://github.com/coqui-ai/TTS/blob/dev/TTS/.models.json>`__
    is available, and some examples can be found at
    `Coqui-TTS Voice Samples <https://mbarnig.github.io/TTS-Models-Comparison/>`__.
    Comes with lots of additional dependencies, such as ``espeak``, ``ffmpeg``, ``libav``, or ``rustc``.
    """

    def __init__(self, *,
                 model_name: Optional[str] = None, download: bool = False, device: Optional[str] = None,
                 **kwargs) -> None:
        """
        :param str model_name: Model to use, in the format *type/language/dataset/model* with ``tts_models`` type.
               For example ``tts_models/en/ljspeech/tacotron2-DDC``.
               If omitted, a :class:`ModelNotFoundError` listing the available models is raised.
        :param bool download: Opt-in automatic downloading of models to ``~/.local/share/tts/``.
               Otherwise, ensure the model exists beforehand.
        :param str device: Device to use, default ``cuda`` if available, otherwise ``cpu``.
        :param kwargs: Extra options passed to :meth:`TTS.tts()`.
        :raises ModelNotFoundError: If no model name is given, or (without ``download``) it is not present locally.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        # TTS() is given a boolean ``gpu`` flag in __enter__, so any explicit device other than "cpu" enables it.
        self._use_gpu: bool = torch.cuda.is_available() if device is None else device != "cpu"
        if model_name is None:
            raise ModelNotFoundError(self.__class__.__name__, None, None, list(self._list_models()))
        # With downloading enabled, loading fetches the model on demand; otherwise fail early if it is missing.
        self._model_name: str = model_name if download else self._find_model(model_name)
        self._options: Dict = kwargs

    @classmethod
    def _model_manager(cls) -> ModelManager:
        """Create a Coqui model manager for the models file reported by the TTS API, without progress output."""
        return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False)

    @classmethod
    def _list_models(cls) -> Iterator[str]:
        """Yield the names of all known models of the ``tts_models`` type."""
        for model_name in cls._model_manager().list_models():
            if model_name.startswith("tts_models/"):
                yield model_name

    @classmethod
    def _find_model(cls, model_name: str) -> str:
        """
        Check that a model is already present locally, without triggering a download.

        :param model_name: Model name in the format *tts_models/language/dataset/model*.
        :return: The given model name, unchanged.
        :raises ModelNotFoundError: If no matching model directory exists below the manager's output prefix.
        """
        # cannot really control downloading, so need a duplicate existence check beforehand
        output_prefix: Path = Path(cls._model_manager().output_prefix)
        # local model directories are named after the model, with "/" replaced by "--"
        model_dir: Path = output_prefix / model_name.replace("/", "--")
        if model_dir.parent == output_prefix and "tts_models" in model_dir.name and model_dir.is_dir():
            return model_name
        raise ModelNotFoundError(cls.__name__, model_name, f"Checked '{output_prefix}'.", list(cls._list_models()))

    def __enter__(self) -> Synthesizer:
        """
        Load the model and wrap it into a synthesizer.

        :return: Synthesizer using the loaded model and the extra options given to the constructor.
        :raises ModelError: If loading fails with a :class:`ValueError`, :class:`zipfile.BadZipFile`,
                or :class:`requests.RequestException`.
        """
        self._logger.info("Entering <%s>[%s]", TTS.__name__, self._model_name)
        try:
            tts: TTS = TTS(model_name=self._model_name, progress_bar=False, gpu=self._use_gpu)
        except (ValueError, zipfile.BadZipFile, requests.RequestException) as e:
            raise ModelError(self.__class__.__name__, self._model_name, str(e)) from e
        return _CoquiSynthesizer(tts, **self._options)

    def __exit__(self, *args) -> None:
        """Nothing to release here: the loaded model is only referenced by the synthesizer returned on enter."""
        pass