import logging
from typing import List
from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.application import create_app_session
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.output import create_output
from .api.message import ConfigError, ModuleError, SttMessageType, LlmMessageType
from .config import Config
from .factory import Factory, FactoryError
from .log import setup_logging, redirect_logging
from .queue import Message, MessageQueueR, MessageQueueW
from .utils.process import Rusage
from .utils.utils import Rstrip, PerfCounter
class MainLlm:
    """LLM stage sitting between the speech-to-text stage and the synthesizer.

    :meth:`run` reads speech-to-text messages from a queue, collects utterances,
    feeds them to a processor created by the :class:`Factory`, and streams the
    generated tokens to the synthesizer queue. :meth:`run_cli` drives the same
    processor from an interactive terminal prompt instead.
    """

    def __init__(self, logger: logging.Logger, factory: Factory) -> None:
        """Store the logger and the factory used to create the processor."""
        self._logger: logging.Logger = logger
        self._factory: Factory = factory

    def _run(self, sst: MessageQueueR[SttMessageType], tts: MessageQueueW[LlmMessageType]) -> None:
        """Consume ``sst`` messages and emit LLM messages on ``tts`` until ``Stop``.

        An initial ``End`` message is sent before reading. Per message:

        - ``Start`` while idle: start recording, send ``FeedbackPos`` with the message data.
        - ``Reset`` while recording: stop recording, discard the collected utterance,
          send ``FeedbackNeg`` with the message data.
        - ``Commit`` while recording: stop recording, send ``Start``, one ``Token``
          message per generated token, then ``End``; the utterance is cleared.
        - ``Stop`` (any state): send ``FeedbackNeg`` with the message data and return.
        - ``Utterance`` with non-``None`` data: appended while recording, dropped otherwise.
        - Anything else: ``ValueError`` with the message type name.

        Also returns when ``sst`` is exhausted.

        :raises ValueError: on a message not accepted in the current state (see above).
        """
        # The factory result is entered as a context manager, and the processor it
        # yields is entered as well, so both get their enter/exit hooks run.
        with self._factory.create_processor() as processor, processor:
            recording: bool = False
            utterance: List[str] = []
            tts.put(Message(LlmMessageType.End))  # all good, trigger first listen (if synthesizer process is also ok)
            for message in sst:
                self._logger.debug(str(message))
                if message.msg == SttMessageType.Start and not recording:
                    recording = True
                    tts.put(Message(LlmMessageType.FeedbackPos, message.data))
                    continue
                elif message.msg == SttMessageType.Reset and recording:
                    recording = False
                    utterance.clear()
                    tts.put(Message(LlmMessageType.FeedbackNeg, message.data))
                    continue
                elif message.msg == SttMessageType.Commit and recording:
                    recording = False  # the only branch that falls through to generation below
                elif message.msg == SttMessageType.Stop:
                    tts.put(Message(LlmMessageType.FeedbackNeg, message.data))
                    return
                elif message.msg == SttMessageType.Utterance and message.data is not None:
                    if recording:
                        utterance.append(message.data)
                    continue
                else:
                    # NOTE(review): also reached by valid message types arriving in the wrong
                    # state (Start while recording, Reset/Commit while idle) and by Utterance
                    # messages without data -- confirm that aborting the stage is intended.
                    raise ValueError(message.msg.name)

                # Commit: run the collected utterance through the processor and stream tokens.
                tts.put(Message(LlmMessageType.Start))
                self._logger.info("Starting LLM processing")
                with PerfCounter(self._logger, logging.INFO, "tokens", skip_delay=True) as counter:
                    for token in processor.generate(utterance):
                        tts.put(Message(LlmMessageType.Token, token))
                        counter()
                utterance.clear()
                tts.put(Message(LlmMessageType.End))

    @classmethod
    def run(cls, config: Config, sst: MessageQueueR[SttMessageType], tts: MessageQueueW[LlmMessageType]) -> int:
        """Entry point for the pipelined LLM stage.

        Sets up logging from ``config``, builds the factory and runs :meth:`_run`.
        Regardless of outcome, a final ``Stop`` message is put on ``tts`` and the
        queue is closed.

        :return: 0 on normal completion or ``KeyboardInterrupt``, 1 on any other error.
        """
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        with redirect_logging(logger), Rusage(logger, logging.INFO):
            try:
                factory: Factory = Factory(config)
                cls(logger, factory)._run(sst, tts)
            except KeyboardInterrupt:
                return 0
            except FactoryError as e:
                logger.error(str(e), exc_info=e.__cause__)
                return 1
            except (ConfigError, ModuleError) as e:
                logger.error(str(e))  # expected user-facing errors: no traceback
                return 1
            except BaseException as e:
                logger.error(str(e), exc_info=e)
                return 1
            else:
                return 0
            finally:
                # Always signal the downstream stage, even on failure.
                tts.put(Message(LlmMessageType.Stop))
                tts.close()

    @classmethod
    def run_cli(cls, config: Config) -> int:
        """Interactive terminal front-end for the LLM processor.

        Repeatedly reads a prompt, passes its stripped lines to the processor and
        prints the generated tokens (each passed through ``Rstrip.push`` before
        printing). An empty input line, Ctrl-C or Ctrl-D ends the session.

        :return: 0 on normal exit or ``KeyboardInterrupt``, 1 on any other error.
        """
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        try:
            multiline: bool = False
            with redirect_logging(logger) as stdout, \
                    Rusage(logger, logging.INFO), \
                    create_app_session(output=create_output(stdout)):
                prompt_session: PromptSession = PromptSession(multiline=multiline, output=create_output(stdout))
                factory: Factory = Factory(config)
                with factory.create_processor() as processor, processor:
                    try:
                        while True:
                            prompt: str = prompt_session.prompt(
                                message=[("ansigreen bold", "> ")],
                                placeholder="<Meta+Enter>" if multiline else None,
                            )
                            if not prompt:  # empty input ends the session (^C / ^D raise, see below)
                                break
                            stripper: Rstrip = Rstrip()
                            print_formatted_text(FormattedText([("ansired bold", "> ")]), end="", flush=True)
                            with PerfCounter(logger, logging.INFO, "tokens", skip_delay=True) as counter:
                                for token in processor.generate(prompt.strip().splitlines(keepends=True)):
                                    print_formatted_text(stripper.push(token), end="", flush=True)
                                    counter()
                            print_formatted_text("\n", end="", flush=True)
                    except (KeyboardInterrupt, EOFError):  # ^C / ^D at the prompt or during generation
                        return 0
        except KeyboardInterrupt:
            # ^C outside the prompt loop (e.g. while the model is loading) is a clean
            # shutdown too, consistent with run(); without this it was logged as an error.
            return 0
        except FactoryError as e:
            logger.error(str(e), exc_info=e.__cause__)
            return 1
        except (ConfigError, ModuleError) as e:
            logger.error(str(e))  # expected user-facing errors: no traceback
            return 1
        except BaseException as e:
            logger.error(str(e), exc_info=e)
            return 1
        else:
            return 0