pydantic-ai-slim 1.7.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +2 -0
- pydantic_ai/_agent_graph.py +3 -0
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/_run_context.py +8 -2
- pydantic_ai/_tool_manager.py +1 -0
- pydantic_ai/_utils.py +18 -0
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/__init__.py +13 -3
- pydantic_ai/agent/abstract.py +172 -9
- pydantic_ai/agent/wrapper.py +5 -0
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +31 -0
- pydantic_ai/durable_exec/prefect/_agent.py +28 -0
- pydantic_ai/durable_exec/temporal/_agent.py +28 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -73
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +9 -3
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/messages.py +49 -8
- pydantic_ai/models/__init__.py +42 -1
- pydantic_ai/models/google.py +5 -12
- pydantic_ai/models/groq.py +9 -1
- pydantic_ai/models/openai.py +6 -3
- pydantic_ai/profiles/openai.py +5 -2
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/result.py +178 -11
- pydantic_ai/tools.py +10 -6
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
- pydantic_ai/ui/ag_ui/app.py +148 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/METADATA +10 -6
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/RECORD +47 -33
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ag_ui.py
CHANGED
|
@@ -4,107 +4,35 @@ This package provides seamless integration between pydantic-ai agents and ag-ui
|
|
|
4
4
|
for building interactive AI applications with streaming event-based communication.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
# TODO (v2): Remove this module in favor of `pydantic_ai.ui.ag_ui`
|
|
8
8
|
|
|
9
|
-
import
|
|
10
|
-
import uuid
|
|
11
|
-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence
|
|
12
|
-
from dataclasses import Field, dataclass, field, replace
|
|
13
|
-
from http import HTTPStatus
|
|
14
|
-
from typing import (
|
|
15
|
-
Any,
|
|
16
|
-
ClassVar,
|
|
17
|
-
Final,
|
|
18
|
-
Generic,
|
|
19
|
-
Protocol,
|
|
20
|
-
TypeAlias,
|
|
21
|
-
TypeVar,
|
|
22
|
-
runtime_checkable,
|
|
23
|
-
)
|
|
9
|
+
from __future__ import annotations
|
|
24
10
|
|
|
25
|
-
from
|
|
11
|
+
from collections.abc import AsyncIterator, Sequence
|
|
12
|
+
from typing import Any
|
|
26
13
|
|
|
27
|
-
from . import
|
|
28
|
-
from .
|
|
29
|
-
from .
|
|
30
|
-
from .exceptions import UserError
|
|
31
|
-
from .messages import (
|
|
32
|
-
BaseToolCallPart,
|
|
33
|
-
BuiltinToolCallPart,
|
|
34
|
-
BuiltinToolReturnPart,
|
|
35
|
-
FunctionToolResultEvent,
|
|
36
|
-
ModelMessage,
|
|
37
|
-
ModelRequest,
|
|
38
|
-
ModelRequestPart,
|
|
39
|
-
ModelResponse,
|
|
40
|
-
ModelResponsePart,
|
|
41
|
-
ModelResponseStreamEvent,
|
|
42
|
-
PartDeltaEvent,
|
|
43
|
-
PartStartEvent,
|
|
44
|
-
SystemPromptPart,
|
|
45
|
-
TextPart,
|
|
46
|
-
TextPartDelta,
|
|
47
|
-
ThinkingPart,
|
|
48
|
-
ThinkingPartDelta,
|
|
49
|
-
ToolCallPart,
|
|
50
|
-
ToolCallPartDelta,
|
|
51
|
-
ToolReturnPart,
|
|
52
|
-
UserPromptPart,
|
|
53
|
-
)
|
|
14
|
+
from . import DeferredToolResults
|
|
15
|
+
from .agent import AbstractAgent
|
|
16
|
+
from .messages import ModelMessage
|
|
54
17
|
from .models import KnownModelName, Model
|
|
55
|
-
from .output import
|
|
18
|
+
from .output import OutputSpec
|
|
56
19
|
from .settings import ModelSettings
|
|
57
|
-
from .tools import AgentDepsT
|
|
20
|
+
from .tools import AgentDepsT
|
|
58
21
|
from .toolsets import AbstractToolset
|
|
59
|
-
from .toolsets.external import ExternalToolset
|
|
60
22
|
from .usage import RunUsage, UsageLimits
|
|
61
23
|
|
|
62
24
|
try:
|
|
63
|
-
from ag_ui.core import
|
|
64
|
-
|
|
65
|
-
BaseEvent,
|
|
66
|
-
DeveloperMessage,
|
|
67
|
-
EventType,
|
|
68
|
-
Message,
|
|
69
|
-
RunAgentInput,
|
|
70
|
-
RunErrorEvent,
|
|
71
|
-
RunFinishedEvent,
|
|
72
|
-
RunStartedEvent,
|
|
73
|
-
State,
|
|
74
|
-
SystemMessage,
|
|
75
|
-
TextMessageContentEvent,
|
|
76
|
-
TextMessageEndEvent,
|
|
77
|
-
TextMessageStartEvent,
|
|
78
|
-
ThinkingEndEvent,
|
|
79
|
-
ThinkingStartEvent,
|
|
80
|
-
ThinkingTextMessageContentEvent,
|
|
81
|
-
ThinkingTextMessageEndEvent,
|
|
82
|
-
ThinkingTextMessageStartEvent,
|
|
83
|
-
Tool as AGUITool,
|
|
84
|
-
ToolCallArgsEvent,
|
|
85
|
-
ToolCallEndEvent,
|
|
86
|
-
ToolCallResultEvent,
|
|
87
|
-
ToolCallStartEvent,
|
|
88
|
-
ToolMessage,
|
|
89
|
-
UserMessage,
|
|
90
|
-
)
|
|
91
|
-
from ag_ui.encoder import EventEncoder
|
|
92
|
-
except ImportError as e: # pragma: no cover
|
|
93
|
-
raise ImportError(
|
|
94
|
-
'Please install the `ag-ui-protocol` package to use `Agent.to_ag_ui()` method, '
|
|
95
|
-
'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
|
|
96
|
-
) from e
|
|
97
|
-
|
|
98
|
-
try:
|
|
99
|
-
from starlette.applications import Starlette
|
|
100
|
-
from starlette.middleware import Middleware
|
|
25
|
+
from ag_ui.core import BaseEvent
|
|
26
|
+
from ag_ui.core.types import RunAgentInput
|
|
101
27
|
from starlette.requests import Request
|
|
102
|
-
from starlette.responses import Response
|
|
103
|
-
|
|
104
|
-
from
|
|
28
|
+
from starlette.responses import Response
|
|
29
|
+
|
|
30
|
+
from .ui import SSE_CONTENT_TYPE, OnCompleteFunc, StateDeps, StateHandler
|
|
31
|
+
from .ui.ag_ui import AGUIAdapter
|
|
32
|
+
from .ui.ag_ui.app import AGUIApp
|
|
105
33
|
except ImportError as e: # pragma: no cover
|
|
106
34
|
raise ImportError(
|
|
107
|
-
'Please install the `starlette`
|
|
35
|
+
'Please install the `ag-ui-protocol` and `starlette` packages to use `AGUIAdapter`, '
|
|
108
36
|
'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
|
|
109
37
|
) from e
|
|
110
38
|
|
|
@@ -119,113 +47,14 @@ __all__ = [
|
|
|
119
47
|
'run_ag_ui',
|
|
120
48
|
]
|
|
121
49
|
|
|
122
|
-
SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
|
|
123
|
-
"""Content type header value for Server-Sent Events (SSE)."""
|
|
124
|
-
|
|
125
|
-
OnCompleteFunc: TypeAlias = Callable[[AgentRunResult[Any]], None] | Callable[[AgentRunResult[Any]], Awaitable[None]]
|
|
126
|
-
"""Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async."""
|
|
127
|
-
|
|
128
|
-
_BUILTIN_TOOL_CALL_ID_PREFIX: Final[str] = 'pyd_ai_builtin'
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
|
|
132
|
-
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
|
|
133
|
-
|
|
134
|
-
def __init__(
|
|
135
|
-
self,
|
|
136
|
-
agent: AbstractAgent[AgentDepsT, OutputDataT],
|
|
137
|
-
*,
|
|
138
|
-
# Agent.iter parameters.
|
|
139
|
-
output_type: OutputSpec[Any] | None = None,
|
|
140
|
-
model: Model | KnownModelName | str | None = None,
|
|
141
|
-
deps: AgentDepsT = None,
|
|
142
|
-
model_settings: ModelSettings | None = None,
|
|
143
|
-
usage_limits: UsageLimits | None = None,
|
|
144
|
-
usage: RunUsage | None = None,
|
|
145
|
-
infer_name: bool = True,
|
|
146
|
-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
147
|
-
# Starlette parameters.
|
|
148
|
-
debug: bool = False,
|
|
149
|
-
routes: Sequence[BaseRoute] | None = None,
|
|
150
|
-
middleware: Sequence[Middleware] | None = None,
|
|
151
|
-
exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
|
|
152
|
-
on_startup: Sequence[Callable[[], Any]] | None = None,
|
|
153
|
-
on_shutdown: Sequence[Callable[[], Any]] | None = None,
|
|
154
|
-
lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
|
|
155
|
-
) -> None:
|
|
156
|
-
"""An ASGI application that handles every AG-UI request by running the agent.
|
|
157
|
-
|
|
158
|
-
Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's
|
|
159
|
-
injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol.
|
|
160
|
-
To provide different `deps` for each request (e.g. based on the authenticated user),
|
|
161
|
-
use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or
|
|
162
|
-
[`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead.
|
|
163
|
-
|
|
164
|
-
Args:
|
|
165
|
-
agent: The agent to run.
|
|
166
|
-
|
|
167
|
-
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
|
|
168
|
-
no output validators since output validators would expect an argument that matches the agent's
|
|
169
|
-
output type.
|
|
170
|
-
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
171
|
-
deps: Optional dependencies to use for this run.
|
|
172
|
-
model_settings: Optional settings to use for this model's request.
|
|
173
|
-
usage_limits: Optional limits on model request count or token usage.
|
|
174
|
-
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
175
|
-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
176
|
-
toolsets: Optional additional toolsets for this run.
|
|
177
|
-
|
|
178
|
-
debug: Boolean indicating if debug tracebacks should be returned on errors.
|
|
179
|
-
routes: A list of routes to serve incoming HTTP and WebSocket requests.
|
|
180
|
-
middleware: A list of middleware to run for every request. A starlette application will always
|
|
181
|
-
automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
|
|
182
|
-
outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
|
|
183
|
-
`ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
|
|
184
|
-
exception cases occurring in the routing or endpoints.
|
|
185
|
-
exception_handlers: A mapping of either integer status codes, or exception class types onto
|
|
186
|
-
callables which handle the exceptions. Exception handler callables should be of the form
|
|
187
|
-
`handler(request, exc) -> response` and may be either standard functions, or async functions.
|
|
188
|
-
on_startup: A list of callables to run on application startup. Startup handler callables do not
|
|
189
|
-
take any arguments, and may be either standard functions, or async functions.
|
|
190
|
-
on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
|
|
191
|
-
not take any arguments, and may be either standard functions, or async functions.
|
|
192
|
-
lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
|
|
193
|
-
This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
|
|
194
|
-
the other, not both.
|
|
195
|
-
"""
|
|
196
|
-
super().__init__(
|
|
197
|
-
debug=debug,
|
|
198
|
-
routes=routes,
|
|
199
|
-
middleware=middleware,
|
|
200
|
-
exception_handlers=exception_handlers,
|
|
201
|
-
on_startup=on_startup,
|
|
202
|
-
on_shutdown=on_shutdown,
|
|
203
|
-
lifespan=lifespan,
|
|
204
|
-
)
|
|
205
|
-
|
|
206
|
-
async def endpoint(request: Request) -> Response:
|
|
207
|
-
"""Endpoint to run the agent with the provided input data."""
|
|
208
|
-
return await handle_ag_ui_request(
|
|
209
|
-
agent,
|
|
210
|
-
request,
|
|
211
|
-
output_type=output_type,
|
|
212
|
-
model=model,
|
|
213
|
-
deps=deps,
|
|
214
|
-
model_settings=model_settings,
|
|
215
|
-
usage_limits=usage_limits,
|
|
216
|
-
usage=usage,
|
|
217
|
-
infer_name=infer_name,
|
|
218
|
-
toolsets=toolsets,
|
|
219
|
-
)
|
|
220
|
-
|
|
221
|
-
self.router.add_route('/', endpoint, methods=['POST'], name='run_agent')
|
|
222
|
-
|
|
223
50
|
|
|
224
51
|
async def handle_ag_ui_request(
|
|
225
52
|
agent: AbstractAgent[AgentDepsT, Any],
|
|
226
53
|
request: Request,
|
|
227
54
|
*,
|
|
228
55
|
output_type: OutputSpec[Any] | None = None,
|
|
56
|
+
message_history: Sequence[ModelMessage] | None = None,
|
|
57
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
229
58
|
model: Model | KnownModelName | str | None = None,
|
|
230
59
|
deps: AgentDepsT = None,
|
|
231
60
|
model_settings: ModelSettings | None = None,
|
|
@@ -233,7 +62,7 @@ async def handle_ag_ui_request(
|
|
|
233
62
|
usage: RunUsage | None = None,
|
|
234
63
|
infer_name: bool = True,
|
|
235
64
|
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
236
|
-
on_complete: OnCompleteFunc | None = None,
|
|
65
|
+
on_complete: OnCompleteFunc[BaseEvent] | None = None,
|
|
237
66
|
) -> Response:
|
|
238
67
|
"""Handle an AG-UI request by running the agent and returning a streaming response.
|
|
239
68
|
|
|
@@ -243,6 +72,8 @@ async def handle_ag_ui_request(
|
|
|
243
72
|
|
|
244
73
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
245
74
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
75
|
+
message_history: History of the conversation so far.
|
|
76
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
246
77
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
247
78
|
deps: Optional dependencies to use for this run.
|
|
248
79
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -256,41 +87,31 @@ async def handle_ag_ui_request(
|
|
|
256
87
|
Returns:
|
|
257
88
|
A streaming Starlette response with AG-UI protocol events.
|
|
258
89
|
"""
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
accept,
|
|
274
|
-
output_type=output_type,
|
|
275
|
-
model=model,
|
|
276
|
-
deps=deps,
|
|
277
|
-
model_settings=model_settings,
|
|
278
|
-
usage_limits=usage_limits,
|
|
279
|
-
usage=usage,
|
|
280
|
-
infer_name=infer_name,
|
|
281
|
-
toolsets=toolsets,
|
|
282
|
-
on_complete=on_complete,
|
|
283
|
-
),
|
|
284
|
-
media_type=accept,
|
|
90
|
+
return await AGUIAdapter[AgentDepsT].dispatch_request(
|
|
91
|
+
request,
|
|
92
|
+
agent=agent,
|
|
93
|
+
deps=deps,
|
|
94
|
+
output_type=output_type,
|
|
95
|
+
message_history=message_history,
|
|
96
|
+
deferred_tool_results=deferred_tool_results,
|
|
97
|
+
model=model,
|
|
98
|
+
model_settings=model_settings,
|
|
99
|
+
usage_limits=usage_limits,
|
|
100
|
+
usage=usage,
|
|
101
|
+
infer_name=infer_name,
|
|
102
|
+
toolsets=toolsets,
|
|
103
|
+
on_complete=on_complete,
|
|
285
104
|
)
|
|
286
105
|
|
|
287
106
|
|
|
288
|
-
|
|
107
|
+
def run_ag_ui(
|
|
289
108
|
agent: AbstractAgent[AgentDepsT, Any],
|
|
290
109
|
run_input: RunAgentInput,
|
|
291
110
|
accept: str = SSE_CONTENT_TYPE,
|
|
292
111
|
*,
|
|
293
112
|
output_type: OutputSpec[Any] | None = None,
|
|
113
|
+
message_history: Sequence[ModelMessage] | None = None,
|
|
114
|
+
deferred_tool_results: DeferredToolResults | None = None,
|
|
294
115
|
model: Model | KnownModelName | str | None = None,
|
|
295
116
|
deps: AgentDepsT = None,
|
|
296
117
|
model_settings: ModelSettings | None = None,
|
|
@@ -298,7 +119,7 @@ async def run_ag_ui(
|
|
|
298
119
|
usage: RunUsage | None = None,
|
|
299
120
|
infer_name: bool = True,
|
|
300
121
|
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
301
|
-
on_complete: OnCompleteFunc | None = None,
|
|
122
|
+
on_complete: OnCompleteFunc[BaseEvent] | None = None,
|
|
302
123
|
) -> AsyncIterator[str]:
|
|
303
124
|
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.
|
|
304
125
|
|
|
@@ -309,6 +130,8 @@ async def run_ag_ui(
|
|
|
309
130
|
|
|
310
131
|
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
311
132
|
output validators since output validators would expect an argument that matches the agent's output type.
|
|
133
|
+
message_history: History of the conversation so far.
|
|
134
|
+
deferred_tool_results: Optional results for deferred tool calls in the message history.
|
|
312
135
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
313
136
|
deps: Optional dependencies to use for this run.
|
|
314
137
|
model_settings: Optional settings to use for this model's request.
|
|
@@ -322,50 +145,12 @@ async def run_ag_ui(
|
|
|
322
145
|
Yields:
|
|
323
146
|
Streaming event chunks encoded as strings according to the accept header value.
|
|
324
147
|
"""
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
toolsets = [*toolsets, toolset] if toolsets else [toolset]
|
|
332
|
-
|
|
333
|
-
try:
|
|
334
|
-
yield encoder.encode(
|
|
335
|
-
RunStartedEvent(
|
|
336
|
-
thread_id=run_input.thread_id,
|
|
337
|
-
run_id=run_input.run_id,
|
|
338
|
-
),
|
|
339
|
-
)
|
|
340
|
-
|
|
341
|
-
if not run_input.messages:
|
|
342
|
-
raise _NoMessagesError
|
|
343
|
-
|
|
344
|
-
raw_state: dict[str, Any] = run_input.state or {}
|
|
345
|
-
if isinstance(deps, StateHandler):
|
|
346
|
-
if isinstance(deps.state, BaseModel):
|
|
347
|
-
try:
|
|
348
|
-
state = type(deps.state).model_validate(raw_state)
|
|
349
|
-
except ValidationError as e: # pragma: no cover
|
|
350
|
-
raise _InvalidStateError from e
|
|
351
|
-
else:
|
|
352
|
-
state = raw_state
|
|
353
|
-
|
|
354
|
-
deps = replace(deps, state=state)
|
|
355
|
-
elif raw_state:
|
|
356
|
-
raise UserError(
|
|
357
|
-
f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
|
|
358
|
-
)
|
|
359
|
-
else:
|
|
360
|
-
# `deps` not being a `StateHandler` is OK if there is no state.
|
|
361
|
-
pass
|
|
362
|
-
|
|
363
|
-
messages = _messages_from_ag_ui(run_input.messages)
|
|
364
|
-
|
|
365
|
-
async with agent.iter(
|
|
366
|
-
user_prompt=None,
|
|
367
|
-
output_type=[output_type or agent.output_type, DeferredToolRequests],
|
|
368
|
-
message_history=messages,
|
|
148
|
+
adapter = AGUIAdapter(agent=agent, run_input=run_input, accept=accept)
|
|
149
|
+
return adapter.encode_stream(
|
|
150
|
+
adapter.run_stream(
|
|
151
|
+
output_type=output_type,
|
|
152
|
+
message_history=message_history,
|
|
153
|
+
deferred_tool_results=deferred_tool_results,
|
|
369
154
|
model=model,
|
|
370
155
|
deps=deps,
|
|
371
156
|
model_settings=model_settings,
|
|
@@ -373,437 +158,6 @@ async def run_ag_ui(
|
|
|
373
158
|
usage=usage,
|
|
374
159
|
infer_name=infer_name,
|
|
375
160
|
toolsets=toolsets,
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
yield encoder.encode(event)
|
|
379
|
-
|
|
380
|
-
if on_complete is not None and run.result is not None:
|
|
381
|
-
if _utils.is_async_callable(on_complete):
|
|
382
|
-
await on_complete(run.result)
|
|
383
|
-
else:
|
|
384
|
-
await _utils.run_in_executor(on_complete, run.result)
|
|
385
|
-
except _RunError as e:
|
|
386
|
-
yield encoder.encode(
|
|
387
|
-
RunErrorEvent(message=e.message, code=e.code),
|
|
388
|
-
)
|
|
389
|
-
except Exception as e:
|
|
390
|
-
yield encoder.encode(
|
|
391
|
-
RunErrorEvent(message=str(e)),
|
|
392
|
-
)
|
|
393
|
-
raise e
|
|
394
|
-
else:
|
|
395
|
-
yield encoder.encode(
|
|
396
|
-
RunFinishedEvent(
|
|
397
|
-
thread_id=run_input.thread_id,
|
|
398
|
-
run_id=run_input.run_id,
|
|
399
|
-
),
|
|
400
|
-
)
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEvent]:
|
|
404
|
-
"""Run the agent streaming responses using AG-UI protocol events.
|
|
405
|
-
|
|
406
|
-
Args:
|
|
407
|
-
run: The agent run to process.
|
|
408
|
-
|
|
409
|
-
Yields:
|
|
410
|
-
AG-UI Server-Sent Events (SSE).
|
|
411
|
-
"""
|
|
412
|
-
async for node in run:
|
|
413
|
-
stream_ctx = _RequestStreamContext()
|
|
414
|
-
if isinstance(node, ModelRequestNode):
|
|
415
|
-
async with node.stream(run.ctx) as request_stream:
|
|
416
|
-
async for agent_event in request_stream:
|
|
417
|
-
async for msg in _handle_model_request_event(stream_ctx, agent_event):
|
|
418
|
-
yield msg
|
|
419
|
-
|
|
420
|
-
if stream_ctx.part_end: # pragma: no branch
|
|
421
|
-
yield stream_ctx.part_end
|
|
422
|
-
stream_ctx.part_end = None
|
|
423
|
-
if stream_ctx.thinking:
|
|
424
|
-
yield ThinkingEndEvent(
|
|
425
|
-
type=EventType.THINKING_END,
|
|
426
|
-
)
|
|
427
|
-
stream_ctx.thinking = False
|
|
428
|
-
elif isinstance(node, CallToolsNode):
|
|
429
|
-
async with node.stream(run.ctx) as handle_stream:
|
|
430
|
-
async for event in handle_stream:
|
|
431
|
-
if isinstance(event, FunctionToolResultEvent):
|
|
432
|
-
async for msg in _handle_tool_result_event(stream_ctx, event):
|
|
433
|
-
yield msg
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
async def _handle_model_request_event( # noqa: C901
|
|
437
|
-
stream_ctx: _RequestStreamContext,
|
|
438
|
-
agent_event: ModelResponseStreamEvent,
|
|
439
|
-
) -> AsyncIterator[BaseEvent]:
|
|
440
|
-
"""Handle an agent event and yield AG-UI protocol events.
|
|
441
|
-
|
|
442
|
-
Args:
|
|
443
|
-
stream_ctx: The request stream context to manage state.
|
|
444
|
-
agent_event: The agent event to process.
|
|
445
|
-
|
|
446
|
-
Yields:
|
|
447
|
-
AG-UI Server-Sent Events (SSE) based on the agent event.
|
|
448
|
-
"""
|
|
449
|
-
if isinstance(agent_event, PartStartEvent):
|
|
450
|
-
if stream_ctx.part_end:
|
|
451
|
-
# End the previous part.
|
|
452
|
-
yield stream_ctx.part_end
|
|
453
|
-
stream_ctx.part_end = None
|
|
454
|
-
|
|
455
|
-
part = agent_event.part
|
|
456
|
-
if isinstance(part, ThinkingPart): # pragma: no branch
|
|
457
|
-
if not stream_ctx.thinking:
|
|
458
|
-
yield ThinkingStartEvent(
|
|
459
|
-
type=EventType.THINKING_START,
|
|
460
|
-
)
|
|
461
|
-
stream_ctx.thinking = True
|
|
462
|
-
|
|
463
|
-
if part.content:
|
|
464
|
-
yield ThinkingTextMessageStartEvent(
|
|
465
|
-
type=EventType.THINKING_TEXT_MESSAGE_START,
|
|
466
|
-
)
|
|
467
|
-
yield ThinkingTextMessageContentEvent(
|
|
468
|
-
type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
|
|
469
|
-
delta=part.content,
|
|
470
|
-
)
|
|
471
|
-
stream_ctx.part_end = ThinkingTextMessageEndEvent(
|
|
472
|
-
type=EventType.THINKING_TEXT_MESSAGE_END,
|
|
473
|
-
)
|
|
474
|
-
else:
|
|
475
|
-
if stream_ctx.thinking:
|
|
476
|
-
yield ThinkingEndEvent(
|
|
477
|
-
type=EventType.THINKING_END,
|
|
478
|
-
)
|
|
479
|
-
stream_ctx.thinking = False
|
|
480
|
-
|
|
481
|
-
if isinstance(part, TextPart):
|
|
482
|
-
message_id = stream_ctx.new_message_id()
|
|
483
|
-
yield TextMessageStartEvent(
|
|
484
|
-
message_id=message_id,
|
|
485
|
-
)
|
|
486
|
-
if part.content: # pragma: no branch
|
|
487
|
-
yield TextMessageContentEvent(
|
|
488
|
-
message_id=message_id,
|
|
489
|
-
delta=part.content,
|
|
490
|
-
)
|
|
491
|
-
stream_ctx.part_end = TextMessageEndEvent(
|
|
492
|
-
message_id=message_id,
|
|
493
|
-
)
|
|
494
|
-
elif isinstance(part, BaseToolCallPart):
|
|
495
|
-
tool_call_id = part.tool_call_id
|
|
496
|
-
if isinstance(part, BuiltinToolCallPart):
|
|
497
|
-
builtin_tool_call_id = '|'.join(
|
|
498
|
-
[_BUILTIN_TOOL_CALL_ID_PREFIX, part.provider_name or '', tool_call_id]
|
|
499
|
-
)
|
|
500
|
-
stream_ctx.builtin_tool_call_ids[tool_call_id] = builtin_tool_call_id
|
|
501
|
-
tool_call_id = builtin_tool_call_id
|
|
502
|
-
|
|
503
|
-
message_id = stream_ctx.message_id or stream_ctx.new_message_id()
|
|
504
|
-
yield ToolCallStartEvent(
|
|
505
|
-
tool_call_id=tool_call_id,
|
|
506
|
-
tool_call_name=part.tool_name,
|
|
507
|
-
parent_message_id=message_id,
|
|
508
|
-
)
|
|
509
|
-
if part.args:
|
|
510
|
-
yield ToolCallArgsEvent(
|
|
511
|
-
tool_call_id=tool_call_id,
|
|
512
|
-
delta=part.args_as_json_str(),
|
|
513
|
-
)
|
|
514
|
-
stream_ctx.part_end = ToolCallEndEvent(
|
|
515
|
-
tool_call_id=tool_call_id,
|
|
516
|
-
)
|
|
517
|
-
elif isinstance(part, BuiltinToolReturnPart): # pragma: no branch
|
|
518
|
-
tool_call_id = stream_ctx.builtin_tool_call_ids[part.tool_call_id]
|
|
519
|
-
yield ToolCallResultEvent(
|
|
520
|
-
message_id=stream_ctx.new_message_id(),
|
|
521
|
-
type=EventType.TOOL_CALL_RESULT,
|
|
522
|
-
role='tool',
|
|
523
|
-
tool_call_id=tool_call_id,
|
|
524
|
-
content=part.model_response_str(),
|
|
525
|
-
)
|
|
526
|
-
|
|
527
|
-
elif isinstance(agent_event, PartDeltaEvent):
|
|
528
|
-
delta = agent_event.delta
|
|
529
|
-
if isinstance(delta, TextPartDelta):
|
|
530
|
-
if delta.content_delta: # pragma: no branch
|
|
531
|
-
yield TextMessageContentEvent(
|
|
532
|
-
message_id=stream_ctx.message_id,
|
|
533
|
-
delta=delta.content_delta,
|
|
534
|
-
)
|
|
535
|
-
elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
|
|
536
|
-
tool_call_id = delta.tool_call_id
|
|
537
|
-
assert tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
|
|
538
|
-
if tool_call_id in stream_ctx.builtin_tool_call_ids:
|
|
539
|
-
tool_call_id = stream_ctx.builtin_tool_call_ids[tool_call_id]
|
|
540
|
-
yield ToolCallArgsEvent(
|
|
541
|
-
tool_call_id=tool_call_id,
|
|
542
|
-
delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
|
|
543
|
-
)
|
|
544
|
-
elif isinstance(delta, ThinkingPartDelta): # pragma: no branch
|
|
545
|
-
if delta.content_delta: # pragma: no branch
|
|
546
|
-
if not isinstance(stream_ctx.part_end, ThinkingTextMessageEndEvent):
|
|
547
|
-
yield ThinkingTextMessageStartEvent(
|
|
548
|
-
type=EventType.THINKING_TEXT_MESSAGE_START,
|
|
549
|
-
)
|
|
550
|
-
stream_ctx.part_end = ThinkingTextMessageEndEvent(
|
|
551
|
-
type=EventType.THINKING_TEXT_MESSAGE_END,
|
|
552
|
-
)
|
|
553
|
-
|
|
554
|
-
yield ThinkingTextMessageContentEvent(
|
|
555
|
-
type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
|
|
556
|
-
delta=delta.content_delta,
|
|
557
|
-
)
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
async def _handle_tool_result_event(
|
|
561
|
-
stream_ctx: _RequestStreamContext,
|
|
562
|
-
event: FunctionToolResultEvent,
|
|
563
|
-
) -> AsyncIterator[BaseEvent]:
|
|
564
|
-
"""Convert a tool call result to AG-UI events.
|
|
565
|
-
|
|
566
|
-
Args:
|
|
567
|
-
stream_ctx: The request stream context to manage state.
|
|
568
|
-
event: The tool call result event to process.
|
|
569
|
-
|
|
570
|
-
Yields:
|
|
571
|
-
AG-UI Server-Sent Events (SSE).
|
|
572
|
-
"""
|
|
573
|
-
result = event.result
|
|
574
|
-
if not isinstance(result, ToolReturnPart):
|
|
575
|
-
return
|
|
576
|
-
|
|
577
|
-
yield ToolCallResultEvent(
|
|
578
|
-
message_id=stream_ctx.new_message_id(),
|
|
579
|
-
type=EventType.TOOL_CALL_RESULT,
|
|
580
|
-
role='tool',
|
|
581
|
-
tool_call_id=result.tool_call_id,
|
|
582
|
-
content=result.model_response_str(),
|
|
161
|
+
on_complete=on_complete,
|
|
162
|
+
),
|
|
583
163
|
)
|
|
584
|
-
|
|
585
|
-
# Now check for AG-UI events returned by the tool calls.
|
|
586
|
-
possible_event = result.metadata or result.content
|
|
587
|
-
if isinstance(possible_event, BaseEvent):
|
|
588
|
-
yield possible_event
|
|
589
|
-
elif isinstance(possible_event, str | bytes): # pragma: no branch
|
|
590
|
-
# Avoid iterable check for strings and bytes.
|
|
591
|
-
pass
|
|
592
|
-
elif isinstance(possible_event, Iterable): # pragma: no branch
|
|
593
|
-
for item in possible_event: # type: ignore[reportUnknownMemberType]
|
|
594
|
-
if isinstance(item, BaseEvent): # pragma: no branch
|
|
595
|
-
yield item
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
|
|
599
|
-
"""Convert a AG-UI history to a Pydantic AI one."""
|
|
600
|
-
result: list[ModelMessage] = []
|
|
601
|
-
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
|
|
602
|
-
request_parts: list[ModelRequestPart] | None = None
|
|
603
|
-
response_parts: list[ModelResponsePart] | None = None
|
|
604
|
-
for msg in messages:
|
|
605
|
-
if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
|
|
606
|
-
isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX)
|
|
607
|
-
):
|
|
608
|
-
if request_parts is None:
|
|
609
|
-
request_parts = []
|
|
610
|
-
result.append(ModelRequest(parts=request_parts))
|
|
611
|
-
response_parts = None
|
|
612
|
-
|
|
613
|
-
if isinstance(msg, UserMessage):
|
|
614
|
-
request_parts.append(UserPromptPart(content=msg.content))
|
|
615
|
-
elif isinstance(msg, SystemMessage | DeveloperMessage):
|
|
616
|
-
request_parts.append(SystemPromptPart(content=msg.content))
|
|
617
|
-
else:
|
|
618
|
-
tool_call_id = msg.tool_call_id
|
|
619
|
-
tool_name = tool_calls.get(tool_call_id)
|
|
620
|
-
if tool_name is None: # pragma: no cover
|
|
621
|
-
raise _ToolCallNotFoundError(tool_call_id=tool_call_id)
|
|
622
|
-
|
|
623
|
-
request_parts.append(
|
|
624
|
-
ToolReturnPart(
|
|
625
|
-
tool_name=tool_name,
|
|
626
|
-
content=msg.content,
|
|
627
|
-
tool_call_id=tool_call_id,
|
|
628
|
-
)
|
|
629
|
-
)
|
|
630
|
-
|
|
631
|
-
elif isinstance(msg, AssistantMessage) or ( # pragma: no branch
|
|
632
|
-
isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX)
|
|
633
|
-
):
|
|
634
|
-
if response_parts is None:
|
|
635
|
-
response_parts = []
|
|
636
|
-
result.append(ModelResponse(parts=response_parts))
|
|
637
|
-
request_parts = None
|
|
638
|
-
|
|
639
|
-
if isinstance(msg, AssistantMessage):
|
|
640
|
-
if msg.content:
|
|
641
|
-
response_parts.append(TextPart(content=msg.content))
|
|
642
|
-
|
|
643
|
-
if msg.tool_calls:
|
|
644
|
-
for tool_call in msg.tool_calls:
|
|
645
|
-
tool_call_id = tool_call.id
|
|
646
|
-
tool_name = tool_call.function.name
|
|
647
|
-
tool_calls[tool_call_id] = tool_name
|
|
648
|
-
|
|
649
|
-
if tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX):
|
|
650
|
-
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
|
|
651
|
-
response_parts.append(
|
|
652
|
-
BuiltinToolCallPart(
|
|
653
|
-
tool_name=tool_name,
|
|
654
|
-
args=tool_call.function.arguments,
|
|
655
|
-
tool_call_id=tool_call_id,
|
|
656
|
-
provider_name=provider_name,
|
|
657
|
-
)
|
|
658
|
-
)
|
|
659
|
-
else:
|
|
660
|
-
response_parts.append(
|
|
661
|
-
ToolCallPart(
|
|
662
|
-
tool_name=tool_name,
|
|
663
|
-
tool_call_id=tool_call_id,
|
|
664
|
-
args=tool_call.function.arguments,
|
|
665
|
-
)
|
|
666
|
-
)
|
|
667
|
-
else:
|
|
668
|
-
tool_call_id = msg.tool_call_id
|
|
669
|
-
tool_name = tool_calls.get(tool_call_id)
|
|
670
|
-
if tool_name is None: # pragma: no cover
|
|
671
|
-
raise _ToolCallNotFoundError(tool_call_id=tool_call_id)
|
|
672
|
-
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
|
|
673
|
-
|
|
674
|
-
response_parts.append(
|
|
675
|
-
BuiltinToolReturnPart(
|
|
676
|
-
tool_name=tool_name,
|
|
677
|
-
content=msg.content,
|
|
678
|
-
tool_call_id=tool_call_id,
|
|
679
|
-
provider_name=provider_name,
|
|
680
|
-
)
|
|
681
|
-
)
|
|
682
|
-
|
|
683
|
-
return result
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
@runtime_checkable
class StateHandler(Protocol):
    """Structural type for run dependencies that carry a readable/writable `state`.

    A conforming object must be a dataclass instance (so `dataclasses.replace`
    can be used to update it) and must expose a `state` attribute.
    """

    # The dataclass marker is part of the protocol: the state is updated with
    # `replace`, which only works on dataclass instances. This mirrors typeshed's
    # `DataclassInstance` definition:
    # https://github.com/python/typeshed/blob/9ab7fde0a0cd24ed7a72837fcb21093b811b80d8/stdlib/_typeshed/__init__.pyi#L352
    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]

    @property
    def state(self) -> State:
        """Return the state currently attached to the agent run."""
        ...

    @state.setter
    def state(self, state: State) -> None:
        """Attach a new state to the agent run.

        Args:
            state: The replacement run state.

        Raises:
            _InvalidStateError: Implementations may raise this when `state`
                does not match the expected model.
        """
        ...
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
StateT = TypeVar('StateT', bound=BaseModel)
"""Type parameter for `StateDeps`; it must be bound to a `BaseModel` subclass."""


@dataclass
class StateDeps(Generic[StateT]):
    """Dependency container holding the AG-UI state of an agent run.

    Parametrize it with the `BaseModel` subclass that describes the state,
    e.g. `StateDeps[MyState]`. The `Adapter` assigns the `state` field when
    the run starts.

    Being a dataclass with a `state` attribute, it satisfies the
    `StateHandler` protocol.
    """

    # Current run state, typed by the `StateT` parameter.
    state: StateT
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
@dataclass(repr=False)
|
|
736
|
-
class _RequestStreamContext:
|
|
737
|
-
"""Data class to hold request stream context."""
|
|
738
|
-
|
|
739
|
-
message_id: str = ''
|
|
740
|
-
part_end: BaseEvent | None = None
|
|
741
|
-
thinking: bool = False
|
|
742
|
-
builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)
|
|
743
|
-
|
|
744
|
-
def new_message_id(self) -> str:
|
|
745
|
-
"""Generate a new message ID for the request stream.
|
|
746
|
-
|
|
747
|
-
Assigns a new UUID to the `message_id` and returns it.
|
|
748
|
-
|
|
749
|
-
Returns:
|
|
750
|
-
A new message ID.
|
|
751
|
-
"""
|
|
752
|
-
self.message_id = str(uuid.uuid4())
|
|
753
|
-
return self.message_id
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
@dataclass
|
|
757
|
-
class _RunError(Exception):
|
|
758
|
-
"""Exception raised for errors during agent runs."""
|
|
759
|
-
|
|
760
|
-
message: str
|
|
761
|
-
code: str
|
|
762
|
-
|
|
763
|
-
def __str__(self) -> str: # pragma: no cover
|
|
764
|
-
return self.message
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
@dataclass
class _NoMessagesError(_RunError):
    """Raised when the run input does not contain any messages."""

    # Both fields have defaults, so the error can be raised without arguments.
    message: str = 'no messages found in the input'
    code: str = 'no_messages'
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
@dataclass
class _InvalidStateError(_RunError, ValidationError):
    """Raised when the state supplied for a run is not valid.

    It subclasses both `_RunError` and `ValidationError`, so handlers for
    either type will catch it.
    """

    # Defaults allow raising the error without arguments.
    message: str = 'invalid state provided'
    code: str = 'invalid_state'
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
class _ToolCallNotFoundError(_RunError, ValueError):
    """Exception raised when a tool result is present without a matching call."""

    def __init__(self, tool_call_id: str) -> None:
        """Initialize the exception from the ID of the unmatched tool result.

        Args:
            tool_call_id: ID carried by the tool result whose originating
                tool call was not found in the message history.
        """
        super().__init__(  # pragma: no cover
            message=f'Tool call with ID {tool_call_id} not found in the history.',
            code='tool_call_not_found',
        )
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
class _AGUIFrontendToolset(ExternalToolset[AgentDepsT]):
    """External toolset built from the tools declared by the AG-UI frontend."""

    def __init__(self, tools: list[AGUITool]):
        # Convert each frontend tool description into a `ToolDefinition`.
        definitions = [
            ToolDefinition(
                name=frontend_tool.name,
                description=frontend_tool.description,
                parameters_json_schema=frontend_tool.parameters,
            )
            for frontend_tool in tools
        ]
        super().__init__(definitions)

    @property
    def label(self) -> str:
        """Return the label for this toolset."""
        return 'the AG-UI frontend tools'  # pragma: no cover
|