import logging
from typing import List

from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.application import create_app_session
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.output import create_output

from .api.message import ConfigError, ModuleError, SttMessageType, LlmMessageType
from .config import Config
from .factory import Factory, FactoryError
from .log import setup_logging, redirect_logging
from .queue import Message, MessageQueueR, MessageQueueW
from .utils.process import Rusage
from .utils.utils import Rstrip, PerfCounter


class MainLlm:
    """LLM stage: turns committed speech-to-text utterances into a stream of generated tokens.

    Two entry points are provided:

    * :meth:`run` reads STT messages from ``sst`` and writes LLM messages to ``tts``.
    * :meth:`run_cli` reads prompts interactively from the terminal and prints the tokens.
    """

    def __init__(self, logger: logging.Logger, factory: Factory) -> None:
        self._logger: logging.Logger = logger
        self._factory: Factory = factory

    @staticmethod
    def _report_error(logger: logging.Logger, e: BaseException) -> int:
        """Log a fatal error and return the process exit code ``1``.

        Detail depends on the error type:

        * ``FactoryError`` logs the traceback of its ``__cause__``.
        * ``ConfigError``/``ModuleError`` log only their message.
        * Anything else logs its own traceback.
        """
        if isinstance(e, FactoryError):
            logger.error(str(e), exc_info=e.__cause__)
        elif isinstance(e, (ConfigError, ModuleError)):
            logger.error(str(e))
        else:
            logger.error(str(e), exc_info=e)
        return 1

    def _run(self, sst: MessageQueueR[SttMessageType], tts: MessageQueueW[LlmMessageType]) -> None:
        """Consume STT messages until a Stop message arrives or ``sst`` is exhausted.

        Message handling:

        * ``Start`` begins recording.
        * ``Utterance`` messages are collected while recording and ignored otherwise.
        * ``Reset`` discards the collected utterances.
        * ``Commit`` ends recording and runs the processor on the collected utterance.
          Each generated token is sent to ``tts``, framed by ``Start``/``End`` messages.

        Raises:
            ValueError: for a message unexpected in the current state. Examples: Start while
                already recording, Reset/Commit while not recording, or Utterance without data.
        """
        # NOTE(review): the created object and the processor are both entered as context
        # managers; presumably create_processor() and the processor each manage resources.
        with self._factory.create_processor() as processor, processor:
            recording: bool = False
            utterance: List[str] = []

            tts.put(Message(LlmMessageType.End))  # all good, trigger first listen (if synthesizer process is also ok)
            for message in sst:
                self._logger.debug(str(message))

                if message.msg == SttMessageType.Start and not recording:
                    recording = True
                    tts.put(Message(LlmMessageType.FeedbackPos, message.data))
                    continue
                elif message.msg == SttMessageType.Reset and recording:
                    recording = False
                    utterance.clear()
                    tts.put(Message(LlmMessageType.FeedbackNeg, message.data))
                    continue
                elif message.msg == SttMessageType.Commit and recording:
                    # End of recording: fall through to generation below.
                    recording = False
                elif message.msg == SttMessageType.Stop:
                    tts.put(Message(LlmMessageType.FeedbackNeg, message.data))
                    return
                elif message.msg == SttMessageType.Utterance and message.data is not None:
                    # Utterances arriving outside a recording are dropped.
                    if recording:
                        utterance.append(message.data)
                    continue
                else:
                    raise ValueError(message.msg.name)

                # NOTE(review): a Commit with nothing collected still calls generate([]) — confirm intended.
                tts.put(Message(LlmMessageType.Start))
                self._logger.info("Starting LLM processing")
                with PerfCounter(self._logger, logging.INFO, "tokens", skip_delay=True) as counter:
                    for token in processor.generate(utterance):
                        tts.put(Message(LlmMessageType.Token, token))
                        counter()
                utterance.clear()
                tts.put(Message(LlmMessageType.End))

    @classmethod
    def run(cls, config: Config, sst: MessageQueueR[SttMessageType], tts: MessageQueueW[LlmMessageType]) -> int:
        """Run the message-queue loop and return a process exit code.

        Returns ``0`` on normal termination or ``KeyboardInterrupt``, and ``1`` on any other error.
        A ``Stop`` message is always sent to ``tts`` and the queue is closed afterwards. The queue
        is closed even if sending ``Stop`` fails.
        """
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        with redirect_logging(logger), Rusage(logger, logging.INFO):
            try:
                factory: Factory = Factory(config)
                cls(logger, factory)._run(sst, tts)
            except KeyboardInterrupt:
                return 0
            except BaseException as e:
                return cls._report_error(logger, e)
            else:
                return 0
            finally:
                try:
                    tts.put(Message(LlmMessageType.Stop))
                finally:
                    tts.close()

    @classmethod
    def run_cli(cls, config: Config) -> int:
        """Interactive prompt loop: read a prompt, print the generated tokens, repeat.

        An empty prompt, ^C or ^D ends the session. Returns ``0`` on normal termination or
        ``KeyboardInterrupt`` (also while setting up, matching :meth:`run`), and ``1`` on any
        other error.
        """
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        try:
            multiline: bool = False
            # NOTE(review): a second output is created for PromptSession although the app
            # session already has one on the same stream — check whether both are needed.
            with redirect_logging(logger) as stdout, \
                    Rusage(logger, logging.INFO), \
                    create_app_session(output=create_output(stdout)):
                prompt_session: PromptSession = PromptSession(multiline=multiline, output=create_output(stdout))

                factory: Factory = Factory(config)
                with factory.create_processor() as processor, processor:
                    try:
                        while True:
                            prompt: str = prompt_session.prompt(
                                message=[("ansigreen bold", "> ")],
                                placeholder="<Meta+Enter>" if multiline else None,
                            )
                            if not prompt:  # empty input ends the session; ^C / ^D raise, handled below
                                break

                            # Rstrip post-processes each streamed token before printing.
                            stripper: Rstrip = Rstrip()
                            print_formatted_text(FormattedText([("ansired bold", "> ")]), end="", flush=True)
                            with PerfCounter(logger, logging.INFO, "tokens", skip_delay=True) as counter:
                                for token in processor.generate(prompt.strip().splitlines(keepends=True)):
                                    print_formatted_text(stripper.push(token), end="", flush=True)
                                    counter()
                                print_formatted_text("\n", end="", flush=True)
                    except (KeyboardInterrupt, EOFError):
                        return 0
        except KeyboardInterrupt:
            # ^C outside the prompt loop (e.g. while the processor is being created) is a
            # user abort, not a failure — same exit code as run().
            return 0
        except BaseException as e:
            return cls._report_error(logger, e)
        else:
            return 0