#!/usr/bin/env python3

"""
Framework for running Ollama or GPT4All models with voice recognition and text-to-speech output.
"""

import argparse
import sys
from multiprocessing import set_start_method, Process
from pathlib import Path

from .config import Config, ConfigError
from .main_llm import MainLlm
from .main_stt import MainStt
from .main_tts import MainTts
from .queue import IpcMessageQueue


def main() -> int:
    """Parse the command line and run one component CLI or the full pipeline.

    With ``--cli`` the selected component's ``run_cli`` is called and its
    result is returned. Otherwise the STT and TTS components are started in
    separate daemon processes (``spawn`` start method) while the LLM component
    runs in the current process.

    Returns:
        1 if the configuration could not be loaded or ``MainLlm.run`` raised;
        the value of ``run_cli`` in ``--cli`` mode; otherwise the larger of
        the value returned by ``MainLlm.run`` and 1-if-any-child-process
        exited with a non-zero exit code (0 when all children exited cleanly).
    """
    parser = argparse.ArgumentParser(description=__doc__.strip(),
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--config", type=Path, metavar="YAML", required=True,
                        help="configuration file")
    parser.add_argument("--log-level", type=str, metavar="LEVEL", default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        help="logging level (DEBUG, INFO, WARNING, or ERROR)")
    parser.add_argument("--cli", type=str, metavar="MODE", default=None,
                        choices=["stt", "llm", "tts"],
                        help="run certain cli instead of full pipeline (stt, llm, or tts)")

    # argparse reports usage errors itself (via SystemExit), so only the
    # configuration loading needs to be guarded.
    args = parser.parse_args()
    try:
        config: Config = Config.from_file(filename=args.config, log_level=args.log_level)
    except ConfigError as e:
        print(str(e), file=sys.stderr)
        return 1

    # Single-component mode: run the requested CLI and stop.
    if args.cli == "stt":
        return MainStt.run_cli(config)
    if args.cli == "llm":
        return MainLlm.run_cli(config)
    if args.cli == "tts":
        return MainTts.run_cli(config)
    if args.cli is not None:
        # Defensive: argparse `choices` should already reject anything else.
        return 1

    # Full pipeline mode.
    set_start_method("spawn")
    # NOTE(review): presumably stt_queue carries STT -> LLM, tts_queue carries
    # LLM -> TTS, and ctl_queue carries TTS -> STT control messages; confirm
    # against the MainStt/MainLlm/MainTts `run` signatures.
    ctl_queue: IpcMessageQueue = IpcMessageQueue()
    stt_queue: IpcMessageQueue = IpcMessageQueue()
    tts_queue: IpcMessageQueue = IpcMessageQueue()

    stt: Process = Process(target=MainStt.run, args=(config, ctl_queue, stt_queue), daemon=True)
    tts: Process = Process(target=MainTts.run, args=(config, tts_queue, ctl_queue), daemon=True)
    stt.start()
    tts.start()

    try:
        rv: int = MainLlm.run(config, stt_queue, tts_queue)
    except BaseException as e:
        # Deliberately broad (includes KeyboardInterrupt/SystemExit): whatever
        # stops the LLM loop, the children must be terminated so that the
        # joins below cannot block forever.
        print(str(e), file=sys.stderr)
        rv = 1
        tts.terminate()
        stt.terminate()

    tts.join()
    stt.join()

    # A child counts as failed only when its exit code is non-zero (negative
    # values mean it was killed by a signal, e.g. after terminate()).
    children_failed = any(proc.exitcode != 0 for proc in (tts, stt))
    return max(rv, 1 if children_failed else 0)


# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())