"""
Listening source and playback sink using `PyAudio <https://people.csail.mit.edu/hubert/pyaudio/>`__,
which relies on the cross-platform PortAudio library.

Building requires the ``portaudio19-dev`` package or similar.
At the time of writing, on Ubuntu 22.04, this conflicts with ``jackd`` and the pre-built ``python3-pyaudio`` binary
package 0.2.11 is broken for Python 3.10.
Also, problems might arise for non-default microphone sampling rates using ALSA.

This option provides the most high-level abstraction and compatibility if neither ALSA nor PulseAudio is supported.
"""

import logging
import time
from typing import Optional, Iterator, Dict, List

import pyaudio

from sttts.api.message import ModuleError, ModelNotFoundError
from sttts.api.model import AudioSink, AudioSource


class PyAudioPlayer(AudioSink):
    """
    Playback sink writing mono signed 16-bit PCM to a PyAudio output stream.

    Used as context manager: entering opens and starts the stream, leaving stops and closes it again.
    """

    def __init__(self, sample_rate: int, *,
                 device: Optional[str] = None, buffer_length: float = 5.0,
                 **kwargs) -> None:
        """
        :param int sample_rate: Sampling rate the stream is opened with, in frames per second.
        :param str device: Playback device to use, none for default, for example ``bcm2835 Headphones: - (hw:0,0)``.
               If invalid, error out with a list of available devices.
        :param float buffer_length: Requested output buffer length in seconds, default 5. Note that the actually
               applied buffer size might be lower.
        :param kwargs: Extra options passed to :class:`pyaudio.Stream`.
        :raises ModelNotFoundError: If ``device`` is given but does not name an available output device.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._p: pyaudio.PyAudio = pyaudio.PyAudio()
        self._pcm: Optional[pyaudio.Stream] = None
        self._device: Optional[int] = self._find_device(self._p, device) if device is not None else None
        self._sample_rate: int = sample_rate
        self._buffer_length: int = round(buffer_length * sample_rate)  # in frames
        self._options: Dict = kwargs

    @classmethod
    def _find_device(cls, p: pyaudio.PyAudio, device: str) -> int:
        """
        Resolve an output device name to its PyAudio device index.

        :param p: PyAudio instance to enumerate devices from.
        :param str device: Exact device name to look up.
        :return: Index of the device with at least one output channel and the given name.
        :raises ModelNotFoundError: If there is no such device, carrying the names of all output devices as options.
        """
        devices: List[Dict] = [p.get_device_info_by_index(i) for i in range(p.get_device_count())]
        # Only devices that can actually play back are candidates.
        indices: Dict[str, int] = {info["name"]: info["index"] for info in devices if info["maxOutputChannels"] > 0}
        try:
            return indices[device]
        except KeyError:
            raise ModelNotFoundError(cls.__name__, device, None, options=list(indices)) from None

    def _discard_stream(self) -> None:
        """Close a stream left over from a failed setup, so it is not leaked and entering can be retried."""
        pcm, self._pcm = self._pcm, None
        if pcm is None:
            return
        try:
            pcm.close()
        except IOError as e:
            # The setup error is the one worth reporting, a failing cleanup is only noted.
            self._logger.debug(f"Cannot close stream after failed setup: {str(e)}")

    def __enter__(self) -> None:
        """
        Open and start the output stream.

        :raises ModuleError: If the stream cannot be opened or started.
        """
        assert self._pcm is None
        try:
            self._pcm = self._p.open(rate=self._sample_rate, channels=1, format=pyaudio.paInt16,
                                     output=True, output_device_index=self._device, start=False,
                                     frames_per_buffer=self._buffer_length, **self._options)
            # The applied buffer may be smaller than requested, and drain() must not wait for more than that.
            self._buffer_length = min(self._buffer_length, self._pcm.get_write_available())
            self._pcm.start_stream()
        except IOError as e:
            # The stream might have been opened already, release it instead of leaving it dangling.
            self._discard_stream()
            raise ModuleError(self.__class__.__name__, f"Cannot setup output stream: {str(e)}") from e
        else:
            self._logger.info(f"Entering {self._pcm}[{self._device}]")

    def play(self, buffer: bytes) -> None:
        """
        Write audio data to the output stream, without raising on buffer underflows.

        :param bytes buffer: Audio data, written to the stream as is.
        """
        assert self._pcm is not None
        self._pcm.write(buffer, exception_on_underflow=False)

    def drain(self) -> None:
        """
        Wait until previously written audio has been consumed by the stream.

        This is best-effort: I/O errors while padding or polling are logged as warning and end the wait.
        """
        # XXX: apparently no native flush/wait/drain exposed, and might not even have started for small buffers
        assert self._pcm is not None
        try:
            # Fill up the remaining space with silence, two zero bytes per 16-bit mono frame.
            self._pcm.write(b"\x00\x00" * self._pcm.get_write_available(), exception_on_underflow=False)
            # Poll until the whole buffer is reported as writable again.
            while True:
                available: int = self._pcm.get_write_available()
                if available >= self._buffer_length:
                    break
                self._logger.debug(f"Waiting for drain: {available} < {self._buffer_length}")
                time.sleep(0.2)
        except IOError as e:
            self._logger.warning(str(e))

    def __exit__(self, *args) -> None:
        """Stop and close the output stream, closing it even if stopping fails."""
        assert self._pcm is not None
        # Forget the stream first, such that the object is in a clean state whatever happens below.
        pcm, self._pcm = self._pcm, None
        try:
            pcm.stop_stream()
        finally:
            pcm.close()


class PyAudioRecorder(AudioSource):
    """
    Listening source reading mono signed 16-bit PCM from a PyAudio input stream.

    Used as context manager: entering opens and starts the stream, leaving stops and closes it again.
    While entered, iterating yields chunks of recorded audio.
    """

    def __init__(self, sample_rate: int, *,
                 device: Optional[str] = None, buffer_length: float = 0.25,
                 **kwargs) -> None:
        """
        :param int sample_rate: Sampling rate the stream is opened with, in frames per second.
        :param str device: Recording device to use such as ``USB PnP Sound Device: Audio (hw:1,0)``, none for default.
               If invalid, error out with a list of available devices.
        :param float buffer_length: Read size in seconds, default 250ms.
        :param kwargs: Extra options passed to :class:`pyaudio.Stream`.
        :raises ModelNotFoundError: If ``device`` is given but does not name an available input device.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._p: pyaudio.PyAudio = pyaudio.PyAudio()
        self._pcm: Optional[pyaudio.Stream] = None
        self._device: Optional[int] = self._find_device(self._p, device) if device is not None else None
        self._sample_rate: int = sample_rate
        self._buffer_length: int = round(buffer_length * sample_rate)  # in frames
        self._options: Dict = kwargs

    @classmethod
    def _find_device(cls, p: pyaudio.PyAudio, device: str) -> int:
        """
        Resolve an input device name to its PyAudio device index.

        :param p: PyAudio instance to enumerate devices from.
        :param str device: Exact device name to look up.
        :return: Index of the device with at least one input channel and the given name.
        :raises ModelNotFoundError: If there is no such device, carrying the names of all input devices as options.
        """
        devices: List[Dict] = [p.get_device_info_by_index(i) for i in range(p.get_device_count())]
        # Only devices that can actually record are candidates.
        indices: Dict[str, int] = {info["name"]: info["index"] for info in devices if info["maxInputChannels"] > 0}
        try:
            return indices[device]
        except KeyError:
            raise ModelNotFoundError(cls.__name__, device, None, options=list(indices)) from None

    def _discard_stream(self) -> None:
        """Close a stream left over from a failed setup, so it is not leaked and entering can be retried."""
        pcm, self._pcm = self._pcm, None
        if pcm is None:
            return
        try:
            pcm.close()
        except IOError as e:
            # The setup error is the one worth reporting, a failing cleanup is only noted.
            self._logger.debug(f"Cannot close stream after failed setup: {str(e)}")

    def __enter__(self) -> None:
        """
        Open and start the input stream.

        :raises ModuleError: If the stream cannot be opened or started.
        """
        assert self._pcm is None
        try:
            self._pcm = self._p.open(rate=self._sample_rate, channels=1, format=pyaudio.paInt16,
                                     input=True, input_device_index=self._device, start=False,
                                     frames_per_buffer=self._buffer_length, **self._options)
            self._pcm.start_stream()
        except IOError as e:
            # The stream might have been opened already, release it instead of leaving it dangling.
            self._discard_stream()
            raise ModuleError(self.__class__.__name__, f"Cannot setup input stream: {str(e)}") from e
        else:
            self._logger.info(f"Entering {self._pcm}[{self._device}]")

    def __iter__(self) -> Iterator[bytes]:
        """Return this object itself as iterator over recorded chunks, which requires the stream to be open."""
        assert self._pcm is not None
        return self

    def __next__(self) -> bytes:
        """
        Read the next chunk of audio from the input stream, without raising on buffer overflows.

        :return: Audio data for the configured read size in frames.
        """
        assert self._pcm is not None
        return self._pcm.read(self._buffer_length, exception_on_overflow=False)

    def __exit__(self, *args) -> None:
        """Stop and close the input stream, closing it even if stopping fails."""
        assert self._pcm is not None
        # Forget the stream first, such that the object is in a clean state whatever happens below.
        pcm, self._pcm = self._pcm, None
        try:
            pcm.stop_stream()
        finally:
            pcm.close()