"""
Mostly for debugging purposes, audio can be directly read from or written to ``*.wav`` files.
"""

import logging
import wave
from typing import Optional, Iterator

from sttts.api.message import ModelError
from sttts.api.model import AudioSource, AudioSink


class WaveRecorder(AudioSink):
    """
    Audio sink that writes every played buffer to a ``.wav`` file.

    The file is created when the context is entered and finalized by closing the writer on exit.
    """

    def __init__(self, sample_rate: int, *, filename: str) -> None:
        """
        :param int sample_rate: Frame rate stored in the file header.
        :param str filename: Output ``.wav`` file, written in S16LE mono.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = sample_rate
        self._filename: str = filename
        self._written: int = 0  # bytes passed to play() since the file was opened
        self._wave: Optional[wave.Wave_write] = None

    def __enter__(self) -> None:
        """
        Create the output file and configure it as mono with 2-byte samples at the configured rate.

        :raises ModelError: If the file cannot be created or the writer rejects the parameters.
        """

        assert self._wave is None
        try:
            self._wave = wave.open(self._filename, "wb")
            self._wave.setframerate(self._sample_rate)
            self._wave.setnchannels(1)
            self._wave.setsampwidth(2)
        except (OSError, wave.Error) as e:
            # __exit__ is not called when __enter__ raises, so release a half-configured writer here (e.g. after
            # setframerate() rejected the rate). Otherwise its file handle leaks and re-entering trips the assert.
            if self._wave is not None:
                try:
                    self._wave.close()
                except (OSError, wave.Error):
                    pass  # closing with incomplete parameters may fail; the setup error is the one to report
                self._wave = None
            raise ModelError(self.__class__.__name__, self._filename, str(e)) from e
        self._logger.info(f"Opened '{self._filename}'")

    def play(self, buffer: bytes) -> None:
        """
        Append raw frame data to the file.

        :param bytes buffer: Frames in S16LE mono, written as-is.
        """

        assert self._wave is not None
        self._wave.writeframesraw(buffer)
        self._written += len(buffer)

    def drain(self) -> None:
        """Nothing to wait for: :meth:`play` hands frames to the writer directly, there is no queue."""

    def __exit__(self, *args) -> None:
        """
        Close the writer, which finalizes the file, and reset the state so the context can be re-entered.

        :raises ModelError: If closing the file fails.
        """

        assert self._wave is not None
        try:
            self._wave.close()
        except (OSError, wave.Error) as e:
            raise ModelError(self.__class__.__name__, self._filename, str(e)) from e
        else:
            self._logger.info(f"Closed '{self._filename}', {self._written} bytes written")
        finally:
            self._wave = None
            self._written = 0


class WavePlayer(AudioSource):
    """
    Audio source that yields the frames of a ``.wav`` file in chunks of ``buffer_length`` seconds.

    The file is opened and validated when the context is entered; iteration stops at the end of the file.
    """

    def __init__(self, sample_rate: int, *, filename: str, buffer_length: float = 0.25) -> None:
        """
        :param int sample_rate: Frame rate the input file must have.
        :param str filename: Input ``.wav`` file, must match internal sample rate and be in S16LE mono.
        :param float buffer_length: Read size in seconds, default 250ms. At least one frame is read per chunk.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = sample_rate
        self._filename: str = filename
        # readframes(0) returns b"", which __next__ would treat as end of file before yielding any audio.
        self._num_frames: int = max(1, round(buffer_length * sample_rate))
        self._wave: Optional[wave.Wave_read] = None

    def __enter__(self) -> None:
        """
        Open the input file and check that it is mono, 2 bytes per sample, and at the expected sample rate.

        :raises ModelError: If the file cannot be opened or parsed, or its format does not match.
        """

        assert self._wave is None
        try:
            self._wave = wave.open(self._filename, "rb")
        except (OSError, EOFError, wave.Error) as e:
            # The wave module raises EOFError (not an OSError) for empty or truncated headers, and its message
            # can be empty, so fall back to the exception name.
            raise ModelError(self.__class__.__name__, self._filename, str(e) or e.__class__.__name__) from e
        self._logger.info(f"Opened '{self._filename}': {self._wave.getparams()}")

        try:
            rate: int = self._wave.getframerate()
            if rate != self._sample_rate:
                raise ModelError(self.__class__.__name__, self._filename,
                                 f"Sample rate {rate}, need {self._sample_rate}")
            channels: int = self._wave.getnchannels()
            if channels != 1:
                raise ModelError(self.__class__.__name__, self._filename,
                                 f"{channels} channels")
            width: int = self._wave.getsampwidth()
            if width != 2:
                raise ModelError(self.__class__.__name__, self._filename,
                                 f"Sample width {width}, need 2")
        except BaseException:
            # __exit__ will not run for a failed __enter__, so close the reader before propagating.
            self._wave.close()
            self._wave = None
            raise

    def __iter__(self) -> Iterator[bytes]:
        return self

    def __next__(self) -> bytes:
        """
        Read the next chunk of up to ``buffer_length`` seconds of raw frames.

        :raises StopIteration: When no frames are left.
        """

        assert self._wave is not None
        data: bytes = self._wave.readframes(self._num_frames)
        if not data:
            raise StopIteration
        return data

    def __exit__(self, *args) -> None:
        """Close the input file so the context can be re-entered."""

        assert self._wave is not None
        self._wave.close()
        self._wave = None