import logging
import time
from typing import Iterator, Optional

import pocketsphinx

from sttts.api.model import SpeechSegmenter


class SphinxSegmenter(SpeechSegmenter):
    """
    Use `PocketSphinx <https://github.com/cmusphinx/pocketsphinx>`__ :class:`pocketsphinx.Endpointer`
    for voice activity detection (VAD), similar to the basic :class:`pocketsphinx.Segmenter`.

    Raw audio passed to :meth:`push` is cut into endpointer-sized frames. Frames the endpointer returns as speech
    are accumulated and yielded as one segment once the endpointer reports that it left the speech region.
    """

    def __init__(self, sample_rate: int, buffer_limit: float, pause_limit: float, *,
                 mode: int = 0, window: float = 0.3, ratio: float = 0.9,
                 **kwargs) -> None:
        """
        :param int sample_rate: Audio sample rate, passed to :class:`pocketsphinx.Endpointer`.
        :param float buffer_limit: Maximum length in seconds of a single ongoing speech segment before
            :exc:`OverflowError` is raised. Converted to bytes assuming 2 bytes per sample (16-bit mono).
        :param float pause_limit: Seconds (wall-clock) without any detected speech before :exc:`TimeoutError`
            is raised.
        :param int mode: Aggressiveness of voice activity detection (0-3, loose-strict, default 0).
        :param float window: Length in seconds of window for decision (0.3).
        :param float ratio: Fraction of window that must be speech or non-speech to make a transition (0.9).
        :param kwargs: Extra options passed to :class:`pocketsphinx.Endpointer`.
        """

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._segmenter = pocketsphinx.Endpointer(sample_rate=sample_rate, vad_mode=mode, window=window, ratio=ratio,
                                                  **kwargs)
        self._buffer: bytes = b""  # leftover input shorter than one endpointer frame
        self._speech_buffer: bytearray = bytearray()  # speech of the current, not yet finished segment
        self._buffer_limit: int = round(buffer_limit * 2 * sample_rate)  # in bytes
        self._pause_limit: float = float(pause_limit)
        self._pause_since: float = self._now()

    def __enter__(self) -> None:
        self._pause_since = self._now()

    def __exit__(self, *args) -> None:
        # XXX: the stream is closed with a single silent 16-bit sample instead of an empty buffer, presumably
        # because the binding does not accept empty input (TODO confirm). Any speech returned is discarded.
        self._segmenter.end_stream(b"\x00\x00")
        self._buffer = b""
        self._speech_buffer = bytearray()

    def push(self, buffer: bytes) -> Iterator[bytes]:
        """
        Feed raw audio and yield every speech segment that it completes.

        Input that does not fill a whole endpointer frame is kept and prepended to the next call.

        :param bytes buffer: Raw audio data in the format expected by :class:`pocketsphinx.Endpointer`.
        :return: Iterator over completed speech segments.
        :raises TimeoutError: If no speech was detected for more than ``pause_limit`` seconds.
        :raises OverflowError: If an ongoing speech segment reaches ``buffer_limit``. The partial segment is
            discarded first, so the segmenter can be fed again afterwards.
        """

        if not buffer:
            return

        frame_bytes: int = self._segmenter.frame_bytes
        self._buffer += buffer
        while len(self._buffer) >= frame_bytes:
            # Advance the stored buffer before processing, so the unprocessed remainder survives a raised
            # exception or an abandoned generator.
            frame, self._buffer = self._buffer[:frame_bytes], self._buffer[frame_bytes:]
            speech: Optional[bytes] = self._segmenter.process(frame)

            if speech is None:
                now: float = self._now()
                if self._pause_since < now - self._pause_limit:
                    self._pause_since = now  # reset so the next call does not time out again immediately
                    raise TimeoutError(str(self._pause_limit))
                continue

            if not self._speech_buffer:
                self._logger.info("Detected speech start")
            self._speech_buffer += speech  # in-place extend, avoids quadratic bytes concatenation
            self._pause_since = self._now()

            if not self._segmenter.in_speech:
                self._logger.info("Detected speech end")
                if self._speech_buffer:
                    segment, self._speech_buffer = bytes(self._speech_buffer), bytearray()
                    yield segment
            elif len(self._speech_buffer) >= self._buffer_limit:
                # Still in speech but the segment is too long: drop it, so that subsequent pushes do not keep
                # failing on the same oversized buffer.
                size: int = len(self._speech_buffer)
                self._speech_buffer = bytearray()
                raise OverflowError(str(size))

    @classmethod
    def _now(cls) -> float:
        return time.monotonic()