pydantic-ai-slim 1.8.0__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.

Files changed (33)
  1. pydantic_ai/__init__.py +2 -0
  2. pydantic_ai/_agent_graph.py +3 -0
  3. pydantic_ai/ag_ui.py +50 -696
  4. pydantic_ai/agent/abstract.py +13 -3
  5. pydantic_ai/direct.py +12 -0
  6. pydantic_ai/durable_exec/dbos/_agent.py +3 -0
  7. pydantic_ai/durable_exec/prefect/_agent.py +3 -0
  8. pydantic_ai/durable_exec/temporal/_agent.py +3 -0
  9. pydantic_ai/messages.py +39 -7
  10. pydantic_ai/models/__init__.py +42 -1
  11. pydantic_ai/models/google.py +5 -12
  12. pydantic_ai/models/groq.py +9 -1
  13. pydantic_ai/providers/anthropic.py +2 -2
  14. pydantic_ai/result.py +19 -7
  15. pydantic_ai/ui/__init__.py +16 -0
  16. pydantic_ai/ui/_adapter.py +386 -0
  17. pydantic_ai/ui/_event_stream.py +591 -0
  18. pydantic_ai/ui/_messages_builder.py +28 -0
  19. pydantic_ai/ui/ag_ui/__init__.py +9 -0
  20. pydantic_ai/ui/ag_ui/_adapter.py +187 -0
  21. pydantic_ai/ui/ag_ui/_event_stream.py +227 -0
  22. pydantic_ai/ui/ag_ui/app.py +148 -0
  23. pydantic_ai/ui/vercel_ai/__init__.py +16 -0
  24. pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
  25. pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
  26. pydantic_ai/ui/vercel_ai/_utils.py +16 -0
  27. pydantic_ai/ui/vercel_ai/request_types.py +275 -0
  28. pydantic_ai/ui/vercel_ai/response_types.py +230 -0
  29. {pydantic_ai_slim-1.8.0.dist-info → pydantic_ai_slim-1.9.1.dist-info}/METADATA +5 -3
  30. {pydantic_ai_slim-1.8.0.dist-info → pydantic_ai_slim-1.9.1.dist-info}/RECORD +33 -19
  31. {pydantic_ai_slim-1.8.0.dist-info → pydantic_ai_slim-1.9.1.dist-info}/WHEEL +0 -0
  32. {pydantic_ai_slim-1.8.0.dist-info → pydantic_ai_slim-1.9.1.dist-info}/entry_points.txt +0 -0
  33. {pydantic_ai_slim-1.8.0.dist-info → pydantic_ai_slim-1.9.1.dist-info}/licenses/LICENSE +0 -0
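
The headline change in 1.9.1 is the new pydantic_ai/ui/ package (items 15-28 above), which factors the AG-UI integration into a reusable adapter layer and adds a Vercel AI adapter alongside it, while pydantic_ai/ag_ui.py becomes a thin compatibility shim, as the diff below shows. A minimal import sketch of the relocated surface, inferred from the new files listed above and the imports added in the diff; the exact public re-exports should be confirmed against the released 1.9.1 wheel:

    # Hypothetical import sketch based on this diff, not an authoritative API listing.
    from pydantic_ai.ui import SSE_CONTENT_TYPE, OnCompleteFunc, StateDeps, StateHandler
    from pydantic_ai.ui.ag_ui import AGUIAdapter
    from pydantic_ai.ui.ag_ui.app import AGUIApp

    # The legacy module keeps working for now; the diff marks it for removal in v2.
    from pydantic_ai.ag_ui import handle_ag_ui_request, run_ag_ui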
pydantic_ai/ag_ui.py CHANGED
@@ -4,107 +4,35 @@ This package provides seamless integration between pydantic-ai agents and ag-ui
  for building interactive AI applications with streaming event-based communication.
  """

- from __future__ import annotations
+ # TODO (v2): Remove this module in favor of `pydantic_ai.ui.ag_ui`

- import json
- import uuid
- from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence
- from dataclasses import Field, dataclass, field, replace
- from http import HTTPStatus
- from typing import (
-     Any,
-     ClassVar,
-     Final,
-     Generic,
-     Protocol,
-     TypeAlias,
-     TypeVar,
-     runtime_checkable,
- )
+ from __future__ import annotations

- from pydantic import BaseModel, ValidationError
+ from collections.abc import AsyncIterator, Sequence
+ from typing import Any

- from . import _utils
- from ._agent_graph import CallToolsNode, ModelRequestNode
- from .agent import AbstractAgent, AgentRun, AgentRunResult
- from .exceptions import UserError
- from .messages import (
-     BaseToolCallPart,
-     BuiltinToolCallPart,
-     BuiltinToolReturnPart,
-     FunctionToolResultEvent,
-     ModelMessage,
-     ModelRequest,
-     ModelRequestPart,
-     ModelResponse,
-     ModelResponsePart,
-     ModelResponseStreamEvent,
-     PartDeltaEvent,
-     PartStartEvent,
-     SystemPromptPart,
-     TextPart,
-     TextPartDelta,
-     ThinkingPart,
-     ThinkingPartDelta,
-     ToolCallPart,
-     ToolCallPartDelta,
-     ToolReturnPart,
-     UserPromptPart,
- )
+ from . import DeferredToolResults
+ from .agent import AbstractAgent
+ from .messages import ModelMessage
  from .models import KnownModelName, Model
- from .output import OutputDataT, OutputSpec
+ from .output import OutputSpec
  from .settings import ModelSettings
- from .tools import AgentDepsT, DeferredToolRequests, ToolDefinition
+ from .tools import AgentDepsT
  from .toolsets import AbstractToolset
- from .toolsets.external import ExternalToolset
  from .usage import RunUsage, UsageLimits

  try:
-     from ag_ui.core import (
-         AssistantMessage,
-         BaseEvent,
-         DeveloperMessage,
-         EventType,
-         Message,
-         RunAgentInput,
-         RunErrorEvent,
-         RunFinishedEvent,
-         RunStartedEvent,
-         State,
-         SystemMessage,
-         TextMessageContentEvent,
-         TextMessageEndEvent,
-         TextMessageStartEvent,
-         ThinkingEndEvent,
-         ThinkingStartEvent,
-         ThinkingTextMessageContentEvent,
-         ThinkingTextMessageEndEvent,
-         ThinkingTextMessageStartEvent,
-         Tool as AGUITool,
-         ToolCallArgsEvent,
-         ToolCallEndEvent,
-         ToolCallResultEvent,
-         ToolCallStartEvent,
-         ToolMessage,
-         UserMessage,
-     )
-     from ag_ui.encoder import EventEncoder
- except ImportError as e: # pragma: no cover
-     raise ImportError(
-         'Please install the `ag-ui-protocol` package to use `Agent.to_ag_ui()` method, '
-         'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
-     ) from e
-
- try:
-     from starlette.applications import Starlette
-     from starlette.middleware import Middleware
+     from ag_ui.core import BaseEvent
+     from ag_ui.core.types import RunAgentInput
      from starlette.requests import Request
-     from starlette.responses import Response, StreamingResponse
-     from starlette.routing import BaseRoute
-     from starlette.types import ExceptionHandler, Lifespan
+     from starlette.responses import Response
+
+     from .ui import SSE_CONTENT_TYPE, OnCompleteFunc, StateDeps, StateHandler
+     from .ui.ag_ui import AGUIAdapter
+     from .ui.ag_ui.app import AGUIApp
  except ImportError as e: # pragma: no cover
      raise ImportError(
-         'Please install the `starlette` package to use `Agent.to_ag_ui()` method, '
+         'Please install the `ag-ui-protocol` and `starlette` packages to use `AGUIAdapter`, '
          'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
      ) from e

@@ -119,113 +47,14 @@ __all__ = [
      'run_ag_ui',
  ]

- SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
- """Content type header value for Server-Sent Events (SSE)."""
-
- OnCompleteFunc: TypeAlias = Callable[[AgentRunResult[Any]], None] | Callable[[AgentRunResult[Any]], Awaitable[None]]
- """Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async."""
-
- _BUILTIN_TOOL_CALL_ID_PREFIX: Final[str] = 'pyd_ai_builtin'
-
-
- class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
-     """ASGI application for running Pydantic AI agents with AG-UI protocol support."""
-
-     def __init__(
-         self,
-         agent: AbstractAgent[AgentDepsT, OutputDataT],
-         *,
-         # Agent.iter parameters.
-         output_type: OutputSpec[Any] | None = None,
-         model: Model | KnownModelName | str | None = None,
-         deps: AgentDepsT = None,
-         model_settings: ModelSettings | None = None,
-         usage_limits: UsageLimits | None = None,
-         usage: RunUsage | None = None,
-         infer_name: bool = True,
-         toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
-         # Starlette parameters.
-         debug: bool = False,
-         routes: Sequence[BaseRoute] | None = None,
-         middleware: Sequence[Middleware] | None = None,
-         exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
-         on_startup: Sequence[Callable[[], Any]] | None = None,
-         on_shutdown: Sequence[Callable[[], Any]] | None = None,
-         lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
-     ) -> None:
-         """An ASGI application that handles every AG-UI request by running the agent.
-
-         Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's
-         injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol.
-         To provide different `deps` for each request (e.g. based on the authenticated user),
-         use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or
-         [`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead.
-
-         Args:
-             agent: The agent to run.
-
-             output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
-                 no output validators since output validators would expect an argument that matches the agent's
-                 output type.
-             model: Optional model to use for this run, required if `model` was not set when creating the agent.
-             deps: Optional dependencies to use for this run.
-             model_settings: Optional settings to use for this model's request.
-             usage_limits: Optional limits on model request count or token usage.
-             usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
-             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
-             toolsets: Optional additional toolsets for this run.
-
-             debug: Boolean indicating if debug tracebacks should be returned on errors.
-             routes: A list of routes to serve incoming HTTP and WebSocket requests.
-             middleware: A list of middleware to run for every request. A starlette application will always
-                 automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
-                 outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
-                 `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
-                 exception cases occurring in the routing or endpoints.
-             exception_handlers: A mapping of either integer status codes, or exception class types onto
-                 callables which handle the exceptions. Exception handler callables should be of the form
-                 `handler(request, exc) -> response` and may be either standard functions, or async functions.
-             on_startup: A list of callables to run on application startup. Startup handler callables do not
-                 take any arguments, and may be either standard functions, or async functions.
-             on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
-                 not take any arguments, and may be either standard functions, or async functions.
-             lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
-                 This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
-                 the other, not both.
-         """
-         super().__init__(
-             debug=debug,
-             routes=routes,
-             middleware=middleware,
-             exception_handlers=exception_handlers,
-             on_startup=on_startup,
-             on_shutdown=on_shutdown,
-             lifespan=lifespan,
-         )
-
-         async def endpoint(request: Request) -> Response:
-             """Endpoint to run the agent with the provided input data."""
-             return await handle_ag_ui_request(
-                 agent,
-                 request,
-                 output_type=output_type,
-                 model=model,
-                 deps=deps,
-                 model_settings=model_settings,
-                 usage_limits=usage_limits,
-                 usage=usage,
-                 infer_name=infer_name,
-                 toolsets=toolsets,
-             )
-
-         self.router.add_route('/', endpoint, methods=['POST'], name='run_agent')
-

  async def handle_ag_ui_request(
      agent: AbstractAgent[AgentDepsT, Any],
      request: Request,
      *,
      output_type: OutputSpec[Any] | None = None,
+     message_history: Sequence[ModelMessage] | None = None,
+     deferred_tool_results: DeferredToolResults | None = None,
      model: Model | KnownModelName | str | None = None,
      deps: AgentDepsT = None,
      model_settings: ModelSettings | None = None,
@@ -233,7 +62,7 @@ async def handle_ag_ui_request(
      usage: RunUsage | None = None,
      infer_name: bool = True,
      toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
-     on_complete: OnCompleteFunc | None = None,
+     on_complete: OnCompleteFunc[BaseEvent] | None = None,
  ) -> Response:
      """Handle an AG-UI request by running the agent and returning a streaming response.

@@ -243,6 +72,8 @@

          output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
              output validators since output validators would expect an argument that matches the agent's output type.
+         message_history: History of the conversation so far.
+         deferred_tool_results: Optional results for deferred tool calls in the message history.
          model: Optional model to use for this run, required if `model` was not set when creating the agent.
          deps: Optional dependencies to use for this run.
          model_settings: Optional settings to use for this model's request.
@@ -256,41 +87,31 @@ async def handle_ag_ui_request(
      Returns:
          A streaming Starlette response with AG-UI protocol events.
      """
-     accept = request.headers.get('accept', SSE_CONTENT_TYPE)
-     try:
-         input_data = RunAgentInput.model_validate(await request.json())
-     except ValidationError as e: # pragma: no cover
-         return Response(
-             content=json.dumps(e.json()),
-             media_type='application/json',
-             status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
-         )
-
-     return StreamingResponse(
-         run_ag_ui(
-             agent,
-             input_data,
-             accept,
-             output_type=output_type,
-             model=model,
-             deps=deps,
-             model_settings=model_settings,
-             usage_limits=usage_limits,
-             usage=usage,
-             infer_name=infer_name,
-             toolsets=toolsets,
-             on_complete=on_complete,
-         ),
-         media_type=accept,
+     return await AGUIAdapter[AgentDepsT].dispatch_request(
+         request,
+         agent=agent,
+         deps=deps,
+         output_type=output_type,
+         message_history=message_history,
+         deferred_tool_results=deferred_tool_results,
+         model=model,
+         model_settings=model_settings,
+         usage_limits=usage_limits,
+         usage=usage,
+         infer_name=infer_name,
+         toolsets=toolsets,
+         on_complete=on_complete,
      )


- async def run_ag_ui(
+ def run_ag_ui(
      agent: AbstractAgent[AgentDepsT, Any],
      run_input: RunAgentInput,
      accept: str = SSE_CONTENT_TYPE,
      *,
      output_type: OutputSpec[Any] | None = None,
+     message_history: Sequence[ModelMessage] | None = None,
+     deferred_tool_results: DeferredToolResults | None = None,
      model: Model | KnownModelName | str | None = None,
      deps: AgentDepsT = None,
      model_settings: ModelSettings | None = None,
@@ -298,7 +119,7 @@ async def run_ag_ui(
      usage: RunUsage | None = None,
      infer_name: bool = True,
      toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
-     on_complete: OnCompleteFunc | None = None,
+     on_complete: OnCompleteFunc[BaseEvent] | None = None,
  ) -> AsyncIterator[str]:
      """Run the agent with the AG-UI run input and stream AG-UI protocol events.

@@ -309,6 +130,8 @@

          output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
              output validators since output validators would expect an argument that matches the agent's output type.
+         message_history: History of the conversation so far.
+         deferred_tool_results: Optional results for deferred tool calls in the message history.
          model: Optional model to use for this run, required if `model` was not set when creating the agent.
          deps: Optional dependencies to use for this run.
          model_settings: Optional settings to use for this model's request.
@@ -322,50 +145,12 @@
      Yields:
          Streaming event chunks encoded as strings according to the accept header value.
      """
-     encoder = EventEncoder(accept=accept)
-     if run_input.tools:
-         # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the
-         # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any
-         # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
-         toolset = _AGUIFrontendToolset[AgentDepsT](run_input.tools)
-         toolsets = [*toolsets, toolset] if toolsets else [toolset]
-
-     try:
-         yield encoder.encode(
-             RunStartedEvent(
-                 thread_id=run_input.thread_id,
-                 run_id=run_input.run_id,
-             ),
-         )
-
-         if not run_input.messages:
-             raise _NoMessagesError
-
-         raw_state: dict[str, Any] = run_input.state or {}
-         if isinstance(deps, StateHandler):
-             if isinstance(deps.state, BaseModel):
-                 try:
-                     state = type(deps.state).model_validate(raw_state)
-                 except ValidationError as e: # pragma: no cover
-                     raise _InvalidStateError from e
-             else:
-                 state = raw_state
-
-             deps = replace(deps, state=state)
-         elif raw_state:
-             raise UserError(
-                 f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
-             )
-         else:
-             # `deps` not being a `StateHandler` is OK if there is no state.
-             pass
-
-         messages = _messages_from_ag_ui(run_input.messages)
-
-         async with agent.iter(
-             user_prompt=None,
-             output_type=[output_type or agent.output_type, DeferredToolRequests],
-             message_history=messages,
+     adapter = AGUIAdapter(agent=agent, run_input=run_input, accept=accept)
+     return adapter.encode_stream(
+         adapter.run_stream(
+             output_type=output_type,
+             message_history=message_history,
+             deferred_tool_results=deferred_tool_results,
              model=model,
              deps=deps,
              model_settings=model_settings,
@@ -373,437 +158,6 @@ async def run_ag_ui(
              usage=usage,
              infer_name=infer_name,
              toolsets=toolsets,
-         ) as run:
-             async for event in _agent_stream(run):
-                 yield encoder.encode(event)
-
-             if on_complete is not None and run.result is not None:
-                 if _utils.is_async_callable(on_complete):
-                     await on_complete(run.result)
-                 else:
-                     await _utils.run_in_executor(on_complete, run.result)
-     except _RunError as e:
-         yield encoder.encode(
-             RunErrorEvent(message=e.message, code=e.code),
-         )
-     except Exception as e:
-         yield encoder.encode(
-             RunErrorEvent(message=str(e)),
-         )
-         raise e
-     else:
-         yield encoder.encode(
-             RunFinishedEvent(
-                 thread_id=run_input.thread_id,
-                 run_id=run_input.run_id,
-             ),
-         )
-
-
- async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEvent]:
-     """Run the agent streaming responses using AG-UI protocol events.
-
-     Args:
-         run: The agent run to process.
-
-     Yields:
-         AG-UI Server-Sent Events (SSE).
-     """
-     async for node in run:
-         stream_ctx = _RequestStreamContext()
-         if isinstance(node, ModelRequestNode):
-             async with node.stream(run.ctx) as request_stream:
-                 async for agent_event in request_stream:
-                     async for msg in _handle_model_request_event(stream_ctx, agent_event):
-                         yield msg
-
-                 if stream_ctx.part_end: # pragma: no branch
-                     yield stream_ctx.part_end
-                     stream_ctx.part_end = None
-                 if stream_ctx.thinking:
-                     yield ThinkingEndEvent(
-                         type=EventType.THINKING_END,
-                     )
-                     stream_ctx.thinking = False
-         elif isinstance(node, CallToolsNode):
-             async with node.stream(run.ctx) as handle_stream:
-                 async for event in handle_stream:
-                     if isinstance(event, FunctionToolResultEvent):
-                         async for msg in _handle_tool_result_event(stream_ctx, event):
-                             yield msg
-
-
- async def _handle_model_request_event( # noqa: C901
-     stream_ctx: _RequestStreamContext,
-     agent_event: ModelResponseStreamEvent,
- ) -> AsyncIterator[BaseEvent]:
-     """Handle an agent event and yield AG-UI protocol events.
-
-     Args:
-         stream_ctx: The request stream context to manage state.
-         agent_event: The agent event to process.
-
-     Yields:
-         AG-UI Server-Sent Events (SSE) based on the agent event.
-     """
-     if isinstance(agent_event, PartStartEvent):
-         if stream_ctx.part_end:
-             # End the previous part.
-             yield stream_ctx.part_end
-             stream_ctx.part_end = None
-
-         part = agent_event.part
-         if isinstance(part, ThinkingPart): # pragma: no branch
-             if not stream_ctx.thinking:
-                 yield ThinkingStartEvent(
-                     type=EventType.THINKING_START,
-                 )
-                 stream_ctx.thinking = True
-
-             if part.content:
-                 yield ThinkingTextMessageStartEvent(
-                     type=EventType.THINKING_TEXT_MESSAGE_START,
-                 )
-                 yield ThinkingTextMessageContentEvent(
-                     type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
-                     delta=part.content,
-                 )
-                 stream_ctx.part_end = ThinkingTextMessageEndEvent(
-                     type=EventType.THINKING_TEXT_MESSAGE_END,
-                 )
-         else:
-             if stream_ctx.thinking:
-                 yield ThinkingEndEvent(
-                     type=EventType.THINKING_END,
-                 )
-                 stream_ctx.thinking = False
-
-             if isinstance(part, TextPart):
-                 message_id = stream_ctx.new_message_id()
-                 yield TextMessageStartEvent(
-                     message_id=message_id,
-                 )
-                 if part.content: # pragma: no branch
-                     yield TextMessageContentEvent(
-                         message_id=message_id,
-                         delta=part.content,
-                     )
-                 stream_ctx.part_end = TextMessageEndEvent(
-                     message_id=message_id,
-                 )
-             elif isinstance(part, BaseToolCallPart):
-                 tool_call_id = part.tool_call_id
-                 if isinstance(part, BuiltinToolCallPart):
-                     builtin_tool_call_id = '|'.join(
-                         [_BUILTIN_TOOL_CALL_ID_PREFIX, part.provider_name or '', tool_call_id]
-                     )
-                     stream_ctx.builtin_tool_call_ids[tool_call_id] = builtin_tool_call_id
-                     tool_call_id = builtin_tool_call_id
-
-                 message_id = stream_ctx.message_id or stream_ctx.new_message_id()
-                 yield ToolCallStartEvent(
-                     tool_call_id=tool_call_id,
-                     tool_call_name=part.tool_name,
-                     parent_message_id=message_id,
-                 )
-                 if part.args:
-                     yield ToolCallArgsEvent(
-                         tool_call_id=tool_call_id,
-                         delta=part.args_as_json_str(),
-                     )
-                 stream_ctx.part_end = ToolCallEndEvent(
-                     tool_call_id=tool_call_id,
-                 )
-             elif isinstance(part, BuiltinToolReturnPart): # pragma: no branch
-                 tool_call_id = stream_ctx.builtin_tool_call_ids[part.tool_call_id]
-                 yield ToolCallResultEvent(
-                     message_id=stream_ctx.new_message_id(),
-                     type=EventType.TOOL_CALL_RESULT,
-                     role='tool',
-                     tool_call_id=tool_call_id,
-                     content=part.model_response_str(),
-                 )
-
-     elif isinstance(agent_event, PartDeltaEvent):
-         delta = agent_event.delta
-         if isinstance(delta, TextPartDelta):
-             if delta.content_delta: # pragma: no branch
-                 yield TextMessageContentEvent(
-                     message_id=stream_ctx.message_id,
-                     delta=delta.content_delta,
-                 )
-         elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
-             tool_call_id = delta.tool_call_id
-             assert tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
-             if tool_call_id in stream_ctx.builtin_tool_call_ids:
-                 tool_call_id = stream_ctx.builtin_tool_call_ids[tool_call_id]
-             yield ToolCallArgsEvent(
-                 tool_call_id=tool_call_id,
-                 delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
-             )
-         elif isinstance(delta, ThinkingPartDelta): # pragma: no branch
-             if delta.content_delta: # pragma: no branch
-                 if not isinstance(stream_ctx.part_end, ThinkingTextMessageEndEvent):
-                     yield ThinkingTextMessageStartEvent(
-                         type=EventType.THINKING_TEXT_MESSAGE_START,
-                     )
-                     stream_ctx.part_end = ThinkingTextMessageEndEvent(
-                         type=EventType.THINKING_TEXT_MESSAGE_END,
-                     )
-
-                 yield ThinkingTextMessageContentEvent(
-                     type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
-                     delta=delta.content_delta,
-                 )
-
-
- async def _handle_tool_result_event(
-     stream_ctx: _RequestStreamContext,
-     event: FunctionToolResultEvent,
- ) -> AsyncIterator[BaseEvent]:
-     """Convert a tool call result to AG-UI events.
-
-     Args:
-         stream_ctx: The request stream context to manage state.
-         event: The tool call result event to process.
-
-     Yields:
-         AG-UI Server-Sent Events (SSE).
-     """
-     result = event.result
-     if not isinstance(result, ToolReturnPart):
-         return
-
-     yield ToolCallResultEvent(
-         message_id=stream_ctx.new_message_id(),
-         type=EventType.TOOL_CALL_RESULT,
-         role='tool',
-         tool_call_id=result.tool_call_id,
-         content=result.model_response_str(),
+             on_complete=on_complete,
+         ),
      )
-
-
-     # Now check for AG-UI events returned by the tool calls.
-     possible_event = result.metadata or result.content
-     if isinstance(possible_event, BaseEvent):
-         yield possible_event
-     elif isinstance(possible_event, str | bytes): # pragma: no branch
-         # Avoid iterable check for strings and bytes.
-         pass
-     elif isinstance(possible_event, Iterable): # pragma: no branch
-         for item in possible_event: # type: ignore[reportUnknownMemberType]
-             if isinstance(item, BaseEvent): # pragma: no branch
-                 yield item
-
-
- def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
-     """Convert a AG-UI history to a Pydantic AI one."""
-     result: list[ModelMessage] = []
-     tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
-     request_parts: list[ModelRequestPart] | None = None
-     response_parts: list[ModelResponsePart] | None = None
-     for msg in messages:
-         if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
-             isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX)
-         ):
-             if request_parts is None:
-                 request_parts = []
-                 result.append(ModelRequest(parts=request_parts))
-                 response_parts = None
-
-             if isinstance(msg, UserMessage):
-                 request_parts.append(UserPromptPart(content=msg.content))
-             elif isinstance(msg, SystemMessage | DeveloperMessage):
-                 request_parts.append(SystemPromptPart(content=msg.content))
-             else:
-                 tool_call_id = msg.tool_call_id
-                 tool_name = tool_calls.get(tool_call_id)
-                 if tool_name is None: # pragma: no cover
-                     raise _ToolCallNotFoundError(tool_call_id=tool_call_id)
-
-                 request_parts.append(
-                     ToolReturnPart(
-                         tool_name=tool_name,
-                         content=msg.content,
-                         tool_call_id=tool_call_id,
-                     )
-                 )
-
-         elif isinstance(msg, AssistantMessage) or ( # pragma: no branch
-             isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX)
-         ):
-             if response_parts is None:
-                 response_parts = []
-                 result.append(ModelResponse(parts=response_parts))
-                 request_parts = None
-
-             if isinstance(msg, AssistantMessage):
-                 if msg.content:
-                     response_parts.append(TextPart(content=msg.content))
-
-                 if msg.tool_calls:
-                     for tool_call in msg.tool_calls:
-                         tool_call_id = tool_call.id
-                         tool_name = tool_call.function.name
-                         tool_calls[tool_call_id] = tool_name
-
-                         if tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX):
-                             _, provider_name, tool_call_id = tool_call_id.split('|', 2)
-                             response_parts.append(
-                                 BuiltinToolCallPart(
-                                     tool_name=tool_name,
-                                     args=tool_call.function.arguments,
-                                     tool_call_id=tool_call_id,
-                                     provider_name=provider_name,
-                                 )
-                             )
-                         else:
-                             response_parts.append(
-                                 ToolCallPart(
-                                     tool_name=tool_name,
-                                     tool_call_id=tool_call_id,
-                                     args=tool_call.function.arguments,
-                                 )
-                             )
-             else:
-                 tool_call_id = msg.tool_call_id
-                 tool_name = tool_calls.get(tool_call_id)
-                 if tool_name is None: # pragma: no cover
-                     raise _ToolCallNotFoundError(tool_call_id=tool_call_id)
-                 _, provider_name, tool_call_id = tool_call_id.split('|', 2)
-
-                 response_parts.append(
-                     BuiltinToolReturnPart(
-                         tool_name=tool_name,
-                         content=msg.content,
-                         tool_call_id=tool_call_id,
-                         provider_name=provider_name,
-                     )
-                 )
-
-     return result
-
-
- @runtime_checkable
- class StateHandler(Protocol):
-     """Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field."""
-
-     # Has to be a dataclass so we can use `replace` to update the state.
-     # From https://github.com/python/typeshed/blob/9ab7fde0a0cd24ed7a72837fcb21093b811b80d8/stdlib/_typeshed/__init__.pyi#L352
-     __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
-
-     @property
-     def state(self) -> State:
-         """Get the current state of the agent run."""
-         ...
-
-     @state.setter
-     def state(self, state: State) -> None:
-         """Set the state of the agent run.
-
-         This method is called to update the state of the agent run with the
-         provided state.
-
-         Args:
-             state: The run state.
-
-         Raises:
-             InvalidStateError: If `state` does not match the expected model.
-         """
-         ...
-
-
- StateT = TypeVar('StateT', bound=BaseModel)
- """Type variable for the state type, which must be a subclass of `BaseModel`."""
-
-
- @dataclass
- class StateDeps(Generic[StateT]):
-     """Provides AG-UI state management.
-
-     This class is used to manage the state of an agent run. It allows setting
-     the state of the agent run with a specific type of state model, which must
-     be a subclass of `BaseModel`.
-
-     The state is set using the `state` setter by the `Adapter` when the run starts.
-
-     Implements the `StateHandler` protocol.
-     """
-
-     state: StateT
-
-
- @dataclass(repr=False)
- class _RequestStreamContext:
-     """Data class to hold request stream context."""
-
-     message_id: str = ''
-     part_end: BaseEvent | None = None
-     thinking: bool = False
-     builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)
-
-     def new_message_id(self) -> str:
-         """Generate a new message ID for the request stream.
-
-         Assigns a new UUID to the `message_id` and returns it.
-
-         Returns:
-             A new message ID.
-         """
-         self.message_id = str(uuid.uuid4())
-         return self.message_id
-
-
- @dataclass
- class _RunError(Exception):
-     """Exception raised for errors during agent runs."""
-
-     message: str
-     code: str
-
-     def __str__(self) -> str: # pragma: no cover
-         return self.message
-
-
- @dataclass
- class _NoMessagesError(_RunError):
-     """Exception raised when no messages are found in the input."""
-
-     message: str = 'no messages found in the input'
-     code: str = 'no_messages'
-
-
- @dataclass
- class _InvalidStateError(_RunError, ValidationError):
-     """Exception raised when an invalid state is provided."""
-
-     message: str = 'invalid state provided'
-     code: str = 'invalid_state'
-
-
- class _ToolCallNotFoundError(_RunError, ValueError):
-     """Exception raised when an tool result is present without a matching call."""
-
-     def __init__(self, tool_call_id: str) -> None:
-         """Initialize the exception with the tool call ID."""
-         super().__init__( # pragma: no cover
-             message=f'Tool call with ID {tool_call_id} not found in the history.',
-             code='tool_call_not_found',
-         )
-
-
- class _AGUIFrontendToolset(ExternalToolset[AgentDepsT]):
-     def __init__(self, tools: list[AGUITool]):
-         super().__init__(
-             [
-                 ToolDefinition(
-                     name=tool.name,
-                     description=tool.description,
-                     parameters_json_schema=tool.parameters,
-                 )
-                 for tool in tools
-             ]
-         )
-
-     @property
-     def label(self) -> str:
-         return 'the AG-UI frontend tools' # pragma: no cover
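
Taken together, the diff replaces the hand-rolled event translation above with the shared AGUIAdapter, while keeping handle_ag_ui_request and run_ag_ui as the public entry points (now also accepting message_history and deferred_tool_results). A minimal usage sketch follows, mirroring the endpoint that the removed AGUIApp class used to register; the model name is illustrative and the `ag-ui` extra is assumed to be installed (pip install "pydantic-ai-slim[ag-ui]"):

    # Hedged sketch: serves one agent over the AG-UI protocol with Starlette.
    from pydantic_ai import Agent
    from pydantic_ai.ag_ui import handle_ag_ui_request
    from starlette.applications import Starlette
    from starlette.requests import Request
    from starlette.responses import Response
    from starlette.routing import Route

    agent = Agent('openai:gpt-4o')  # illustrative model; any configured agent works


    async def run_agent(request: Request) -> Response:
        # Validates the AG-UI payload, runs the agent, and returns a streaming
        # SSE response; per-request `deps` and `message_history` can be passed here.
        return await handle_ag_ui_request(agent, request)


    app = Starlette(routes=[Route('/', run_agent, methods=['POST'])])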