import logging
import math
import re
import struct
import time
from enum import Enum
from typing import Dict, TypeVar, Type, Optional, ContextManager

import numpy as np


class StringUtils:
    E = TypeVar("E", bound=Enum)

    # Punctuation that may precede a hotword (quotes, Spanish inverted marks, opening brackets, dash).
    _LEADING_PUNCTUATION: str = "\"'“¿¡([{-"
    # Punctuation that may follow a hotword: ASCII and full-width (U+FF0C , U+FF01 ! U+FF1F ? U+FF1A :)
    # variants, CJK full stop/comma, closing quotes and brackets.
    _TRAILING_PUNCTUATION: str = "\"'.。,\uff0c!\uff01?\uff1f:\uff1a;”)]}、"

    @classmethod
    def strip_punctuation(cls, s: str) -> str:
        """Strip punctuation chars that might surround a hotword to be detected.

        Outer whitespace is removed first, then leading and trailing punctuation,
        then any whitespace that removing the punctuation exposed.
        """
        return s.strip().lstrip(cls._LEADING_PUNCTUATION).rstrip(cls._TRAILING_PUNCTUATION).strip()

    @classmethod
    def get_enum_map(cls, e: Type[E], d: Dict[str, str]) -> Dict[str, E]:
        """Map configured hotwords to their enum member.

        For each member of ``e`` the hotword is looked up in ``d`` under the member's
        lower-cased name; if it is not configured, the member's *value* is used as the
        hotword. If two members resolve to the same hotword, the later one wins.
        """
        return {d.get(member.name.lower(), member.value): member for member in e}


class MathUtils:
    @classmethod
    def find_pow2(cls, target: int, narrow: int = 0) -> int:
        """Return a power of two bracketing ``target``.

        ``narrow=0`` gives the next power of two strictly greater than ``target``
        (4 -> 8, 5 -> 8); ``narrow=1`` gives the largest power of two smaller than or
        equal to it (4 -> 4, 5 -> 4). The exponent is clamped at 0, so the result is
        always an int >= 1 (there is no power of two <= 0, e.g. for ``target=0``).
        Assumes a non-negative ``target``; ``int.bit_length`` ignores the sign.
        """
        exponent: int = int(target).bit_length() - narrow
        # Clamp: a negative exponent would otherwise yield a float such as 0.5.
        return 2 ** max(0, exponent)

    @classmethod
    def buf2arr(cls, buffer: bytes) -> np.ndarray:
        """Convert signed 16-bit little-endian PCM bytes to float32 samples in [-1.0, 1.0).

        numpy raises ValueError if the buffer length is not a multiple of 2.
        """
        return np.frombuffer(buffer, dtype="<i2").astype(np.float32) / 32768.0


class SineGenerator:
    """Synthesize sine-wave beeps as raw signed 16-bit little-endian PCM."""

    def __init__(self, sample_rate: int) -> None:
        self._sample_rate: int = sample_rate  # samples per second
        self._struct: struct.Struct = struct.Struct("<h")  # packs one int16 sample, little-endian

    def generate(self, freq: int, volume: float, duration: float) -> bytes:
        """Render ``duration`` seconds of a ``freq`` Hz tone.

        ``volume`` is scaled by 32767 and clamped to an amplitude in [0, 32767]; the
        sample count is rounded up and each sample value is truncated toward zero.
        """
        rate: float = float(self._sample_rate)
        amplitude: float = max(0.0, min(32767.0, volume * 32767.0))
        step: float = float(freq) * 2.0 * math.pi / rate  # phase advance per sample, radians
        n_samples: int = math.ceil(duration * rate)

        pack = self._struct.pack
        pcm = bytearray()
        for idx in range(n_samples):
            pcm += pack(int(amplitude * math.sin(step * float(idx))))
        return bytes(pcm)


class Rstrip:
    """Hold back trailing whitespace, for example while printing streamed LLM response tokens.

    Each ``push`` returns the text accumulated so far minus its trailing whitespace.
    That whitespace is kept in a buffer and prepended to the next pushed chunk, so it
    is only emitted once something non-whitespace follows it. Whitespace still
    buffered when the stream ends is never emitted.
    """

    def __init__(self) -> None:
        self._buffer: str = ""  # trailing whitespace withheld from the previous push

    def push(self, s: str) -> str:
        """Append ``s`` to any withheld whitespace and return it without trailing whitespace."""
        text: str = self._buffer + s
        # str.rstrip() and the regex ``\s`` share the same Unicode whitespace definition;
        # unlike the previous ``re.match`` (anchored at the start), this finds trailing
        # whitespace even after non-whitespace text.
        kept: str = text.rstrip()
        self._buffer = text[len(kept):]
        return kept


class NoopCtx(ContextManager[None]):
    """Context manager that does nothing: ``with`` binds ``None`` and exceptions propagate."""

    def __enter__(self) -> None:
        return None

    def __exit__(self, *exc_info) -> None:
        # Returning a falsy value means an in-flight exception is not suppressed.
        return None


class PerfCounter(ContextManager):
    """Context to measure elapsed time and a generic counter to show an average rate.

    Inside the ``with`` body, call the instance (``pc(n)``) to add ``n`` to the counter.
    On exit, if the counter is positive and the measured duration is positive, a single
    line is logged at ``level``: rate per second, count, duration, and the delay between
    entering the context and the first call. With ``skip_delay=True`` the duration (and
    hence the rate) is measured from the first call rather than from entry.
    """

    def __init__(self, logger: logging.Logger, level: int, unit: str, skip_delay: bool = False) -> None:
        self._logger: logging.Logger = logger
        self._level: int = level
        self._unit: str = unit
        self._skip_delay: bool = skip_delay
        self._reset()

    def _reset(self) -> None:
        """Zero the counter, restart the clock and forget the first-call timestamp."""
        self._counter: int = 0
        self._start_at: float = self._now()
        self._first_at: Optional[float] = None

    @classmethod
    def _now(cls) -> float:
        return time.perf_counter()

    def __enter__(self) -> "PerfCounter":
        self._reset()
        return self

    def __call__(self, n: int = 1) -> None:
        """Add ``n`` to the counter; the first call also records its timestamp."""
        if self._first_at is None:
            self._first_at = self._now()
        self._counter += n

    def __exit__(self, *exc_info) -> None:
        measure_from: float = self._start_at
        if self._skip_delay and self._first_at is not None:
            measure_from = self._first_at
        elapsed: float = self._now() - measure_from

        if not (self._counter > 0 and elapsed > 0):
            return

        delay: float = 0.0 if self._first_at is None else self._first_at - self._start_at
        message: str = "{:.2f} {}/sec ({} / {:.3f}) ({:.3f} delay)".format(
            self._counter / elapsed, self._unit, self._counter, elapsed, delay,
        )
        self._logger.log(self._level, message)