agentrun-inner-test 0.0.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun/__init__.py +325 -0
- agentrun/agent_runtime/__client_async_template.py +466 -0
- agentrun/agent_runtime/__endpoint_async_template.py +345 -0
- agentrun/agent_runtime/__init__.py +53 -0
- agentrun/agent_runtime/__runtime_async_template.py +477 -0
- agentrun/agent_runtime/api/__data_async_template.py +58 -0
- agentrun/agent_runtime/api/__init__.py +6 -0
- agentrun/agent_runtime/api/control.py +1362 -0
- agentrun/agent_runtime/api/data.py +98 -0
- agentrun/agent_runtime/client.py +868 -0
- agentrun/agent_runtime/endpoint.py +649 -0
- agentrun/agent_runtime/model.py +362 -0
- agentrun/agent_runtime/runtime.py +904 -0
- agentrun/credential/__client_async_template.py +177 -0
- agentrun/credential/__credential_async_template.py +216 -0
- agentrun/credential/__init__.py +28 -0
- agentrun/credential/api/__init__.py +5 -0
- agentrun/credential/api/control.py +606 -0
- agentrun/credential/client.py +319 -0
- agentrun/credential/credential.py +381 -0
- agentrun/credential/model.py +248 -0
- agentrun/integration/__init__.py +21 -0
- agentrun/integration/agentscope/__init__.py +12 -0
- agentrun/integration/agentscope/adapter.py +17 -0
- agentrun/integration/agentscope/builtin.py +65 -0
- agentrun/integration/agentscope/message_adapter.py +185 -0
- agentrun/integration/agentscope/model_adapter.py +60 -0
- agentrun/integration/agentscope/tool_adapter.py +59 -0
- agentrun/integration/builtin/__init__.py +16 -0
- agentrun/integration/builtin/model.py +93 -0
- agentrun/integration/builtin/sandbox.py +1234 -0
- agentrun/integration/builtin/toolset.py +47 -0
- agentrun/integration/crewai/__init__.py +12 -0
- agentrun/integration/crewai/adapter.py +9 -0
- agentrun/integration/crewai/builtin.py +65 -0
- agentrun/integration/crewai/model_adapter.py +31 -0
- agentrun/integration/crewai/tool_adapter.py +26 -0
- agentrun/integration/google_adk/__init__.py +12 -0
- agentrun/integration/google_adk/adapter.py +15 -0
- agentrun/integration/google_adk/builtin.py +65 -0
- agentrun/integration/google_adk/message_adapter.py +144 -0
- agentrun/integration/google_adk/model_adapter.py +46 -0
- agentrun/integration/google_adk/tool_adapter.py +235 -0
- agentrun/integration/langchain/__init__.py +30 -0
- agentrun/integration/langchain/adapter.py +15 -0
- agentrun/integration/langchain/builtin.py +71 -0
- agentrun/integration/langchain/message_adapter.py +141 -0
- agentrun/integration/langchain/model_adapter.py +37 -0
- agentrun/integration/langchain/tool_adapter.py +50 -0
- agentrun/integration/langgraph/__init__.py +35 -0
- agentrun/integration/langgraph/adapter.py +20 -0
- agentrun/integration/langgraph/agent_converter.py +1073 -0
- agentrun/integration/langgraph/builtin.py +65 -0
- agentrun/integration/pydantic_ai/__init__.py +12 -0
- agentrun/integration/pydantic_ai/adapter.py +13 -0
- agentrun/integration/pydantic_ai/builtin.py +65 -0
- agentrun/integration/pydantic_ai/model_adapter.py +44 -0
- agentrun/integration/pydantic_ai/tool_adapter.py +19 -0
- agentrun/integration/utils/__init__.py +112 -0
- agentrun/integration/utils/adapter.py +560 -0
- agentrun/integration/utils/canonical.py +164 -0
- agentrun/integration/utils/converter.py +134 -0
- agentrun/integration/utils/model.py +110 -0
- agentrun/integration/utils/tool.py +1759 -0
- agentrun/model/__client_async_template.py +357 -0
- agentrun/model/__init__.py +57 -0
- agentrun/model/__model_proxy_async_template.py +270 -0
- agentrun/model/__model_service_async_template.py +267 -0
- agentrun/model/api/__init__.py +6 -0
- agentrun/model/api/control.py +1173 -0
- agentrun/model/api/data.py +196 -0
- agentrun/model/client.py +674 -0
- agentrun/model/model.py +235 -0
- agentrun/model/model_proxy.py +439 -0
- agentrun/model/model_service.py +438 -0
- agentrun/sandbox/__aio_sandbox_async_template.py +523 -0
- agentrun/sandbox/__browser_sandbox_async_template.py +110 -0
- agentrun/sandbox/__client_async_template.py +491 -0
- agentrun/sandbox/__code_interpreter_sandbox_async_template.py +463 -0
- agentrun/sandbox/__init__.py +69 -0
- agentrun/sandbox/__sandbox_async_template.py +463 -0
- agentrun/sandbox/__template_async_template.py +152 -0
- agentrun/sandbox/aio_sandbox.py +905 -0
- agentrun/sandbox/api/__aio_data_async_template.py +335 -0
- agentrun/sandbox/api/__browser_data_async_template.py +140 -0
- agentrun/sandbox/api/__code_interpreter_data_async_template.py +206 -0
- agentrun/sandbox/api/__init__.py +19 -0
- agentrun/sandbox/api/__sandbox_data_async_template.py +107 -0
- agentrun/sandbox/api/aio_data.py +551 -0
- agentrun/sandbox/api/browser_data.py +172 -0
- agentrun/sandbox/api/code_interpreter_data.py +396 -0
- agentrun/sandbox/api/control.py +1051 -0
- agentrun/sandbox/api/playwright_async.py +492 -0
- agentrun/sandbox/api/playwright_sync.py +492 -0
- agentrun/sandbox/api/sandbox_data.py +154 -0
- agentrun/sandbox/browser_sandbox.py +185 -0
- agentrun/sandbox/client.py +925 -0
- agentrun/sandbox/code_interpreter_sandbox.py +823 -0
- agentrun/sandbox/model.py +397 -0
- agentrun/sandbox/sandbox.py +848 -0
- agentrun/sandbox/template.py +217 -0
- agentrun/server/__init__.py +191 -0
- agentrun/server/agui_normalizer.py +180 -0
- agentrun/server/agui_protocol.py +797 -0
- agentrun/server/invoker.py +309 -0
- agentrun/server/model.py +427 -0
- agentrun/server/openai_protocol.py +535 -0
- agentrun/server/protocol.py +140 -0
- agentrun/server/server.py +208 -0
- agentrun/toolset/__client_async_template.py +62 -0
- agentrun/toolset/__init__.py +51 -0
- agentrun/toolset/__toolset_async_template.py +204 -0
- agentrun/toolset/api/__init__.py +17 -0
- agentrun/toolset/api/control.py +262 -0
- agentrun/toolset/api/mcp.py +100 -0
- agentrun/toolset/api/openapi.py +1251 -0
- agentrun/toolset/client.py +102 -0
- agentrun/toolset/model.py +321 -0
- agentrun/toolset/toolset.py +270 -0
- agentrun/utils/__data_api_async_template.py +720 -0
- agentrun/utils/__init__.py +5 -0
- agentrun/utils/__resource_async_template.py +158 -0
- agentrun/utils/config.py +258 -0
- agentrun/utils/control_api.py +78 -0
- agentrun/utils/data_api.py +1120 -0
- agentrun/utils/exception.py +151 -0
- agentrun/utils/helper.py +108 -0
- agentrun/utils/log.py +77 -0
- agentrun/utils/model.py +168 -0
- agentrun/utils/resource.py +291 -0
- agentrun_inner_test-0.0.46.dist-info/METADATA +263 -0
- agentrun_inner_test-0.0.46.dist-info/RECORD +135 -0
- agentrun_inner_test-0.0.46.dist-info/WHEEL +5 -0
- agentrun_inner_test-0.0.46.dist-info/licenses/LICENSE +201 -0
- agentrun_inner_test-0.0.46.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,797 @@
|
|
|
1
|
+
"""AG-UI (Agent-User Interaction Protocol) 协议实现
|
|
2
|
+
|
|
3
|
+
AG-UI 是一种开源、轻量级、基于事件的协议,用于标准化 AI Agent 与前端应用之间的交互。
|
|
4
|
+
参考: https://docs.ag-ui.com/
|
|
5
|
+
|
|
6
|
+
本实现使用 ag-ui-protocol 包提供的事件类型和编码器,
|
|
7
|
+
将 AgentResult 事件转换为 AG-UI SSE 格式。
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import (
|
|
12
|
+
Any,
|
|
13
|
+
AsyncIterator,
|
|
14
|
+
Dict,
|
|
15
|
+
Iterator,
|
|
16
|
+
List,
|
|
17
|
+
Optional,
|
|
18
|
+
TYPE_CHECKING,
|
|
19
|
+
)
|
|
20
|
+
import uuid
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from ag_ui.core import (
|
|
24
|
+
Message as AguiMessage,
|
|
25
|
+
)
|
|
26
|
+
from ag_ui.encoder import EventEncoder
|
|
27
|
+
|
|
28
|
+
from fastapi import APIRouter, Request
|
|
29
|
+
from fastapi.responses import StreamingResponse
|
|
30
|
+
import pydash
|
|
31
|
+
|
|
32
|
+
from ..utils.helper import merge, MergeOptions
|
|
33
|
+
from .model import (
|
|
34
|
+
AgentEvent,
|
|
35
|
+
AgentRequest,
|
|
36
|
+
EventType,
|
|
37
|
+
Message,
|
|
38
|
+
MessageRole,
|
|
39
|
+
ServerConfig,
|
|
40
|
+
Tool,
|
|
41
|
+
ToolCall,
|
|
42
|
+
)
|
|
43
|
+
from .protocol import BaseProtocolHandler
|
|
44
|
+
|
|
45
|
+
if TYPE_CHECKING:
|
|
46
|
+
from .invoker import AgentInvoker
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ============================================================================
|
|
50
|
+
# AG-UI 协议处理器
|
|
51
|
+
# ============================================================================
|
|
52
|
+
|
|
53
|
+
# Route prefix returned by AGUIProtocolHandler.get_prefix() when the AG-UI
# config does not supply its own "prefix".
DEFAULT_PREFIX = "/ag-ui/agent"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
class TextState:
    """Lifecycle flags for the assistant text message currently being streamed.

    Every new instance gets its own random UUID4 string as ``message_id``;
    ``started`` / ``ended`` record whether TEXT_MESSAGE_START and
    TEXT_MESSAGE_END have been emitted for that id.
    """

    # True once TEXT_MESSAGE_START has been emitted for ``message_id``.
    started: bool = False
    # True once TEXT_MESSAGE_END has been emitted for ``message_id``.
    ended: bool = False
    # Identifier shared by the START / CONTENT / END events of one message.
    message_id: str = field(default_factory=lambda: f"{uuid.uuid4()}")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class ToolCallState:
    """Bookkeeping for one tool call within a single AG-UI run."""

    # Tool name sent with TOOL_CALL_START ("" when unknown).
    name: str = ""
    # TOOL_CALL_START has been emitted for this call.
    started: bool = False
    # TOOL_CALL_END has been emitted for this call.
    ended: bool = False
    # Result-received flag; the handler code in this file only sets it to False.
    has_result: bool = False
    # The call represents a human-in-the-loop (HITL) request.
    is_hitl: bool = False
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class StreamStateMachine:
    """Per-run bookkeeping used to inject AG-UI boundary events.

    Holds the currently tracked text message, every tool call seen so far
    (keyed by tool call id), buffered TOOL_RESULT_CHUNK deltas, and whether a
    RUN_ERROR has been emitted. The generator helpers yield already-encoded
    SSE strings produced by the supplied encoder.
    """

    text: TextState = field(default_factory=TextState)
    tool_call_states: Dict[str, ToolCallState] = field(default_factory=dict)
    tool_result_chunks: Dict[str, List[str]] = field(default_factory=dict)
    run_errored: bool = False

    def end_all_tools(
        self, encoder: "EventEncoder", exclude: Optional[str] = None
    ) -> Iterator[str]:
        """Emit TOOL_CALL_END for every started-but-not-ended tool call.

        Args:
            encoder: AG-UI event encoder.
            exclude: Optional tool call id to leave open (ignored when empty).

        Yields:
            Encoded TOOL_CALL_END events, in the dict's insertion order.
        """
        from ag_ui.core import ToolCallEndEvent

        for call_id, call_state in self.tool_call_states.items():
            skip_this = bool(exclude) and call_id == exclude
            if skip_this or not call_state.started or call_state.ended:
                continue
            yield encoder.encode(ToolCallEndEvent(tool_call_id=call_id))
            call_state.ended = True

    def ensure_text_started(self, encoder: "EventEncoder") -> Iterator[str]:
        """Emit TEXT_MESSAGE_START unless a text message is currently open.

        When the previous message has already ended, a fresh ``TextState``
        (and therefore a new message id) is created first.
        """
        from ag_ui.core import TextMessageStartEvent

        if self.text.started and not self.text.ended:
            # A message is already open; keep streaming into it.
            return
        if self.text.ended:
            self.text = TextState()
        start_event = TextMessageStartEvent(
            message_id=self.text.message_id,
            role="assistant",
        )
        yield encoder.encode(start_event)
        self.text.started = True
        self.text.ended = False

    def end_text_if_open(self, encoder: "EventEncoder") -> Iterator[str]:
        """Emit TEXT_MESSAGE_END if a text message is started and not ended."""
        from ag_ui.core import TextMessageEndEvent

        is_open = self.text.started and not self.text.ended
        if not is_open:
            return
        yield encoder.encode(TextMessageEndEvent(message_id=self.text.message_id))
        self.text.ended = True

    def cache_tool_result_chunk(self, tool_id: str, delta: str) -> None:
        """Buffer a TOOL_RESULT_CHUNK delta for ``tool_id``.

        Chunks without a tool id, and ``None`` / empty deltas, are dropped.
        """
        if tool_id and delta:
            self.tool_result_chunks.setdefault(tool_id, []).append(delta)

    def pop_tool_result_chunks(self, tool_id: str) -> str:
        """Remove and return the buffered chunks for ``tool_id`` joined together.

        Returns ``""`` when nothing was buffered for that id.
        """
        buffered = self.tool_result_chunks.pop(tool_id, [])
        return "".join(buffered)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class AGUIProtocolHandler(BaseProtocolHandler):
    """AG-UI protocol handler.

    Implements an AG-UI (Agent-User Interaction Protocol) compatible endpoint.
    Reference: https://docs.ag-ui.com/

    Uses the event types and encoder provided by the ag-ui-protocol package.

    Features:
        - Event-based streaming communication (SSE)
        - Lifecycle, text, tool-call, state and custom AG-UI events
        - State synchronisation (STATE_SNAPSHOT / STATE_DELTA)
        - Tool calls, including human-in-the-loop (HITL) requests

    Example:
        >>> from agentrun.server import AgentRunServer, AGUIProtocolHandler
        >>>
        >>> server = AgentRunServer(
        ...     invoke_agent=my_agent,
        ...     protocols=[AGUIProtocolHandler()]
        ... )
        >>> server.start(port=8000)
        # Available at: POST http://localhost:8000/ag-ui/agent
    """

    name = "ag-ui"

    def __init__(self, config: Optional[ServerConfig] = None):
        """Create the handler.

        Args:
            config: Optional server configuration; only its ``agui`` section
                is kept (as ``self._config``).
        """
        from ag_ui.encoder import EventEncoder

        self._config = config.agui if config else None
        self._encoder = EventEncoder()

    def get_prefix(self) -> str:
        """Return the route prefix (AG-UI suggests ``/ag-ui/agent``).

        Falls back to ``DEFAULT_PREFIX`` when there is no AG-UI config, when
        it has no ``prefix``, or when ``prefix`` is ``None``.
        """
        prefix = pydash.get(self._config, "prefix", DEFAULT_PREFIX)
        return DEFAULT_PREFIX if prefix is None else prefix

    def as_fastapi_router(self, agent_invoker: "AgentInvoker") -> APIRouter:
        """Create the FastAPI router for the AG-UI protocol.

        Routes:
            POST ""        -> run the agent and stream AG-UI SSE events.
            GET  "/health" -> health check.
        """
        router = APIRouter()

        @router.post("")
        async def run_agent(request: Request):
            """AG-UI run-agent endpoint.

            Accepts an AG-UI formatted request and returns an SSE event
            stream. Request-setup failures are reported as a RUN_ERROR event
            inside an SSE stream rather than as an HTTP error status.
            """
            sse_headers = {
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Ask reverse proxies (e.g. nginx) not to buffer the stream.
                "X-Accel-Buffering": "no",
            }

            try:
                request_data = await request.json()
                agent_request, context = await self.parse_request(
                    request, request_data
                )

                # The generator is consumed by StreamingResponse after this
                # handler returns, so errors raised while streaming are
                # handled inside _format_stream, not by the excepts below.
                event_stream = self._format_stream(
                    agent_invoker.invoke_stream(agent_request),
                    context,
                )

                return StreamingResponse(
                    event_stream,
                    media_type=self._encoder.get_content_type(),
                    headers=sse_headers,
                )

            except ValueError as e:
                # Also covers malformed JSON: json.JSONDecodeError is a
                # subclass of ValueError.
                return StreamingResponse(
                    self._error_stream(str(e)),
                    media_type=self._encoder.get_content_type(),
                    headers=sse_headers,
                )
            except Exception as e:
                return StreamingResponse(
                    self._error_stream(f"Internal error: {str(e)}"),
                    media_type=self._encoder.get_content_type(),
                    headers=sse_headers,
                )

        @router.get("/health")
        async def health_check():
            """Health check endpoint."""
            return {"status": "ok", "protocol": "ag-ui", "version": "1.0"}

        return router

    async def parse_request(
        self,
        request: Request,
        request_data: Dict[str, Any],
    ) -> tuple[AgentRequest, Dict[str, Any]]:
        """Parse an AG-UI formatted request.

        Args:
            request: FastAPI Request object.
            request_data: JSON body of the HTTP request.

        Returns:
            tuple: (AgentRequest, context). ``context`` holds ``thread_id`` and
            ``run_id``, taken from ``threadId`` / ``runId`` or freshly
            generated UUID4 strings when absent or empty.

        Raises:
            ValueError: If the request body is not a JSON object.
        """
        if not isinstance(request_data, dict):
            raise ValueError("AG-UI request body must be a JSON object")

        context = {
            "thread_id": request_data.get("threadId") or str(uuid.uuid4()),
            "run_id": request_data.get("runId") or str(uuid.uuid4()),
        }

        # `or []` also covers an explicit `"messages": null`.
        messages = self._parse_messages(request_data.get("messages") or [])
        tools = self._parse_tools(request_data.get("tools"))

        agent_request = AgentRequest(
            protocol="agui",  # protocol name
            messages=messages,
            stream=True,  # AG-UI is always streaming
            tools=tools,
            raw_request=request,  # keep the original request object
        )

        return agent_request, context

    def _parse_messages(
        self, raw_messages: List[Dict[str, Any]]
    ) -> List[Message]:
        """Parse the AG-UI message list into internal ``Message`` objects.

        Non-dict messages and non-dict ``toolCalls`` entries are skipped;
        unknown roles fall back to ``MessageRole.USER``.

        Args:
            raw_messages: Raw message data from the request body.

        Returns:
            The normalised message list.
        """
        messages = []

        for msg_data in raw_messages:
            if not isinstance(msg_data, dict):
                continue

            role_str = msg_data.get("role", "user")
            try:
                role = MessageRole(role_str)
            except ValueError:
                role = MessageRole.USER

            tool_calls = None
            raw_tool_calls = msg_data.get("toolCalls")
            if raw_tool_calls and isinstance(raw_tool_calls, list):
                # `or None` keeps "no usable tool calls" consistent with an
                # absent / empty `toolCalls` field.
                tool_calls = [
                    ToolCall(
                        id=tc.get("id", ""),
                        type=tc.get("type", "function"),
                        function=tc.get("function", {}),
                    )
                    for tc in raw_tool_calls
                    if isinstance(tc, dict)
                ] or None

            messages.append(
                Message(
                    id=msg_data.get("id"),
                    role=role,
                    content=msg_data.get("content"),
                    name=msg_data.get("name"),
                    tool_calls=tool_calls,
                    tool_call_id=msg_data.get("toolCallId"),
                )
            )

        return messages

    def _parse_tools(
        self, raw_tools: Optional[List[Dict[str, Any]]]
    ) -> Optional[List[Tool]]:
        """Parse the tool list.

        Args:
            raw_tools: Raw tool data from the request body.

        Returns:
            The normalised tool list, or ``None`` if there are no usable tools.
        """
        if not raw_tools:
            return None

        tools = []
        for tool_data in raw_tools:
            if not isinstance(tool_data, dict):
                continue

            tools.append(
                Tool(
                    type=tool_data.get("type", "function"),
                    function=tool_data.get("function", {}),
                )
            )

        return tools if tools else None

    async def _format_stream(
        self,
        event_stream: AsyncIterator[AgentEvent],
        context: Dict[str, Any],
    ) -> AsyncIterator[str]:
        """Convert an AgentEvent stream into AG-UI SSE strings.

        Boundary events are generated automatically:
        - RUN_STARTED / RUN_FINISHED (lifecycle)
        - TEXT_MESSAGE_START / TEXT_MESSAGE_END (text boundaries)
        - TOOL_CALL_START / TOOL_CALL_END (tool-call boundaries)

        Note: nothing may be sent after RUN_ERROR (not even RUN_FINISHED).
        If the upstream stream or event conversion raises an ``Exception``,
        it is logged and a RUN_ERROR with code ``INTERNAL_ERROR`` is emitted
        (unless a RUN_ERROR was already sent), then the stream ends.

        Args:
            event_stream: Stream of AgentEvent objects.
            context: Context containing ``thread_id`` and ``run_id``.

        Yields:
            SSE-formatted strings.
        """
        from ag_ui.core import RunErrorEvent, RunFinishedEvent, RunStartedEvent

        state = StreamStateMachine()

        yield self._encoder.encode(
            RunStartedEvent(
                thread_id=context.get("thread_id"),
                run_id=context.get("run_id"),
            )
        )

        try:
            async for event in event_stream:
                # After RUN_ERROR, remaining events are drained but not sent.
                if state.run_errored:
                    continue

                if event.event == EventType.ERROR:
                    state.run_errored = True

                for sse_data in self._process_event_with_boundaries(
                    event,
                    context,
                    state,
                ):
                    if sse_data:
                        yield sse_data
        except Exception as e:
            # Without this, a mid-stream failure would truncate the SSE
            # stream with neither RUN_ERROR nor RUN_FINISHED.
            import logging

            logging.getLogger(__name__).exception("AG-UI event stream failed")
            if state.run_errored:
                return
            state.run_errored = True
            yield self._encoder.encode(
                RunErrorEvent(
                    message=f"Internal error: {str(e)}",
                    code="INTERNAL_ERROR",
                )
            )
            return

        # No cleanup events may follow RUN_ERROR.
        if state.run_errored:
            return

        # Close any tool calls that are still open.
        for sse_data in state.end_all_tools(self._encoder):
            yield sse_data

        # Close the text message if one is still open.
        for sse_data in state.end_text_if_open(self._encoder):
            yield sse_data

        yield self._encoder.encode(
            RunFinishedEvent(
                thread_id=context.get("thread_id"),
                run_id=context.get("run_id"),
            )
        )

    def _start_tool_call_if_needed(
        self, state: StreamStateMachine, tool_id: str, tool_name: str
    ) -> Iterator[str]:
        """Emit TOOL_CALL_START for ``tool_id`` unless that call is already open.

        Nothing is emitted for an empty id. For an unseen id, or one whose
        previous call has ended, a START is emitted and the state entry is
        replaced with a fresh started ``ToolCallState``.
        """
        from ag_ui.core import ToolCallStartEvent

        if not tool_id:
            return
        current_state = state.tool_call_states.get(tool_id)
        if current_state is not None and not current_state.ended:
            return

        yield self._encoder.encode(
            ToolCallStartEvent(
                tool_call_id=tool_id,
                tool_call_name=tool_name,
            )
        )
        state.tool_call_states[tool_id] = ToolCallState(
            name=tool_name,
            started=True,
            ended=False,
        )

    @staticmethod
    def _stringify_tool_result(value: Any) -> str:
        """Coerce a tool result into the string ``ToolCallResultEvent`` expects.

        ``None`` becomes ``""`` and strings pass through unchanged. Other
        values are JSON-encoded, falling back to ``str()`` when the value is
        not JSON-serialisable.
        """
        import json

        if value is None:
            return ""
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(value)

    def _process_event_with_boundaries(
        self,
        event: AgentEvent,
        context: Dict[str, Any],
        state: StreamStateMachine,
    ) -> Iterator[str]:
        """Translate one AgentEvent into AG-UI SSE strings, injecting boundaries.

        Args:
            event: The agent event to translate.
            context: Run context (currently unused here).
            state: Mutable per-run stream state.

        Yields:
            SSE-formatted strings.
        """
        import json

        from ag_ui.core import CustomEvent as AguiCustomEvent
        from ag_ui.core import (
            RunErrorEvent,
            StateDeltaEvent,
            StateSnapshotEvent,
            TextMessageContentEvent,
            ToolCallArgsEvent,
            ToolCallEndEvent,
            ToolCallResultEvent,
            ToolCallStartEvent,
        )

        # RAW events are passed through verbatim (normalised to end in a
        # blank line, which terminates an SSE event).
        if event.event == EventType.RAW:
            raw_data = event.data.get("raw", "")
            if raw_data:
                if not raw_data.endswith("\n\n"):
                    raw_data = raw_data.rstrip("\n") + "\n\n"
                yield raw_data
            return

        # TEXT: inject TEXT_MESSAGE_START before the first TEXT. Open tool
        # calls are ended first, before the text message starts.
        if event.event == EventType.TEXT:
            for sse_data in state.end_all_tools(self._encoder):
                yield sse_data

            for sse_data in state.ensure_text_started(self._encoder):
                yield sse_data

            agui_event = TextMessageContentEvent(
                message_id=state.text.message_id,
                delta=event.data.get("delta", ""),
            )
            if event.addition:
                # Extra fields are merged into the serialised event, so it is
                # emitted as hand-built SSE instead of via the encoder.
                event_dict = agui_event.model_dump(
                    by_alias=True, exclude_none=True
                )
                event_dict = self._apply_addition(
                    event_dict,
                    event.addition,
                    event.addition_merge_options,
                )
                json_str = json.dumps(event_dict, ensure_ascii=False)
                yield f"data: {json_str}\n\n"
            else:
                yield self._encoder.encode(agui_event)
            return

        # TOOL_CALL_CHUNK: inject TOOL_CALL_START before the first chunk.
        if event.event == EventType.TOOL_CALL_CHUNK:
            tool_id = event.data.get("id", "")
            tool_name = event.data.get("name", "")

            for sse_data in state.end_text_if_open(self._encoder):
                yield sse_data

            for sse_data in self._start_tool_call_if_needed(
                state, tool_id, tool_name
            ):
                yield sse_data

            yield self._encoder.encode(
                ToolCallArgsEvent(
                    tool_call_id=tool_id,
                    delta=event.data.get("args_delta", ""),
                )
            )
            return

        # TOOL_CALL: a complete tool call in one event.
        if event.event == EventType.TOOL_CALL:
            tool_id = event.data.get("id", "")
            tool_name = event.data.get("name", "")
            tool_args = event.data.get("args", "")

            for sse_data in state.end_text_if_open(self._encoder):
                yield sse_data

            for sse_data in self._start_tool_call_if_needed(
                state, tool_id, tool_name
            ):
                yield sse_data

            # Send the arguments only if present.
            if tool_args:
                yield self._encoder.encode(
                    ToolCallArgsEvent(
                        tool_call_id=tool_id,
                        delta=tool_args,
                    )
                )
            return

        # TOOL_RESULT_CHUNK: streamed output while the tool runs; buffered
        # and prepended to the final TOOL_RESULT content.
        if event.event == EventType.TOOL_RESULT_CHUNK:
            tool_id = event.data.get("id", "")
            delta = event.data.get("delta", "")
            state.cache_tool_result_chunk(tool_id, delta)
            return

        # HITL: request human intervention.
        if event.event == EventType.HITL:
            hitl_id = event.data.get("id", "")
            tool_call_id = event.data.get("tool_call_id", "")
            hitl_type = event.data.get("type", "confirmation")
            prompt = event.data.get("prompt", "")
            options = event.data.get("options")
            default = event.data.get("default")
            timeout = event.data.get("timeout")
            schema = event.data.get("schema")

            for sse_data in state.end_text_if_open(self._encoder):
                yield sse_data

            # Attached to an already-known tool call: close it and mark it
            # as HITL instead of emitting a synthetic call.
            if tool_call_id and tool_call_id in state.tool_call_states:
                tool_state = state.tool_call_states[tool_call_id]
                if tool_state.started and not tool_state.ended:
                    yield self._encoder.encode(
                        ToolCallEndEvent(tool_call_id=tool_call_id)
                    )
                    tool_state.ended = True
                tool_state.is_hitl = True
                tool_state.has_result = False
                return

            args_dict: Dict[str, Any] = {
                "type": hitl_type,
                "prompt": prompt,
            }
            if options:
                args_dict["options"] = options
            if default is not None:
                args_dict["default"] = default
            if timeout is not None:
                args_dict["timeout"] = timeout
            if schema:
                args_dict["schema"] = schema

            args_json = json.dumps(args_dict, ensure_ascii=False)
            actual_id = tool_call_id or hitl_id

            # Otherwise emit a complete synthetic "hitl_<type>" tool call.
            yield self._encoder.encode(
                ToolCallStartEvent(
                    tool_call_id=actual_id,
                    tool_call_name=f"hitl_{hitl_type}",
                )
            )
            yield self._encoder.encode(
                ToolCallArgsEvent(
                    tool_call_id=actual_id,
                    delta=args_json,
                )
            )
            yield self._encoder.encode(ToolCallEndEvent(tool_call_id=actual_id))

            state.tool_call_states[actual_id] = ToolCallState(
                name=f"hitl_{hitl_type}",
                started=True,
                ended=True,
                has_result=False,
                is_hitl=True,
            )
            return

        # TOOL_RESULT: make sure the tool call is started and ended first.
        if event.event == EventType.TOOL_RESULT:
            tool_id = event.data.get("id", "")
            tool_name = event.data.get("name", "")

            for sse_data in state.end_text_if_open(self._encoder):
                yield sse_data

            tool_state = (
                state.tool_call_states.get(tool_id) if tool_id else None
            )
            if tool_id and tool_state is None:
                # Result for a call never announced: synthesise its START.
                yield self._encoder.encode(
                    ToolCallStartEvent(
                        tool_call_id=tool_id,
                        tool_call_name=tool_name or "",
                    )
                )
                tool_state = ToolCallState(
                    name=tool_name, started=True, ended=False
                )
                state.tool_call_states[tool_id] = tool_state

            if tool_state and tool_state.started and not tool_state.ended:
                yield self._encoder.encode(
                    ToolCallEndEvent(tool_call_id=tool_id)
                )
                tool_state.ended = True

            # Results may be non-string (dict/list/None); coerce so the
            # concatenation and ToolCallResultEvent.content both get a str.
            final_result = self._stringify_tool_result(
                event.data.get("content") or event.data.get("result", "")
            )
            if tool_id:
                cached_chunks = state.pop_tool_result_chunks(tool_id)
                if cached_chunks:
                    final_result = cached_chunks + final_result

            yield self._encoder.encode(
                ToolCallResultEvent(
                    message_id=event.data.get(
                        "message_id", f"tool-result-{tool_id}"
                    ),
                    tool_call_id=tool_id,
                    content=final_result,
                    role="tool",
                )
            )
            return

        # ERROR event -> RUN_ERROR.
        if event.event == EventType.ERROR:
            yield self._encoder.encode(
                RunErrorEvent(
                    message=event.data.get("message", ""),
                    code=event.data.get("code"),
                )
            )
            return

        # STATE event: snapshot, delta, or (otherwise) the whole payload as
        # a snapshot.
        if event.event == EventType.STATE:
            if "snapshot" in event.data:
                yield self._encoder.encode(
                    StateSnapshotEvent(snapshot=event.data.get("snapshot", {}))
                )
            elif "delta" in event.data:
                yield self._encoder.encode(
                    StateDeltaEvent(delta=event.data.get("delta", []))
                )
            else:
                yield self._encoder.encode(
                    StateSnapshotEvent(snapshot=event.data)
                )
            return

        # CUSTOM event.
        if event.event == EventType.CUSTOM:
            yield self._encoder.encode(
                AguiCustomEvent(
                    name=event.data.get("name", "custom"),
                    value=event.data.get("value"),
                )
            )
            return

        # Any other event type is forwarded as a CUSTOM event named after it.
        event_name = (
            event.event.value
            if hasattr(event.event, "value")
            else str(event.event)
        )
        yield self._encoder.encode(
            AguiCustomEvent(
                name=event_name,
                value=event.data,
            )
        )

    def _convert_messages_for_snapshot(
        self, messages: List[Dict[str, Any]]
    ) -> List["AguiMessage"]:
        """Convert message dicts into ag-ui-protocol message objects.

        Non-dict entries and unrecognised roles are skipped. A missing or
        empty ``id`` is replaced with a fresh UUID4 string.

        Args:
            messages: List of message dicts.

        Returns:
            List of ag-ui-protocol messages.
        """
        from ag_ui.core import AssistantMessage, SystemMessage
        from ag_ui.core import ToolMessage as AguiToolMessage
        from ag_ui.core import UserMessage

        result = []
        for msg in messages:
            if not isinstance(msg, dict):
                continue

            role = msg.get("role", "user")
            content = msg.get("content", "")
            # `or` also covers an explicit `"id": None`.
            msg_id = msg.get("id") or str(uuid.uuid4())

            if role == "user":
                result.append(
                    UserMessage(id=msg_id, role="user", content=content)
                )
            elif role == "assistant":
                result.append(
                    AssistantMessage(
                        id=msg_id,
                        role="assistant",
                        content=content,
                    )
                )
            elif role == "system":
                result.append(
                    SystemMessage(id=msg_id, role="system", content=content)
                )
            elif role == "tool":
                result.append(
                    AguiToolMessage(
                        id=msg_id,
                        role="tool",
                        content=content,
                        tool_call_id=msg.get("tool_call_id", ""),
                    )
                )

        return result

    def _apply_addition(
        self,
        event_data: Dict[str, Any],
        addition: Optional[Dict[str, Any]],
        merge_options: Optional[MergeOptions] = None,
    ) -> Dict[str, Any]:
        """Merge the ``addition`` fields into serialised event data.

        Args:
            event_data: Original event data.
            addition: Extra fields to merge in.
            merge_options: Options passed through to ``utils.helper.merge``.

        Returns:
            The merged event data (``event_data`` itself if no addition).
        """
        if not addition:
            return event_data

        return merge(event_data, addition, **(merge_options or {}))

    async def _error_stream(self, message: str) -> AsyncIterator[str]:
        """Produce a minimal error event stream.

        Emits RUN_STARTED (with fresh UUID4 thread/run ids) followed by a
        RUN_ERROR with code ``REQUEST_ERROR``.

        Args:
            message: Error message.

        Yields:
            SSE-formatted events.
        """
        from ag_ui.core import RunErrorEvent, RunStartedEvent

        thread_id = str(uuid.uuid4())
        run_id = str(uuid.uuid4())

        # Lifecycle start.
        yield self._encoder.encode(
            RunStartedEvent(thread_id=thread_id, run_id=run_id)
        )

        # Error event.
        yield self._encoder.encode(
            RunErrorEvent(message=message, code="REQUEST_ERROR")
        )
|