pydantic-ai-slim 0.4.10__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry, and is provided for informational purposes only.
- pydantic_ai/_function_schema.py +7 -4
- pydantic_ai/_parts_manager.py +8 -9
- pydantic_ai/_thinking_part.py +7 -12
- pydantic_ai/ag_ui.py +346 -316
- pydantic_ai/agent.py +7 -5
- pydantic_ai/messages.py +37 -10
- pydantic_ai/models/__init__.py +2 -2
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/gemini.py +1 -1
- pydantic_ai/models/google.py +1 -1
- pydantic_ai/models/groq.py +7 -3
- pydantic_ai/models/huggingface.py +7 -2
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +7 -3
- pydantic_ai/models/test.py +3 -1
- pydantic_ai/profiles/__init__.py +3 -0
- pydantic_ai/profiles/anthropic.py +1 -1
- pydantic_ai/profiles/openai.py +22 -12
- pydantic_ai/tools.py +13 -5
- {pydantic_ai_slim-0.4.10.dist-info → pydantic_ai_slim-0.5.0.dist-info}/METADATA +3 -3
- {pydantic_ai_slim-0.4.10.dist-info → pydantic_ai_slim-0.5.0.dist-info}/RECORD +24 -24
- {pydantic_ai_slim-0.4.10.dist-info → pydantic_ai_slim-0.5.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.10.dist-info → pydantic_ai_slim-0.5.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.10.dist-info → pydantic_ai_slim-0.5.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ag_ui.py
CHANGED
@@ -8,11 +8,10 @@ from __future__ import annotations

 import json
 import uuid
-from collections.abc import Iterable, Mapping, Sequence
-from dataclasses import Field, dataclass,
+from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
+from dataclasses import Field, dataclass, replace
 from http import HTTPStatus
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     ClassVar,
@@ -23,10 +22,36 @@ from typing import (
     runtime_checkable,
 )

-from
+from pydantic import BaseModel, ValidationError

-
-
+from ._agent_graph import CallToolsNode, ModelRequestNode
+from .agent import Agent, AgentRun
+from .exceptions import UserError
+from .messages import (
+    AgentStreamEvent,
+    FunctionToolResultEvent,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    PartDeltaEvent,
+    PartStartEvent,
+    SystemPromptPart,
+    TextPart,
+    TextPartDelta,
+    ThinkingPart,
+    ThinkingPartDelta,
+    ToolCallPart,
+    ToolCallPartDelta,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from .models import KnownModelName, Model
+from .output import DeferredToolCalls, OutputDataT, OutputSpec
+from .settings import ModelSettings
+from .tools import AgentDepsT, ToolDefinition
+from .toolsets import AbstractToolset
+from .toolsets.deferred import DeferredToolset
+from .usage import Usage, UsageLimits

 try:
     from ag_ui.core import (
@@ -74,43 +99,13 @@ except ImportError as e:  # pragma: no cover
         'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
     ) from e

-from collections.abc import AsyncGenerator
-
-from pydantic import BaseModel, ValidationError
-
-from ._agent_graph import CallToolsNode, ModelRequestNode
-from .agent import Agent, AgentRun, RunOutputDataT
-from .messages import (
-    AgentStreamEvent,
-    FunctionToolResultEvent,
-    ModelMessage,
-    ModelRequest,
-    ModelResponse,
-    PartDeltaEvent,
-    PartStartEvent,
-    SystemPromptPart,
-    TextPart,
-    TextPartDelta,
-    ThinkingPart,
-    ThinkingPartDelta,
-    ToolCallPart,
-    ToolCallPartDelta,
-    ToolReturnPart,
-    UserPromptPart,
-)
-from .models import KnownModelName, Model
-from .output import DeferredToolCalls, OutputDataT, OutputSpec
-from .settings import ModelSettings
-from .tools import AgentDepsT, ToolDefinition
-from .toolsets import AbstractToolset
-from .toolsets.deferred import DeferredToolset
-from .usage import Usage, UsageLimits
-
 __all__ = [
     'SSE_CONTENT_TYPE',
     'StateDeps',
     'StateHandler',
     'AGUIApp',
+    'handle_ag_ui_request',
+    'run_ag_ui',
 ]

 SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
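The hunk above adds `handle_ag_ui_request` and `run_ag_ui` to the module's public API next to the existing `AGUIApp`. For orientation, a minimal sketch of serving the ASGI app itself, based only on the constructor signature shown in the hunks below; the model name and the uvicorn invocation are illustrative and not taken from this diff:

    from pydantic_ai import Agent
    from pydantic_ai.ag_ui import AGUIApp

    agent = Agent('openai:gpt-4o')  # illustrative model name
    # AGUIApp subclasses Starlette, so this is an ASGI app: serve with e.g. `uvicorn module:app`
    app = AGUIApp(agent)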
@@ -125,7 +120,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
         agent: Agent[AgentDepsT, OutputDataT],
         *,
         # Agent.iter parameters.
-        output_type: OutputSpec[
+        output_type: OutputSpec[Any] | None = None,
         model: Model | KnownModelName | str | None = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
@@ -142,10 +137,16 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
         on_shutdown: Sequence[Callable[[], Any]] | None = None,
         lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
     ) -> None:
-        """
+        """An ASGI application that handles every AG-UI request by running the agent.
+
+        Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's
+        injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol.
+        To provide different `deps` for each request (e.g. based on the authenticated user),
+        use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or
+        [`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead.

         Args:
-            agent: The
+            agent: The agent to run.

             output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
                 no output validators since output validators would expect an argument that matches the agent's
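The rewritten docstring above notes that `AGUIApp` reuses the same `deps` for every request and points to `run_ag_ui` / `handle_ag_ui_request` for per-request dependencies. A minimal sketch of that pattern, assuming FastAPI and an illustrative `Deps` dataclass (neither appears in this diff); the new `handle_ag_ui_request` signature is shown in the large hunk further down:

    from dataclasses import dataclass

    from fastapi import FastAPI, Request

    from pydantic_ai import Agent
    from pydantic_ai.ag_ui import handle_ag_ui_request


    @dataclass
    class Deps:
        user_id: str  # illustrative per-request dependency


    agent = Agent('openai:gpt-4o', deps_type=Deps)  # illustrative model name
    app = FastAPI()


    @app.post('/agent')
    async def run_agent(request: Request):
        # Build deps from the incoming request, e.g. from auth headers.
        deps = Deps(user_id=request.headers.get('x-user-id', 'anonymous'))
        return await handle_ag_ui_request(agent, request, deps=deps)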
@@ -156,7 +157,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
             usage_limits: Optional limits on model request count or token usage.
             usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
-            toolsets: Optional
+            toolsets: Optional additional toolsets for this run.

             debug: Boolean indicating if debug tracebacks should be returned on errors.
             routes: A list of routes to serve incoming HTTP and WebSocket requests.
@@ -185,320 +186,349 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
             on_shutdown=on_shutdown,
             lifespan=lifespan,
         )
-        adapter = _Adapter(agent=agent)

-        async def endpoint(request: Request) -> Response
+        async def endpoint(request: Request) -> Response:
             """Endpoint to run the agent with the provided input data."""
-
-
-
-
-
-
-
-
-
-
-
-                adapter.run(
-                    input_data,
-                    accept,
-                    output_type=output_type,
-                    model=model,
-                    deps=deps,
-                    model_settings=model_settings,
-                    usage_limits=usage_limits,
-                    usage=usage,
-                    infer_name=infer_name,
-                    toolsets=toolsets,
-                ),
-                media_type=SSE_CONTENT_TYPE,
+            return await handle_ag_ui_request(
+                agent,
+                request,
+                output_type=output_type,
+                model=model,
+                deps=deps,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                usage=usage,
+                infer_name=infer_name,
+                toolsets=toolsets,
             )

         self.router.add_route('/', endpoint, methods=['POST'], name='run_agent')


-
-
-
-
-
-
-
+async def handle_ag_ui_request(
+    agent: Agent[AgentDepsT, Any],
+    request: Request,
+    *,
+    output_type: OutputSpec[Any] | None = None,
+    model: Model | KnownModelName | str | None = None,
+    deps: AgentDepsT = None,
+    model_settings: ModelSettings | None = None,
+    usage_limits: UsageLimits | None = None,
+    usage: Usage | None = None,
+    infer_name: bool = True,
+    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+) -> Response:
+    """Handle an AG-UI request by running the agent and returning a streaming response.

     Args:
-        agent: The
+        agent: The agent to run.
+        request: The Starlette request (e.g. from FastAPI) containing the AG-UI run input.
+
+        output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
+            output validators since output validators would expect an argument that matches the agent's output type.
+        model: Optional model to use for this run, required if `model` was not set when creating the agent.
+        deps: Optional dependencies to use for this run.
+        model_settings: Optional settings to use for this model's request.
+        usage_limits: Optional limits on model request count or token usage.
+        usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+        infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+        toolsets: Optional additional toolsets for this run.
+
+    Returns:
+        A streaming Starlette response with AG-UI protocol events.
     """
+    accept = request.headers.get('accept', SSE_CONTENT_TYPE)
+    try:
+        input_data = RunAgentInput.model_validate(await request.json())
+    except ValidationError as e:  # pragma: no cover
+        return Response(
+            content=json.dumps(e.json()),
+            media_type='application/json',
+            status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
+        )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-    """Run the agent with streaming response using AG-UI protocol events.
+    return StreamingResponse(
+        run_ag_ui(
+            agent,
+            input_data,
+            accept,
+            output_type=output_type,
+            model=model,
+            deps=deps,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            usage=usage,
+            infer_name=infer_name,
+            toolsets=toolsets,
+        ),
+        media_type=accept,
+    )

-    The first two arguments are specific to `Adapter` the rest map directly to the `Agent.iter` method.

-
-
-
+async def run_ag_ui(
+    agent: Agent[AgentDepsT, Any],
+    run_input: RunAgentInput,
+    accept: str = SSE_CONTENT_TYPE,
+    *,
+    output_type: OutputSpec[Any] | None = None,
+    model: Model | KnownModelName | str | None = None,
+    deps: AgentDepsT = None,
+    model_settings: ModelSettings | None = None,
+    usage_limits: UsageLimits | None = None,
+    usage: Usage | None = None,
+    infer_name: bool = True,
+    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+) -> AsyncIterator[str]:
+    """Run the agent with the AG-UI run input and stream AG-UI protocol events.

-
-
-
-
-
-
-
-
-
+    Args:
+        agent: The agent to run.
+        run_input: The AG-UI run input containing thread_id, run_id, messages, etc.
+        accept: The accept header value for the run.
+
+        output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
+            output validators since output validators would expect an argument that matches the agent's output type.
+        model: Optional model to use for this run, required if `model` was not set when creating the agent.
+        deps: Optional dependencies to use for this run.
+        model_settings: Optional settings to use for this model's request.
+        usage_limits: Optional limits on model request count or token usage.
+        usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+        infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+        toolsets: Optional additional toolsets for this run.
+
+    Yields:
+        Streaming event chunks encoded as strings according to the accept header value.
+    """
+    encoder = EventEncoder(accept=accept)
+    if run_input.tools:
+        # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the
+        # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any
+        # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
+        toolset = DeferredToolset[AgentDepsT](
+            [
+                ToolDefinition(
+                    name=tool.name,
+                    description=tool.description,
+                    parameters_json_schema=tool.parameters,
+                )
+                for tool in run_input.tools
+            ]
+        )
+        toolsets = [*toolsets, toolset] if toolsets else [toolset]
+
+    try:
+        yield encoder.encode(
+            RunStartedEvent(
+                thread_id=run_input.thread_id,
+                run_id=run_input.run_id,
+            ),
+        )

-
-
-    """
-    encoder = EventEncoder(accept=accept)
-    if run_input.tools:
-        # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the
-        # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any
-        # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
-        toolset = DeferredToolset[AgentDepsT](
-            [
-                ToolDefinition(
-                    name=tool.name,
-                    description=tool.description,
-                    parameters_json_schema=tool.parameters,
-                )
-                for tool in run_input.tools
-            ]
-        )
-        toolsets = [*toolsets, toolset] if toolsets else [toolset]
-
-    try:
-        yield encoder.encode(
-            RunStartedEvent(
-                thread_id=run_input.thread_id,
-                run_id=run_input.run_id,
-            ),
-        )
+        if not run_input.messages:
+            raise _NoMessagesError

-
-
-
-
-
-
-
-                    state = type(deps.state).model_validate(raw_state)
-                except ValidationError as e:  # pragma: no cover
-                    raise _InvalidStateError from e
-            else:
-                state = raw_state
-
-            deps = replace(deps, state=state)
-        elif raw_state:
-            raise UserError(
-                f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
-            )
+        raw_state: dict[str, Any] = run_input.state or {}
+        if isinstance(deps, StateHandler):
+            if isinstance(deps.state, BaseModel):
+                try:
+                    state = type(deps.state).model_validate(raw_state)
+                except ValidationError as e:  # pragma: no cover
+                    raise _InvalidStateError from e
             else:
-
-            pass
-
-        messages = _messages_from_ag_ui(run_input.messages)
+                state = raw_state

-
-
-
-
-            model=model,
-            deps=deps,
-            model_settings=model_settings,
-            usage_limits=usage_limits,
-            usage=usage,
-            infer_name=infer_name,
-            toolsets=toolsets,
-        ) as run:
-            async for event in self._agent_stream(run):
-                yield encoder.encode(event)
-    except _RunError as e:
-        yield encoder.encode(
-            RunErrorEvent(message=e.message, code=e.code),
-        )
-    except Exception as e:
-        yield encoder.encode(
-            RunErrorEvent(message=str(e)),
+            deps = replace(deps, state=state)
+        elif raw_state:
+            raise UserError(
+                f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
             )
-        raise e
         else:
-
-
-                thread_id=run_input.thread_id,
-                run_id=run_input.run_id,
-            ),
-        )
+            # `deps` not being a `StateHandler` is OK if there is no state.
+            pass

-
-
-
-
-
+        messages = _messages_from_ag_ui(run_input.messages)
+
+        async with agent.iter(
+            user_prompt=None,
+            output_type=[output_type or agent.output_type, DeferredToolCalls],
+            message_history=messages,
+            model=model,
+            deps=deps,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            usage=usage,
+            infer_name=infer_name,
+            toolsets=toolsets,
+        ) as run:
+            async for event in _agent_stream(run):
+                yield encoder.encode(event)
+    except _RunError as e:
+        yield encoder.encode(
+            RunErrorEvent(message=e.message, code=e.code),
+        )
+    except Exception as e:
+        yield encoder.encode(
+            RunErrorEvent(message=str(e)),
+        )
+        raise e
+    else:
+        yield encoder.encode(
+            RunFinishedEvent(
+                thread_id=run_input.thread_id,
+                run_id=run_input.run_id,
+            ),
+        )

-    Args:
-        run: The agent run to process.

-
-
-
-
-
-
-
-
-
+async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEvent]:
+    """Run the agent streaming responses using AG-UI protocol events.
+
+    Args:
+        run: The agent run to process.
+
+    Yields:
+        AG-UI Server-Sent Events (SSE).
+    """
+    async for node in run:
+        stream_ctx = _RequestStreamContext()
+        if isinstance(node, ModelRequestNode):
+            async with node.stream(run.ctx) as request_stream:
+                async for agent_event in request_stream:
+                    async for msg in _handle_model_request_event(stream_ctx, agent_event):
+                        yield msg
+
+                if stream_ctx.part_end:  # pragma: no branch
+                    yield stream_ctx.part_end
+                    stream_ctx.part_end = None
+        elif isinstance(node, CallToolsNode):
+            async with node.stream(run.ctx) as handle_stream:
+                async for event in handle_stream:
+                    if isinstance(event, FunctionToolResultEvent):
+                        async for msg in _handle_tool_result_event(stream_ctx, event):
                             yield msg

-                if stream_ctx.part_end:  # pragma: no branch
-                    yield stream_ctx.part_end
-                    stream_ctx.part_end = None
-        elif isinstance(node, CallToolsNode):
-            async with node.stream(run.ctx) as handle_stream:
-                async for event in handle_stream:
-                    if isinstance(event, FunctionToolResultEvent):
-                        async for msg in self._handle_tool_result_event(stream_ctx, event):
-                            yield msg
-
-async def _handle_model_request_event(
-    self,
-    stream_ctx: _RequestStreamContext,
-    agent_event: AgentStreamEvent,
-) -> AsyncGenerator[BaseEvent, None]:
-    """Handle an agent event and yield AG-UI protocol events.

-
-
-
+async def _handle_model_request_event(
+    stream_ctx: _RequestStreamContext,
+    agent_event: AgentStreamEvent,
+) -> AsyncIterator[BaseEvent]:
+    """Handle an agent event and yield AG-UI protocol events.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    Args:
+        stream_ctx: The request stream context to manage state.
+        agent_event: The agent event to process.
+
+    Yields:
+        AG-UI Server-Sent Events (SSE) based on the agent event.
+    """
+    if isinstance(agent_event, PartStartEvent):
+        if stream_ctx.part_end:
+            # End the previous part.
+            yield stream_ctx.part_end
+            stream_ctx.part_end = None
+
+        part = agent_event.part
+        if isinstance(part, TextPart):
+            message_id = stream_ctx.new_message_id()
+            yield TextMessageStartEvent(
+                message_id=message_id,
+            )
+            if part.content:  # pragma: no branch
+                yield TextMessageContentEvent(
                     message_id=message_id,
+                    delta=part.content,
                 )
-
-                message_id
-
-
-
-
-
-
-
-
-
-
-            stream_ctx.part_end = ToolCallEndEvent(
+            stream_ctx.part_end = TextMessageEndEvent(
+                message_id=message_id,
+            )
+        elif isinstance(part, ToolCallPart):  # pragma: no branch
+            message_id = stream_ctx.message_id or stream_ctx.new_message_id()
+            yield ToolCallStartEvent(
+                tool_call_id=part.tool_call_id,
+                tool_call_name=part.tool_name,
+                parent_message_id=message_id,
+            )
+            if part.args:
+                yield ToolCallArgsEvent(
                     tool_call_id=part.tool_call_id,
+                    delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
                 )
+            stream_ctx.part_end = ToolCallEndEvent(
+                tool_call_id=part.tool_call_id,
+            )

-
-
-
-
-
-
+        elif isinstance(part, ThinkingPart):  # pragma: no branch
+            yield ThinkingTextMessageStartEvent(
+                type=EventType.THINKING_TEXT_MESSAGE_START,
+            )
+            # Always send the content even if it's empty, as it may be
+            # used to indicate the start of thinking.
+            yield ThinkingTextMessageContentEvent(
+                type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
+                delta=part.content,
+            )
+            stream_ctx.part_end = ThinkingTextMessageEndEvent(
+                type=EventType.THINKING_TEXT_MESSAGE_END,
+            )
+
+    elif isinstance(agent_event, PartDeltaEvent):
+        delta = agent_event.delta
+        if isinstance(delta, TextPartDelta):
+            yield TextMessageContentEvent(
+                message_id=stream_ctx.message_id,
+                delta=delta.content_delta,
+            )
+        elif isinstance(delta, ToolCallPartDelta):  # pragma: no branch
+            assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
+            yield ToolCallArgsEvent(
+                tool_call_id=delta.tool_call_id,
+                delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
+            )
+        elif isinstance(delta, ThinkingPartDelta):  # pragma: no branch
+            if delta.content_delta:  # pragma: no branch
                 yield ThinkingTextMessageContentEvent(
                     type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
-                delta=part.content,
-            )
-            stream_ctx.part_end = ThinkingTextMessageEndEvent(
-                type=EventType.THINKING_TEXT_MESSAGE_END,
-            )
-
-    elif isinstance(agent_event, PartDeltaEvent):
-        delta = agent_event.delta
-        if isinstance(delta, TextPartDelta):
-            yield TextMessageContentEvent(
-                message_id=stream_ctx.message_id,
                     delta=delta.content_delta,
                 )
-        elif isinstance(delta, ToolCallPartDelta):  # pragma: no branch
-            assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
-            yield ToolCallArgsEvent(
-                tool_call_id=delta.tool_call_id,
-                delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
-            )
-        elif isinstance(delta, ThinkingPartDelta):  # pragma: no branch
-            if delta.content_delta:  # pragma: no branch
-                yield ThinkingTextMessageContentEvent(
-                    type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
-                    delta=delta.content_delta,
-                )

-async def _handle_tool_result_event(
-    self,
-    stream_ctx: _RequestStreamContext,
-    event: FunctionToolResultEvent,
-) -> AsyncGenerator[BaseEvent, None]:
-    """Convert a tool call result to AG-UI events.

-
-
-
+async def _handle_tool_result_event(
+    stream_ctx: _RequestStreamContext,
+    event: FunctionToolResultEvent,
+) -> AsyncIterator[BaseEvent]:
+    """Convert a tool call result to AG-UI events.

-
-
-
-    result = event.result
-    if not isinstance(result, ToolReturnPart):
-        return
-
-    message_id = stream_ctx.new_message_id()
-    yield ToolCallResultEvent(
-        message_id=message_id,
-        type=EventType.TOOL_CALL_RESULT,
-        role='tool',
-        tool_call_id=result.tool_call_id,
-        content=result.model_response_str(),
-    )
+    Args:
+        stream_ctx: The request stream context to manage state.
+        event: The tool call result event to process.

-
-
-
-
-
-
-
-
-
-
-
+    Yields:
+        AG-UI Server-Sent Events (SSE).
+    """
+    result = event.result
+    if not isinstance(result, ToolReturnPart):
+        return
+
+    message_id = stream_ctx.new_message_id()
+    yield ToolCallResultEvent(
+        message_id=message_id,
+        type=EventType.TOOL_CALL_RESULT,
+        role='tool',
+        tool_call_id=result.tool_call_id,
+        content=result.model_response_str(),
+    )
+
+    # Now check for AG-UI events returned by the tool calls.
+    content = result.content
+    if isinstance(content, BaseEvent):
+        yield content
+    elif isinstance(content, (str, bytes)):  # pragma: no branch
+        # Avoid iterable check for strings and bytes.
+        pass
+    elif isinstance(content, Iterable):  # pragma: no branch
+        for item in content:  # type: ignore[reportUnknownMemberType]
+            if isinstance(item, BaseEvent):  # pragma: no branch
+                yield item


 def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
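Since `run_ag_ui` yields already-encoded event strings (see its `Yields:` docstring in the hunk above), it can back any transport, not just the Starlette response that `handle_ag_ui_request` builds. A minimal sketch of driving it directly; it assumes `RunAgentInput` is importable from `ag_ui.core` (the same package imported at the top of this module) and is constructed elsewhere, and the model name is illustrative:

    from ag_ui.core import RunAgentInput

    from pydantic_ai import Agent
    from pydantic_ai.ag_ui import SSE_CONTENT_TYPE, run_ag_ui

    agent = Agent('openai:gpt-4o')  # illustrative model name


    async def stream_events(run_input: RunAgentInput) -> None:
        # Each chunk is an AG-UI protocol event encoded for the given accept header.
        async for chunk in run_ag_ui(agent, run_input, accept=SSE_CONTENT_TYPE):
            print(chunk, end='')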
|