pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +6 -0
- pydantic_ai/_agent_graph.py +67 -20
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/_output.py +20 -12
- pydantic_ai/_run_context.py +6 -2
- pydantic_ai/_utils.py +26 -8
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/__init__.py +13 -25
- pydantic_ai/agent/abstract.py +146 -9
- pydantic_ai/builtin_tools.py +106 -4
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +3 -0
- pydantic_ai/durable_exec/prefect/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/__init__.py +11 -0
- pydantic_ai/durable_exec/temporal/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/exceptions.py +6 -1
- pydantic_ai/mcp.py +1 -22
- pydantic_ai/messages.py +46 -8
- pydantic_ai/models/__init__.py +87 -38
- pydantic_ai/models/anthropic.py +132 -11
- pydantic_ai/models/bedrock.py +4 -4
- pydantic_ai/models/cohere.py +0 -7
- pydantic_ai/models/gemini.py +9 -2
- pydantic_ai/models/google.py +26 -23
- pydantic_ai/models/groq.py +13 -5
- pydantic_ai/models/huggingface.py +2 -2
- pydantic_ai/models/openai.py +251 -52
- pydantic_ai/models/outlines.py +563 -0
- pydantic_ai/models/test.py +6 -3
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/__init__.py +25 -12
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/bedrock.py +60 -16
- pydantic_ai/providers/gateway.py +60 -72
- pydantic_ai/providers/google.py +91 -24
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/providers/outlines.py +40 -0
- pydantic_ai/providers/ovhcloud.py +95 -0
- pydantic_ai/result.py +173 -8
- pydantic_ai/run.py +40 -24
- pydantic_ai/settings.py +8 -0
- pydantic_ai/tools.py +10 -6
- pydantic_ai/toolsets/fastmcp.py +215 -0
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
- pydantic_ai/ui/ag_ui/app.py +148 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- pydantic_ai/usage.py +13 -2
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ui/ag_ui/_adapter.py (new file)
@@ -0,0 +1,187 @@
+"""AG-UI adapter for handling requests."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import (
+    TYPE_CHECKING,
+    Any,
+)
+
+from ... import ExternalToolset, ToolDefinition
+from ...messages import (
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
+    ModelMessage,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ...output import OutputDataT
+from ...tools import AgentDepsT
+from ...toolsets import AbstractToolset
+
+try:
+    from ag_ui.core import (
+        AssistantMessage,
+        BaseEvent,
+        DeveloperMessage,
+        Message,
+        RunAgentInput,
+        SystemMessage,
+        Tool as AGUITool,
+        ToolMessage,
+        UserMessage,
+    )
+
+    from .. import MessagesBuilder, UIAdapter, UIEventStream
+    from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
+except ImportError as e:  # pragma: no cover
+    raise ImportError(
+        'Please install the `ag-ui-protocol` package to use AG-UI integration, '
+        'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
+    ) from e
+
+if TYPE_CHECKING:
+    pass
+
+__all__ = ['AGUIAdapter']
+
+
+# Frontend toolset
+
+
+class _AGUIFrontendToolset(ExternalToolset[AgentDepsT]):
+    """Toolset for AG-UI frontend tools."""
+
+    def __init__(self, tools: list[AGUITool]):
+        """Initialize the toolset with AG-UI tools.
+
+        Args:
+            tools: List of AG-UI tool definitions.
+        """
+        super().__init__(
+            [
+                ToolDefinition(
+                    name=tool.name,
+                    description=tool.description,
+                    parameters_json_schema=tool.parameters,
+                )
+                for tool in tools
+            ]
+        )
+
+    @property
+    def label(self) -> str:
+        """Return the label for this toolset."""
+        return 'the AG-UI frontend tools'  # pragma: no cover
+
+
+class AGUIAdapter(UIAdapter[RunAgentInput, Message, BaseEvent, AgentDepsT, OutputDataT]):
+    """UI adapter for the Agent-User Interaction (AG-UI) protocol."""
+
+    @classmethod
+    def build_run_input(cls, body: bytes) -> RunAgentInput:
+        """Build an AG-UI run input object from the request body."""
+        return RunAgentInput.model_validate_json(body)
+
+    def build_event_stream(self) -> UIEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]:
+        """Build an AG-UI event stream transformer."""
+        return AGUIEventStream(self.run_input, accept=self.accept)
+
+    @cached_property
+    def messages(self) -> list[ModelMessage]:
+        """Pydantic AI messages from the AG-UI run input."""
+        return self.load_messages(self.run_input.messages)
+
+    @cached_property
+    def toolset(self) -> AbstractToolset[AgentDepsT] | None:
+        """Toolset representing frontend tools from the AG-UI run input."""
+        if self.run_input.tools:
+            return _AGUIFrontendToolset[AgentDepsT](self.run_input.tools)
+        return None
+
+    @cached_property
+    def state(self) -> dict[str, Any] | None:
+        """Frontend state from the AG-UI run input."""
+        return self.run_input.state
+
+    @classmethod
+    def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
+        """Transform AG-UI messages into Pydantic AI messages."""
+        builder = MessagesBuilder()
+        tool_calls: dict[str, str] = {}  # Tool call ID to tool name mapping.
+
+        for msg in messages:
+            if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
+                isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
+            ):
+                if isinstance(msg, UserMessage):
+                    builder.add(UserPromptPart(content=msg.content))
+                elif isinstance(msg, SystemMessage | DeveloperMessage):
+                    builder.add(SystemPromptPart(content=msg.content))
+                else:
+                    tool_call_id = msg.tool_call_id
+                    tool_name = tool_calls.get(tool_call_id)
+                    if tool_name is None:  # pragma: no cover
+                        raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
+
+                    builder.add(
+                        ToolReturnPart(
+                            tool_name=tool_name,
+                            content=msg.content,
+                            tool_call_id=tool_call_id,
+                        )
+                    )
+
+            elif isinstance(msg, AssistantMessage) or (  # pragma: no branch
+                isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
+            ):
+                if isinstance(msg, AssistantMessage):
+                    if msg.content:
+                        builder.add(TextPart(content=msg.content))
+
+                    if msg.tool_calls:
+                        for tool_call in msg.tool_calls:
+                            tool_call_id = tool_call.id
+                            tool_name = tool_call.function.name
+                            tool_calls[tool_call_id] = tool_name
+
+                            if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
+                                _, provider_name, tool_call_id = tool_call_id.split('|', 2)
+                                builder.add(
+                                    BuiltinToolCallPart(
+                                        tool_name=tool_name,
+                                        args=tool_call.function.arguments,
+                                        tool_call_id=tool_call_id,
+                                        provider_name=provider_name,
+                                    )
+                                )
+                            else:
+                                builder.add(
+                                    ToolCallPart(
+                                        tool_name=tool_name,
+                                        tool_call_id=tool_call_id,
+                                        args=tool_call.function.arguments,
+                                    )
+                                )
+                else:
+                    tool_call_id = msg.tool_call_id
+                    tool_name = tool_calls.get(tool_call_id)
+                    if tool_name is None:  # pragma: no cover
+                        raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
+                    _, provider_name, tool_call_id = tool_call_id.split('|', 2)
+
+                    builder.add(
+                        BuiltinToolReturnPart(
+                            tool_name=tool_name,
+                            content=msg.content,
+                            tool_call_id=tool_call_id,
+                            provider_name=provider_name,
+                        )
+                    )
+
+        return builder.messages
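For orientation, here is a minimal sketch (not part of the diff) of how the new `AGUIAdapter.load_messages` classmethod converts AG-UI chat history into Pydantic AI messages. The `id`/`role`/`content` fields on the `ag_ui.core` message models are assumed from the imports and attribute accesses above.

from ag_ui.core import SystemMessage, UserMessage

from pydantic_ai.ui.ag_ui import AGUIAdapter

# Hypothetical AG-UI history as it would arrive from a frontend.
history = [
    SystemMessage(id='msg_1', role='system', content='You are a helpful assistant.'),
    UserMessage(id='msg_2', role='user', content='What is the weather like today?'),
]

# `load_messages` is a classmethod, so no adapter instance (or request body) is needed.
model_messages = AGUIAdapter.load_messages(history)
# `model_messages` is a list of Pydantic AI `ModelMessage` objects built from
# `SystemPromptPart` and `UserPromptPart` parts, per the branches above.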
pydantic_ai/ui/ag_ui/_event_stream.py (new file)
@@ -0,0 +1,236 @@
+"""AG-UI protocol adapter for Pydantic AI agents.
+
+This module provides classes for integrating Pydantic AI agents with the AG-UI protocol,
+enabling streaming event-based communication for interactive AI applications.
+"""
+
+from __future__ import annotations
+
+import json
+from collections.abc import AsyncIterator, Iterable
+from dataclasses import dataclass, field
+from typing import Final
+
+from ...messages import (
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
+    FunctionToolResultEvent,
+    RetryPromptPart,
+    TextPart,
+    TextPartDelta,
+    ThinkingPart,
+    ThinkingPartDelta,
+    ToolCallPart,
+    ToolCallPartDelta,
+    ToolReturnPart,
+)
+from ...output import OutputDataT
+from ...tools import AgentDepsT
+from .. import SSE_CONTENT_TYPE, UIEventStream
+
+try:
+    from ag_ui.core import (
+        BaseEvent,
+        EventType,
+        RunAgentInput,
+        RunErrorEvent,
+        RunFinishedEvent,
+        RunStartedEvent,
+        TextMessageContentEvent,
+        TextMessageEndEvent,
+        TextMessageStartEvent,
+        ThinkingEndEvent,
+        ThinkingStartEvent,
+        ThinkingTextMessageContentEvent,
+        ThinkingTextMessageEndEvent,
+        ThinkingTextMessageStartEvent,
+        ToolCallArgsEvent,
+        ToolCallEndEvent,
+        ToolCallResultEvent,
+        ToolCallStartEvent,
+    )
+    from ag_ui.encoder import EventEncoder
+
+except ImportError as e:  # pragma: no cover
+    raise ImportError(
+        'Please install the `ag-ui-protocol` package to use AG-UI integration, '
+        'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
+    ) from e
+
+__all__ = [
+    'AGUIEventStream',
+    'RunAgentInput',
+    'RunStartedEvent',
+    'RunFinishedEvent',
+]
+
+BUILTIN_TOOL_CALL_ID_PREFIX: Final[str] = 'pyd_ai_builtin'
+
+
+@dataclass
+class AGUIEventStream(UIEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]):
+    """UI event stream transformer for the Agent-User Interaction (AG-UI) protocol."""
+
+    _thinking_text: bool = False
+    _builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)
+    _error: bool = False
+
+    @property
+    def _event_encoder(self) -> EventEncoder:
+        return EventEncoder(accept=self.accept or SSE_CONTENT_TYPE)
+
+    @property
+    def content_type(self) -> str:
+        return self._event_encoder.get_content_type()
+
+    def encode_event(self, event: BaseEvent) -> str:
+        return self._event_encoder.encode(event)
+
+    async def before_stream(self) -> AsyncIterator[BaseEvent]:
+        yield RunStartedEvent(
+            thread_id=self.run_input.thread_id,
+            run_id=self.run_input.run_id,
+        )
+
+    async def before_response(self) -> AsyncIterator[BaseEvent]:
+        # Prevent parts from a subsequent response being tied to parts from an earlier response.
+        # See https://github.com/pydantic/pydantic-ai/issues/3316
+        self.new_message_id()
+        return
+        yield  # Make this an async generator
+
+    async def after_stream(self) -> AsyncIterator[BaseEvent]:
+        if not self._error:
+            yield RunFinishedEvent(
+                thread_id=self.run_input.thread_id,
+                run_id=self.run_input.run_id,
+            )
+
+    async def on_error(self, error: Exception) -> AsyncIterator[BaseEvent]:
+        self._error = True
+        yield RunErrorEvent(message=str(error))
+
+    async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseEvent]:
+        if follows_text:
+            message_id = self.message_id
+        else:
+            message_id = self.new_message_id()
+            yield TextMessageStartEvent(message_id=message_id)
+
+        if part.content:  # pragma: no branch
+            yield TextMessageContentEvent(message_id=message_id, delta=part.content)
+
+    async def handle_text_delta(self, delta: TextPartDelta) -> AsyncIterator[BaseEvent]:
+        if delta.content_delta:  # pragma: no branch
+            yield TextMessageContentEvent(message_id=self.message_id, delta=delta.content_delta)
+
+    async def handle_text_end(self, part: TextPart, followed_by_text: bool = False) -> AsyncIterator[BaseEvent]:
+        if not followed_by_text:
+            yield TextMessageEndEvent(message_id=self.message_id)
+
+    async def handle_thinking_start(
+        self, part: ThinkingPart, follows_thinking: bool = False
+    ) -> AsyncIterator[BaseEvent]:
+        if not follows_thinking:
+            yield ThinkingStartEvent(type=EventType.THINKING_START)
+
+        if part.content:
+            yield ThinkingTextMessageStartEvent(type=EventType.THINKING_TEXT_MESSAGE_START)
+            yield ThinkingTextMessageContentEvent(type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=part.content)
+            self._thinking_text = True
+
+    async def handle_thinking_delta(self, delta: ThinkingPartDelta) -> AsyncIterator[BaseEvent]:
+        if not delta.content_delta:
+            return  # pragma: no cover
+
+        if not self._thinking_text:
+            yield ThinkingTextMessageStartEvent(type=EventType.THINKING_TEXT_MESSAGE_START)
+            self._thinking_text = True
+
+        yield ThinkingTextMessageContentEvent(type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=delta.content_delta)
+
+    async def handle_thinking_end(
+        self, part: ThinkingPart, followed_by_thinking: bool = False
+    ) -> AsyncIterator[BaseEvent]:
+        if self._thinking_text:
+            yield ThinkingTextMessageEndEvent(type=EventType.THINKING_TEXT_MESSAGE_END)
+            self._thinking_text = False
+
+        if not followed_by_thinking:
+            yield ThinkingEndEvent(type=EventType.THINKING_END)
+
+    def handle_tool_call_start(self, part: ToolCallPart | BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
+        return self._handle_tool_call_start(part)
+
+    def handle_builtin_tool_call_start(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
+        tool_call_id = part.tool_call_id
+        builtin_tool_call_id = '|'.join([BUILTIN_TOOL_CALL_ID_PREFIX, part.provider_name or '', tool_call_id])
+        self._builtin_tool_call_ids[tool_call_id] = builtin_tool_call_id
+        tool_call_id = builtin_tool_call_id
+
+        return self._handle_tool_call_start(part, tool_call_id)
+
+    async def _handle_tool_call_start(
+        self, part: ToolCallPart | BuiltinToolCallPart, tool_call_id: str | None = None
+    ) -> AsyncIterator[BaseEvent]:
+        tool_call_id = tool_call_id or part.tool_call_id
+        parent_message_id = self.message_id
+
+        yield ToolCallStartEvent(
+            tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=parent_message_id
+        )
+        if part.args:
+            yield ToolCallArgsEvent(tool_call_id=tool_call_id, delta=part.args_as_json_str())
+
+    async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterator[BaseEvent]:
+        tool_call_id = delta.tool_call_id
+        assert tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
+        if tool_call_id in self._builtin_tool_call_ids:
+            tool_call_id = self._builtin_tool_call_ids[tool_call_id]
+        yield ToolCallArgsEvent(
+            tool_call_id=tool_call_id,
+            delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
+        )
+
+    async def handle_tool_call_end(self, part: ToolCallPart) -> AsyncIterator[BaseEvent]:
+        yield ToolCallEndEvent(tool_call_id=part.tool_call_id)
+
+    async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
+        yield ToolCallEndEvent(tool_call_id=self._builtin_tool_call_ids[part.tool_call_id])
+
+    async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[BaseEvent]:
+        tool_call_id = self._builtin_tool_call_ids[part.tool_call_id]
+        yield ToolCallResultEvent(
+            message_id=self.new_message_id(),
+            type=EventType.TOOL_CALL_RESULT,
+            role='tool',
+            tool_call_id=tool_call_id,
+            content=part.model_response_str(),
+        )
+
+    async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> AsyncIterator[BaseEvent]:
+        result = event.result
+        output = result.model_response() if isinstance(result, RetryPromptPart) else result.model_response_str()
+
+        yield ToolCallResultEvent(
+            message_id=self.new_message_id(),
+            type=EventType.TOOL_CALL_RESULT,
+            role='tool',
+            tool_call_id=result.tool_call_id,
+            content=output,
+        )
+
+        # ToolCallResultEvent.content may hold user parts (e.g. text, images) that AG-UI does not currently have events for
+
+        if isinstance(result, ToolReturnPart):
+            # Check for AG-UI events returned by tool calls.
+            possible_event = result.metadata or result.content
+            if isinstance(possible_event, BaseEvent):
+                yield possible_event
+            elif isinstance(possible_event, str | bytes):  # pragma: no branch
+                # Avoid iterable check for strings and bytes.
+                pass
+            elif isinstance(possible_event, Iterable):  # pragma: no branch
+                for item in possible_event:  # type: ignore[reportUnknownMemberType]
+                    if isinstance(item, BaseEvent):  # pragma: no branch
+                        yield item
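The `BUILTIN_TOOL_CALL_ID_PREFIX` above is what lets built-in tool calls round-trip through the frontend: the event stream encodes the provider name into the tool call ID it emits, and `AGUIAdapter.load_messages` splits it back out to rebuild `BuiltinToolCallPart`/`BuiltinToolReturnPart` instead of regular tool parts. A small sketch of that encoding, with made-up values:

BUILTIN_TOOL_CALL_ID_PREFIX = 'pyd_ai_builtin'

provider_name = 'openai'      # hypothetical provider name
tool_call_id = 'call_abc123'  # hypothetical model-generated tool call ID

# Encoding, as in `handle_builtin_tool_call_start`:
frontend_id = '|'.join([BUILTIN_TOOL_CALL_ID_PREFIX, provider_name, tool_call_id])
assert frontend_id == 'pyd_ai_builtin|openai|call_abc123'

# Decoding, as in `AGUIAdapter.load_messages`:
_, decoded_provider, decoded_id = frontend_id.split('|', 2)
assert (decoded_provider, decoded_id) == (provider_name, tool_call_id)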
pydantic_ai/ui/ag_ui/app.py (new file)
@@ -0,0 +1,148 @@
+"""AG-UI protocol integration for Pydantic AI agents."""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Mapping, Sequence
+from dataclasses import replace
+from typing import Any, Generic
+
+from typing_extensions import Self
+
+from pydantic_ai import DeferredToolResults
+from pydantic_ai.agent import AbstractAgent
+from pydantic_ai.builtin_tools import AbstractBuiltinTool
+from pydantic_ai.messages import ModelMessage
+from pydantic_ai.models import KnownModelName, Model
+from pydantic_ai.output import OutputDataT, OutputSpec
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.tools import AgentDepsT
+from pydantic_ai.toolsets import AbstractToolset
+from pydantic_ai.usage import RunUsage, UsageLimits
+
+from .. import OnCompleteFunc, StateHandler
+from ._adapter import AGUIAdapter
+
+try:
+    from starlette.applications import Starlette
+    from starlette.middleware import Middleware
+    from starlette.requests import Request
+    from starlette.responses import Response
+    from starlette.routing import BaseRoute
+    from starlette.types import ExceptionHandler, Lifespan
+except ImportError as e:  # pragma: no cover
+    raise ImportError(
+        'Please install the `starlette` package to use `AGUIApp`, '
+        'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
+    ) from e
+
+
+class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
+    """ASGI application for running Pydantic AI agents with AG-UI protocol support."""
+
+    def __init__(
+        self,
+        agent: AbstractAgent[AgentDepsT, OutputDataT],
+        *,
+        # AGUIAdapter.dispatch_request parameters
+        output_type: OutputSpec[Any] | None = None,
+        message_history: Sequence[ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
+        model: Model | KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
+        usage: RunUsage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+        builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
+        on_complete: OnCompleteFunc[Any] | None = None,
+        # Starlette parameters
+        debug: bool = False,
+        routes: Sequence[BaseRoute] | None = None,
+        middleware: Sequence[Middleware] | None = None,
+        exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
+        on_startup: Sequence[Callable[[], Any]] | None = None,
+        on_shutdown: Sequence[Callable[[], Any]] | None = None,
+        lifespan: Lifespan[Self] | None = None,
+    ) -> None:
+        """An ASGI application that handles every request by running the agent and streaming the response.
+
+        Note that the `deps` will be the same for each request, with the exception of the frontend state that's
+        injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ui.StateHandler] protocol.
+        To provide different `deps` for each request (e.g. based on the authenticated user),
+        use [`AGUIAdapter.run_stream()`][pydantic_ai.ui.ag_ui.AGUIAdapter.run_stream] or
+        [`AGUIAdapter.dispatch_request()`][pydantic_ai.ui.ag_ui.AGUIAdapter.dispatch_request] instead.
+
+        Args:
+            agent: The agent to run.
+
+            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
+                no output validators since output validators would expect an argument that matches the agent's
+                output type.
+            message_history: History of the conversation so far.
+            deferred_tool_results: Optional results for deferred tool calls in the message history.
+            model: Optional model to use for this run, required if `model` was not set when creating the agent.
+            deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional additional toolsets for this run.
+            builtin_tools: Optional additional builtin tools for this run.
+            on_complete: Optional callback function called when the agent run completes successfully.
+                The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data.
+
+            debug: Boolean indicating if debug tracebacks should be returned on errors.
+            routes: A list of routes to serve incoming HTTP and WebSocket requests.
+            middleware: A list of middleware to run for every request. A starlette application will always
+                automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
+                outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
+                `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
+                exception cases occurring in the routing or endpoints.
+            exception_handlers: A mapping of either integer status codes, or exception class types onto
+                callables which handle the exceptions. Exception handler callables should be of the form
+                `handler(request, exc) -> response` and may be either standard functions, or async functions.
+            on_startup: A list of callables to run on application startup. Startup handler callables do not
+                take any arguments, and may be either standard functions, or async functions.
+            on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
+                not take any arguments, and may be either standard functions, or async functions.
+            lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
+                This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
+                the other, not both.
+        """
+        super().__init__(
+            debug=debug,
+            routes=routes,
+            middleware=middleware,
+            exception_handlers=exception_handlers,
+            on_startup=on_startup,
+            on_shutdown=on_shutdown,
+            lifespan=lifespan,
+        )
+
+        async def run_agent(request: Request) -> Response:
+            """Endpoint to run the agent with the provided input data."""
+            # `dispatch_request` will store the frontend state from the request on `deps.state` (if it implements the `StateHandler` protocol),
+            # so we need to copy the deps to avoid different requests mutating the same deps object.
+            nonlocal deps
+            if isinstance(deps, StateHandler):  # pragma: no branch
+                deps = replace(deps)
+
+            return await AGUIAdapter[AgentDepsT, OutputDataT].dispatch_request(
+                request,
+                agent=agent,
+                output_type=output_type,
+                message_history=message_history,
+                deferred_tool_results=deferred_tool_results,
+                model=model,
+                deps=deps,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                usage=usage,
+                infer_name=infer_name,
+                toolsets=toolsets,
+                builtin_tools=builtin_tools,
+                on_complete=on_complete,
+            )
+
+        self.router.add_route('/', run_agent, methods=['POST'])
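A minimal sketch (the model string and instructions are placeholders, not from the diff) of exposing an agent through the new `AGUIApp`, which registers the `run_agent` endpoint on `POST /` as shown above:

from pydantic_ai import Agent
from pydantic_ai.ui.ag_ui.app import AGUIApp

agent = Agent('openai:gpt-4o', instructions='Be concise.')  # placeholder model and instructions
app = AGUIApp(agent)

# Serve with any ASGI server, e.g.:
#   uvicorn my_module:app --port 8000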
pydantic_ai/ui/vercel_ai/__init__.py (new file)
@@ -0,0 +1,16 @@
+"""Vercel AI protocol adapter for Pydantic AI agents.
+
+This module provides classes for integrating Pydantic AI agents with the Vercel AI protocol,
+enabling streaming event-based communication for interactive AI applications.
+
+Converted to Python from:
+https://github.com/vercel/ai/blob/ai%405.0.34/packages/ai/src/ui/ui-messages.ts
+"""
+
+from ._adapter import VercelAIAdapter
+from ._event_stream import VercelAIEventStream
+
+__all__ = [
+    'VercelAIEventStream',
+    'VercelAIAdapter',
+]