agentrun-inner-test 0.0.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agentrun-inner-test might be problematic. Click here for more details.
- agentrun/__init__.py +358 -0
- agentrun/agent_runtime/__client_async_template.py +466 -0
- agentrun/agent_runtime/__endpoint_async_template.py +345 -0
- agentrun/agent_runtime/__init__.py +53 -0
- agentrun/agent_runtime/__runtime_async_template.py +477 -0
- agentrun/agent_runtime/api/__data_async_template.py +58 -0
- agentrun/agent_runtime/api/__init__.py +6 -0
- agentrun/agent_runtime/api/control.py +1362 -0
- agentrun/agent_runtime/api/data.py +98 -0
- agentrun/agent_runtime/client.py +868 -0
- agentrun/agent_runtime/endpoint.py +649 -0
- agentrun/agent_runtime/model.py +362 -0
- agentrun/agent_runtime/runtime.py +904 -0
- agentrun/credential/__client_async_template.py +177 -0
- agentrun/credential/__credential_async_template.py +216 -0
- agentrun/credential/__init__.py +28 -0
- agentrun/credential/api/__init__.py +5 -0
- agentrun/credential/api/control.py +606 -0
- agentrun/credential/client.py +319 -0
- agentrun/credential/credential.py +381 -0
- agentrun/credential/model.py +248 -0
- agentrun/integration/__init__.py +21 -0
- agentrun/integration/agentscope/__init__.py +13 -0
- agentrun/integration/agentscope/adapter.py +17 -0
- agentrun/integration/agentscope/builtin.py +88 -0
- agentrun/integration/agentscope/message_adapter.py +185 -0
- agentrun/integration/agentscope/model_adapter.py +60 -0
- agentrun/integration/agentscope/tool_adapter.py +59 -0
- agentrun/integration/builtin/__init__.py +18 -0
- agentrun/integration/builtin/knowledgebase.py +137 -0
- agentrun/integration/builtin/model.py +93 -0
- agentrun/integration/builtin/sandbox.py +1234 -0
- agentrun/integration/builtin/toolset.py +47 -0
- agentrun/integration/crewai/__init__.py +13 -0
- agentrun/integration/crewai/adapter.py +9 -0
- agentrun/integration/crewai/builtin.py +88 -0
- agentrun/integration/crewai/model_adapter.py +31 -0
- agentrun/integration/crewai/tool_adapter.py +26 -0
- agentrun/integration/google_adk/__init__.py +13 -0
- agentrun/integration/google_adk/adapter.py +15 -0
- agentrun/integration/google_adk/builtin.py +88 -0
- agentrun/integration/google_adk/message_adapter.py +144 -0
- agentrun/integration/google_adk/model_adapter.py +46 -0
- agentrun/integration/google_adk/tool_adapter.py +235 -0
- agentrun/integration/langchain/__init__.py +31 -0
- agentrun/integration/langchain/adapter.py +15 -0
- agentrun/integration/langchain/builtin.py +94 -0
- agentrun/integration/langchain/message_adapter.py +141 -0
- agentrun/integration/langchain/model_adapter.py +37 -0
- agentrun/integration/langchain/tool_adapter.py +50 -0
- agentrun/integration/langgraph/__init__.py +36 -0
- agentrun/integration/langgraph/adapter.py +20 -0
- agentrun/integration/langgraph/agent_converter.py +1073 -0
- agentrun/integration/langgraph/builtin.py +88 -0
- agentrun/integration/pydantic_ai/__init__.py +13 -0
- agentrun/integration/pydantic_ai/adapter.py +13 -0
- agentrun/integration/pydantic_ai/builtin.py +88 -0
- agentrun/integration/pydantic_ai/model_adapter.py +44 -0
- agentrun/integration/pydantic_ai/tool_adapter.py +19 -0
- agentrun/integration/utils/__init__.py +112 -0
- agentrun/integration/utils/adapter.py +560 -0
- agentrun/integration/utils/canonical.py +164 -0
- agentrun/integration/utils/converter.py +134 -0
- agentrun/integration/utils/model.py +110 -0
- agentrun/integration/utils/tool.py +1759 -0
- agentrun/knowledgebase/__client_async_template.py +173 -0
- agentrun/knowledgebase/__init__.py +53 -0
- agentrun/knowledgebase/__knowledgebase_async_template.py +438 -0
- agentrun/knowledgebase/api/__data_async_template.py +414 -0
- agentrun/knowledgebase/api/__init__.py +19 -0
- agentrun/knowledgebase/api/control.py +606 -0
- agentrun/knowledgebase/api/data.py +624 -0
- agentrun/knowledgebase/client.py +311 -0
- agentrun/knowledgebase/knowledgebase.py +748 -0
- agentrun/knowledgebase/model.py +270 -0
- agentrun/memory_collection/__client_async_template.py +178 -0
- agentrun/memory_collection/__init__.py +37 -0
- agentrun/memory_collection/__memory_collection_async_template.py +457 -0
- agentrun/memory_collection/api/__init__.py +5 -0
- agentrun/memory_collection/api/control.py +610 -0
- agentrun/memory_collection/client.py +323 -0
- agentrun/memory_collection/memory_collection.py +844 -0
- agentrun/memory_collection/model.py +162 -0
- agentrun/model/__client_async_template.py +357 -0
- agentrun/model/__init__.py +57 -0
- agentrun/model/__model_proxy_async_template.py +270 -0
- agentrun/model/__model_service_async_template.py +267 -0
- agentrun/model/api/__init__.py +6 -0
- agentrun/model/api/control.py +1173 -0
- agentrun/model/api/data.py +196 -0
- agentrun/model/client.py +674 -0
- agentrun/model/model.py +235 -0
- agentrun/model/model_proxy.py +439 -0
- agentrun/model/model_service.py +438 -0
- agentrun/sandbox/__aio_sandbox_async_template.py +523 -0
- agentrun/sandbox/__browser_sandbox_async_template.py +110 -0
- agentrun/sandbox/__client_async_template.py +491 -0
- agentrun/sandbox/__code_interpreter_sandbox_async_template.py +463 -0
- agentrun/sandbox/__init__.py +69 -0
- agentrun/sandbox/__sandbox_async_template.py +463 -0
- agentrun/sandbox/__template_async_template.py +152 -0
- agentrun/sandbox/aio_sandbox.py +912 -0
- agentrun/sandbox/api/__aio_data_async_template.py +335 -0
- agentrun/sandbox/api/__browser_data_async_template.py +140 -0
- agentrun/sandbox/api/__code_interpreter_data_async_template.py +206 -0
- agentrun/sandbox/api/__init__.py +19 -0
- agentrun/sandbox/api/__sandbox_data_async_template.py +107 -0
- agentrun/sandbox/api/aio_data.py +551 -0
- agentrun/sandbox/api/browser_data.py +172 -0
- agentrun/sandbox/api/code_interpreter_data.py +396 -0
- agentrun/sandbox/api/control.py +1051 -0
- agentrun/sandbox/api/playwright_async.py +492 -0
- agentrun/sandbox/api/playwright_sync.py +492 -0
- agentrun/sandbox/api/sandbox_data.py +154 -0
- agentrun/sandbox/browser_sandbox.py +185 -0
- agentrun/sandbox/client.py +925 -0
- agentrun/sandbox/code_interpreter_sandbox.py +823 -0
- agentrun/sandbox/model.py +384 -0
- agentrun/sandbox/sandbox.py +848 -0
- agentrun/sandbox/template.py +217 -0
- agentrun/server/__init__.py +191 -0
- agentrun/server/agui_normalizer.py +180 -0
- agentrun/server/agui_protocol.py +797 -0
- agentrun/server/invoker.py +309 -0
- agentrun/server/model.py +427 -0
- agentrun/server/openai_protocol.py +535 -0
- agentrun/server/protocol.py +140 -0
- agentrun/server/server.py +208 -0
- agentrun/toolset/__client_async_template.py +62 -0
- agentrun/toolset/__init__.py +51 -0
- agentrun/toolset/__toolset_async_template.py +204 -0
- agentrun/toolset/api/__init__.py +17 -0
- agentrun/toolset/api/control.py +262 -0
- agentrun/toolset/api/mcp.py +100 -0
- agentrun/toolset/api/openapi.py +1251 -0
- agentrun/toolset/client.py +102 -0
- agentrun/toolset/model.py +321 -0
- agentrun/toolset/toolset.py +271 -0
- agentrun/utils/__data_api_async_template.py +721 -0
- agentrun/utils/__init__.py +5 -0
- agentrun/utils/__resource_async_template.py +158 -0
- agentrun/utils/config.py +270 -0
- agentrun/utils/control_api.py +105 -0
- agentrun/utils/data_api.py +1121 -0
- agentrun/utils/exception.py +151 -0
- agentrun/utils/helper.py +108 -0
- agentrun/utils/log.py +77 -0
- agentrun/utils/model.py +168 -0
- agentrun/utils/resource.py +291 -0
- agentrun_inner_test-0.0.62.dist-info/METADATA +265 -0
- agentrun_inner_test-0.0.62.dist-info/RECORD +154 -0
- agentrun_inner_test-0.0.62.dist-info/WHEEL +5 -0
- agentrun_inner_test-0.0.62.dist-info/licenses/LICENSE +201 -0
- agentrun_inner_test-0.0.62.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1073 @@
|
|
|
1
|
+
"""LangGraph/LangChain 事件转换模块 / LangGraph/LangChain Event Converter
|
|
2
|
+
|
|
3
|
+
提供将 LangGraph/LangChain 流式事件转换为 AG-UI 协议事件的方法。
|
|
4
|
+
|
|
5
|
+
使用示例:
|
|
6
|
+
|
|
7
|
+
# 使用 AgentRunConverter 类(推荐)
|
|
8
|
+
>>> converter = AgentRunConverter()
|
|
9
|
+
>>> async for event in agent.astream_events(input_data, version="v2"):
|
|
10
|
+
... for item in converter.convert(event):
|
|
11
|
+
... yield item
|
|
12
|
+
|
|
13
|
+
# 使用静态方法(无状态)
|
|
14
|
+
>>> async for event in agent.astream_events(input_data, version="v2"):
|
|
15
|
+
... for item in AgentRunConverter.to_agui_events(event):
|
|
16
|
+
... yield item
|
|
17
|
+
|
|
18
|
+
# 使用 stream (updates 模式)
|
|
19
|
+
>>> for event in agent.stream(input_data, stream_mode="updates"):
|
|
20
|
+
... for item in AgentRunConverter.to_agui_events(event):
|
|
21
|
+
... yield item
|
|
22
|
+
|
|
23
|
+
# 使用 astream (updates 模式)
|
|
24
|
+
>>> async for event in agent.astream(input_data, stream_mode="updates"):
|
|
25
|
+
... for item in AgentRunConverter.to_agui_events(event):
|
|
26
|
+
... yield item
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import json
|
|
30
|
+
from typing import Any, Dict, Iterator, List, Optional, Union
|
|
31
|
+
|
|
32
|
+
from agentrun.server.model import AgentResult, EventType
|
|
33
|
+
from agentrun.utils.log import logger
|
|
34
|
+
|
|
35
|
+
# 需要从工具输入中过滤掉的内部字段(LangGraph/MCP 注入的运行时对象)
|
|
36
|
+
_TOOL_INPUT_INTERNAL_KEYS = frozenset({
|
|
37
|
+
"runtime", # MCP ToolRuntime 对象
|
|
38
|
+
"__pregel_runtime",
|
|
39
|
+
"__pregel_task_id",
|
|
40
|
+
"__pregel_send",
|
|
41
|
+
"__pregel_read",
|
|
42
|
+
"__pregel_checkpointer",
|
|
43
|
+
"__pregel_scratchpad",
|
|
44
|
+
"__pregel_call",
|
|
45
|
+
"config", # LangGraph config 对象,包含内部状态
|
|
46
|
+
"configurable",
|
|
47
|
+
})
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AgentRunConverter:
|
|
51
|
+
"""AgentRun 事件转换器
|
|
52
|
+
|
|
53
|
+
将 LangGraph/LangChain 流式事件转换为 AG-UI 协议事件。
|
|
54
|
+
此类维护必要的状态以确保:
|
|
55
|
+
1. 流式工具调用的 tool_call_id 一致性
|
|
56
|
+
2. AG-UI 协议要求的事件顺序(TOOL_CALL_START → TOOL_CALL_ARGS → TOOL_CALL_END)
|
|
57
|
+
|
|
58
|
+
在流式工具调用中,第一个 chunk 包含 id 和 name,后续 chunk 只有 index 和 args。
|
|
59
|
+
此类维护 index -> id 的映射,确保所有相关事件使用相同的 tool_call_id。
|
|
60
|
+
|
|
61
|
+
同时,此类跟踪已发送 TOOL_CALL_START 的工具调用,确保:
|
|
62
|
+
- 在流式场景中,TOOL_CALL_START 在第一个参数 chunk 前发送
|
|
63
|
+
- 避免在 on_tool_start 中重复发送 TOOL_CALL_START
|
|
64
|
+
|
|
65
|
+
Example:
|
|
66
|
+
>>> from agentrun.integration.langgraph import AgentRunConverter
|
|
67
|
+
>>>
|
|
68
|
+
>>> async def invoke_agent(request: AgentRequest):
|
|
69
|
+
... converter = AgentRunConverter()
|
|
70
|
+
... async for event in agent.astream_events(input, version="v2"):
|
|
71
|
+
... for item in converter.convert(event):
|
|
72
|
+
... yield item
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
def __init__(self) -> None:
|
|
76
|
+
self._tool_call_id_map: Dict[int, str] = {}
|
|
77
|
+
self._tool_call_started_set: set = set()
|
|
78
|
+
# tool_name -> [tool_call_id] 队列映射
|
|
79
|
+
# 用于在 on_tool_start 中查找对应的 tool_call_id(当 runtime.tool_call_id 不可用时)
|
|
80
|
+
self._tool_name_to_call_ids: Dict[str, List[str]] = {}
|
|
81
|
+
# run_id -> tool_call_id 映射
|
|
82
|
+
# 用于在 on_tool_end 中查找对应的 tool_call_id
|
|
83
|
+
self._run_id_to_tool_call_id: Dict[str, str] = {}
|
|
84
|
+
|
|
85
|
+
def convert(
    self,
    event: Union[Dict[str, Any], Any],
    messages_key: str = "messages",
) -> Iterator[Union[AgentResult, str]]:
    """Convert a single LangGraph/LangChain event into AG-UI protocol output.

    This stateful wrapper delegates to ``to_agui_events``. It passes in the
    converter's own bookkeeping (index -> id map, started set, name -> id
    queues, run_id -> id map) so tool-call ids stay consistent across events.

    Args:
        event: A LangGraph/LangChain streaming event (a StreamEvent object
            or a dict).
        messages_key: Key of the message list inside the graph state.
            Defaults to ``"messages"``.

    Yields:
        ``str`` text fragments or ``AgentResult`` AG-UI events.
    """
    normalized = self._event_to_dict(event)
    kind = normalized.get("event", "")

    # Always log the raw event type (debugging aid).
    logger.debug(
        f"[AgentRunConverter] Raw event type: {type(event).__name__}, "
        f"event_type={kind}, "
        f"is_dict={isinstance(event, dict)}"
    )

    # For streaming/tool events, also log the converter's tracking state.
    tracked_kinds = ("on_chat_model_stream", "on_tool_start", "on_tool_end")
    if kind in tracked_kinds:
        logger.debug(
            f"[AgentRunConverter] Input event: type={kind}, "
            f"run_id={normalized.get('run_id', '')}, "
            f"name={normalized.get('name', '')}, "
            f"tool_call_started_set={self._tool_call_started_set}, "
            f"tool_name_to_call_ids={self._tool_name_to_call_ids}"
        )

    bookkeeping = (
        self._tool_call_id_map,
        self._tool_call_started_set,
        self._tool_name_to_call_ids,
        self._run_id_to_tool_call_id,
    )
    for output in self.to_agui_events(event, messages_key, *bookkeeping):
        if isinstance(output, AgentResult):
            # Log every converted AG-UI event.
            logger.debug(f"[AgentRunConverter] Output event: {output}")
        yield output
|
|
135
|
+
|
|
136
|
+
def reset(self) -> None:
|
|
137
|
+
"""重置状态,清空 tool_call_id 映射和已发送状态
|
|
138
|
+
|
|
139
|
+
在处理新的请求时,建议创建新的 AgentRunConverter 实例,
|
|
140
|
+
而不是复用旧实例并调用 reset。
|
|
141
|
+
"""
|
|
142
|
+
self._tool_call_id_map.clear()
|
|
143
|
+
self._tool_call_started_set.clear()
|
|
144
|
+
self._tool_name_to_call_ids.clear()
|
|
145
|
+
self._run_id_to_tool_call_id.clear()
|
|
146
|
+
|
|
147
|
+
# =========================================================================
|
|
148
|
+
# 内部工具方法(静态方法)
|
|
149
|
+
# =========================================================================
|
|
150
|
+
|
|
151
|
+
@staticmethod
|
|
152
|
+
def _format_tool_output(output: Any) -> str:
|
|
153
|
+
"""格式化工具输出为字符串,优先提取常见字段或 content 属性,最后回退到 JSON/str。"""
|
|
154
|
+
if output is None:
|
|
155
|
+
return ""
|
|
156
|
+
# dict-like
|
|
157
|
+
if isinstance(output, dict):
|
|
158
|
+
for key in ("content", "result", "output"):
|
|
159
|
+
if key in output:
|
|
160
|
+
v = output[key]
|
|
161
|
+
if isinstance(v, (dict, list)):
|
|
162
|
+
return json.dumps(v, ensure_ascii=False)
|
|
163
|
+
return str(v) if v is not None else ""
|
|
164
|
+
try:
|
|
165
|
+
return json.dumps(output, ensure_ascii=False)
|
|
166
|
+
except Exception:
|
|
167
|
+
return str(output)
|
|
168
|
+
|
|
169
|
+
# 对象有 content 属性
|
|
170
|
+
if hasattr(output, "content"):
|
|
171
|
+
c = AgentRunConverter._get_message_content(output)
|
|
172
|
+
if isinstance(c, (dict, list)):
|
|
173
|
+
try:
|
|
174
|
+
return json.dumps(c, ensure_ascii=False)
|
|
175
|
+
except Exception:
|
|
176
|
+
return str(c)
|
|
177
|
+
return c or ""
|
|
178
|
+
|
|
179
|
+
try:
|
|
180
|
+
return str(output)
|
|
181
|
+
except Exception:
|
|
182
|
+
return ""
|
|
183
|
+
|
|
184
|
+
@staticmethod
|
|
185
|
+
def _safe_json_dumps(obj: Any) -> str:
|
|
186
|
+
"""JSON 序列化兜底,无法序列化则回退到 str。"""
|
|
187
|
+
try:
|
|
188
|
+
return json.dumps(obj, ensure_ascii=False, default=str)
|
|
189
|
+
except Exception:
|
|
190
|
+
try:
|
|
191
|
+
return str(obj)
|
|
192
|
+
except Exception:
|
|
193
|
+
return ""
|
|
194
|
+
|
|
195
|
+
@staticmethod
def _filter_tool_input(tool_input: Any) -> Any:
    """Strip internal fields from a tool input, keeping only real arguments.

    Non-dict inputs are returned unchanged. For a dict, a new dict is built
    that omits every key listed in ``_TOOL_INPUT_INTERNAL_KEYS`` and every
    key starting with an underscore (single or double).

    Args:
        tool_input: The raw tool input (a dict or any other value).

    Returns:
        The filtered dict, or ``tool_input`` itself when it is not a dict.
    """
    if not isinstance(tool_input, dict):
        return tool_input
    return {
        name: arg
        for name, arg in tool_input.items()
        if name not in _TOOL_INPUT_INTERNAL_KEYS and not name.startswith("_")
    }
|
|
219
|
+
|
|
220
|
+
@staticmethod
|
|
221
|
+
def _extract_tool_call_id(tool_input: Any) -> Optional[str]:
|
|
222
|
+
"""从工具输入中提取原始的 tool_call_id。
|
|
223
|
+
|
|
224
|
+
MCP 工具会在 input 中注入 runtime 对象,其中包含 LLM 返回的原始 tool_call_id。
|
|
225
|
+
使用这个 ID 可以保证工具调用事件的 ID 一致性。
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
tool_input: 工具输入(可能是 dict 或其他类型)
|
|
229
|
+
|
|
230
|
+
Returns:
|
|
231
|
+
tool_call_id 或 None
|
|
232
|
+
"""
|
|
233
|
+
if not isinstance(tool_input, dict):
|
|
234
|
+
return None
|
|
235
|
+
|
|
236
|
+
# 尝试从 runtime 对象中提取 tool_call_id
|
|
237
|
+
runtime = tool_input.get("runtime")
|
|
238
|
+
if runtime is not None and hasattr(runtime, "tool_call_id"):
|
|
239
|
+
tc_id = runtime.tool_call_id
|
|
240
|
+
if isinstance(tc_id, str) and tc_id:
|
|
241
|
+
return tc_id
|
|
242
|
+
|
|
243
|
+
return None
|
|
244
|
+
|
|
245
|
+
@staticmethod
|
|
246
|
+
def _extract_content(chunk: Any) -> Optional[str]:
|
|
247
|
+
"""从 chunk 中提取文本内容"""
|
|
248
|
+
if chunk is None:
|
|
249
|
+
return None
|
|
250
|
+
|
|
251
|
+
if hasattr(chunk, "content"):
|
|
252
|
+
content = chunk.content
|
|
253
|
+
if isinstance(content, str):
|
|
254
|
+
return content if content else None
|
|
255
|
+
if isinstance(content, list):
|
|
256
|
+
text_parts = []
|
|
257
|
+
for item in content:
|
|
258
|
+
if isinstance(item, str):
|
|
259
|
+
text_parts.append(item)
|
|
260
|
+
elif isinstance(item, dict) and item.get("type") == "text":
|
|
261
|
+
text_parts.append(item.get("text", ""))
|
|
262
|
+
return "".join(text_parts) if text_parts else None
|
|
263
|
+
|
|
264
|
+
return None
|
|
265
|
+
|
|
266
|
+
@staticmethod
|
|
267
|
+
def _extract_tool_call_chunks(chunk: Any) -> List[Dict]:
|
|
268
|
+
"""从 AIMessageChunk 中提取工具调用增量"""
|
|
269
|
+
tool_calls = []
|
|
270
|
+
|
|
271
|
+
if hasattr(chunk, "tool_call_chunks") and chunk.tool_call_chunks:
|
|
272
|
+
for tc in chunk.tool_call_chunks:
|
|
273
|
+
if isinstance(tc, dict):
|
|
274
|
+
tool_calls.append(tc)
|
|
275
|
+
else:
|
|
276
|
+
tool_calls.append({
|
|
277
|
+
"id": getattr(tc, "id", None),
|
|
278
|
+
"name": getattr(tc, "name", None),
|
|
279
|
+
"args": getattr(tc, "args", None),
|
|
280
|
+
"index": getattr(tc, "index", None),
|
|
281
|
+
})
|
|
282
|
+
|
|
283
|
+
return tool_calls
|
|
284
|
+
|
|
285
|
+
@staticmethod
|
|
286
|
+
def _get_message_type(msg: Any) -> str:
|
|
287
|
+
"""获取消息类型"""
|
|
288
|
+
if hasattr(msg, "type"):
|
|
289
|
+
return str(msg.type).lower()
|
|
290
|
+
|
|
291
|
+
if isinstance(msg, dict):
|
|
292
|
+
msg_type = msg.get("type", msg.get("role", ""))
|
|
293
|
+
return str(msg_type).lower()
|
|
294
|
+
|
|
295
|
+
class_name = type(msg).__name__.lower()
|
|
296
|
+
if "ai" in class_name or "assistant" in class_name:
|
|
297
|
+
return "ai"
|
|
298
|
+
if "tool" in class_name:
|
|
299
|
+
return "tool"
|
|
300
|
+
if "human" in class_name or "user" in class_name:
|
|
301
|
+
return "human"
|
|
302
|
+
|
|
303
|
+
return "unknown"
|
|
304
|
+
|
|
305
|
+
@staticmethod
|
|
306
|
+
def _get_message_content(msg: Any) -> Optional[str]:
|
|
307
|
+
"""获取消息内容"""
|
|
308
|
+
if hasattr(msg, "content"):
|
|
309
|
+
content = msg.content
|
|
310
|
+
if isinstance(content, str):
|
|
311
|
+
return content
|
|
312
|
+
return str(content) if content else None
|
|
313
|
+
|
|
314
|
+
if isinstance(msg, dict):
|
|
315
|
+
return msg.get("content")
|
|
316
|
+
|
|
317
|
+
return None
|
|
318
|
+
|
|
319
|
+
@staticmethod
|
|
320
|
+
def _get_message_tool_calls(msg: Any) -> List[Dict]:
|
|
321
|
+
"""获取消息中的工具调用"""
|
|
322
|
+
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
|
323
|
+
tool_calls = []
|
|
324
|
+
for tc in msg.tool_calls:
|
|
325
|
+
if isinstance(tc, dict):
|
|
326
|
+
tool_calls.append(tc)
|
|
327
|
+
else:
|
|
328
|
+
tool_calls.append({
|
|
329
|
+
"id": getattr(tc, "id", None),
|
|
330
|
+
"name": getattr(tc, "name", None),
|
|
331
|
+
"args": getattr(tc, "args", None),
|
|
332
|
+
})
|
|
333
|
+
return tool_calls
|
|
334
|
+
|
|
335
|
+
if isinstance(msg, dict) and msg.get("tool_calls"):
|
|
336
|
+
return msg["tool_calls"]
|
|
337
|
+
|
|
338
|
+
return []
|
|
339
|
+
|
|
340
|
+
@staticmethod
|
|
341
|
+
def _get_tool_call_id(msg: Any) -> Optional[str]:
|
|
342
|
+
"""获取 ToolMessage 的 tool_call_id"""
|
|
343
|
+
if hasattr(msg, "tool_call_id"):
|
|
344
|
+
return msg.tool_call_id
|
|
345
|
+
|
|
346
|
+
if isinstance(msg, dict):
|
|
347
|
+
return msg.get("tool_call_id")
|
|
348
|
+
|
|
349
|
+
return None
|
|
350
|
+
|
|
351
|
+
# =========================================================================
|
|
352
|
+
# 事件格式检测(静态方法)
|
|
353
|
+
# =========================================================================
|
|
354
|
+
|
|
355
|
+
@staticmethod
|
|
356
|
+
def _event_to_dict(event: Any) -> Dict[str, Any]:
|
|
357
|
+
"""将 StreamEvent 或 dict 标准化为 dict 以便后续处理"""
|
|
358
|
+
if isinstance(event, dict):
|
|
359
|
+
return event
|
|
360
|
+
|
|
361
|
+
result: Dict[str, Any] = {}
|
|
362
|
+
# 常见属性映射,兼容多种 StreamEvent 实现
|
|
363
|
+
if hasattr(event, "event"):
|
|
364
|
+
result["event"] = getattr(event, "event")
|
|
365
|
+
if hasattr(event, "data"):
|
|
366
|
+
result["data"] = getattr(event, "data")
|
|
367
|
+
if hasattr(event, "name"):
|
|
368
|
+
result["name"] = getattr(event, "name")
|
|
369
|
+
if hasattr(event, "run_id"):
|
|
370
|
+
result["run_id"] = getattr(event, "run_id")
|
|
371
|
+
|
|
372
|
+
return result
|
|
373
|
+
|
|
374
|
+
@staticmethod
|
|
375
|
+
def is_astream_events_format(event_dict: Dict[str, Any]) -> bool:
|
|
376
|
+
"""检测是否是 astream_events 格式的事件
|
|
377
|
+
|
|
378
|
+
astream_events 格式特征:有 "event" 字段,值以 "on_" 开头
|
|
379
|
+
"""
|
|
380
|
+
event_type = event_dict.get("event", "")
|
|
381
|
+
return isinstance(event_type, str) and event_type.startswith("on_")
|
|
382
|
+
|
|
383
|
+
@staticmethod
|
|
384
|
+
def is_stream_updates_format(event_dict: Dict[str, Any]) -> bool:
|
|
385
|
+
"""检测是否是 stream/astream(stream_mode="updates") 格式的事件
|
|
386
|
+
|
|
387
|
+
updates 格式特征:{node_name: {messages_key: [...]}} 或 {node_name: state_dict}
|
|
388
|
+
没有 "event" 字段,键是 node 名称(如 "model", "agent", "tools"),值是 state 更新
|
|
389
|
+
|
|
390
|
+
与 values 格式的区别:
|
|
391
|
+
- updates: {node_name: {messages: [...]}} - 嵌套结构
|
|
392
|
+
- values: {messages: [...]} - 扁平结构
|
|
393
|
+
"""
|
|
394
|
+
if "event" in event_dict:
|
|
395
|
+
return False
|
|
396
|
+
|
|
397
|
+
# 如果直接包含 "messages" 键且值是 list,这是 values 格式,不是 updates
|
|
398
|
+
if "messages" in event_dict and isinstance(
|
|
399
|
+
event_dict["messages"], list
|
|
400
|
+
):
|
|
401
|
+
return False
|
|
402
|
+
|
|
403
|
+
# 检查是否有类似 node 更新的结构
|
|
404
|
+
for key, value in event_dict.items():
|
|
405
|
+
if key == "__end__":
|
|
406
|
+
continue
|
|
407
|
+
# value 应该是一个 dict(state 更新),包含 messages 等字段
|
|
408
|
+
if isinstance(value, dict):
|
|
409
|
+
return True
|
|
410
|
+
|
|
411
|
+
return False
|
|
412
|
+
|
|
413
|
+
@staticmethod
|
|
414
|
+
def is_stream_values_format(event_dict: Dict[str, Any]) -> bool:
|
|
415
|
+
"""检测是否是 stream/astream(stream_mode="values") 格式的事件
|
|
416
|
+
|
|
417
|
+
values 格式特征:直接是完整 state,如 {messages: [...], ...}
|
|
418
|
+
没有 "event" 字段,直接包含 "messages" 或类似的 state 字段
|
|
419
|
+
|
|
420
|
+
与 updates 格式的区别:
|
|
421
|
+
- values: {messages: [...]} - 扁平结构,messages 值直接是 list
|
|
422
|
+
- updates: {node_name: {messages: [...]}} - 嵌套结构
|
|
423
|
+
"""
|
|
424
|
+
if "event" in event_dict:
|
|
425
|
+
return False
|
|
426
|
+
|
|
427
|
+
# 检查是否直接包含 messages 列表(扁平结构)
|
|
428
|
+
if "messages" in event_dict and isinstance(
|
|
429
|
+
event_dict["messages"], list
|
|
430
|
+
):
|
|
431
|
+
return True
|
|
432
|
+
|
|
433
|
+
return False
|
|
434
|
+
|
|
435
|
+
# =========================================================================
|
|
436
|
+
# 事件转换器(静态方法)
|
|
437
|
+
# =========================================================================
|
|
438
|
+
|
|
439
|
+
@staticmethod
def _convert_stream_updates_event(
    event_dict: Dict[str, Any],
    messages_key: str = "messages",
) -> Iterator[Union[AgentResult, str]]:
    """Convert one ``stream``/``astream(stream_mode="updates")`` event.

    Args:
        event_dict: Event dict shaped ``{node_name: state_update}``.
        messages_key: Key of the message list inside each state update.

    Yields:
        ``str`` text fragments or ``AgentResult`` events.

    Note:
        In updates mode a tool call and its result arrive in different
        events:

        - AI messages: text content is yielded as ``str``. Each tool call
          that has an id yields one ``TOOL_CALL_CHUNK`` carrying its
          complete arguments (dict args are JSON-encoded).
        - Tool messages: each one with a ``tool_call_id`` yields a
          ``TOOL_RESULT``.

        If the ``messages_key`` value is a single message object rather
        than a list, it is treated as a one-element list. When
        ``messages_key`` is missing or unusable, the keys ``"message"``,
        ``"output"`` and ``"response"`` are tried in that order.
    """

    def _as_message_list(value: Any) -> Optional[List[Any]]:
        # A list is used as-is and a single message-like object (anything
        # with a ``content`` attribute) is wrapped. Anything else is
        # rejected so we never iterate over a non-list value.
        if isinstance(value, list):
            return value
        if value is not None and hasattr(value, "content"):
            return [value]
        return None

    for node_name, state_update in event_dict.items():
        if node_name == "__end__" or not isinstance(state_update, dict):
            continue

        messages = _as_message_list(state_update.get(messages_key))
        if messages is None:
            # Fall back to other common keys used by custom nodes.
            for alt_key in ("message", "output", "response"):
                if alt_key in state_update:
                    messages = _as_message_list(state_update[alt_key])
                    if messages is not None:
                        break
        if not messages:
            continue

        for msg in messages:
            msg_type = AgentRunConverter._get_message_type(msg)

            if msg_type == "ai":
                content = AgentRunConverter._get_message_content(msg)
                if content:
                    yield content

                # Only the call itself (id, name, full args) is emitted here.
                # Its result arrives later in a separate tool message.
                for tc in AgentRunConverter._get_message_tool_calls(msg):
                    tc_id = tc.get("id", "")
                    if not tc_id:
                        continue
                    tc_args = tc.get("args", {})
                    args_str = ""
                    if tc_args:
                        args_str = (
                            AgentRunConverter._safe_json_dumps(tc_args)
                            if isinstance(tc_args, dict)
                            else str(tc_args)
                        )
                    yield AgentResult(
                        event=EventType.TOOL_CALL_CHUNK,
                        data={
                            "id": tc_id,
                            "name": tc.get("name", ""),
                            "args_delta": args_str,
                        },
                    )

            elif msg_type == "tool":
                tool_call_id = AgentRunConverter._get_tool_call_id(msg)
                if tool_call_id:
                    tool_content = AgentRunConverter._get_message_content(msg)
                    yield AgentResult(
                        event=EventType.TOOL_RESULT,
                        data={
                            "id": tool_call_id,
                            "result": (
                                str(tool_content) if tool_content else ""
                            ),
                        },
                    )
|
|
527
|
+
|
|
528
|
+
@staticmethod
def _convert_stream_values_event(
    event_dict: Dict[str, Any],
    messages_key: str = "messages",
) -> Iterator[Union[AgentResult, str]]:
    """Convert one ``stream``/``astream(stream_mode="values")`` event.

    Args:
        event_dict: The full graph state dict.
        messages_key: Key of the message list inside the state.

    Yields:
        ``str`` text fragments or ``AgentResult`` events.

    Note:
        Values events carry the whole state, so only the newest messages
        are converted:

        - If the last message is an AI message, its text is yielded. Each
          tool call with an id yields one ``TOOL_CALL_CHUNK`` carrying its
          complete arguments (dict args are JSON-encoded).
        - If the last message is a tool message, every message in the
          trailing run of consecutive tool messages yields a
          ``TOOL_RESULT``. A single step can append several tool messages
          at once (parallel tool calls). Emitting only the very last one
          would leave the other tool calls without a result.
    """
    messages = event_dict.get(messages_key, [])
    if not isinstance(messages, list) or not messages:
        return

    last_msg = messages[-1]
    msg_type = AgentRunConverter._get_message_type(last_msg)

    if msg_type == "ai":
        content = AgentRunConverter._get_message_content(last_msg)
        if content:
            yield content

        for tc in AgentRunConverter._get_message_tool_calls(last_msg):
            tc_id = tc.get("id", "")
            if not tc_id:
                continue
            tc_args = tc.get("args", {})
            args_str = ""
            if tc_args:
                args_str = (
                    AgentRunConverter._safe_json_dumps(tc_args)
                    if isinstance(tc_args, dict)
                    else str(tc_args)
                )
            yield AgentResult(
                event=EventType.TOOL_CALL_CHUNK,
                data={
                    "id": tc_id,
                    "name": tc.get("name", ""),
                    "args_delta": args_str,
                },
            )

    elif msg_type == "tool":
        # Walk back over the trailing run of tool messages; parallel tool
        # calls append all of their results in the same step.
        first = len(messages) - 1
        while (
            first > 0
            and AgentRunConverter._get_message_type(messages[first - 1])
            == "tool"
        ):
            first -= 1

        for tool_msg in messages[first:]:
            tool_call_id = AgentRunConverter._get_tool_call_id(tool_msg)
            if not tool_call_id:
                continue
            tool_content = AgentRunConverter._get_message_content(tool_msg)
            yield AgentResult(
                event=EventType.TOOL_RESULT,
                data={
                    "id": tool_call_id,
                    "result": str(tool_content) if tool_content else "",
                },
            )
|
|
597
|
+
|
|
598
|
+
@staticmethod
def _convert_astream_events_event(
    event_dict: Dict[str, Any],
    tool_call_id_map: Optional[Dict[int, str]] = None,
    tool_call_started_set: Optional[set] = None,
    tool_name_to_call_ids: Optional[Dict[str, List[str]]] = None,
    run_id_to_tool_call_id: Optional[Dict[str, str]] = None,
) -> Iterator[Union[AgentResult, str]]:
    """Convert a single ``astream_events``-format event.

    Args:
        event_dict: Event dict shaped like ``{"event": "on_xxx", "data": {...}}``.
            A missing or non-dict ``data`` payload is treated as empty.
        tool_call_id_map: Optional ``index -> tool_call_id`` map. In streamed
            tool calls only the first chunk carries an id; later chunks carry
            only an index. This map (updated in place) keeps every chunk of
            one call on the same tool_call_id.
        tool_call_started_set: Optional set of tool_call_ids whose initial
            (named) TOOL_CALL_CHUNK has already been emitted. Updated in place
            so each tool call is announced only once across
            on_chat_model_stream / on_chain_stream / on_tool_start.
        tool_name_to_call_ids: Optional ``tool_name -> [tool_call_id]`` FIFO
            queues, filled when a tool call is first announced and consumed by
            on_tool_start to recover the matching tool_call_id.
        run_id_to_tool_call_id: Optional ``run_id -> tool_call_id`` map,
            filled by on_tool_start and consumed (entry removed) by
            on_tool_end / on_tool_error.

    Yields:
        str (text content) or AgentResult (event).
    """
    event_type = event_dict.get("event", "")
    data = event_dict.get("data")
    # Guard against ``"data": None`` (or a non-dict payload) so ``.get`` below is safe.
    if not isinstance(data, dict):
        data = {}

    # 1. LangGraph format: on_chat_model_stream
    if event_type == "on_chat_model_stream":
        chunk = data.get("chunk")
        if chunk:
            # Text content
            content = AgentRunConverter._extract_content(chunk)
            if content:
                yield content

            # Streamed tool-call arguments
            for tc in AgentRunConverter._extract_tool_call_chunks(chunk):
                tc_raw_id = tc.get("id")
                tc_name = tc.get("name", "")
                tc_args = tc.get("args", "")

                tc_id = AgentRunConverter._resolve_stream_tool_call_id(
                    tc_raw_id, tc.get("index"), tool_call_id_map
                )
                if not tc_id:
                    continue

                # The first chunk of a call carries id + name; later chunks
                # only carry args. The protocol layer derives START/END
                # boundary events on its own.
                is_first_chunk = bool(
                    tc_raw_id
                    and tc_name
                    and (
                        tool_call_started_set is None
                        or tc_id not in tool_call_started_set
                    )
                )

                if is_first_chunk:
                    if tool_call_started_set is not None:
                        tool_call_started_set.add(tc_id)
                    # Remember tool_name -> tool_call_id for on_tool_start.
                    AgentRunConverter._enqueue_tool_call_id(
                        tool_name_to_call_ids, tc_name, tc_id
                    )
                    yield AgentResult(
                        event=EventType.TOOL_CALL_CHUNK,
                        data={
                            "id": tc_id,
                            "name": tc_name,
                            "args_delta": AgentRunConverter._stringify_tool_args(
                                tc_args
                            ),
                        },
                    )
                elif tc_args:
                    # Follow-up chunk: only an args delta.
                    yield AgentResult(
                        event=EventType.TOOL_CALL_CHUNK,
                        data={
                            "id": tc_id,
                            "args_delta": AgentRunConverter._stringify_tool_args(
                                tc_args
                            ),
                        },
                    )

    # 2. LangChain format: on_chain_stream emitted by the "model" node
    elif (
        event_type == "on_chain_stream"
        and event_dict.get("name") == "model"
    ):
        chunk_data = data.get("chunk", {})
        if isinstance(chunk_data, dict):
            # ``or []`` tolerates an explicit ``"messages": None``.
            for msg in chunk_data.get("messages") or []:
                content = AgentRunConverter._get_message_content(msg)
                if content:
                    yield content

                for tc in AgentRunConverter._get_message_tool_calls(msg):
                    tc_id = tc.get("id", "")
                    if not tc_id:
                        continue
                    # Skip tool calls that were already announced.
                    if (
                        tool_call_started_set is not None
                        and tc_id in tool_call_started_set
                    ):
                        continue

                    tc_name = tc.get("name", "")
                    # Mark as started so on_tool_start does not re-announce it.
                    if tool_call_started_set is not None:
                        tool_call_started_set.add(tc_id)
                    AgentRunConverter._enqueue_tool_call_id(
                        tool_name_to_call_ids, tc_name, tc_id
                    )
                    yield AgentResult(
                        event=EventType.TOOL_CALL_CHUNK,
                        data={
                            "id": tc_id,
                            "name": tc_name,
                            "args_delta": AgentRunConverter._stringify_tool_args(
                                tc.get("args", {})
                            ),
                        },
                    )

    # 3. Tool start
    elif event_type == "on_tool_start":
        run_id = event_dict.get("run_id", "")
        tool_name = event_dict.get("name", "")
        tool_input_raw = data.get("input", {})

        # Prefer the original tool_call_id carried in the runtime for ID consistency.
        tool_call_id = AgentRunConverter._extract_tool_call_id(tool_input_raw)

        name_queue = (
            tool_name_to_call_ids.get(tool_name)
            if tool_name_to_call_ids is not None and tool_name
            else None
        )
        if tool_call_id:
            # The runtime id wins; drop it from the name queue so a later
            # start of the same tool is not handed this (already used) id.
            if name_queue and tool_call_id in name_queue:
                name_queue.remove(tool_call_id)
        elif name_queue:
            # Non-runtime tools: take the oldest id announced for this name (FIFO).
            tool_call_id = name_queue.pop(0)

        # Last resort: fall back to run_id.
        if not tool_call_id:
            tool_call_id = run_id

        # Record run_id -> tool_call_id so on_tool_end / on_tool_error can find it.
        if run_id_to_tool_call_id is not None and run_id and tool_call_id:
            run_id_to_tool_call_id[run_id] = tool_call_id

        # Strip internal fields (e.g. the runtime injected by MCP).
        tool_input = AgentRunConverter._filter_tool_input(tool_input_raw)

        if tool_call_id:
            already_started = (
                tool_call_started_set is not None
                and tool_call_id in tool_call_started_set
            )
            if not already_started:
                # Non-streaming case (or no stream events seen): send the full call.
                if tool_call_started_set is not None:
                    tool_call_started_set.add(tool_call_id)
                yield AgentResult(
                    event=EventType.TOOL_CALL_CHUNK,
                    data={
                        "id": tool_call_id,
                        "name": tool_name,
                        "args_delta": AgentRunConverter._stringify_tool_args(
                            tool_input
                        ),
                    },
                )
                # The protocol layer emits boundary events; no TOOL_CALL_END here.

    # 4. Tool end
    elif event_type == "on_tool_end":
        run_id = event_dict.get("run_id", "")
        output = data.get("output", "")
        tool_call_id = AgentRunConverter._resolve_finished_tool_call_id(
            data.get("input", {}), run_id, run_id_to_tool_call_id
        )
        if tool_call_id:
            yield AgentResult(
                event=EventType.TOOL_RESULT,
                data={
                    "id": tool_call_id,
                    "result": AgentRunConverter._format_tool_output(output),
                },
            )

    # 5. LLM end
    elif event_type == "on_chat_model_end":
        # Intentionally ignored in stateless mode to avoid duplicate output.
        pass

    # 6. Tool error
    elif event_type == "on_tool_error":
        run_id = event_dict.get("run_id", "")
        tool_name = event_dict.get("name", "")
        tool_call_id = AgentRunConverter._resolve_finished_tool_call_id(
            data.get("input", {}), run_id, run_id_to_tool_call_id
        )
        error_message = AgentRunConverter._format_error_message(
            data.get("error")
        )
        yield AgentResult(
            event=EventType.ERROR,
            data={
                "message": (
                    f"Tool '{tool_name}' error: {error_message}"
                    if tool_name
                    else error_message
                ),
                "code": "TOOL_ERROR",
                "tool_call_id": tool_call_id,
            },
        )

    # 7. LLM error
    elif event_type == "on_llm_error":
        error_message = AgentRunConverter._format_error_message(
            data.get("error")
        )
        yield AgentResult(
            event=EventType.ERROR,
            data={
                "message": f"LLM error: {error_message}",
                "code": "LLM_ERROR",
            },
        )

    # 8. Chain error
    elif event_type == "on_chain_error":
        chain_name = event_dict.get("name", "")
        error_message = AgentRunConverter._format_error_message(
            data.get("error")
        )
        yield AgentResult(
            event=EventType.ERROR,
            data={
                "message": (
                    f"Chain '{chain_name}' error: {error_message}"
                    if chain_name
                    else error_message
                ),
                "code": "CHAIN_ERROR",
            },
        )

    # 9. Retriever error
    elif event_type == "on_retriever_error":
        retriever_name = event_dict.get("name", "")
        error_message = AgentRunConverter._format_error_message(
            data.get("error")
        )
        yield AgentResult(
            event=EventType.ERROR,
            data={
                "message": (
                    f"Retriever '{retriever_name}' error: {error_message}"
                    if retriever_name
                    else error_message
                ),
                "code": "RETRIEVER_ERROR",
            },
        )


@staticmethod
def _resolve_stream_tool_call_id(
    tc_raw_id: Any,
    tc_index: Optional[int],
    tool_call_id_map: Optional[Dict[int, str]],
) -> str:
    """Resolve the tool_call_id of one streamed tool-call chunk.

    1. A non-empty id is used as-is and recorded under its index (even if the
       chunk has no args, because later chunks may only carry the index).
    2. Otherwise the id is looked up by index in ``tool_call_id_map``.
    3. Otherwise the index itself (as a string) is used; ``""`` if there is
       no index either.
    """
    if tc_raw_id:
        if tool_call_id_map is not None and tc_index is not None:
            tool_call_id_map[tc_index] = tc_raw_id
        return tc_raw_id
    if tc_index is None:
        return ""
    if tool_call_id_map is not None and tc_index in tool_call_id_map:
        return tool_call_id_map[tc_index]
    return str(tc_index)


@staticmethod
def _enqueue_tool_call_id(
    tool_name_to_call_ids: Optional[Dict[str, List[str]]],
    tool_name: Any,
    tool_call_id: str,
) -> None:
    """Append ``tool_call_id`` to the FIFO queue for ``tool_name`` (if tracking)."""
    if tool_name_to_call_ids is not None and tool_name:
        tool_name_to_call_ids.setdefault(tool_name, []).append(tool_call_id)


@staticmethod
def _stringify_tool_args(args: Any) -> str:
    """Render tool-call arguments as an ``args_delta`` string.

    Falsy args become ``""``; dicts and lists are JSON-encoded through
    ``_safe_json_dumps``; anything else (e.g. a partial JSON string from a
    streaming chunk) is converted with ``str``.
    """
    if not args:
        return ""
    if isinstance(args, (dict, list)):
        return AgentRunConverter._safe_json_dumps(args)
    return str(args)


@staticmethod
def _format_error_message(error: Any) -> str:
    """Format an event ``error`` payload as text.

    Exceptions render as ``"<TypeName>: <message>"``, strings pass through,
    other objects use ``str``, and ``None`` becomes ``""``.
    """
    if error is None:
        return ""
    if isinstance(error, Exception):
        return f"{type(error).__name__}: {str(error)}"
    if isinstance(error, str):
        return error
    return str(error)


@staticmethod
def _resolve_finished_tool_call_id(
    tool_input_raw: Any,
    run_id: Any,
    run_id_to_tool_call_id: Optional[Dict[str, str]],
) -> Any:
    """Resolve the tool_call_id for on_tool_end / on_tool_error.

    Priority: id carried in the tool input (runtime), then the id recorded by
    on_tool_start for ``run_id``, then ``run_id`` itself. The run is finished,
    so its ``run_id_to_tool_call_id`` entry is removed to keep the map bounded.
    """
    tool_call_id = AgentRunConverter._extract_tool_call_id(tool_input_raw)
    mapped_id = None
    if run_id_to_tool_call_id is not None and run_id:
        mapped_id = run_id_to_tool_call_id.pop(run_id, None)
    return tool_call_id or mapped_id or run_id
|
|
995
|
+
|
|
996
|
+
# =========================================================================
|
|
997
|
+
# 主要 API(静态方法)
|
|
998
|
+
# =========================================================================
|
|
999
|
+
|
|
1000
|
+
@staticmethod
def to_agui_events(
    event: Union[Dict[str, Any], Any],
    messages_key: str = "messages",
    tool_call_id_map: Optional[Dict[int, str]] = None,
    tool_call_started_set: Optional[set] = None,
    tool_name_to_call_ids: Optional[Dict[str, List[str]]] = None,
    run_id_to_tool_call_id: Optional[Dict[str, str]] = None,
) -> Iterator[Union[AgentResult, str]]:
    """Convert a LangGraph/LangChain streaming event into AG-UI protocol events.

    Handles events produced by any of these call styles:
      - agent.astream_events(input, version="v2")
      - agent.stream(input, stream_mode="updates")
      - agent.astream(input, stream_mode="updates")
      - agent.stream(input, stream_mode="values")
      - agent.astream(input, stream_mode="values")

    The event is first normalized with ``_event_to_dict`` and then routed to
    the converter for the first matching format check, in this order:
    astream_events, stream "updates", stream "values".

    Args:
        event: A LangGraph/LangChain streaming event (a StreamEvent object or
            a dict).
        messages_key: Key of the message list inside the state; defaults to
            "messages". Used only by the "updates" and "values" converters.
        tool_call_id_map: Optional ``index -> tool_call_id`` map that keeps
            streamed tool-call chunks on a consistent id. Updated in place
            when provided. Only used for astream_events-format events.
        tool_call_started_set: Optional set of tool_call_ids whose initial
            tool-call chunk has already been emitted, so each tool call is
            announced only once. Updated in place when provided. Only used
            for astream_events-format events.
        tool_name_to_call_ids: Optional ``tool_name -> [tool_call_id]`` queue
            map, used by on_tool_start to find the matching tool_call_id.
            Only used for astream_events-format events.
        run_id_to_tool_call_id: Optional ``run_id -> tool_call_id`` map, used
            by on_tool_end to find the matching tool_call_id. Only used for
            astream_events-format events.

    Yields:
        str (text content) or AgentResult (AG-UI event).

    Example::

        # astream_events (prefer driving this through the AgentRunConverter class)
        async for event in agent.astream_events(input, version="v2"):
            for item in AgentRunConverter.to_agui_events(event):
                yield item

        # stream, "updates" mode
        for event in agent.stream(input, stream_mode="updates"):
            for item in AgentRunConverter.to_agui_events(event):
                yield item

        # astream, "updates" mode
        async for event in agent.astream(input, stream_mode="updates"):
            for item in AgentRunConverter.to_agui_events(event):
                yield item
    """
    event_dict = AgentRunConverter._event_to_dict(event)

    # Pick the converter that matches the event's shape.
    if AgentRunConverter.is_astream_events_format(event_dict):
        # astream_events format: {"event": "on_xxx", "data": {...}}
        yield from AgentRunConverter._convert_astream_events_event(
            event_dict,
            tool_call_id_map,
            tool_call_started_set,
            tool_name_to_call_ids,
            run_id_to_tool_call_id,
        )

    elif AgentRunConverter.is_stream_updates_format(event_dict):
        # stream/astream(stream_mode="updates") format: {node_name: state_update}
        yield from AgentRunConverter._convert_stream_updates_event(
            event_dict, messages_key
        )

    elif AgentRunConverter.is_stream_values_format(event_dict):
        # stream/astream(stream_mode="values") format: the full state dict
        yield from AgentRunConverter._convert_stream_values_event(
            event_dict, messages_key
        )
|