import logging
import signal
from typing import Iterator, Optional

from prompt_toolkit import print_formatted_text
from prompt_toolkit.application import create_app_session
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.output import create_output

from .api.message import ConfigError, ModuleError, CtlMessageType, SttMessageType
from .api.model import AudioSource, SpeechSegmenter, Recognizer
from .config import Config
from .factory import Factory, FactoryError
from .log import setup_logging, redirect_logging
from .queue import Message, MessageQueueR, MessageQueueW, DummyQueue
from .utils.process import Rusage


class MainStt:
    """Speech-to-text worker.

    Audio is pulled from an ``AudioSource``, split by a ``SpeechSegmenter`` and fed to a
    ``Recognizer``; all three are created through the injected ``Factory``. ``run`` is driven
    by a control queue and forwards recognition messages to an output queue, while
    ``run_cli`` prints them to the terminal instead.
    """

    def __init__(self, logger: logging.Logger, factory: Factory) -> None:
        """
        :param logger: Logger used for progress and error reporting.
        :param factory: Creates the recognizer, and the audio source and speech segmenter for it.
        """
        self._logger: logging.Logger = logger
        self._factory: Factory = factory

    def _listen(self,
                ctl_queue: MessageQueueR[CtlMessageType],
                source: AudioSource,
                segmenter: SpeechSegmenter, recognizer: Recognizer) -> Iterator[Message[SttMessageType]]:
        """Listen for one utterance and yield the resulting (normalized) messages.

        Every raw chunk from ``source`` is pushed through ``segmenter`` and ``recognizer``.
        The recognizer's messages are normalized against a local ``recording`` flag:

        * ``Start``: begins recording; a repeated ``Start`` is yielded as ``Utterance``.
        * ``Reset``: ends recording if recording, otherwise it is yielded as ``Utterance``.
        * ``Commit``: while recording it is yielded and the generator finishes; a spurious
          ``Commit`` is yielded as ``Utterance``.
        * ``Stop``: while recording it is yielded as ``Utterance``, otherwise listening is
          interrupted.
        * ``Utterance``: passed through unchanged.

        The control queue is polled without blocking once per audio chunk; any event other
        than ``Listen`` aborts listening.

        :raises InterruptedError: on a non-``Listen`` control event, on ``Stop`` while not
            recording, or when ``source`` is closed/depleted.
        :raises OverflowError: when segmenting/recognizing raises ``OverflowError``, or raises
            ``TimeoutError`` while recording; a ``Reset`` is yielded first if recording.
            A ``TimeoutError`` while not recording is ignored and listening continues.
        :raises ValueError: on an unexpected message type.
        """
        recording: bool = False
        for raw in source:
            # Non-blocking check for control events between audio chunks.
            event: Optional[Message[CtlMessageType]] = ctl_queue.get_nowait()
            if event is not None and event.msg != CtlMessageType.Listen:
                raise InterruptedError(str(event))
            elif event is not None:
                self._logger.info("Wakeup event, continuing")

            try:
                for buf in segmenter.push(raw):
                    for message in recognizer.accept(buf):
                        self._logger.info(f"Input: {message}")
                        if message.msg == SttMessageType.Start:
                            if not recording:
                                recording = True
                                yield message
                            else:
                                yield Message(SttMessageType.Utterance, message.data)
                        elif message.msg == SttMessageType.Reset:
                            if not recording:
                                yield Message(SttMessageType.Utterance, message.data)
                            else:
                                recording = False
                                yield message
                        elif message.msg == SttMessageType.Commit:
                            if not recording:  # spurious
                                yield Message(SttMessageType.Utterance, message.data)
                            else:  # success, don't listen until triggered again
                                yield message
                                return
                        elif message.msg == SttMessageType.Stop:
                            if not recording:
                                raise InterruptedError(str(message))
                            else:
                                yield Message(SttMessageType.Utterance, message.data)
                        elif message.msg == SttMessageType.Utterance:
                            yield message  # will get ignored if not started, though
                        else:
                            raise ValueError(message.msg.name)
            except TimeoutError as e:  # silence timeout during utterance
                if recording:
                    yield Message(SttMessageType.Reset)
                    # Raised from within this handler, so the OverflowError handler below
                    # does not catch it; it propagates to the caller as a rollback request.
                    raise OverflowError(f"Silence limit reached: {str(e)}") from None
            except OverflowError as e:  # too long utterance segment buffer
                if recording:
                    yield Message(SttMessageType.Reset)
                raise OverflowError(f"Utterance limit reached: {str(e)}") from None

        raise InterruptedError(f"{source.__class__.__name__} closed or depleted")

    def _run(self, ctl_queue: MessageQueueR[CtlMessageType], sst_queue: MessageQueueW[SttMessageType]) -> None:
        """Serve ``Listen`` requests from ``ctl_queue`` and forward results to ``sst_queue``.

        The recognizer is kept open for the whole call. For each ``Listen`` event, source
        and segmenter are entered and listening runs until a committed utterance; an
        ``OverflowError`` triggers a rollback and a fresh attempt. Returns on a non-``Listen``
        control event, on ``InterruptedError`` while listening, or when iterating
        ``ctl_queue`` ends.
        """
        with self._factory.create_recognizer() as recognizer:
            source: AudioSource = self._factory.create_source_for(recognizer)
            segmenter: SpeechSegmenter = self._factory.create_speech_segmenter_for(recognizer)

            for event in ctl_queue:  # block until listen command, otherwise propagate stop and exit
                if event.msg != CtlMessageType.Listen:
                    self._logger.warning(f"Interrupted: {str(event)}")
                    return
                else:
                    self._logger.info("Wakeup event, listening")

                while True:
                    with source, segmenter:  # keep recognizer and source, mute only
                        try:
                            recognizer.reset()
                            for message in self._listen(ctl_queue, source, segmenter, recognizer):
                                sst_queue.put(message)
                        except OverflowError as e:
                            self._logger.warning(f"Rollback: {e}")
                            continue
                        except InterruptedError as e:
                            self._logger.warning(f"Interrupted: {e}")
                            return
                        else:  # commit
                            self._logger.info("Start processing, mute")
                            break

    def _run_cli(self, ctl_queue: MessageQueueR[CtlMessageType]) -> None:
        """Listen and print every resulting message to the terminal.

        Retries after an ``OverflowError`` rollback. Returns after the first committed
        utterance or when listening raises ``InterruptedError``.
        """
        with self._factory.create_recognizer() as recognizer:
            source: AudioSource = self._factory.create_source_for(recognizer)
            segmenter: SpeechSegmenter = self._factory.create_speech_segmenter_for(recognizer)

            while True:
                with source, segmenter:
                    try:
                        recognizer.reset()
                        for message in self._listen(ctl_queue, source, segmenter, recognizer):
                            # Utterances with text print the text; everything else prints "[Type]".
                            print_formatted_text(
                                FormattedText([("ansired bold", ">")]),
                                FormattedText([("bold", f"[{message.msg.name}]")])
                                if message.data is None or message.msg != SttMessageType.Utterance
                                else message.data.rstrip(),
                                end="\n", flush=True
                            )
                    except OverflowError as e:
                        self._logger.warning(f"Rollback: {e}")
                        continue
                    except InterruptedError as e:
                        self._logger.warning(f"Interrupted: {e}")
                        return
                    else:
                        break

    @staticmethod
    def _report_error(logger: logging.Logger, e: BaseException) -> int:
        """Log ``e`` with detail depending on its type and return exit code 1.

        ``FactoryError`` logs the traceback of its cause (if any), ``ConfigError`` and
        ``ModuleError`` log only their message, anything else logs its own traceback.
        """
        if isinstance(e, FactoryError):
            logger.error(str(e), exc_info=e.__cause__)
        elif isinstance(e, (ConfigError, ModuleError)):
            logger.error(str(e))
        else:
            logger.error(str(e), exc_info=e)
        return 1

    @classmethod
    def run_cli(cls, config: Config) -> int:
        """Interactive entry point: listen and print recognition results to the terminal.

        :return: Exit code, 0 on success or ``KeyboardInterrupt``, 1 on any other error.
        """
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        with redirect_logging(logger) as stdout, Rusage(logger, logging.INFO):
            with create_app_session(output=create_output(stdout)):
                try:
                    factory: Factory = Factory(config)
                    cls(logger, factory)._run_cli(DummyQueue())
                except KeyboardInterrupt:
                    return 0
                except BaseException as e:
                    return cls._report_error(logger, e)
                else:
                    return 0

    @classmethod
    def run(cls, config: Config, ctl: MessageQueueR[CtlMessageType], q: MessageQueueW[SttMessageType]) -> int:
        """Worker entry point: serve ``ctl`` requests and publish recognition results to ``q``.

        SIGINT is ignored in this process. A final ``Stop`` message is always put on ``q``
        and ``q`` is always closed, even if putting that message fails.

        :return: Exit code, 0 on success, 1 on any error.
        """
        signal.signal(signal.SIGINT, signal.SIG_IGN)  # handled and propagated by llm main
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        with redirect_logging(logger), Rusage(logger, logging.INFO):
            try:
                factory: Factory = Factory(config)
                cls(logger, factory)._run(ctl, q)
            except BaseException as e:
                return cls._report_error(logger, e)
            else:
                return 0
            finally:
                # Signal the consumer that we are done; close the queue even if the put fails.
                try:
                    q.put(Message(SttMessageType.Stop))
                finally:
                    q.close()