92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
# MIT License
|
|
# Copyright (c) 2024
|
|
"""Command line entry point."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
from .dsl import load_dsl, save_dsl
|
|
from .executor import ExecContext, execute_spec
|
|
from .llm import DummyLLM, LLMClient
|
|
from .recorder import Recorder
|
|
from .schema import EventRecord
|
|
|
|
|
|
def cmd_record(args: argparse.Namespace) -> None:
    """Start a multimodal recording session from the ``record`` sub-command.

    A Recorder is built from the ``--out``, ``--hotkey``, ``--fps`` and
    ``--screen`` options. The stop hotkey is announced, then
    ``Recorder.start()`` is called (presumably blocking until the hotkey is
    pressed -- TODO confirm in recorder.py). Finally the session directory it
    returns is printed.
    """
    stop_key = args.hotkey
    recorder = Recorder(
        Path(args.out),
        hotkey=stop_key,
        fps=args.fps,
        screen=args.screen,
    )
    print(f"Recording... press {stop_key} to stop.")
    saved_dir = recorder.start()
    print(f"Session saved to: {saved_dir}")
|
|
|
|
|
|
def _load_events(path: Path) -> list[EventRecord]:
|
|
events = []
|
|
with path.open("r", encoding="utf-8") as f:
|
|
for line in f:
|
|
if not line.strip():
|
|
continue
|
|
events.append(EventRecord.parse_obj(json.loads(line)))
|
|
return events
|
|
|
|
|
|
def cmd_infer(args: argparse.Namespace) -> None:
    """Generate a DSL spec from a recorded events file.

    The events are loaded from ``--session`` and passed to a ``DummyLLM``
    client's ``generate``. The resulting spec is written to ``--output`` with
    ``save_dsl``.
    """
    recorded = _load_events(Path(args.session))
    llm: LLMClient = DummyLLM()
    generated = llm.generate(recorded)
    destination = Path(args.output)
    save_dsl(generated, destination)
    print(f"DSL saved to {destination}")
|
|
|
|
|
|
def cmd_run(args: argparse.Namespace) -> None:
|
|
"""Execute DSL."""
|
|
spec = load_dsl(Path(args.dsl))
|
|
if args.params:
|
|
spec.params.update(json.loads(args.params))
|
|
ctx = ExecContext(allow_title=args.allow_title, dry_run=args.dry_run)
|
|
execute_spec(spec, ctx)
|
|
print("Done")
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser.

    The parser requires one of three sub-commands: ``record``, ``infer`` or
    ``run``. Each sub-parser stores its handler under ``func`` via
    ``set_defaults``, so a parsed namespace can be dispatched with
    ``args.func(args)``.
    """
    # (sub-command, help text, handler, [(flag, add_argument options), ...])
    command_table = (
        (
            "record",
            "开始录制",
            cmd_record,
            [
                ("--out", dict(type=str, default="sessions", help="输出目录")),
                ("--hotkey", dict(type=str, default="F9", help="停止录制的热键")),
                ("--fps", dict(type=int, default=12, help="录屏帧率")),
                ("--screen", dict(type=int, default=0, help="屏幕编号,默认主屏")),
            ],
        ),
        (
            "infer",
            "LLM 归纳生成 DSL",
            cmd_infer,
            [
                ("--session", dict(type=str, required=True, help="events.jsonl 文件")),
                ("--output", dict(type=str, default="flow.yaml", help="输出 DSL 路径")),
            ],
        ),
        (
            "run",
            "执行 DSL",
            cmd_run,
            [
                ("--dsl", dict(type=str, required=True, help="DSL YAML 文件")),
                ("--params", dict(type=str, help="JSON 参数覆盖")),
                (
                    "--allow-title",
                    dict(type=str, default="记事本|Notepad", help="允许的窗口标题正则"),
                ),
                ("--dry-run", dict(action="store_true", help="仅打印动作不执行")),
            ],
        ),
    )

    parser = argparse.ArgumentParser(description="示教式自动化原型")
    subcommands = parser.add_subparsers(dest="command", required=True)
    for name, summary, handler, options in command_table:
        subparser = subcommands.add_parser(name, help=summary)
        for flag, settings in options:
            subparser.add_argument(flag, **settings)
        subparser.set_defaults(func=handler)
    return parser
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Parse command-line arguments and dispatch to the chosen sub-command.

    Args:
        argv: Arguments to parse, excluding the program name. The default
            ``None`` makes argparse read ``sys.argv[1:]``, matching the
            original no-argument behaviour. Passing a list lets the CLI be
            driven programmatically (e.g. from tests).
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    # build_parser registers each sub-command's handler via set_defaults(func=...).
    args.func(args)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only run the CLI when this module is executed directly, not on import.
    main()
|