import math

import numpy as np

from sttts.utils.utils import MathUtils
from .buffer import BufferedSegmenter


class SimpleSegmenter(BufferedSegmenter):
    """
    Determine silence/speech audio by a simple absolute RMS/volume threshold, which can require tweaking and good
    recording environments.
    """

    def __init__(self, sample_rate: int, buffer_limit: float, pause_limit: float, *,
                 frames: float = 2.0, threshold: float = 0.2) -> None:
        """
        :param int sample_rate: Passed through to :class:`BufferedSegmenter`.
        :param float buffer_limit: Passed through to :class:`BufferedSegmenter`.
        :param float pause_limit: Passed through to :class:`BufferedSegmenter`.
        :param float frames: Length of the sliding look behind window in seconds (2.0).
        :param float threshold: RMS threshold, smaller values will be considered silent. Default 0.2.
        """
        super().__init__(sample_rate, buffer_limit, pause_limit, frames=frames)
        self._threshold: float = threshold

    def _check(self, buffer: bytes) -> bool:
        """
        Classify one look behind window by its absolute RMS volume.

        :param bytes buffer: Raw audio of the window, decoded via ``MathUtils.buf2arr``.
        :return: True if the RMS reaches the threshold (speech), False otherwise (silence).
                 An empty window has no energy and is considered silence.
        """
        # Square in float64: avoids wrap-around should buf2arr ever yield integer PCM samples
        # (NOTE(review): dtype of buf2arr not visible here) and keeps full precision for float32 input.
        samples: np.ndarray = np.asarray(MathUtils.buf2arr(buffer), dtype=np.float64)
        if samples.size == 0:
            # Nothing to measure; avoid numpy's "mean of empty slice" warning and NaN comparisons.
            return False

        rms: float = math.sqrt(float(np.mean(samples ** 2)))
        return rms >= self._threshold


class MedianSegmenter(BufferedSegmenter):
    """
    Determine silence/speech audio by comparing the RMS with the median (percentile) energy.
    Idea: If the median is much smaller than the RMS, the energy is concentrated in a few peaks, i.e., speech;
    a larger ratio indicates a flat noise/silence distribution.

    This simple method should be self-adaptive wrt. background noise to automatically detect volume outliers.
    The calculation is applied to a sliding window of past audio frames, with a change from speech to silence
    leading to returning the buffered utterance as a whole.
    """

    def __init__(self, sample_rate: int, buffer_limit: float, pause_limit: float, *,
                 frames: float = 2.0, percentile: int = 50, threshold: float = 0.5) -> None:
        """
        :param int sample_rate: Passed through to :class:`BufferedSegmenter`.
        :param float buffer_limit: Passed through to :class:`BufferedSegmenter`.
        :param float pause_limit: Passed through to :class:`BufferedSegmenter`.
        :param float frames: Length of the sliding look behind window in seconds (2.0).
        :param int percentile: Percentile that is compared with the RMS energy, for example 50 for median (default).
        :param float threshold: Ratio of percentile to RMS; greater values will be considered as silence. Default 0.5.
        """

        super().__init__(sample_rate, buffer_limit, pause_limit, frames=frames)
        self._energy_percentile: int = percentile
        self._energy_threshold: float = threshold

    def _check(self, buffer: bytes) -> bool:
        """
        Classify one look behind window by its percentile-to-RMS ratio.

        :param bytes buffer: Raw audio of the window, decoded via ``MathUtils.buf2arr``.
        :return: True if the ratio is at most the threshold (speech), False otherwise (silence).
                 An empty or all-zero window has no energy and is always considered silence.
        """
        # Square in float64 so integer samples (if buf2arr ever returns them) cannot overflow.
        window: np.ndarray = np.asarray(MathUtils.buf2arr(buffer), dtype=np.float64) ** 2  # abs energy
        if window.size == 0:
            # np.percentile cannot operate on an empty window; no audio means no speech.
            return False

        rms: float = math.sqrt(float(np.mean(window)))
        med: float = math.sqrt(float(np.percentile(window, self._energy_percentile)))
        # Zero energy (digital silence) makes the ratio 0/0. It must not be mapped to 0.0,
        # which would pass the threshold test and be reported as speech.
        fac: float = med / rms if rms > 0.0 else math.nan
        self._logger.debug(f"RMS: {rms:.6f}, MED {med:.6f}, MED/RMS {fac:.6f}")

        return rms > 0.0 and fac <= self._energy_threshold