import logging
import time
from abc import abstractmethod
from typing import Iterator, Optional

from sttts.api.model import SpeechSegmenter


class BufferedSegmenter(SpeechSegmenter):
    """
    Skeleton for implementations that continuously monitor a sliding window for voice activity.
    Provides buffering and buffer/pause limit checks.

    Subclasses implement :meth:`_check`, which classifies the most recent window of audio as voice or silence.
    All byte sizes are derived as ``samples * 2``, i.e. two bytes per sample
    (presumably 16-bit mono PCM -- confirm against the audio source).
    """

    def __init__(self, sample_rate: int, buffer_limit: float, pause_limit: float, *, frames: float) -> None:
        """
        :param int sample_rate: Samples per second of the pushed audio.
        :param float buffer_limit: Maximum length in seconds of buffered speech; reaching it raises
            :class:`OverflowError` from :meth:`push`.
        :param float pause_limit: Seconds of continued silence (counted from entering the context, the last
            speech end, or the previous timeout) after which :meth:`push` raises :class:`TimeoutError`.
        :param float frames: Length of the sliding look behind window in seconds.
        :raises ValueError: If ``sample_rate`` is not positive or the window would contain no samples.
        """

        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate!r}")
        num_frames: int = round(frames * sample_rate)
        if num_frames <= 0:
            # A zero-length window would make ``buffer[-0:]`` select the *whole* buffer, so silence would never
            # trim it and it would grow without bound; a negative one would slice from the wrong end.
            raise ValueError(f"frames must cover at least one sample, got {frames!r}s at {sample_rate}Hz")

        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._buffer: bytes = b""
        self._sample_rate: int = sample_rate
        self._num_frames: int = num_frames
        self._buffer_len: int = self._num_frames * 2  # window size in bytes (2 bytes per sample)
        self._buffer_limit: int = round(buffer_limit * sample_rate) * 2  # speech limit in bytes
        self._pause_limit: float = float(pause_limit)
        # ``None`` while speech is ongoing, otherwise the monotonic time at which the current pause started.
        self._pause_since: Optional[float] = None

    def __enter__(self) -> None:
        """Start in the paused state, so leading silence is only trimmed and the pause timer runs from now."""
        self._pause_since = self._now()

    def __exit__(self, *args) -> None:
        """Drop any buffered audio; the pause state is re-initialized by the next ``__enter__``."""
        self._buffer = b""

    @abstractmethod
    def _check(self, buffer: bytes) -> bool:
        """Detect voice for the sliding window, ``False`` for silence."""
        raise NotImplementedError

    def push(self, buffer: bytes) -> Iterator[bytes]:
        """
        Append ``buffer`` and evaluate the most recent window.

        Nothing happens until at least one full window is buffered. Then:

        * voice after a pause: speech starts; the buffered audio is kept as lead-in of the utterance.
        * voice during speech: the utterance keeps growing.
        * silence during speech: speech ends, the whole buffered utterance (including the trailing silent
          window) is yielded once and the pause timer starts.
        * silence during a pause: the buffer is trimmed to the window only.

        :raises TimeoutError: If the pause lasted longer than ``pause_limit``. The message is the limit; the pause
            timer restarts, so the error repeats after every further ``pause_limit`` seconds of silence.
        :raises OverflowError: If buffered speech reaches the buffer limit. The message is the buffer size in
            bytes; the buffer is not cleared, so further voiced pushes raise again.
        """
        if not buffer:
            return

        self._buffer += buffer
        if len(self._buffer) < self._buffer_len:
            return  # not enough audio for a full window yet

        window: bytes = self._buffer[-self._buffer_len:]
        if self._check(window):
            if self._pause_since is not None:
                self._logger.info("Detected speech start")
                self._pause_since = None
            elif len(self._buffer) >= self._buffer_limit:
                raise OverflowError(str(len(self._buffer)))
            return

        now: float = self._now()  # sample the clock once so the comparison and the reset agree
        if self._pause_since is None:
            self._logger.info("Detected speech end")
            self._pause_since = now
            self._buffer, result = b"", self._buffer
            yield result
        else:
            # Still silent: keep only the window (lead-in for the next utterance) to bound memory.
            self._buffer = window
            if self._pause_since < now - self._pause_limit:
                self._pause_since = now
                raise TimeoutError(str(self._pause_limit))

    @classmethod
    def _now(cls) -> float:
        """Return the current time in seconds from :func:`time.monotonic`."""
        return time.monotonic()