from typing import Dict, Iterator

from sttts.api.message import Message, LlmMessageType
from sttts.api.model import OutputFilter
from sttts.utils.utils import SineGenerator


class BeepFeedbackGenerator(OutputFilter):
    """
    Play beep sounds for certain internal messages.

    :meth:`accept` passes every message through unchanged.

    For the message types ``Start``, ``End``, ``FeedbackPos`` and ``FeedbackNeg``,
    :meth:`generate` yields a pre-rendered beep as PCM bytes. Any other message type
    produces no audio.
    """

    def __init__(self, sample_rate: int, *, volume: float = 0.5) -> None:
        """
        :param int sample_rate: Sample rate handed to the :class:`SineGenerator` that renders the beeps.
        :param float volume: Volume of the generated PCM, default 0.5.
        """

        # NOTE(review): OutputFilter.__init__ is not invoked here (same as before);
        # confirm the base class needs no initialization.
        beeper: SineGenerator = SineGenerator(sample_rate)

        # Render each distinct tone only once. The buffers are shared by several
        # message types below.
        # NOTE(review): the arguments are presumably (frequency in Hz, volume,
        # duration in seconds). Confirm against SineGenerator.generate.
        high_beep: bytes = beeper.generate(800, volume, 0.25)
        low_beep: bytes = beeper.generate(400, volume, 0.5)

        self._feedback: Dict[LlmMessageType, bytes] = {
            LlmMessageType.Start: high_beep,
            LlmMessageType.End: low_beep,
            LlmMessageType.FeedbackPos: high_beep,
            LlmMessageType.FeedbackNeg: low_beep,
        }

    def accept(self, message: Message[LlmMessageType]) -> Iterator[Message[LlmMessageType]]:
        """
        Pass the incoming message through unchanged.

        :param message: Message to forward.
        :return: Iterator yielding exactly ``message``.
        """
        yield message

    def generate(self, message: Message[LlmMessageType]) -> Iterator[bytes]:
        """
        Yield the beep registered for ``message.msg``, if there is one.

        :param message: Message whose ``msg`` type selects the beep.
        :return: Iterator over at most one PCM buffer; empty for unregistered types.
        """
        # Single lookup instead of a membership test followed by indexing.
        beep = self._feedback.get(message.msg)
        if beep is not None:
            yield beep