397 lines
15 KiB
Python
397 lines
15 KiB
Python
# MIT License
|
||
# Copyright (c) 2024
|
||
"""多模态归纳:读取 session 目录,组装提示,调用 LLM,生成 DSL"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import base64
|
||
import json
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import requests # type: ignore
|
||
|
||
try:
    # Prefer python-dotenv for .env loading; fall back to manual parsing
    # in _load_env_file() when the package is not installed.
    from dotenv import load_dotenv  # type: ignore
except Exception:
    # Sentinel: _load_env_file() checks for a falsy load_dotenv and then
    # parses the .env file by hand.
    load_dotenv = None
|
||
|
||
from .prompt_templates import SYSTEM_PROMPT, render_user_prompt
|
||
from .schema import DSLSpec, EventRecord, FramePaths, UISnapshot, UISelector
|
||
|
||
|
||
# --------- Pydantic v1/v2 兼容辅助 ---------
|
||
def _model_validate(cls, data: Any) -> Any:
|
||
if hasattr(cls, "model_validate"):
|
||
return cls.model_validate(data) # type: ignore[attr-defined]
|
||
return cls.parse_obj(data) # type: ignore[attr-defined]
|
||
|
||
|
||
def _model_dump(obj: Any, **kwargs: Any) -> Dict[str, Any]:
|
||
if hasattr(obj, "model_dump"):
|
||
return obj.model_dump(**kwargs) # type: ignore[attr-defined]
|
||
return obj.dict(**kwargs) # type: ignore[attr-defined]
|
||
|
||
|
||
def _load_env_file() -> None:
    """Load the project-root .env; prefer python-dotenv, else parse by hand."""
    env_path = Path(__file__).resolve().parent.parent / ".env"
    if load_dotenv:
        load_dotenv(env_path)
        return
    if not env_path.exists():
        return
    for raw in env_path.read_text(encoding="utf-8").splitlines():
        entry = raw.strip()
        # Skip blanks, comments, and malformed lines without '='.
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        key, _, value = entry.partition("=")
        # setdefault: variables already in the environment win over .env values.
        os.environ.setdefault(key.strip(), value.strip())
|
||
|
||
|
||
def _coerce_assertions(spec_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""将 assertions 内的非字符串条目转换为字符串,防止验证失败"""
|
||
assertions = spec_dict.get("assertions")
|
||
if isinstance(assertions, list):
|
||
new_items = []
|
||
for item in assertions:
|
||
if isinstance(item, str):
|
||
new_items.append(item)
|
||
else:
|
||
try:
|
||
new_items.append(json.dumps(item, ensure_ascii=False))
|
||
except Exception:
|
||
new_items.append(str(item))
|
||
spec_dict["assertions"] = new_items
|
||
return spec_dict
|
||
|
||
|
||
def _strip_code_fences(text: str) -> str:
|
||
"""去除 ```json ... ``` 或 ``` ... ``` 包裹"""
|
||
stripped = text.strip()
|
||
if stripped.startswith("```"):
|
||
parts = stripped.split("```")
|
||
if len(parts) >= 3:
|
||
return parts[1].lstrip("json").strip() if parts[1].startswith("json") else parts[1].strip()
|
||
return stripped
|
||
|
||
|
||
def _normalize_steps(spec_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the "steps" list to the actions/fields the schema supports.

    Mutates and returns *spec_dict*: renames selector -> target and
    value -> text, rewrites the custom wait_for_window action into wait_for,
    downgrades unknown actions to assert_exists, and canonicalizes the
    "Window" ControlType name. Non-dict steps are dropped.
    """
    steps = spec_dict.get("steps")
    if not isinstance(steps, list):
        # Missing or malformed steps field: nothing to normalize.
        return spec_dict
    normalized = []
    for step in steps:
        if not isinstance(step, dict):
            continue
        # Map selector -> target (the schema's field name).
        if "target" not in step and "selector" in step:
            step["target"] = step["selector"]
            step.pop("selector", None)

        action = step.get("action")
        # Normalize value -> text, for set_value/type compatibility.
        if "value" in step and "text" not in step:
            step["text"] = step.get("value")
            step.pop("value", None)

        # Handle the custom wait_for_window action.
        if action == "wait_for_window":
            title = step.pop("window_title_part", None)
            timeout = step.pop("timeout", None)
            step["action"] = "wait_for"
            step["target"] = step.get("target") or {}
            if title:
                step["target"].setdefault("Name", title)
            step["target"].setdefault("ControlType", "WindowControl")
            if timeout:
                # Incoming timeout is in milliseconds; waits are in seconds.
                secs = float(timeout) / 1000.0
                step["waits"] = {"appear": secs, "disappear": 5.0}
        # Any action outside the allowed set is downgraded to assert_exists.
        if step.get("action") not in {"click", "type", "set_value", "assert_exists", "wait_for"}:
            step["action"] = "assert_exists"

        # Standardize the ControlType naming.
        tgt = step.get("target", {})
        if isinstance(tgt, dict) and tgt.get("ControlType") == "Window":
            tgt["ControlType"] = "WindowControl"
        normalized.append(step)
    spec_dict["steps"] = normalized
    return spec_dict
|
||
|
||
|
||
# ---------------- LLM 抽象 ----------------
|
||
class LLMClient:
    """Abstract interface for LLM backends."""

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        """Return the raw model response text; subclasses must override.

        *images*, when given, is a list of dicts carrying a base64-encoded
        PNG under the "b64" key.
        """
        raise NotImplementedError
|
||
|
||
|
||
class DummyLLM(LLMClient):
    """Offline, text-only generator driven by simple event heuristics."""

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        """Build a DSL JSON string from the event summary embedded in *user_prompt*.

        Heuristics: mouse_click -> click step, text_input -> type step; when a
        Notepad window title plus a text input are both seen, a Save-button
        click and a matching assertion are appended.
        """
        events = json.loads(user_prompt.split("事件摘要(JSON):")[-1])
        steps: List[Dict[str, Any]] = []
        params: Dict[str, Any] = {}
        assertions: List[str] = []
        typed_something = False
        notepad_seen = False
        for event in events:
            kind = event.get("event_type")
            target = event.get("uia_selector") or {}
            if kind == "mouse_click":
                steps.append({"action": "click", "target": target})
            elif kind == "text_input":
                typed_something = True
                # First text_input wins as the "text" parameter.
                params.setdefault("text", event.get("text", ""))
                steps.append({"action": "type", "target": target, "text": "{{text}}"})
            if event.get("window_title") and "记事本" in event.get("window_title", ""):
                notepad_seen = True
        if notepad_seen and typed_something:
            assertions.append("文本已输入记事本")
            steps.append({"action": "click", "target": {"Name": "保存", "ControlType": "Button"}})
        if not assertions:
            assertions.append("关键控件存在")
        spec = {
            "params": params,
            "steps": steps if steps else [{"action": "assert_exists", "target": {"Name": "dummy"}}],
            "assertions": assertions,
            "retry_policy": {"max_attempts": 2, "interval": 1.0},
            "waits": {"appear": 5.0, "disappear": 5.0},
        }
        return json.dumps(spec, ensure_ascii=False)
|
||
|
||
|
||
class OpenAIVisionClient(LLMClient):
    """OpenAI-compatible multimodal client with custom base_url and model.

    Sends a chat/completions request with the user text plus optional
    base64-encoded PNG image parts, retrying transient failures.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-5.1-high",
        base_url: str = "https://api.wgetai.com/v1",
        timeout: float = 120.0,
        retries: int = 1,
    ) -> None:
        self.api_key = api_key
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Number of *extra* attempts beyond the first; never negative.
        self.retries = max(0, retries)

    def generate(self, system_prompt: str, user_prompt: str, images: Optional[List[Dict[str, Any]]] = None) -> str:
        """POST the prompt (and images) and return the first choice's content.

        Raises the last requests/HTTP error after exhausting all retries.
        """
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
        for img in images or []:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img['b64']}"}})
        payload = {
            "model": self.model,
            "messages": [
                # BUG FIX: honor the system_prompt argument. Previously the
                # module-level SYSTEM_PROMPT constant was hard-coded here,
                # silently ignoring whatever the caller passed in.
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
            "temperature": 0.2,
        }
        url = f"{self.base_url}/chat/completions"
        last_err: Optional[Exception] = None
        for attempt in range(self.retries + 1):
            try:
                resp = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
                resp.raise_for_status()
                return resp.json()["choices"][0]["message"]["content"]
            except Exception as exc:  # noqa: BLE001
                last_err = exc
                if attempt >= self.retries:
                    # Final attempt failed: propagate the original error.
                    raise
        # Defensive: unreachable, the loop either returns or re-raises above.
        raise last_err or RuntimeError("LLM 调用失败")
|
||
|
||
|
||
# ---------------- 数据加载与压缩 ----------------
|
||
def _load_events(session_dir: Path) -> List[EventRecord]:
    """Parse session_dir/events.jsonl into a list of EventRecord objects."""
    events_path = session_dir / "events.jsonl"
    records: List[EventRecord] = []
    with events_path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.strip()
            # Tolerate blank lines in the JSONL stream.
            if raw:
                records.append(_model_validate(EventRecord, json.loads(raw)))
    return records
|
||
|
||
|
||
def _load_snapshot(path: Optional[str]) -> Optional[UISnapshot]:
    """Load a UI snapshot JSON file; return None when the path is unset or missing."""
    if not path:
        return None
    snapshot_file = Path(path)
    if not snapshot_file.exists():
        return None
    data = json.loads(snapshot_file.read_text(encoding="utf-8"))
    return _model_validate(UISnapshot, data)
|
||
|
||
|
||
def _best_image(frame_paths: Optional[FramePaths]) -> Optional[str]:
|
||
if not frame_paths:
|
||
return None
|
||
for cand in [frame_paths.crop_element, frame_paths.crop_mouse, frame_paths.full]:
|
||
if cand and Path(cand).exists():
|
||
return cand
|
||
return None
|
||
|
||
|
||
def _selector_summary(selector: Optional[UISelector]) -> Dict[str, Any]:
|
||
if not selector:
|
||
return {}
|
||
return {
|
||
"AutomationId": selector.automation_id,
|
||
"Name": selector.name,
|
||
"ClassName": selector.class_name,
|
||
"ControlType": selector.control_type,
|
||
}
|
||
|
||
|
||
def _compress_tree(snapshot: Optional[UISnapshot], selector: Optional[UISelector]) -> List[Dict[str, Any]]:
    """Shrink the UI tree for the prompt.

    Keeps every node with depth <= 2, plus deeper nodes sharing the hit
    control's Name or ControlType. Returns compact dicts (None fields dropped).
    """
    if not snapshot:
        return []
    kept: List[Dict[str, Any]] = []
    for node in snapshot.tree:
        shallow = node.depth <= 2
        related = bool(selector) and (node.name == selector.name or node.control_type == selector.control_type)
        if shallow or related:
            kept.append(_model_dump(node, exclude_none=True))
    return kept
|
||
|
||
|
||
def _encode_image_b64(path: Optional[str]) -> Optional[str]:
|
||
if not path:
|
||
return None
|
||
try:
|
||
with open(path, "rb") as f:
|
||
return base64.b64encode(f.read()).decode("ascii")
|
||
except Exception:
|
||
return None
|
||
|
||
|
||
def _pack_events(events: List[EventRecord], multimodal: bool) -> List[Dict[str, Any]]:
    """Condense recorder events into prompt-ready dicts.

    Only mouse_click / text_input / window_change events are kept. When
    *multimodal* is true and a frame image exists, its base64 encoding is
    attached under "image_base64".
    """
    interesting = {"mouse_click", "text_input", "window_change"}
    packed: List[Dict[str, Any]] = []
    for ev in events:
        if ev.event_type not in interesting:
            continue
        image_path = _best_image(ev.frame_paths)
        selector = ev.uia
        entry: Dict[str, Any] = {
            "event_type": ev.event_type,
            "ts": ev.ts,
            "video_time_offset_ms": ev.video_time_offset_ms,
            "text": ev.text,
            "window_title": ev.window.title if ev.window else None,
            "window_process": ev.window.process_name if ev.window else None,
            "uia_selector": _selector_summary(selector),
            "uia_tree": _compress_tree(_load_snapshot(ev.ui_snapshot), selector),
            "frame_path": image_path,
        }
        if multimodal and image_path:
            encoded = _encode_image_b64(image_path)
            if encoded:
                entry["image_base64"] = encoded
        packed.append(entry)
    return packed
|
||
|
||
|
||
# ---------------- 主入口 ----------------
|
||
def infer_session(
    session_dir: Path,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    model: str = "gpt-5.1-high",
    timeout: float = 120.0,
    retries: int = 1,
) -> DSLSpec:
    """Read a session directory and return a validated DSLSpec.

    When *api_key* is given, the multimodal OpenAI-compatible client is used
    (with images from the session frames); on any failure it degrades to the
    offline DummyLLM. Without an api_key, DummyLLM is used directly.

    Raises RuntimeError when the model returns empty or non-JSON output.
    """
    events = _load_events(session_dir)
    # Multimodal mode is keyed purely on whether an API key was supplied.
    multimodal = api_key is not None
    packed = _pack_events(events, multimodal=multimodal)
    user_prompt = render_user_prompt(packed)
    client: LLMClient
    images_payload = [{"b64": e["image_base64"]} for e in packed if "image_base64" in e] if multimodal else None

    raw: str
    if multimodal:
        client = OpenAIVisionClient(
            api_key=api_key,
            base_url=base_url or "https://api.wgetai.com/v1",
            model=model,
            timeout=timeout,
            retries=retries,
        )
        try:
            raw = client.generate(SYSTEM_PROMPT, user_prompt, images=images_payload)
        except Exception as exc:  # noqa: BLE001
            # Degrade to the offline text-only generator rather than failing.
            print(f"[warn] 多模态归纳失败,降级为文本-only(原因: {exc})")
            client = DummyLLM()
            raw = client.generate(SYSTEM_PROMPT, user_prompt, images=None)
    else:
        client = DummyLLM()
        raw = client.generate(SYSTEM_PROMPT, user_prompt, images=None)

    if not raw or not raw.strip():
        raise RuntimeError("LLM 返回为空,无法解析为 JSON")
    # Strip markdown fences, parse, then coerce/normalize before validation.
    cleaned = _strip_code_fences(raw)
    try:
        spec_dict = json.loads(cleaned)
    except Exception as exc:
        preview = cleaned[:500]
        raise RuntimeError(f"LLM 返回非 JSON,可见前 500 字符: {preview}") from exc
    spec_dict = _coerce_assertions(spec_dict)
    spec_dict = _normalize_steps(spec_dict)
    return _model_validate(DSLSpec, spec_dict)
|
||
|
||
|
||
def main() -> None:
    """CLI entry point: induce a DSL from a session directory and write it to JSON."""
    parser = argparse.ArgumentParser(description="从 session 目录归纳 DSL(支持多模态)")
    parser.add_argument("--session-dir", type=str, required=True, help="session 目录,包含 events.jsonl / manifest.json / frames / ui_snapshots")
    parser.add_argument("--out", type=str, default="dsl.json", help="输出 DSL JSON 路径")
    parser.add_argument("--api-key", type=str, help="LLM API Key,缺省读取环境变量 OPENAI_API_KEY")
    # BUG FIX: the option previously defaulted to the literal URL, so the
    # OPENAI_BASE_URL environment fallback below was dead code. Default to
    # None; resolution order is CLI flag > env var > infer_session's default.
    parser.add_argument("--base-url", type=str, default=None, help="LLM Base URL")
    parser.add_argument("--model", type=str, default="gpt-5.1-high", help="LLM 模型名")
    parser.add_argument("--timeout", type=float, default=120.0, help="LLM 请求超时时间(秒)")
    parser.add_argument("--retries", type=int, default=1, help="LLM 请求重试次数(额外重试次数)")
    args = parser.parse_args()

    # Load .env before reading the environment so its values are visible.
    _load_env_file()

    session_dir = Path(args.session_dir)
    api_key = args.api_key or os.environ.get("OPENAI_API_KEY")
    base_url = args.base_url or os.environ.get("OPENAI_BASE_URL")

    spec = infer_session(
        session_dir,
        api_key=api_key,
        base_url=base_url,
        model=args.model,
        timeout=args.timeout,
        retries=args.retries,
    )
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        f.write(json.dumps(_model_dump(spec), ensure_ascii=False, indent=2))
    print(f"DSL 写入: {out_path}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|