pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -156
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -9
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +3 -3
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ag_ui.py
ADDED
|
@@ -0,0 +1,675 @@
|
|
|
1
|
+
"""Provides an AG-UI protocol adapter for the Pydantic AI agent.
|
|
2
|
+
|
|
3
|
+
This package provides seamless integration between pydantic-ai agents and ag-ui
|
|
4
|
+
for building interactive AI applications with streaming event-based communication.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import uuid
|
|
11
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from http import HTTPStatus
|
|
14
|
+
from typing import (
|
|
15
|
+
Any,
|
|
16
|
+
Callable,
|
|
17
|
+
Final,
|
|
18
|
+
Generic,
|
|
19
|
+
Protocol,
|
|
20
|
+
TypeVar,
|
|
21
|
+
runtime_checkable,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from ag_ui.core import (
|
|
26
|
+
AssistantMessage,
|
|
27
|
+
BaseEvent,
|
|
28
|
+
DeveloperMessage,
|
|
29
|
+
EventType,
|
|
30
|
+
Message,
|
|
31
|
+
RunAgentInput,
|
|
32
|
+
RunErrorEvent,
|
|
33
|
+
RunFinishedEvent,
|
|
34
|
+
RunStartedEvent,
|
|
35
|
+
State,
|
|
36
|
+
SystemMessage,
|
|
37
|
+
TextMessageContentEvent,
|
|
38
|
+
TextMessageEndEvent,
|
|
39
|
+
TextMessageStartEvent,
|
|
40
|
+
ThinkingTextMessageContentEvent,
|
|
41
|
+
ThinkingTextMessageEndEvent,
|
|
42
|
+
ThinkingTextMessageStartEvent,
|
|
43
|
+
ToolCallArgsEvent,
|
|
44
|
+
ToolCallEndEvent,
|
|
45
|
+
ToolCallResultEvent,
|
|
46
|
+
ToolCallStartEvent,
|
|
47
|
+
ToolMessage,
|
|
48
|
+
UserMessage,
|
|
49
|
+
)
|
|
50
|
+
from ag_ui.encoder import EventEncoder
|
|
51
|
+
except ImportError as e: # pragma: no cover
|
|
52
|
+
raise ImportError(
|
|
53
|
+
'Please install the `ag-ui-protocol` package to use `Agent.to_ag_ui()` method, '
|
|
54
|
+
'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
|
|
55
|
+
) from e
|
|
56
|
+
|
|
57
|
+
try:
|
|
58
|
+
from starlette.applications import Starlette
|
|
59
|
+
from starlette.middleware import Middleware
|
|
60
|
+
from starlette.requests import Request
|
|
61
|
+
from starlette.responses import Response, StreamingResponse
|
|
62
|
+
from starlette.routing import BaseRoute
|
|
63
|
+
from starlette.types import ExceptionHandler, Lifespan
|
|
64
|
+
except ImportError as e: # pragma: no cover
|
|
65
|
+
raise ImportError(
|
|
66
|
+
'Please install the `starlette` package to use `Agent.to_ag_ui()` method, '
|
|
67
|
+
'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
|
|
68
|
+
) from e
|
|
69
|
+
|
|
70
|
+
from collections.abc import AsyncGenerator
|
|
71
|
+
|
|
72
|
+
from pydantic import BaseModel, ValidationError
|
|
73
|
+
|
|
74
|
+
from ._agent_graph import CallToolsNode, ModelRequestNode
|
|
75
|
+
from .agent import Agent, AgentRun, RunOutputDataT
|
|
76
|
+
from .messages import (
|
|
77
|
+
AgentStreamEvent,
|
|
78
|
+
FunctionToolResultEvent,
|
|
79
|
+
ModelMessage,
|
|
80
|
+
ModelRequest,
|
|
81
|
+
ModelResponse,
|
|
82
|
+
PartDeltaEvent,
|
|
83
|
+
PartStartEvent,
|
|
84
|
+
SystemPromptPart,
|
|
85
|
+
TextPart,
|
|
86
|
+
TextPartDelta,
|
|
87
|
+
ThinkingPart,
|
|
88
|
+
ThinkingPartDelta,
|
|
89
|
+
ToolCallPart,
|
|
90
|
+
ToolCallPartDelta,
|
|
91
|
+
ToolReturnPart,
|
|
92
|
+
UserPromptPart,
|
|
93
|
+
)
|
|
94
|
+
from .models import KnownModelName, Model
|
|
95
|
+
from .output import DeferredToolCalls, OutputDataT, OutputSpec
|
|
96
|
+
from .settings import ModelSettings
|
|
97
|
+
from .tools import AgentDepsT, ToolDefinition
|
|
98
|
+
from .toolsets import AbstractToolset
|
|
99
|
+
from .toolsets.deferred import DeferredToolset
|
|
100
|
+
from .usage import Usage, UsageLimits
|
|
101
|
+
|
|
102
|
+
__all__ = [
|
|
103
|
+
'SSE_CONTENT_TYPE',
|
|
104
|
+
'StateDeps',
|
|
105
|
+
'StateHandler',
|
|
106
|
+
'AGUIApp',
|
|
107
|
+
]
|
|
108
|
+
|
|
109
|
+
# MIME type used both as the default `Accept` header value in the endpoint and
# as the `media_type` of the streaming response carrying AG-UI events.
SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
"""Content type header value for Server-Sent Events (SSE)."""
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
    """ASGI application for running Pydantic AI agents with AG-UI protocol support.

    Exposes a single `POST /` route that accepts an AG-UI `RunAgentInput`
    payload and streams the agent run back as SSE-encoded AG-UI events.
    """

    def __init__(
        self,
        agent: Agent[AgentDepsT, OutputDataT],
        *,
        # Agent.iter parameters.
        output_type: OutputSpec[OutputDataT] | None = None,
        model: Model | KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: Usage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
        # Starlette parameters.
        debug: bool = False,
        routes: Sequence[BaseRoute] | None = None,
        middleware: Sequence[Middleware] | None = None,
        exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
        on_startup: Sequence[Callable[[], Any]] | None = None,
        on_shutdown: Sequence[Callable[[], Any]] | None = None,
        lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
    ) -> None:
        """Initialise the AG-UI application.

        Args:
            agent: The Pydantic AI `Agent` to adapt.

            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
                no output validators since output validators would expect an argument that matches the agent's
                output type.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
            toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.

            debug: Boolean indicating if debug tracebacks should be returned on errors.
            routes: A list of routes to serve incoming HTTP and WebSocket requests.
            middleware: A list of middleware to run for every request. A starlette application will always
                automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
                outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
                `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
                exception cases occurring in the routing or endpoints.
            exception_handlers: A mapping of either integer status codes, or exception class types onto
                callables which handle the exceptions. Exception handler callables should be of the form
                `handler(request, exc) -> response` and may be either standard functions, or async functions.
            on_startup: A list of callables to run on application startup. Startup handler callables do not
                take any arguments, and may be either standard functions, or async functions.
            on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
                not take any arguments, and may be either standard functions, or async functions.
            lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
                This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
                the other, not both.
        """
        super().__init__(
            debug=debug,
            routes=routes,
            middleware=middleware,
            exception_handlers=exception_handlers,
            on_startup=on_startup,
            on_shutdown=on_shutdown,
            lifespan=lifespan,
        )
        adapter = _Adapter(agent=agent)

        async def endpoint(request: Request) -> Response | StreamingResponse:
            """Endpoint to run the agent with the provided input data."""
            accept = request.headers.get('accept', SSE_CONTENT_TYPE)
            try:
                input_data = RunAgentInput.model_validate(await request.json())
            except ValidationError as e:  # pragma: no cover
                # Bug fix: `ValidationError.json()` already returns a serialized
                # JSON document (a `str`); passing it through `json.dumps()`
                # double-encoded it, so clients received a JSON *string literal*
                # instead of the error array despite the `application/json`
                # media type.
                return Response(
                    content=e.json(),
                    media_type='application/json',
                    status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                )

            # Stream SSE chunks from the adapter straight through to the client.
            return StreamingResponse(
                adapter.run(
                    input_data,
                    accept,
                    output_type=output_type,
                    model=model,
                    deps=deps,
                    model_settings=model_settings,
                    usage_limits=usage_limits,
                    usage=usage,
                    infer_name=infer_name,
                    toolsets=toolsets,
                ),
                media_type=SSE_CONTENT_TYPE,
            )

        self.router.add_route('/', endpoint, methods=['POST'], name='run_agent')
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@dataclass(repr=False)
class _Adapter(Generic[AgentDepsT, OutputDataT]):
    """An agent adapter providing AG-UI protocol support for Pydantic AI agents.

    This class manages the agent runs, tool calls, state storage and providing
    an adapter for running agents with Server-Sent Event (SSE) streaming
    responses using the AG-UI protocol.

    Args:
        agent: The Pydantic AI `Agent` to adapt.
    """

    # The wrapped agent; excluded from repr to avoid dumping its full config.
    agent: Agent[AgentDepsT, OutputDataT] = field(repr=False)

    async def run(
        self,
        run_input: RunAgentInput,
        accept: str = SSE_CONTENT_TYPE,
        *,
        output_type: OutputSpec[RunOutputDataT] | None = None,
        model: Model | KnownModelName | str | None = None,
        deps: AgentDepsT = None,
        model_settings: ModelSettings | None = None,
        usage_limits: UsageLimits | None = None,
        usage: Usage | None = None,
        infer_name: bool = True,
        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    ) -> AsyncGenerator[str, None]:
        """Run the agent with streaming response using AG-UI protocol events.

        The first two arguments are specific to `Adapter` the rest map directly to the `Agent.iter` method.

        Args:
            run_input: The AG-UI run input containing thread_id, run_id, messages, etc.
            accept: The accept header value for the run.

            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
                output validators since output validators would expect an argument that matches the agent's output type.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            model_settings: Optional settings to use for this model's request.
            usage_limits: Optional limits on model request count or token usage.
            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
            toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.

        Yields:
            Streaming SSE-formatted event chunks.
        """
        encoder = EventEncoder(accept=accept)
        if run_input.tools:
            # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the
            # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any
            # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
            toolset = DeferredToolset[AgentDepsT](
                [
                    ToolDefinition(
                        name=tool.name,
                        description=tool.description,
                        parameters_json_schema=tool.parameters,
                    )
                    for tool in run_input.tools
                ]
            )
            # Client-provided tools are appended after any caller-supplied toolsets.
            toolsets = [*toolsets, toolset] if toolsets else [toolset]

        try:
            # The AG-UI protocol requires a RUN_STARTED event before anything else.
            yield encoder.encode(
                RunStartedEvent(
                    thread_id=run_input.thread_id,
                    run_id=run_input.run_id,
                ),
            )

            if not run_input.messages:
                raise _NoMessagesError

            # If the deps object opted into state handling, push the incoming
            # frontend state onto it before the run starts.
            if isinstance(deps, StateHandler):
                deps.state = run_input.state

            history = _History.from_ag_ui(run_input.messages)

            # DeferredToolCalls is added to the output spec so runs can end on
            # tool calls that are resolved client-side (AG-UI frontend tools).
            async with self.agent.iter(
                user_prompt=None,
                output_type=[output_type or self.agent.output_type, DeferredToolCalls],
                message_history=history.messages,
                model=model,
                deps=deps,
                model_settings=model_settings,
                usage_limits=usage_limits,
                usage=usage,
                infer_name=infer_name,
                toolsets=toolsets,
            ) as run:
                async for event in self._agent_stream(run, history):
                    yield encoder.encode(event)
        except _RunError as e:
            # Known error: report it to the client with its protocol error code.
            yield encoder.encode(
                RunErrorEvent(message=e.message, code=e.code),
            )
        except Exception as e:  # pragma: no cover
            # Unknown error: notify the client, then re-raise for the server.
            yield encoder.encode(
                RunErrorEvent(message=str(e)),
            )
            raise e
        else:
            # Only emit RUN_FINISHED when the run completed without error.
            yield encoder.encode(
                RunFinishedEvent(
                    thread_id=run_input.thread_id,
                    run_id=run_input.run_id,
                ),
            )

    async def _agent_stream(
        self,
        run: AgentRun[AgentDepsT, Any],
        history: _History,
    ) -> AsyncGenerator[BaseEvent, None]:
        """Run the agent streaming responses using AG-UI protocol events.

        Args:
            run: The agent run to process.
            history: The history of messages and tool calls to use for the run.

        Yields:
            AG-UI Server-Sent Events (SSE).
        """
        async for node in run:
            if isinstance(node, ModelRequestNode):
                # One context per model request: tracks the current message ID
                # and the pending end-of-part event.
                stream_ctx = _RequestStreamContext()
                async with node.stream(run.ctx) as request_stream:
                    async for agent_event in request_stream:
                        async for msg in self._handle_model_request_event(stream_ctx, agent_event):
                            yield msg

                # Flush the end event of the final part of the request, which
                # is otherwise only emitted when the *next* part starts.
                if stream_ctx.part_end:  # pragma: no branch
                    yield stream_ctx.part_end
                    stream_ctx.part_end = None
            elif isinstance(node, CallToolsNode):
                async with node.stream(run.ctx) as handle_stream:
                    async for event in handle_stream:
                        if isinstance(event, FunctionToolResultEvent) and isinstance(event.result, ToolReturnPart):
                            async for msg in self._handle_tool_result_event(event.result, history.prompt_message_id):
                                yield msg

    async def _handle_model_request_event(
        self,
        stream_ctx: _RequestStreamContext,
        agent_event: AgentStreamEvent,
    ) -> AsyncGenerator[BaseEvent, None]:
        """Handle an agent event and yield AG-UI protocol events.

        Args:
            stream_ctx: The request stream context to manage state.
            agent_event: The agent event to process.

        Yields:
            AG-UI Server-Sent Events (SSE) based on the agent event.
        """
        if isinstance(agent_event, PartStartEvent):
            if stream_ctx.part_end:
                # End the previous part.
                yield stream_ctx.part_end
                stream_ctx.part_end = None

            part = agent_event.part
            if isinstance(part, TextPart):
                message_id = stream_ctx.new_message_id()
                yield TextMessageStartEvent(
                    message_id=message_id,
                )
                # Deferred until the next part begins (or the request ends).
                stream_ctx.part_end = TextMessageEndEvent(
                    message_id=message_id,
                )
                if part.content:  # pragma: no branch
                    yield TextMessageContentEvent(
                        message_id=message_id,
                        delta=part.content,
                    )
            elif isinstance(part, ToolCallPart):  # pragma: no branch
                yield ToolCallStartEvent(
                    tool_call_id=part.tool_call_id,
                    tool_call_name=part.tool_name,
                )
                stream_ctx.part_end = ToolCallEndEvent(
                    tool_call_id=part.tool_call_id,
                )

            elif isinstance(part, ThinkingPart):  # pragma: no branch
                yield ThinkingTextMessageStartEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_START,
                )
                # Always send the content even if it's empty, as it may be
                # used to indicate the start of thinking.
                yield ThinkingTextMessageContentEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
                    delta=part.content or '',
                )
                stream_ctx.part_end = ThinkingTextMessageEndEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_END,
                )

        elif isinstance(agent_event, PartDeltaEvent):
            delta = agent_event.delta
            if isinstance(delta, TextPartDelta):
                yield TextMessageContentEvent(
                    message_id=stream_ctx.message_id,
                    delta=delta.content_delta,
                )
            elif isinstance(delta, ToolCallPartDelta):  # pragma: no branch
                assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
                yield ToolCallArgsEvent(
                    tool_call_id=delta.tool_call_id,
                    # Args may arrive as a raw JSON string or as a structured
                    # object; AG-UI expects a string delta either way.
                    delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
                )
            elif isinstance(delta, ThinkingPartDelta):  # pragma: no branch
                if delta.content_delta:  # pragma: no branch
                    yield ThinkingTextMessageContentEvent(
                        type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
                        delta=delta.content_delta,
                    )

    async def _handle_tool_result_event(
        self,
        result: ToolReturnPart,
        prompt_message_id: str,
    ) -> AsyncGenerator[BaseEvent, None]:
        """Convert a tool call result to AG-UI events.

        Args:
            result: The tool call result to process.
            prompt_message_id: The message ID of the prompt that initiated the tool call.

        Yields:
            AG-UI Server-Sent Events (SSE).
        """
        yield ToolCallResultEvent(
            message_id=prompt_message_id,
            type=EventType.TOOL_CALL_RESULT,
            role='tool',
            tool_call_id=result.tool_call_id,
            content=result.model_response_str(),
        )

        # Now check for AG-UI events returned by the tool calls.
        content = result.content
        if isinstance(content, BaseEvent):
            yield content
        elif isinstance(content, (str, bytes)):  # pragma: no branch
            # Avoid iterable check for strings and bytes.
            pass
        elif isinstance(content, Iterable):  # pragma: no branch
            for item in content:  # type: ignore[reportUnknownMemberType]
                if isinstance(item, BaseEvent):  # pragma: no branch
                    yield item
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
@dataclass
class _History:
    """Pydantic AI view of an AG-UI conversation history."""

    # ID of the most recent AG-UI user message (empty string if none seen).
    prompt_message_id: str
    messages: list[ModelMessage]

    @classmethod
    def from_ag_ui(cls, messages: list[Message]) -> _History:
        """Convert a AG-UI history to a Pydantic AI one.

        Args:
            messages: List of AG-UI messages to convert.

        Returns:
            List of Pydantic AI model messages.
        """
        last_user_id = ''
        converted: list[ModelMessage] = []
        call_names: dict[str, str] = {}  # Tool call ID to tool name mapping.

        for message in messages:
            if isinstance(message, UserMessage):
                last_user_id = message.id
                converted.append(ModelRequest(parts=[UserPromptPart(content=message.content)]))
            elif isinstance(message, AssistantMessage):
                if message.tool_calls:
                    # Remember each call's name so later ToolMessages can be matched.
                    for call in message.tool_calls:
                        call_names[call.id] = call.function.name

                    call_parts = [
                        ToolCallPart(
                            tool_name=call.function.name,
                            tool_call_id=call.id,
                            args=call.function.arguments,
                        )
                        for call in message.tool_calls
                    ]
                    converted.append(ModelResponse(parts=call_parts))

                if message.content:
                    converted.append(ModelResponse(parts=[TextPart(content=message.content)]))
            elif isinstance(message, SystemMessage):
                converted.append(ModelRequest(parts=[SystemPromptPart(content=message.content)]))
            elif isinstance(message, ToolMessage):
                name = call_names.get(message.tool_call_id)
                if name is None:  # pragma: no cover
                    raise _ToolCallNotFoundError(tool_call_id=message.tool_call_id)

                return_part = ToolReturnPart(
                    tool_name=name,
                    content=message.content,
                    tool_call_id=message.tool_call_id,
                )
                converted.append(ModelRequest(parts=[return_part]))
            elif isinstance(message, DeveloperMessage):  # pragma: no branch
                converted.append(ModelRequest(parts=[SystemPromptPart(content=message.content)]))

        return cls(prompt_message_id=last_user_id, messages=converted)
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
@runtime_checkable
class StateHandler(Protocol):
    """Protocol for state handlers in agent runs.

    A deps object implementing this protocol receives the frontend state from
    the AG-UI `RunAgentInput` via the `state` setter when a run starts.
    """

    @property
    def state(self) -> State:
        """Get the current state of the agent run."""
        ...

    @state.setter
    def state(self, state: State) -> None:
        """Set the state of the agent run.

        This method is called to update the state of the agent run with the
        provided state.

        Args:
            state: The run state.

        Raises:
            InvalidStateError: If `state` does not match the expected model.
        """
        ...
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
# Parameterizes `StateDeps` with the concrete state model for a given app.
StateT = TypeVar('StateT', bound=BaseModel)
"""Type variable for the state type, which must be a subclass of `BaseModel`."""
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
class StateDeps(Generic[StateT]):
    """Provides AG-UI state management.

    Holds a state model instance for an agent run: construct it with a default
    state (a `BaseModel` subclass instance), and the `Adapter` replaces that
    state through the `state` setter when a run starts, validating the incoming
    payload against the default's type.

    Implements the `StateHandler` protocol.
    """

    def __init__(self, default: StateT) -> None:
        """Store `default` as the initial run state."""
        self._state = default

    @property
    def state(self) -> StateT:
        """Get the current state of the agent run.

        Returns:
            The current run state.
        """
        return self._state

    @state.setter
    def state(self, state: State) -> None:
        """Set the state of the agent run.

        Called by the adapter to replace the run state with the payload sent
        by the frontend.

        Implements the `StateHandler` protocol.

        Args:
            state: The run state, which must be `None` or model validate for the state type.

        Raises:
            InvalidStateError: If `state` does not validate.
        """
        if state is None:
            # No state was supplied for this run; keep what we already hold
            # (initially the default passed to `__init__`).
            return

        try:
            self._state = type(self._state).model_validate(state)
        except ValidationError as e:  # pragma: no cover
            raise _InvalidStateError from e
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
@dataclass(repr=False)
|
|
622
|
+
class _RequestStreamContext:
|
|
623
|
+
"""Data class to hold request stream context."""
|
|
624
|
+
|
|
625
|
+
message_id: str = ''
|
|
626
|
+
part_end: BaseEvent | None = None
|
|
627
|
+
|
|
628
|
+
def new_message_id(self) -> str:
|
|
629
|
+
"""Generate a new message ID for the request stream.
|
|
630
|
+
|
|
631
|
+
Assigns a new UUID to the `message_id` and returns it.
|
|
632
|
+
|
|
633
|
+
Returns:
|
|
634
|
+
A new message ID.
|
|
635
|
+
"""
|
|
636
|
+
self.message_id = str(uuid.uuid4())
|
|
637
|
+
return self.message_id
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
@dataclass
|
|
641
|
+
class _RunError(Exception):
|
|
642
|
+
"""Exception raised for errors during agent runs."""
|
|
643
|
+
|
|
644
|
+
message: str
|
|
645
|
+
code: str
|
|
646
|
+
|
|
647
|
+
def __str__(self) -> str: # pragma: no cover
|
|
648
|
+
return self.message
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
@dataclass
class _NoMessagesError(_RunError):
    """Raised when the AG-UI run input carries an empty message list."""

    message: str = 'no messages found in the input'
    code: str = 'no_messages'
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
@dataclass
class _InvalidStateError(_RunError, ValidationError):
    """Exception raised when an invalid state is provided."""

    # NOTE(review): subclassing pydantic-v2 `ValidationError` (a Rust-backed
    # type not designed for Python subclassing/instantiation) via `@dataclass`
    # looks fragile — confirm this class can actually be instantiated and
    # raised, or whether the `ValidationError` base exists only so callers'
    # `except ValidationError` clauses match it.
    message: str = 'invalid state provided'
    code: str = 'invalid_state'
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
class _ToolCallNotFoundError(_RunError, ValueError):
    """Raised when a tool result is present without a matching tool call."""

    def __init__(self, tool_call_id: str) -> None:
        """Initialize the exception with the offending tool call ID."""
        super().__init__(  # pragma: no cover
            message=f'Tool call with ID {tool_call_id} not found in the history.',
            code='tool_call_not_found',
        )
|