"""
Mostly for debugging purposes, audio can be directly read from or written to ``*.wav`` files.
"""
import logging
import wave
from typing import Optional, Iterator
from sttts.api.message import ModelError
from sttts.api.model import AudioSource, AudioSink
class WaveRecorder(AudioSink):
    """
    Audio sink that writes every played buffer to a ``.wav`` file instead of a sound device.

    The file is opened by :meth:`__enter__` and finalised by :meth:`__exit__`.
    """

    def __init__(self, sample_rate: int, *, filename: str) -> None:
        """
        :param int sample_rate: Frame rate stored in the file header.
        :param str filename: Output ``.wav`` file, written in S16LE mono.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = sample_rate
        self._filename: str = filename
        self._written: int = 0  # bytes passed to play() since the file was opened
        self._wave: Optional[wave.Wave_write] = None  # None while closed

    def __enter__(self) -> None:
        """
        Open the output file and configure it as 16-bit mono at the given sample rate.

        :raises ModelError: If the file cannot be opened or the parameters are rejected.
        """
        assert self._wave is None
        try:
            wav: wave.Wave_write = wave.open(self._filename, "wb")
        except (OSError, wave.Error) as e:
            raise ModelError(self.__class__.__name__, self._filename, str(e)) from e

        try:
            wav.setframerate(self._sample_rate)
            wav.setnchannels(1)
            wav.setsampwidth(2)
        except (OSError, wave.Error) as e:
            # Do not leak the handle or leave a half-configured writer behind.
            # Closing an unconfigured writer may itself complain about missing
            # parameters; the original error is the one worth reporting.
            try:
                wav.close()
            except (OSError, wave.Error):
                pass
            raise ModelError(self.__class__.__name__, self._filename, str(e)) from e

        # Only publish the writer once it is fully set up.
        self._wave = wav
        self._logger.info(f"Opened '{self._filename}'")

    def play(self, buffer: bytes) -> None:
        """
        Append raw frames to the file.

        :param bytes buffer: Audio data, written as-is.
        """
        assert self._wave is not None
        # The frame count in the header is corrected by close(), not per write.
        self._wave.writeframesraw(buffer)
        self._written += len(buffer)

    def drain(self) -> None:
        """No-op, this sink does nothing on drain."""
        pass

    def __exit__(self, *args) -> None:
        """
        Close the file and reset the internal state, even if closing fails.

        :raises ModelError: If finalising or closing the file fails.
        """
        assert self._wave is not None
        try:
            self._wave.close()
        except (OSError, wave.Error) as e:
            raise ModelError(self.__class__.__name__, self._filename, str(e)) from e
        finally:
            self._wave = None
            self._written = 0
class WavePlayer(AudioSource):
    """
    Audio source that yields the contents of a ``.wav`` file in fixed-size chunks.

    The file is opened and validated by :meth:`__enter__`; iterating the object
    afterwards returns successive ``bytes`` buffers until the file is exhausted.
    """

    def __init__(self, sample_rate: int, *, filename: str, buffer_length: float = 0.25) -> None:
        """
        :param int sample_rate: Sample rate the input file is required to have.
        :param str filename: Input ``.wav`` file, must match internal sample rate and assumed to be in S16LE mono.
        :param float buffer_length: Read size in seconds, default 250ms.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = sample_rate
        self._filename: str = filename
        self._num_frames: int = round(buffer_length * sample_rate)  # frames per read
        self._wave: Optional[wave.Wave_read] = None  # None while closed

    def __enter__(self) -> None:
        """
        Open the input file and verify sample rate, channel count and sample width.

        :raises ModelError: If the file cannot be opened or parsed, or its format does not match.
        """
        assert self._wave is None
        try:
            wav: wave.Wave_read = wave.open(self._filename, "rb")
        except (OSError, EOFError, wave.Error) as e:
            # EOFError is what the wave module raises for empty or truncated files.
            raise ModelError(self.__class__.__name__, self._filename, str(e)) from e

        try:
            self._logger.info(f"Opened '{self._filename}': {wav.getparams()}")
            if wav.getframerate() != self._sample_rate:
                raise ModelError(self.__class__.__name__, self._filename,
                                 f"Sample rate {wav.getframerate()}, need {self._sample_rate}")
            if wav.getnchannels() != 1:
                raise ModelError(self.__class__.__name__, self._filename,
                                 f"{wav.getnchannels()} channels")
            if wav.getsampwidth() != 2:
                raise ModelError(self.__class__.__name__, self._filename,
                                 f"Sample width {wav.getsampwidth()}, need 2")
        except BaseException:
            # Validation failed or was interrupted: release the handle before propagating.
            wav.close()
            raise

        # Only publish the reader once it has passed validation.
        self._wave = wav

    def __iter__(self) -> Iterator[bytes]:
        """Return the object itself, which acts as its own iterator."""
        return self

    def __next__(self) -> bytes:
        """
        Read the next chunk of frames.

        :return: Raw frame data, possibly shorter than a full chunk at the end of the file.
        :raises StopIteration: When no more data can be read.
        """
        assert self._wave is not None
        data: bytes = self._wave.readframes(self._num_frames)
        if not data:
            raise StopIteration
        return data

    def __exit__(self, *args) -> None:
        """Close the file and reset the internal state, even if closing fails."""
        assert self._wave is not None
        try:
            self._wave.close()
        finally:
            self._wave = None