python-library-ai-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (62)
  1. ai_agent/__init__.py +66 -0
  2. ai_agent/agent.py +122 -0
  3. ai_agent/app/__init__.py +10 -0
  4. ai_agent/app/_workspace.py +127 -0
  5. ai_agent/app/app.py +321 -0
  6. ai_agent/app/harness_io.py +109 -0
  7. ai_agent/app/output_format.py +77 -0
  8. ai_agent/app/packet.py +39 -0
  9. ai_agent/app/session.py +742 -0
  10. ai_agent/app/session_store.py +85 -0
  11. ai_agent/builtin_tools/__init__.py +18 -0
  12. ai_agent/builtin_tools/current_time.py +39 -0
  13. ai_agent/builtin_tools/pack.py +20 -0
  14. ai_agent/builtin_tools/prefix.py +11 -0
  15. ai_agent/context.py +151 -0
  16. ai_agent/harness/__init__.py +3 -0
  17. ai_agent/harness/current_time.py +25 -0
  18. ai_agent/harness/harness.py +324 -0
  19. ai_agent/harness/process.py +115 -0
  20. ai_agent/harness/prompts.py +38 -0
  21. ai_agent/harness/sandbox.py +139 -0
  22. ai_agent/json_extract.py +70 -0
  23. ai_agent/listener.py +172 -0
  24. ai_agent/llm.py +39 -0
  25. ai_agent/llm_openai.py +117 -0
  26. ai_agent/loop.py +124 -0
  27. ai_agent/mcp_config.py +54 -0
  28. ai_agent/mcp_loader.py +110 -0
  29. ai_agent/memory/__init__.py +9 -0
  30. ai_agent/memory/compression_work.py +71 -0
  31. ai_agent/memory/compressor.py +339 -0
  32. ai_agent/memory/config.py +40 -0
  33. ai_agent/memory/context_builder.py +57 -0
  34. ai_agent/memory/memory_system.py +561 -0
  35. ai_agent/memory/models.py +76 -0
  36. ai_agent/memory/snapshot_merge.py +158 -0
  37. ai_agent/memory/store.py +107 -0
  38. ai_agent/memory/worker.py +227 -0
  39. ai_agent/plan/__init__.py +15 -0
  40. ai_agent/plan/complete.py +64 -0
  41. ai_agent/plan/delivery.py +41 -0
  42. ai_agent/plan/display.py +46 -0
  43. ai_agent/plan/models.py +44 -0
  44. ai_agent/plan/parse.py +39 -0
  45. ai_agent/plan/planner.py +204 -0
  46. ai_agent/plan/runner.py +281 -0
  47. ai_agent/react_tool_turn.py +39 -0
  48. ai_agent/rule/__init__.py +3 -0
  49. ai_agent/rule/rules.py +36 -0
  50. ai_agent/skill/__init__.py +5 -0
  51. ai_agent/skill/builtin_registry.py +56 -0
  52. ai_agent/skill/catalog.py +104 -0
  53. ai_agent/skill/frontmatter.py +83 -0
  54. ai_agent/skill/manager.py +486 -0
  55. ai_agent/skill/models.py +31 -0
  56. ai_agent/skill/roots.py +150 -0
  57. ai_agent/skill/skill_kit.py +80 -0
  58. ai_agent/skill/tool_declarations.py +68 -0
  59. ai_agent/tools.py +123 -0
  60. python_library_ai_agent-0.1.0.dist-info/METADATA +10 -0
  61. python_library_ai_agent-0.1.0.dist-info/RECORD +62 -0
  62. python_library_ai_agent-0.1.0.dist-info/WHEEL +4 -0
ai_agent/harness/process.py ADDED
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+
3
+ import subprocess
4
+ import sys
5
+ from pathlib import Path
6
+
7
# Timeout (seconds) used when the caller passes a non-positive value.
_DEFAULT_TIMEOUT = 120
# Upper bound (seconds) applied to any requested timeout.
_MAX_TIMEOUT = 600
# Formatted results longer than this many characters are truncated.
_MAX_OUTPUT_CHARS = 100_000


def clamp_timeout(timeout_seconds: int) -> int:
    """
    Map a requested timeout onto the allowed range.

    Args:
        timeout_seconds: requested timeout in seconds.

    Returns:
        ``_DEFAULT_TIMEOUT`` for non-positive input, otherwise the request
        capped at ``_MAX_TIMEOUT``.
    """
    return _DEFAULT_TIMEOUT if timeout_seconds <= 0 else min(timeout_seconds, _MAX_TIMEOUT)


def format_process_result(
    *,
    exit_code: int,
    stdout: str,
    stderr: str,
    timed_out: bool = False,
) -> str:
    """
    Render a child-process outcome as a single text report.

    The first line is either a timeout marker or the exit code, followed by
    labelled stdout / stderr sections (each only when non-empty).

    Args:
        exit_code: process exit status (not shown when ``timed_out``).
        stdout: captured standard output.
        stderr: captured standard error.
        timed_out: whether the process was killed for exceeding its timeout.

    Returns:
        The report, truncated to ``_MAX_OUTPUT_CHARS`` characters plus a
        trailing note stating how many characters were dropped.
    """
    sections = ["状态: 超时" if timed_out else f"退出码: {exit_code}"]
    for label, body in (("--- stdout ---", stdout), ("--- stderr ---", stderr)):
        if body:
            sections += [label, body]
    report = "\n".join(sections)
    excess = len(report) - _MAX_OUTPUT_CHARS
    if excess <= 0:
        return report
    return report[:_MAX_OUTPUT_CHARS] + f"\n...(输出已截断,省略 {excess} 字符)"
41
+
42
+
43
def _partial_text(value: str | bytes | None) -> str:
    """
    Normalize the partial output carried by ``subprocess.TimeoutExpired`` to ``str``.

    ``TimeoutExpired.stdout`` / ``.stderr`` are documented to be ``bytes``
    whenever output was captured, even when ``text=True`` was requested, and
    may be ``None`` when nothing was observed. Decoding mirrors the UTF-8,
    ``errors="replace"`` and universal-newline handling that ``subprocess.run``
    applies on the normal (non-timeout) path.
    """
    if value is None:
        return ""
    if isinstance(value, (bytes, bytearray)):
        value = bytes(value).decode("utf-8", errors="replace")
    return value.replace("\r\n", "\n").replace("\r", "\n")


def _run_captured(
    args: str | list[str],
    *,
    shell: bool,
    work_dir: Path,
    timeout_seconds: int,
) -> str:
    """
    Run ``args`` with captured output and return the formatted result.

    stdin is ``subprocess.DEVNULL``, so any read from stdin gets EOF instead of
    waiting. The timeout goes through ``clamp_timeout``. On timeout, the output
    collected so far is kept in the report rather than discarded.
    """
    timeout = clamp_timeout(timeout_seconds)
    try:
        completed = subprocess.run(
            args,
            shell=shell,
            cwd=work_dir,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
            stdin=subprocess.DEVNULL,
        )
    except subprocess.TimeoutExpired as exc:
        return format_process_result(
            exit_code=-1,
            stdout=_partial_text(exc.stdout),
            stderr=_partial_text(exc.stderr),
            timed_out=True,
        )
    return format_process_result(
        exit_code=completed.returncode,
        stdout=completed.stdout,
        stderr=completed.stderr,
    )


def run_shell(
    command: str,
    *,
    work_dir: Path,
    timeout_seconds: int = 0,
) -> str:
    """
    Run ``command`` through the system shell inside ``work_dir``.

    Args:
        command: shell command line; surrounding whitespace is stripped.
        work_dir: working directory of the child process.
        timeout_seconds: requested timeout; non-positive selects the default,
            larger values are capped (see ``clamp_timeout``).

    Returns:
        Formatted status line plus captured stdout / stderr.

    Raises:
        ValueError: ``command`` is empty or whitespace-only.
    """
    command = command.strip()
    if not command:
        raise ValueError("command 不能为空")
    # shell=True is the purpose of this tool (caller supplies a full command
    # line). NOTE(review): only expose this behind an outer sandbox when the
    # command text can come from untrusted sources.
    return _run_captured(
        command,
        shell=True,
        work_dir=work_dir,
        timeout_seconds=timeout_seconds,
    )


def run_python(
    code: str,
    *,
    work_dir: Path,
    timeout_seconds: int = 0,
) -> str:
    """
    Execute ``code`` with the current interpreter (``python -c``) inside ``work_dir``.

    Args:
        code: Python source; surrounding whitespace is stripped.
        work_dir: working directory of the child process.
        timeout_seconds: requested timeout; non-positive selects the default,
            larger values are capped (see ``clamp_timeout``).

    Returns:
        Formatted status line plus captured stdout / stderr.

    Raises:
        ValueError: ``code`` is empty or whitespace-only.
    """
    code = code.strip()
    if not code:
        raise ValueError("code 不能为空")
    return _run_captured(
        [sys.executable, "-c", code],
        shell=False,
        work_dir=work_dir,
        timeout_seconds=timeout_seconds,
    )
ai_agent/harness/prompts.py ADDED
@@ -0,0 +1,38 @@
1
# Planner system prompt (presumably sent as the system message of the
# planning call — confirm at the call site). It asks for a strictly serial
# plan emitted as one bare JSON object with "summary" and "steps"; each step
# carries id / title / objective / optional / hint_tools / required_tool.
# It also pins the search-then-rewrite chat flow at 2 steps and requires
# builtin__current_time to be listed in hint_tools for time-sensitive searches.
# NOTE(review): runtime text — any wording change alters planner output.
+ PLANNING_SYSTEM_PROMPT = """\
2
+
3
+ 你是任务规划助手。根据用户请求与下文业务规则、技能及工具说明,输出一份严格串行、不可并行的执行计划。
4
+
5
+
6
+
7
+ 规则:
8
+
9
+
10
+
11
+ - 若用户请求单一、无需工具,可只规划 1 步;较复杂或多阶段任务建议 2 到 6 步,每步只承担一个清晰目标;须体现规则与技能中的流程要求,用自然语言写 objective,勿套用应用外的硬性步骤模板。
12
+
13
+ - 搜索后聊天短答类任务:固定 2 步(搜索 → 按技能改写交付终稿);勿把 enable_skill 与阅读技能拆成两步,勿单独规划仅调用 builtin__current_time 的步骤——涉及时效时搜索步 objective 须写明「先调用 builtin__current_time 再搜索」,并将 builtin__current_time 写入 hint_tools;每一轮含时效的搜索步均须如此,不得因会话已用过时间工具而省略 hint_tools 或 objective 中的取时说明。
14
+
15
+ - 终稿步 objective 须写明改写 skill 路径(如 skills/chat-search-answer);框架会在该步预载技能正文,required_tool 勿填 enable_skill(填 null),改写须按该 skill 执行。用户要求简短、简要、一句话总结时,终稿步 objective 须写明按 chat-search-answer 的 brief 篇幅(约 80–160 汉字)。
16
+
17
+ - 规划 objective 勿写死具体公历日期或年份;时间窗口由执行阶段在调用 builtin__current_time 后确定。勿在 objective 里写与业务规则可能冲突的年份字面量。执行阶段取时与搜索须分两回合发起,勿在同一助手回合并行 tool call。
18
+
19
+ - 仅输出一个 JSON 对象,不要 Markdown 代码围栏,不要额外说明;summary 一句即可,勿在 JSON 外复述业务规则全文。
20
+
21
+ - JSON 结构:
22
+
23
+ {"summary": "可选一句规划说明", "steps": [
24
+
25
+ {"id": "step-1", "title": "...", "objective": "...", "optional": false,
26
+
27
+ "hint_tools": [], "required_tool": null}
28
+
29
+ ]}
30
+
31
+ - id 使用 step-1、step-2 等形式,互不重复。
32
+
33
+ - optional 为 true 表示执行阶段可判定跳过(例如终稿改写已有足够短答时)。
34
+
35
+ - hint_tools、required_tool 填对外工具名(如 server__tool);无则空列表或 null。
36
+
37
+ """
38
+
ai_agent/harness/sandbox.py ADDED
@@ -0,0 +1,139 @@
1
+ from __future__ import annotations
2
+
3
+ import fnmatch
4
+ from pathlib import Path
5
+
6
_MAX_READ_BYTES = 512 * 1024  # read_text_file refuses files larger than this (512 KiB)


class HarnessSandbox:
    """Confine workspace-relative paths to the root directory chosen at construction time."""

    def __init__(self, workspace: Path) -> None:
        """
        Create (if needed) and pin the workspace root.

        Args:
            workspace: workspace directory; ``~`` is expanded and the path is
                resolved to an absolute path. Missing directories are created.

        Raises:
            ValueError: ``workspace`` exists but is not a directory.
        """
        root = workspace.expanduser().resolve()
        # Check before mkdir: Path.mkdir(exist_ok=True) raises FileExistsError
        # (not ValueError) when the path exists as a non-directory, so a check
        # placed after mkdir could never report this case.
        if root.exists() and not root.is_dir():
            raise ValueError(f"工作区须为目录: {root}")
        root.mkdir(parents=True, exist_ok=True)
        self._root = root

    @property
    def root(self) -> Path:
        """Absolute, resolved workspace root."""
        return self._root

    def resolve_path(self, path: str) -> Path:
        """
        Resolve a workspace-relative path, rejecting anything that escapes the root.

        Args:
            path: path relative to the workspace root; empty, ``.`` or ``./``
                mean the root itself.

        Returns:
            The resolved absolute path.

        Raises:
            ValueError: ``path`` is absolute or resolves outside the workspace.
        """
        cleaned = path.strip()
        if not cleaned or cleaned in (".", "./"):
            return self._root
        if Path(cleaned).is_absolute():
            raise ValueError("path 须为相对工作区的路径")
        # Resolve before the containment check so ".." segments and symlinks
        # leading outside the root are caught.
        target = (self._root / cleaned).resolve()
        try:
            target.relative_to(self._root)
        except ValueError as exc:
            raise ValueError(f"路径越出工作区: {path}") from exc
        return target

    def read_text_file(self, path: str, offset: int = 1, limit: int = 0) -> str:
        """
        Read a slice of a text file with right-aligned line-number prefixes.

        Args:
            path: workspace-relative path.
            offset: first line to return, 1-based.
            limit: maximum number of lines; 0 means read to end of file.

        Returns:
            Lines formatted as ``"<n>|<text>"`` joined by newlines; empty when
            ``offset`` is past the last line.

        Raises:
            ValueError: invalid ``offset``/``limit``, not a file, or file larger
                than ``_MAX_READ_BYTES``.
        """
        if offset < 1:
            raise ValueError("offset 须 >= 1")
        if limit < 0:
            raise ValueError("limit 须 >= 0")
        target = self.resolve_path(path)
        if not target.is_file():
            raise ValueError(f"不是文件: {path}")
        size = target.stat().st_size
        if size > _MAX_READ_BYTES:
            raise ValueError(f"文件过大(>{_MAX_READ_BYTES} 字节): {path}")
        lines = target.read_text(encoding="utf-8", errors="replace").splitlines()
        start = offset - 1
        if start >= len(lines):
            return ""
        chunk = lines[start : start + limit] if limit > 0 else lines[start:]
        # Pad every number to the width of the last line number in the slice.
        width = len(str(start + len(chunk)))
        return "\n".join(
            f"{index:>{width}}|{line}" for index, line in enumerate(chunk, start=offset)
        )

    def write_text_file(self, path: str, content: str, append: bool = False) -> str:
        """
        Write or append text to a workspace file, creating parent directories.

        Args:
            path: workspace-relative path.
            content: text to write (newlines written as ``\\n``).
            append: append when True, otherwise overwrite.

        Returns:
            A short status message with the character count.
        """
        target = self.resolve_path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        mode = "a" if append else "w"
        with target.open(mode, encoding="utf-8", newline="\n") as handle:
            handle.write(content)
        action = "已追加" if append else "已写入"
        return f"{action} {len(content)} 字符到 {path}"

    def list_entries(
        self,
        path: str = "",
        *,
        max_entries: int = 200,
        pattern: str = "",
    ) -> list[str]:
        """
        List files and directories under a workspace subdirectory, recursively.

        Args:
            path: workspace-relative directory; empty, ``.`` or ``./`` mean the root.
            max_entries: maximum number of entries returned.
            pattern: optional ``fnmatch`` glob; only matching paths are kept.

        Returns:
            Root-relative POSIX paths in sorted ``Path`` order, directories
            suffixed with ``/``. An empty workspace root yields ``["./"]``.

        Raises:
            ValueError: ``max_entries < 1`` or ``path`` is not a directory.
        """
        if max_entries < 1:
            raise ValueError("max_entries 须 >= 1")
        base = self.resolve_path(path)
        if not base.is_dir():
            raise ValueError(f"不是目录: {path or '.'}")
        wanted = pattern.strip()
        entries: list[str] = []
        for item in sorted(base.rglob("*")):
            rel = item.relative_to(self._root).as_posix()
            if item.is_dir():
                rel = f"{rel}/"
            if wanted and not fnmatch.fnmatch(rel, wanted):
                continue
            entries.append(rel)
            if len(entries) >= max_entries:
                break
        # Compare resolved paths so "", "." and "./" (all mapped to the root by
        # resolve_path) get the same empty-root marker.
        if base == self._root and not entries:
            entries.append("./")
        return entries
ai_agent/json_extract.py ADDED
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import re
5
+ from typing import Any
6
+
7
# Shared decoder; raw_decode() parses a JSON value that may be followed by extra text.
_JSON_DECODER = json.JSONDecoder()

# Captures the body of a ``` or ```json fenced block (case-insensitive, lazily
# up to the next closing fence, surrounding whitespace excluded).
_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)
13
+
14
+
15
def extract_first_json_value(text: str) -> Any | None:
    """
    Parse the first complete JSON object or array from free-form model output.

    Fenced code blocks are tried first (in order of appearance), then the
    whole stripped text, so explanatory prose around the JSON is tolerated.

    Args:
        text: raw model output.

    Returns:
        The decoded object/array, or None when nothing decodes.
    """
    body = text.strip()
    if not body:
        return None
    fenced = [match.group(1).strip() for match in _FENCE_RE.finditer(body)]
    for candidate in [*filter(None, fenced), body]:
        decoded = _decode_first_json(candidate)
        if decoded is not None:
            return decoded
    return None
39
+
40
+
41
def extract_first_json_object(text: str) -> dict[str, Any] | None:
    """
    Parse the first JSON value from ``text`` and keep it only if it is an object.

    Args:
        text: raw model output.

    Returns:
        The dict when the first decodable JSON value is an object; otherwise None
        (including when that first value is an array).
    """
    first = extract_first_json_value(text)
    return first if isinstance(first, dict) else None
55
+
56
+
57
def _decode_first_json(text: str) -> Any | None:
    """
    Return the first JSON object/array that decodes cleanly from some ``{`` or ``[``.

    Each opening bracket is tried left to right; a failed decode simply moves
    on to the next bracket.
    """
    for opener in re.finditer(r"[{\[]", text):
        try:
            value, _ = _JSON_DECODER.raw_decode(text, opener.start())
        except json.JSONDecodeError:
            continue
        return value
    return None
ai_agent/listener.py ADDED
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from collections.abc import Awaitable, Callable, Iterable, Sequence
5
+ from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from ai_agent.context import RunContext, ToolInvocation
9
+
10
+ if TYPE_CHECKING:
11
+ from ai_agent.app.packet import RunOutputPacket
12
+ from ai_agent.plan.models import Plan, PlanStep
13
+
14
+ ListenerCallback = Callable[..., Any] | Callable[..., Awaitable[Any]]
15
+
16
+
17
+ @dataclass
18
+ class AgentListener:
19
+ """
20
+ 规划、逐步执行与应用交付的可选回调集合。
21
+
22
+ 流式思考与回答经 ``on_thinking_delta`` / ``on_output_delta`` 推送,
23
+ 调用方可读 ``RunContext.phase`` 区分规划、计划步与直连 ReAct。
24
+ 各钩子未设置时不调用;回调可为同步或 async 函数。
25
+ """
26
+
27
+ on_run_start: ListenerCallback | None = None
28
+ on_run_end: ListenerCallback | None = None
29
+ on_thinking_delta: ListenerCallback | None = None
30
+ on_output_delta: ListenerCallback | None = None
31
+ on_tool_start: ListenerCallback | None = None
32
+ on_tool_end: ListenerCallback | None = None
33
+ on_plan_start: ListenerCallback | None = None
34
+ on_plan_ready: ListenerCallback | None = None
35
+ on_plan_step_start: ListenerCallback | None = None
36
+ on_plan_step_end: ListenerCallback | None = None
37
+ on_app_run_end: ListenerCallback | None = None
38
+
39
+
40
def normalize_listeners(
    listeners: AgentListener | Iterable[AgentListener] | None,
) -> list[AgentListener]:
    """
    Coerce one listener, an iterable of listeners, or None into a list.

    Args:
        listeners: a single ``AgentListener``, any iterable of them, or ``None``.

    Returns:
        A new list of listeners (empty for ``None``) for use by ``AgentContext``.
    """
    if listeners is None:
        return []
    return [listeners] if isinstance(listeners, AgentListener) else list(listeners)
57
+
58
+
59
async def _invoke(callback: ListenerCallback, *args: Any) -> None:
    """Call ``callback(*args)`` and await the result when it is awaitable."""
    outcome = callback(*args)
    if inspect.isawaitable(outcome):
        await outcome


async def _fan_out(listeners: Sequence[AgentListener], hook: str, *args: Any) -> None:
    """Invoke hook attribute ``hook`` on each listener that sets it, in list order."""
    for listener in listeners:
        callback = getattr(listener, hook)
        if callback is None:
            continue
        await _invoke(callback, *args)


async def notify_run_start(listeners: Sequence[AgentListener], run: RunContext) -> None:
    """Fire ``on_run_start(run)``."""
    await _fan_out(listeners, "on_run_start", run)


async def notify_run_end(listeners: Sequence[AgentListener], run: RunContext) -> None:
    """Fire ``on_run_end(run)``."""
    await _fan_out(listeners, "on_run_end", run)


async def notify_thinking_delta(
    listeners: Sequence[AgentListener],
    delta: str,
    run: RunContext,
) -> None:
    """Fire ``on_thinking_delta(delta, run)``; empty deltas are dropped."""
    if delta:
        await _fan_out(listeners, "on_thinking_delta", delta, run)


async def notify_output_delta(
    listeners: Sequence[AgentListener],
    delta: str,
    run: RunContext,
) -> None:
    """Fire ``on_output_delta(delta, run)``; empty deltas are dropped."""
    if delta:
        await _fan_out(listeners, "on_output_delta", delta, run)


async def notify_tool_start(
    listeners: Sequence[AgentListener],
    invocation: ToolInvocation,
    run: RunContext,
) -> None:
    """Fire ``on_tool_start(invocation, run)``."""
    await _fan_out(listeners, "on_tool_start", invocation, run)


async def notify_tool_end(
    listeners: Sequence[AgentListener],
    invocation: ToolInvocation,
    run: RunContext,
) -> None:
    """Fire ``on_tool_end(invocation, run)``."""
    await _fan_out(listeners, "on_tool_end", invocation, run)


async def notify_plan_start(listeners: Sequence[AgentListener]) -> None:
    """Fire ``on_plan_start()`` (no arguments)."""
    await _fan_out(listeners, "on_plan_start")


async def notify_plan_ready(listeners: Sequence[AgentListener], plan: Plan) -> None:
    """Fire ``on_plan_ready(plan)``."""
    await _fan_out(listeners, "on_plan_ready", plan)


async def notify_plan_step_start(
    listeners: Sequence[AgentListener],
    *,
    step_index: int,
    step: PlanStep,
    plan: Plan,
) -> None:
    """Fire ``on_plan_step_start(step_index, step, plan)`` with positional arguments."""
    await _fan_out(listeners, "on_plan_step_start", step_index, step, plan)


async def notify_plan_step_end(
    listeners: Sequence[AgentListener],
    *,
    step_index: int,
    step: PlanStep,
    plan: Plan,
    output: str,
    skipped: bool,
) -> None:
    """Fire ``on_plan_step_end(step_index, step, plan, output, skipped)`` positionally."""
    await _fan_out(
        listeners,
        "on_plan_step_end",
        step_index,
        step,
        plan,
        output,
        skipped,
    )


async def notify_app_run_end(
    listeners: Sequence[AgentListener],
    packet: RunOutputPacket,
) -> None:
    """Fire ``on_app_run_end(packet)``."""
    await _fan_out(listeners, "on_app_run_end", packet)
ai_agent/llm.py ADDED
@@ -0,0 +1,39 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Any, AsyncIterator, Protocol
6
+
7
+ from ai_agent.context import RunContext
8
+
9
+
10
class StreamKind(str, Enum):
    """Kind of a streamed model fragment; shared by every adapter in the library."""

    TEXT = "text"  # visible answer text
    REASONING = "reasoning"  # thinking / reasoning text
    TOOL_CALL = "tool_call"  # one assembled tool call
    DONE = "done"  # end-of-stream marker
17
+
18
+
19
@dataclass
class StreamChunk:
    """One streamed fragment; every adapter maps its native events onto this type."""

    kind: StreamKind
    # Text payload for TEXT / REASONING chunks; left empty otherwise.
    delta: str = ""
    # The remaining fields are filled for TOOL_CALL chunks.
    tool_call_id: str | None = None
    tool_name: str | None = None
    tool_arguments: dict[str, Any] | None = None
28
+
29
+
30
class LLMClient(Protocol):
    """Streaming language-model interface; implemented in this library by OpenAILLM."""

    def stream(
        self,
        context: RunContext,
        *,
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Stream a model response for the current context.

        Declared as a plain ``def`` returning ``AsyncIterator`` on purpose.
        Implementations are async generators (``async def`` containing
        ``yield``), and calling one returns the iterator directly. An
        ``async def`` without ``yield`` here would instead declare a coroutine
        that *returns* an iterator, which such implementations do not
        structurally match.

        Args:
            context: the run context whose messages are sent to the model.
            tools: optional tool schemas made available to the model.
        """
        ...
ai_agent/llm_openai.py ADDED
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any, AsyncIterator
5
+
6
+ from ai_agent.context import RunContext
7
+ from ai_agent.llm import StreamChunk, StreamKind
8
+
9
+
10
def _merge_extra_body(kwargs: dict[str, Any], patch: dict[str, Any]) -> None:
    """Merge ``patch`` into ``kwargs["extra_body"]`` (copying, never mutating, the existing dict)."""
    extra = dict(kwargs.get("extra_body") or {})
    extra.update(patch)
    kwargs["extra_body"] = extra


def _apply_thinking_kwargs(kwargs: dict[str, Any], *, base_url: str) -> None:
    """
    Enable thinking mode in the request shape the gateway expects.

    DeepSeek endpoints (``"deepseek"`` in ``base_url``, case-insensitive) get
    ``extra_body["thinking"] = {"type": "enabled"}``. Other OpenAI-compatible
    endpoints get ``extra_body["enable_thinking"] = True``.

    Both are routed through ``extra_body`` because the OpenAI Python SDK's
    ``chat.completions.create`` only accepts its declared keyword arguments.
    A top-level ``enable_thinking=`` kwarg raises ``TypeError`` before any
    request is sent. Non-standard fields must travel in ``extra_body``.
    """
    if "deepseek" in base_url.lower():
        _merge_extra_body(kwargs, {"thinking": {"type": "enabled"}})
        return
    _merge_extra_body(kwargs, {"enable_thinking": True})
22
+
23
+
24
class OpenAILLM:
    """Library adapter: wraps ``AsyncOpenAI`` and maps its streaming events onto ``StreamChunk``."""

    def __init__(
        self,
        client: Any,
        *,
        model: str,
        base_url: str = "",
        temperature: float | None = None,
        max_tokens: int | None = None,
        thinking_enabled: bool = False,
    ) -> None:
        """
        Args:
            client: SDK client; ``stream`` awaits ``client.chat.completions.create``.
            model: model name sent with every request.
            base_url: gateway URL; only used to choose how thinking mode is requested.
            temperature: sampling temperature; omitted from the request when None.
            max_tokens: completion token cap; omitted from the request when None.
            thinking_enabled: ask the gateway for its thinking/reasoning mode.
        """
        self._client = client
        self.model = model
        self._base_url = base_url.strip()
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.thinking_enabled = thinking_enabled

    async def stream(
        self,
        context: RunContext,
        *,
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """
        Stream one chat completion for ``context``.

        Yields TEXT and REASONING chunks as deltas arrive. It yields one
        TOOL_CALL chunk per assembled tool call, either when a choice finishes
        with ``finish_reason == "tool_calls"`` or at end of stream for anything
        still buffered. A single DONE chunk comes last.

        Args:
            context: provides the request messages via ``api_messages()``.
            tools: tool schemas forwarded as ``tools``; omitted when empty/None.
        """
        import inspect  # local: only needed to close the SDK stream below

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": context.api_messages(),
            "stream": True,
        }
        if tools:
            kwargs["tools"] = tools
        if self.temperature is not None:
            kwargs["temperature"] = self.temperature
        if self.max_tokens is not None:
            kwargs["max_tokens"] = self.max_tokens
        if self.thinking_enabled:
            _apply_thinking_kwargs(kwargs, base_url=self._base_url)

        stream = await self._client.chat.completions.create(**kwargs)
        # Tool-call fragments keyed by the delta's ``index``. id/name are set
        # whenever a fragment carries them; argument text is concatenated.
        tool_bufs: dict[int, dict[str, Any]] = {}

        try:
            async for event in stream:
                if not event.choices:
                    continue
                choice = event.choices[0]
                delta = choice.delta

                if delta.content:
                    yield StreamChunk(kind=StreamKind.TEXT, delta=delta.content)

                # Non-standard field used by reasoning-capable gateways.
                reasoning = getattr(delta, "reasoning_content", None)
                if reasoning:
                    yield StreamChunk(kind=StreamKind.REASONING, delta=reasoning)

                if delta.tool_calls:
                    for tc in delta.tool_calls:
                        idx = int(tc.index or 0)
                        buf = tool_bufs.setdefault(
                            idx,
                            {"id": "", "name": "", "arguments": ""},
                        )
                        if tc.id:
                            buf["id"] = tc.id
                        if tc.function and tc.function.name:
                            buf["name"] = tc.function.name
                        if tc.function and tc.function.arguments:
                            buf["arguments"] += tc.function.arguments

                if choice.finish_reason == "tool_calls":
                    async for chunk in _flush_tool_bufs(tool_bufs):
                        yield chunk
                    tool_bufs.clear()
        finally:
            # Release the HTTP response even when the consumer stops early
            # (break / aclose / cancellation). Per the openai SDK, AsyncStream
            # only closes itself once the body has been read to completion.
            close = getattr(stream, "close", None)
            if close is not None:
                closing = close()
                if inspect.isawaitable(closing):
                    await closing

        # Gateways that end with a finish_reason other than "tool_calls" still
        # get their buffered calls emitted here.
        async for chunk in _flush_tool_bufs(tool_bufs):
            yield chunk
        yield StreamChunk(kind=StreamKind.DONE)
+ yield StreamChunk(kind=StreamKind.DONE)
102
+
103
+
104
async def _flush_tool_bufs(tool_bufs: dict[int, dict[str, Any]]) -> AsyncIterator[StreamChunk]:
    """
    Turn accumulated tool-call fragments into TOOL_CALL chunks.

    Buffers are visited in dict insertion order. Argument text that is empty,
    not valid JSON, or not a JSON object becomes ``{}``. Buffers missing an id
    or a name are skipped.
    """
    for pending in tool_bufs.values():
        try:
            parsed = json.loads(pending.get("arguments") or "{}")
        except json.JSONDecodeError:
            parsed = {}
        call_id = pending.get("id")
        name = pending.get("name")
        if not (call_id and name):
            continue
        yield StreamChunk(
            kind=StreamKind.TOOL_CALL,
            tool_call_id=call_id,
            tool_name=name,
            tool_arguments=parsed if isinstance(parsed, dict) else {},
        )
+ )