import logging
from typing import Tuple

import librosa
import librosa.feature
import numpy as np

from sttts.api.message import ModuleError
from sttts.utils.utils import MathUtils, PerfCounter
from .buffer import BufferedSegmenter


class BandSegmenter(BufferedSegmenter):
    """
    Use the `librosa <https://librosa.org/doc/latest/index.html>`__
    `STFT <https://librosa.org/doc/main/generated/librosa.stft.html>`__ FFT implementation as simple band-pass filter.
    The average magnitude of typical voice frequencies is compared against the average magnitude of all other
    frequencies in the spectrum.
    This gives a voice-vs-noise estimate, with a configurable threshold.
    """

    def __init__(self, sample_rate: int, buffer_limit: float, pause_limit: float, *,
                 frames: float = 1.0, threshold: float = 1.0,
                 freq_start: int = 256, freq_end: int = 4096) -> None:
        """
        :param int sample_rate: Audio sample rate in Hz, passed on to :class:`BufferedSegmenter`.
        :param float buffer_limit: Passed on to :class:`BufferedSegmenter`.
        :param float pause_limit: Passed on to :class:`BufferedSegmenter`.
        :param float frames: Length of the sliding look behind window in seconds (1.0).
               This also directly influences the possible FFT resolution.
        :param float threshold: Minimum ratio of the average voice-band magnitude to the average magnitude of
               all other bands for a buffer to count as voice, default 1.0.
        :param int freq_start: Lower band-pass, where the human voice typically starts (256).
        :param int freq_end: Upper band-pass, where the human voice typically ends (4096).
        :raises ModuleError: If the voice band cannot be mapped onto the FFT bins with noise bins on both sides.
        """

        super().__init__(sample_rate, buffer_limit, pause_limit, frames=frames)
        self._threshold: float = threshold
        self._fft_width: int = MathUtils.find_pow2(self._num_frames, narrow=1)

        fft_frequencies: np.ndarray = librosa.fft_frequencies(sr=self._sample_rate, n_fft=self._fft_width)
        # Half-open bin index range [start, end) of the voice band.
        self._voice_frequencies: Tuple[int, int] = self._freq_bounds(fft_frequencies, freq_start, freq_end)
        self._logger.info(f"Using {self._fft_width} FFT, {len(fft_frequencies)} bands until {fft_frequencies[-1]} Hz, "
                          f"{self._voice_frequencies[1] - self._voice_frequencies[0]} voice bands")

    def _freq_bounds(self, fft_frequencies: np.ndarray, freq_start: int, freq_end: int) -> Tuple[int, int]:
        """
        Map a frequency range onto FFT bin indices.

        :param fft_frequencies: Frequency of each FFT bin in Hz, in ascending order.
        :param freq_start: Lowest frequency (Hz) to include in the voice band.
        :param freq_end: Highest frequency (Hz) to include in the voice band.
        :return: Half-open bin range ``(start, end)``: ``start`` is the first bin with frequency >= ``freq_start``,
                 ``end`` the first bin with frequency > ``freq_end``.
        :raises ModuleError: If the range is empty, leaves no noise bin below it, or fewer than two noise bins above it.
        """
        # Binary search instead of a linear scan; relies on the ascending order of librosa.fft_frequencies.
        index_start = int(np.searchsorted(fft_frequencies, freq_start, side="left"))
        index_end = int(np.searchsorted(fft_frequencies, freq_end, side="right"))
        # searchsorted returns len(fft_frequencies) when no bin qualifies; the checks below reject that case, and
        # also guarantee that both noise regions used by _check are non-empty.
        if index_start >= index_end or index_start == 0 or index_end >= len(fft_frequencies) - 1:
            raise ModuleError(self.__class__.__name__, f"Cannot split FFT by {freq_start}-{freq_end} Hz")
        return index_start, index_end

    def _check(self, buffer: bytes) -> bool:
        """
        Classify one buffer as voice or noise.

        :param bytes buffer: Raw audio, converted to samples via ``MathUtils.buf2arr``.
        :return: True if the ratio of the average voice-band magnitude to the average magnitude of all other
                 bands reaches the configured threshold.
        """
        band_start, band_end = self._voice_frequencies
        with PerfCounter(self._logger, logging.DEBUG, "msec") as counter:
            # Magnitude spectrogram with rows = frequency bins and columns = frames (hop of half a window).
            fft: np.ndarray = np.abs(librosa.stft(
                y=MathUtils.buf2arr(buffer), n_fft=self._fft_width, win_length=self._fft_width,
                hop_length=self._fft_width // 2, center=True,
            ))
            counter(round(self._num_frames / self._sample_rate * 1000.0))

            num_bands, num_samples = fft.shape
            voice = fft[band_start:band_end]
            noise_lo = fft[:band_start]
            noise_hi = fft[band_end:]
            # Native NumPy reductions instead of Python-level iteration over every spectrogram cell.
            voice_avg: float = float(np.mean(voice))
            noise_avg: float = float(np.sum(noise_lo) + np.sum(noise_hi)) / (noise_lo.size + noise_hi.size)
            # Guard the division: with no out-of-band energy, any voice energy is an infinite ratio and
            # complete silence a zero ratio (instead of NaN plus a RuntimeWarning).
            ratio: float = voice_avg / noise_avg if noise_avg > 0.0 else (np.inf if voice_avg > 0.0 else 0.0)

        # Lazy %-formatting: the message is only built when DEBUG logging is enabled.
        self._logger.debug("FFT %d (%d bands, %d samples), Voice/Noise %.6f (%.6f / %.6f)",
                           self._fft_width, num_bands, num_samples, ratio, voice_avg, noise_avg)
        return bool(ratio >= self._threshold)