"""
Listening source and playback sink for PulseAudio servers, using `PaSimple <https://pypi.org/project/pasimple/>`__,
that in turn requires the native ``libpulse-simple.so.0`` library.
This sound server implementation is typically in use by default on various Linux desktop distributions.
"""

import logging
from typing import Optional, Iterator, Dict

import pasimple

from sttts.api.message import ModelError
from sttts.api.model import AudioSink, AudioSource


class PulsePlayer(AudioSink):
    """
    Playback sink writing single-channel ``PA_SAMPLE_S16LE`` audio via :class:`pasimple.PaSimple`.

    Use as context manager: The stream is opened when entering and closed again when leaving.
    """

    def __init__(self, sample_rate: int, *,
                 device: Optional[str] = None, buffer_length: float = 5.0,
                 **kwargs) -> None:
        """
        :param int sample_rate: Sample rate in Hz of the audio that will be passed to :meth:`play`.
        :param str device: Playback device to use, none for default.
        :param float buffer_length: Output buffer length in seconds, default 5.
               Generous to unblock the :class:`sttts.api.model.Synthesizer` running in parallel.
        :param kwargs: Extra options passed to :class:`pasimple.PaSimple`.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._device: Optional[str] = device
        self._sample_format: int = pasimple.PA_SAMPLE_S16LE
        self._sample_rate: int = sample_rate
        # Round to whole samples first and only then convert to bytes, such that the
        # requested buffer size is always a multiple of the sample width.
        sample_width: int = pasimple.format2width(self._sample_format)
        self._buffer_length: int = round(buffer_length * sample_rate) * sample_width
        self._options: Dict = kwargs
        self._pcm: Optional[pasimple.PaSimple] = None  # only set while entered

    def __enter__(self) -> None:
        """
        Open the playback stream.

        :raises ModelError: If :class:`pasimple.PaSimple` fails with a ``PaSimpleError``.
        """
        assert self._pcm is None  # not re-entrant
        try:
            self._pcm = pasimple.PaSimple(pasimple.PA_STREAM_PLAYBACK,
                                          format=self._sample_format, channels=1, rate=self._sample_rate,
                                          app_name=__package__ or self.__class__.__name__, device_name=self._device,
                                          tlength=self._buffer_length,
                                          **self._options)
        except pasimple.PaSimpleError as e:
            # Keep the native error as explicit cause for easier debugging.
            raise ModelError(self.__class__.__name__, self._device, str(e)) from e
        else:
            self._logger.info(f"Entering {self._pcm}[{self._device}]")
            self._pcm.__enter__()

    def play(self, buffer: bytes) -> None:
        """Write the given audio bytes to the opened stream."""
        assert self._pcm is not None
        self._pcm.write(buffer)

    def drain(self) -> None:
        """Drain the opened stream, delegating to :meth:`pasimple.PaSimple.drain`."""
        assert self._pcm is not None
        self._pcm.drain()

    def __exit__(self, *args) -> None:
        """Close the stream, forwarding the exception details to the underlying context manager."""
        assert self._pcm is not None
        try:
            self._pcm.__exit__(*args)
        finally:
            # Forget the stream even if closing failed, otherwise entering again would be refused.
            self._pcm = None


class PulseRecorder(AudioSource):
    """
    Listening source reading single-channel ``PA_SAMPLE_S16LE`` audio via :class:`pasimple.PaSimple`.

    Use as context manager and iterate over it, each iteration gives one buffer of recorded bytes.
    """

    def __init__(self, sample_rate: int, *,
                 device: Optional[str] = None, buffer_length: float = 0.25, warmup: int = 4,
                 **kwargs) -> None:
        """
        :param int sample_rate: Sample rate in Hz to record with.
        :param str device: Recording device to use, none for default.
        :param float buffer_length: Adjust read size in seconds, the default 250ms.
        :param int warmup: Skip the first reads, in case of microphone auto-gaining, default 4, thus 1 second.
        :param kwargs: Extra options passed to :class:`pasimple.PaSimple`.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._device: Optional[str] = device
        self._warmup: int = warmup
        self._num_reads: int = 0  # reads attempted since entering, including skipped ones
        self._sample_format: int = pasimple.PA_SAMPLE_S16LE
        self._sample_rate: int = sample_rate
        # Round to whole samples first and only then convert to bytes. Rounding the byte count directly
        # can give an odd size (e.g. 0.25s at 22050Hz), which would split a sample across two reads.
        sample_width: int = pasimple.format2width(self._sample_format)
        self._read_size: int = round(buffer_length * sample_rate) * sample_width
        self._options: Dict = kwargs
        self._pcm: Optional[pasimple.PaSimple] = None  # only set while entered

    def __enter__(self) -> None:
        """
        Open the recording stream.

        :raises ModelError: If :class:`pasimple.PaSimple` fails with a ``PaSimpleError``.
        """
        assert self._pcm is None  # not re-entrant
        try:
            self._pcm = pasimple.PaSimple(pasimple.PA_STREAM_RECORD,
                                          format=self._sample_format, channels=1, rate=self._sample_rate,
                                          app_name=__package__ or self.__class__.__name__, device_name=self._device,
                                          fragsize=self._read_size,
                                          **self._options)
        except pasimple.PaSimpleError as e:
            # Keep the native error as explicit cause for easier debugging.
            raise ModelError(self.__class__.__name__, self._device, str(e)) from e
        else:
            self._logger.info(f"Entering {self._pcm}[{self._device}]")
            self._pcm.__enter__()

    def __iter__(self) -> Iterator[bytes]:
        """Return the recorder itself, which must have been entered before."""
        assert self._pcm is not None
        return self

    def __next__(self) -> bytes:
        """Read and return the next buffer, discarding the warmup reads on first use after entering."""
        assert self._pcm is not None
        self._num_reads += 1
        # Only effective on the first call after entering: Read and drop ``warmup`` buffers,
        # afterwards the counter stays above the threshold.
        while self._num_reads <= self._warmup:
            self._num_reads += 1
            self._pcm.read(self._read_size)
        return self._pcm.read(self._read_size)

    def __exit__(self, *args) -> None:
        """Close the stream and reset the warmup counter, such that entering again warms up again."""
        assert self._pcm is not None
        try:
            self._pcm.__exit__(*args)
        finally:
            # Reset state even if closing failed, otherwise entering again would be refused.
            self._pcm = None
            self._num_reads = 0