"""
`ALSA <https://www.alsa-project.org/wiki/Main_Page>`__ listening source and playback sink, using
`PyAlsaAudio <https://github.com/larsimmisch/pyalsaaudio>`__.
This should be the lowest-level sound system implementation, available by default even on, for example, a Raspberry Pi.

For `installation <http://larsimmisch.github.io/pyalsaaudio/pyalsaaudio.html#installation>`__, the corresponding native
library and headers are needed, such as from the ``libasound2-dev`` package.
For configuration, the ``amixer``, ``aplay``, and ``arecord`` CLI commands from the ``alsa-utils`` package might be
useful, too. Also note that the invoking user must typically be in the ``audio`` group.
"""

import errno
import logging
from typing import Optional, Iterator, Dict

import alsaaudio

from sttts.api.message import ModelError
from sttts.api.model import AudioSource, AudioSink


class AlsaPlayer(AudioSink):
    """
    Blocking ALSA playback sink writing mono, signed 16 bit little-endian samples at the given ``sample_rate``.

    The PCM is opened when entering the context and closed when leaving it.
    """

    def __init__(self, sample_rate: int, *,
                 device: str = "default", buffer_length: float = 5.0, period_size: int = 256,
                 **kwargs) -> None:
        """
        :param str device: Playback PCM to use, as obtained by ``aplay -L``, for example ``default:CARD=Headphones``.
               Default ``default``.
        :param float buffer_length: Output buffer length in seconds, default 5.
               Generous to unblock the :class:`sttts.api.model.Synthesizer` running in parallel.
        :param int period_size: ALSA period size in `frames <https://www.alsa-project.org/wiki/FramesPeriods>`__,
               default 256.
        :param kwargs: Extra options passed to :class:`alsaaudio.PCM`.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = sample_rate
        self._sample_format: int = alsaaudio.PCM_FORMAT_S16_LE
        self._period_size: int = period_size
        # Number of periods that together cover the requested buffer length.
        self._periods: int = round(buffer_length * sample_rate / period_size)
        self._device: str = device
        self._options: Dict = kwargs
        self._pcm: Optional[alsaaudio.PCM] = None  # only set while inside the context

    def __enter__(self) -> None:
        """
        Open the playback PCM.

        :raises ModelError: If the PCM cannot be opened, with the ALSA error chained as cause.
        """
        assert self._pcm is None
        try:
            self._pcm = alsaaudio.PCM(alsaaudio.PCM_PLAYBACK, mode=alsaaudio.PCM_NORMAL,  # NB: blocking
                                      channels=1, rate=self._sample_rate, format=self._sample_format,
                                      periodsize=self._period_size, periods=self._periods, device=self._device,
                                      **self._options)
        except alsaaudio.ALSAAudioError as e:
            raise ModelError(self.__class__.__name__, self._device, str(e)) from e
        else:
            self._logger.info(f"Entering {self._pcm}[{self._device}]")
            # Only query the PCM for its settings if the result is going to be logged.
            if self._logger.isEnabledFor(logging.DEBUG):
                self._logger.debug(str(self._pcm.info()))

    def play(self, buffer: bytes) -> None:
        """
        Write the given samples to the PCM, retrying once if the first attempt reports ``EPIPE``.

        Failures are logged rather than raised.

        :param bytes buffer: Raw samples in the configured format, two bytes per frame.
        """
        assert self._pcm is not None
        expected_frames: int = len(buffer) // 2  # mono 16 bit samples, thus two bytes per frame
        rv: int = self._pcm.write(buffer)
        if rv == -errno.EPIPE:
            self._logger.debug("Playback buffer not ready")
            rv = self._pcm.write(buffer)  # retry
        if rv == -errno.EPIPE:
            self._logger.warning("Playback buffer underrun")
        elif rv < 0:
            self._logger.error(f"Playback write error: {rv}")
        elif rv != expected_frames:
            # Report both numbers in frames, the write result is not a byte count.
            self._logger.warning(f"Playback write buffer: Unexpected length {rv}/{expected_frames} frames")

    def drain(self) -> None:
        """Delegate to :meth:`alsaaudio.PCM.drain` of the opened PCM."""
        assert self._pcm is not None
        self._pcm.drain()

    def __exit__(self, *args) -> None:
        """Close the PCM and forget it, even if closing fails, such that the instance can be entered again."""
        assert self._pcm is not None
        try:
            self._pcm.close()
        finally:
            self._pcm = None


class AlsaRecorder(AudioSource):
    """
    Blocking ALSA capture source yielding mono, signed 16 bit little-endian samples at the given ``sample_rate``.

    The PCM is opened when entering the context and closed when leaving it, iterating yields one read each.
    """

    def __init__(self, sample_rate: int, *,
                 device: Optional[str] = None, buffer_length: float = 0.25, warmup: int = 4, periods: int = 1,
                 **kwargs) -> None:
        """
        :param str device: Capture PCM to use, as obtained by ``arecord -L``, for example ``default:CARD=Device``.
               Default ``default``.
        :param float buffer_length: Adjust read size in seconds, default 250ms.
        :param int warmup: Skip the first reads, in case of microphone auto-gaining, default 4, thus 1 second.
        :param int periods: ALSA `periods <https://www.alsa-project.org/wiki/FramesPeriods>`__, default 1.
        :param kwargs: Extra options passed to :class:`alsaaudio.PCM`.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._sample_rate: int = sample_rate
        self._warmup: int = warmup
        self._num_reads: int = 0  # reads since entering the context, including skipped ones
        self._sample_format: int = alsaaudio.PCM_FORMAT_S16_LE
        self._device: str = device if device is not None else "default"
        self._pcm: Optional[alsaaudio.PCM] = None  # only set while inside the context
        self._periods: int = periods  # XXX: does not seem to have any effect
        # Split the requested read size evenly over the configured periods.
        self._period_size: int = round(buffer_length * sample_rate / periods)
        self._options: Dict = kwargs

    def __enter__(self) -> None:
        """
        Open the capture PCM.

        :raises ModelError: If the PCM cannot be opened, with the ALSA error chained as cause.
        """
        assert self._pcm is None
        try:
            self._pcm = alsaaudio.PCM(alsaaudio.PCM_CAPTURE, mode=alsaaudio.PCM_NORMAL,  # NB: blocking
                                      channels=1, rate=self._sample_rate, format=self._sample_format,
                                      periodsize=self._period_size, periods=self._periods, device=self._device,
                                      **self._options)
        except alsaaudio.ALSAAudioError as e:
            raise ModelError(self.__class__.__name__, self._device, str(e)) from e
        else:
            self._logger.info(f"Entering {self._pcm}[{self._device}]")
            # Only query the PCM for its settings if the result is going to be logged.
            if self._logger.isEnabledFor(logging.DEBUG):
                self._logger.debug(str(self._pcm.info()))

    def __iter__(self) -> Iterator[bytes]:
        """Return the recorder itself, which must have been entered before."""
        assert self._pcm is not None
        return self

    def __next__(self) -> bytes:
        """
        Read the next chunk from the PCM, retrying once if the first attempt reports ``EPIPE``.

        On the first call after entering, ``warmup`` reads are done and discarded beforehand.
        Failures are logged rather than raised.

        :return: The captured samples, or empty bytes if the read failed.
        """
        assert self._pcm is not None
        self._num_reads += 1
        # The counter only passes the warmup threshold once, so results are discarded on the first call only.
        while self._num_reads <= self._warmup:
            self._num_reads += 1
            self._pcm.read()
        rv, buffer = self._pcm.read()
        if rv == -errno.EPIPE:
            self._logger.debug("Capture buffer not ready")
            rv, buffer = self._pcm.read()  # direct retry
        if rv == -errno.EPIPE:
            self._logger.warning("Capture buffer overrun")
            return b""
        if rv <= 0:
            self._logger.error(f"Capture read error: {rv}")
            return b""
        received_frames: int = len(buffer) // 2  # mono 16 bit samples, thus two bytes per frame
        if rv != received_frames:
            # Report both numbers in frames, the read result is not a byte count. Data is passed on nevertheless.
            self._logger.warning(f"Capture read buffer: Unexpected length {rv}/{received_frames} frames")
        return buffer

    def __exit__(self, *args) -> None:
        """Close the PCM and reset the state, even if closing fails, such that the instance can be entered again."""
        assert self._pcm is not None
        try:
            self._pcm.close()
        finally:
            self._pcm = None
            self._num_reads = 0  # warm up again after re-entering