import logging
import signal
from typing import Callable, Iterator, Optional

from prompt_toolkit import print_formatted_text
from prompt_toolkit.application import create_app_session
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.output import create_output

from .api.message import ConfigError, ModuleError, CtlMessageType, SttMessageType
from .api.model import AudioSource, SpeechSegmenter, Recognizer
from .config import Config
from .factory import Factory, FactoryError
from .log import setup_logging, redirect_logging
from .queue import Message, MessageQueueR, MessageQueueW, DummyQueue
from .utils.process import Rusage
class MainStt:
    """Drive the audio source -> speech segmenter -> recognizer pipeline built by a :class:`Factory`.

    :meth:`run` is controlled through a control queue and publishes recognizer messages to an
    output queue; :meth:`run_cli` prints the messages to the terminal instead.
    """

    def __init__(self, logger: logging.Logger, factory: Factory) -> None:
        self._logger: logging.Logger = logger
        self._factory: Factory = factory

    def _listen(self,
                ctl_queue: MessageQueueR[CtlMessageType],
                source: AudioSource,
                segmenter: SpeechSegmenter, recognizer: Recognizer) -> Iterator[Message[SttMessageType]]:
        """Yield normalized recognizer messages for a single listening attempt.

        Raw audio from ``source`` is pushed through ``segmenter``; every resulting buffer is fed to
        ``recognizer``. The messages it produces are normalized according to whether a recording
        has been started (first ``Start`` seen): out-of-place control messages are downgraded to
        ``Utterance`` messages carrying the same data.

        The generator returns normally only after yielding a ``Commit`` while recording.

        :raises InterruptedError: on a non-``Listen`` control event, on ``Stop`` while not
            recording, or when ``source`` is exhausted.
        :raises OverflowError: when the segmenter/recognizer raise ``TimeoutError`` or
            ``OverflowError``; a ``Reset`` is yielded first if a recording was in progress.
            Callers treat this as "roll back and retry".
        :raises ValueError: on an unexpected message type.
        """
        recording: bool = False
        for raw in source:
            # Non-blocking poll for control events between audio chunks.
            event: Optional[Message[CtlMessageType]] = ctl_queue.get_nowait()
            if event is not None and event.msg != CtlMessageType.Listen:
                raise InterruptedError(str(event))
            elif event is not None:
                self._logger.info("Wakeup event, continuing")
            try:
                for buf in segmenter.push(raw):
                    for message in recognizer.accept(buf):
                        # Lazy %-args: formatting is skipped when INFO is disabled.
                        self._logger.info("Input: %s", message)
                        if message.msg == SttMessageType.Start:
                            if not recording:
                                recording = True
                                yield message
                            else:
                                yield Message(SttMessageType.Utterance, message.data)
                        elif message.msg == SttMessageType.Reset:
                            if not recording:
                                yield Message(SttMessageType.Utterance, message.data)
                            else:
                                recording = False
                                yield message
                        elif message.msg == SttMessageType.Commit:
                            if not recording:  # spurious
                                yield Message(SttMessageType.Utterance, message.data)
                            else:  # success, don't listen until triggered again
                                yield message
                                return
                        elif message.msg == SttMessageType.Stop:
                            if not recording:
                                raise InterruptedError(str(message))
                            else:
                                yield Message(SttMessageType.Utterance, message.data)
                        elif message.msg == SttMessageType.Utterance:
                            yield message  # will get ignored if not started, though
                        else:
                            raise ValueError(message.msg.name)
            except TimeoutError as e:  # silence timeout during utterance
                if recording:
                    yield Message(SttMessageType.Reset)
                raise OverflowError(f"Silence limit reached: {e}") from None
            except OverflowError as e:  # too long utterance segment buffer
                if recording:
                    yield Message(SttMessageType.Reset)
                raise OverflowError(f"Utterance limit reached: {e}") from None
        raise InterruptedError(f"{source.__class__.__name__} closed or depleted")

    def _listen_until_commit(self,
                             ctl_queue: MessageQueueR[CtlMessageType],
                             source: AudioSource,
                             segmenter: SpeechSegmenter,
                             recognizer: Recognizer,
                             sink: Callable[[Message[SttMessageType]], None]) -> bool:
        """Repeat :meth:`_listen` attempts until one ends in a commit, passing every message to ``sink``.

        Each attempt enters ``source`` and ``segmenter`` as context managers and resets
        ``recognizer``; the recognizer itself stays open across attempts. An ``OverflowError``
        from :meth:`_listen` is logged as a rollback and a new attempt is started.

        :returns: ``True`` if an attempt completed (i.e. an utterance was committed),
            ``False`` if listening was interrupted (``InterruptedError``).
        """
        while True:
            with source, segmenter:  # keep recognizer and source, mute only
                try:
                    recognizer.reset()
                    for message in self._listen(ctl_queue, source, segmenter, recognizer):
                        sink(message)
                except OverflowError as e:
                    self._logger.warning("Rollback: %s", e)
                    continue
                except InterruptedError as e:
                    self._logger.warning("Interrupted: %s", e)
                    return False
                else:  # commit
                    return True

    def _run(self, ctl_queue: MessageQueueR[CtlMessageType], sst_queue: MessageQueueW[SttMessageType]) -> None:
        """Serve ``Listen`` requests from ``ctl_queue``, publishing recognizer messages to ``sst_queue``.

        Blocks on ``ctl_queue``; every ``Listen`` event starts listening until an utterance is
        committed, after which the next control event is awaited. Any other control event, or an
        interruption while listening, ends the loop.
        """
        with self._factory.create_recognizer() as recognizer:
            source: AudioSource = self._factory.create_source_for(recognizer)
            segmenter: SpeechSegmenter = self._factory.create_speech_segmenter_for(recognizer)
            for event in ctl_queue:  # block until listen command, otherwise propagate stop and exit
                if event.msg != CtlMessageType.Listen:
                    self._logger.warning("Interrupted: %s", event)
                    return
                self._logger.info("Wakeup event, listening")
                if not self._listen_until_commit(ctl_queue, source, segmenter, recognizer, sst_queue.put):
                    return
                # source/segmenter have been exited at this point, i.e. input is muted.
                self._logger.info("Start processing, mute")

    @staticmethod
    def _print_message(message: Message[SttMessageType]) -> None:
        """Echo one recognizer message to the prompt_toolkit output.

        ``Utterance`` messages with data print their (right-stripped) text; everything else
        prints the message type name in brackets.
        """
        print_formatted_text(
            FormattedText([("ansired bold", ">")]),
            FormattedText([("bold", f"[{message.msg.name}]")])
            if message.data is None or message.msg != SttMessageType.Utterance
            else message.data.rstrip(),
            end="\n", flush=True
        )

    def _run_cli(self, ctl_queue: MessageQueueR[CtlMessageType]) -> None:
        """Listen until one utterance is committed (or listening is interrupted), printing messages."""
        with self._factory.create_recognizer() as recognizer:
            source: AudioSource = self._factory.create_source_for(recognizer)
            segmenter: SpeechSegmenter = self._factory.create_speech_segmenter_for(recognizer)
            self._listen_until_commit(ctl_queue, source, segmenter, recognizer, self._print_message)

    @classmethod
    def run_cli(cls, config: Config) -> int:
        """Interactive entry point: print recognizer output to the terminal.

        :returns: ``0`` on normal completion or ``KeyboardInterrupt``, ``1`` on any error.
        """
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        with redirect_logging(logger) as stdout, Rusage(logger, logging.INFO):
            # Route prompt_toolkit output to the stream returned by redirect_logging.
            with create_app_session(output=create_output(stdout)):
                try:
                    factory: Factory = Factory(config)
                    cls(logger, factory)._run_cli(DummyQueue())
                except KeyboardInterrupt:
                    return 0
                except FactoryError as e:
                    logger.error(str(e), exc_info=e.__cause__)
                    return 1
                except (ConfigError, ModuleError) as e:
                    logger.error(str(e))
                    return 1
                except BaseException as e:
                    logger.error(str(e), exc_info=e)
                    return 1
                else:
                    return 0

    @classmethod
    def run(cls, config: Config, ctl: MessageQueueR[CtlMessageType], q: MessageQueueW[SttMessageType]) -> int:
        """Queue-driven entry point: serve control events from ``ctl`` and publish to ``q``.

        A final ``Stop`` message is always put on ``q`` and ``q`` is closed, whatever the outcome.

        :returns: ``0`` on normal completion, ``1`` on any error.
        """
        signal.signal(signal.SIGINT, signal.SIG_IGN)  # handled and propagated by llm main
        logger: logging.Logger = setup_logging(config.logging, cls.__name__, config.log_level)
        with redirect_logging(logger), Rusage(logger, logging.INFO):
            try:
                factory: Factory = Factory(config)
                cls(logger, factory)._run(ctl, q)
            except FactoryError as e:
                logger.error(str(e), exc_info=e.__cause__)
                return 1
            except (ConfigError, ModuleError) as e:
                logger.error(str(e))
                return 1
            except BaseException as e:
                logger.error(str(e), exc_info=e)
                return 1
            else:
                return 0
            finally:
                # Always tell the consumer that no more messages will follow.
                q.put(Message(SttMessageType.Stop))
                q.close()