pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_a2a.py +3 -3
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +378 -164
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/format_prompt.py +3 -6
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -18
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/instrumented.py +6 -1
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +16 -4
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/RECORD +48 -35
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/exceptions.py
CHANGED
|
@@ -2,12 +2,16 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
import sys
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
5
6
|
|
|
6
7
|
if sys.version_info < (3, 11):
|
|
7
8
|
from exceptiongroup import ExceptionGroup
|
|
8
9
|
else:
|
|
9
10
|
ExceptionGroup = ExceptionGroup
|
|
10
11
|
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from .messages import RetryPromptPart
|
|
14
|
+
|
|
11
15
|
__all__ = (
|
|
12
16
|
'ModelRetry',
|
|
13
17
|
'UserError',
|
|
@@ -113,3 +117,11 @@ class ModelHTTPError(AgentRunError):
|
|
|
113
117
|
|
|
114
118
|
class FallbackExceptionGroup(ExceptionGroup):
|
|
115
119
|
"""A group of exceptions that can be raised when all fallback models fail."""
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class ToolRetryError(Exception):
|
|
123
|
+
"""Exception used to signal a `ToolRetry` message should be returned to the LLM."""
|
|
124
|
+
|
|
125
|
+
def __init__(self, tool_retry: RetryPromptPart):
|
|
126
|
+
self.tool_retry = tool_retry
|
|
127
|
+
super().__init__()
|
pydantic_ai/ext/aci.py
CHANGED
|
@@ -4,11 +4,13 @@ try:
|
|
|
4
4
|
except ImportError as _import_error:
|
|
5
5
|
raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error
|
|
6
6
|
|
|
7
|
+
from collections.abc import Sequence
|
|
7
8
|
from typing import Any
|
|
8
9
|
|
|
9
10
|
from aci import ACI
|
|
10
11
|
|
|
11
|
-
from pydantic_ai import Tool
|
|
12
|
+
from pydantic_ai.tools import Tool
|
|
13
|
+
from pydantic_ai.toolsets.function import FunctionToolset
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
def _clean_schema(schema):
|
|
@@ -22,10 +24,10 @@ def _clean_schema(schema):
|
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool:
|
|
25
|
-
"""Creates a Pydantic AI tool proxy from an ACI function.
|
|
27
|
+
"""Creates a Pydantic AI tool proxy from an ACI.dev function.
|
|
26
28
|
|
|
27
29
|
Args:
|
|
28
|
-
aci_function: The ACI function to
|
|
30
|
+
aci_function: The ACI.dev function to wrap.
|
|
29
31
|
linked_account_owner_id: The ACI user ID to execute the function on behalf of.
|
|
30
32
|
|
|
31
33
|
Returns:
|
|
@@ -64,3 +66,10 @@ def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool:
|
|
|
64
66
|
description=function_description,
|
|
65
67
|
json_schema=json_schema,
|
|
66
68
|
)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ACIToolset(FunctionToolset):
|
|
72
|
+
"""A toolset that wraps ACI.dev tools."""
|
|
73
|
+
|
|
74
|
+
def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str):
|
|
75
|
+
super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions])
|
pydantic_ai/ext/langchain.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Any, Protocol
|
|
|
3
3
|
from pydantic.json_schema import JsonSchemaValue
|
|
4
4
|
|
|
5
5
|
from pydantic_ai.tools import Tool
|
|
6
|
+
from pydantic_ai.toolsets.function import FunctionToolset
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class LangChainTool(Protocol):
|
|
@@ -23,7 +24,7 @@ class LangChainTool(Protocol):
|
|
|
23
24
|
def run(self, *args: Any, **kwargs: Any) -> str: ...
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
__all__ = ('tool_from_langchain',)
|
|
27
|
+
__all__ = ('tool_from_langchain', 'LangChainToolset')
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
|
|
@@ -59,3 +60,10 @@ def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
|
|
|
59
60
|
description=function_description,
|
|
60
61
|
json_schema=schema,
|
|
61
62
|
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class LangChainToolset(FunctionToolset):
|
|
66
|
+
"""A toolset that wraps LangChain tools."""
|
|
67
|
+
|
|
68
|
+
def __init__(self, tools: list[LangChainTool]):
|
|
69
|
+
super().__init__([tool_from_langchain(tool) for tool in tools])
|
pydantic_ai/format_prompt.py
CHANGED
|
@@ -13,9 +13,8 @@ __all__ = ('format_as_xml',)
|
|
|
13
13
|
|
|
14
14
|
def format_as_xml(
|
|
15
15
|
obj: Any,
|
|
16
|
-
root_tag: str =
|
|
17
|
-
item_tag: str = '
|
|
18
|
-
include_root_tag: bool = True,
|
|
16
|
+
root_tag: str | None = None,
|
|
17
|
+
item_tag: str = 'item',
|
|
19
18
|
none_str: str = 'null',
|
|
20
19
|
indent: str | None = ' ',
|
|
21
20
|
) -> str:
|
|
@@ -32,8 +31,6 @@ def format_as_xml(
|
|
|
32
31
|
root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
|
|
33
32
|
item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
|
|
34
33
|
for dataclasses and Pydantic models.
|
|
35
|
-
include_root_tag: Whether to include the root tag in the output
|
|
36
|
-
(The root tag is always included if it includes a body - e.g. when the input is a simple value).
|
|
37
34
|
none_str: String to use for `None` values.
|
|
38
35
|
indent: Indentation string to use for pretty printing.
|
|
39
36
|
|
|
@@ -55,7 +52,7 @@ def format_as_xml(
|
|
|
55
52
|
```
|
|
56
53
|
"""
|
|
57
54
|
el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
|
|
58
|
-
if
|
|
55
|
+
if root_tag is None and el.text is None:
|
|
59
56
|
join = '' if indent is None else '\n'
|
|
60
57
|
return join.join(_rootless_xml_elements(el, indent))
|
|
61
58
|
else:
|
pydantic_ai/mcp.py
CHANGED
|
@@ -3,11 +3,11 @@ from __future__ import annotations
|
|
|
3
3
|
import base64
|
|
4
4
|
import functools
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
+
from asyncio import Lock
|
|
6
7
|
from collections.abc import AsyncIterator, Awaitable, Sequence
|
|
7
8
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
8
|
-
from dataclasses import dataclass
|
|
9
|
+
from dataclasses import dataclass, field, replace
|
|
9
10
|
from pathlib import Path
|
|
10
|
-
from types import TracebackType
|
|
11
11
|
from typing import Any, Callable
|
|
12
12
|
|
|
13
13
|
import anyio
|
|
@@ -16,6 +16,11 @@ import pydantic_core
|
|
|
16
16
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
17
17
|
from typing_extensions import Self, assert_never, deprecated
|
|
18
18
|
|
|
19
|
+
from pydantic_ai._run_context import RunContext
|
|
20
|
+
from pydantic_ai.tools import ToolDefinition
|
|
21
|
+
|
|
22
|
+
from .toolsets.abstract import AbstractToolset, ToolsetTool
|
|
23
|
+
|
|
19
24
|
try:
|
|
20
25
|
from mcp import types as mcp_types
|
|
21
26
|
from mcp.client.session import ClientSession, LoggingFnT
|
|
@@ -32,12 +37,18 @@ except ImportError as _import_error:
|
|
|
32
37
|
) from _import_error
|
|
33
38
|
|
|
34
39
|
# after mcp imports so any import error maps to this file, not _mcp.py
|
|
35
|
-
from . import _mcp, exceptions, messages, models
|
|
40
|
+
from . import _mcp, exceptions, messages, models
|
|
36
41
|
|
|
37
42
|
__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
|
|
38
43
|
|
|
44
|
+
TOOL_SCHEMA_VALIDATOR = pydantic_core.SchemaValidator(
|
|
45
|
+
schema=pydantic_core.core_schema.dict_schema(
|
|
46
|
+
pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema()
|
|
47
|
+
)
|
|
48
|
+
)
|
|
49
|
+
|
|
39
50
|
|
|
40
|
-
class MCPServer(ABC):
|
|
51
|
+
class MCPServer(AbstractToolset[Any], ABC):
|
|
41
52
|
"""Base class for attaching agents to MCP servers.
|
|
42
53
|
|
|
43
54
|
See <https://modelcontextprotocol.io> for more information.
|
|
@@ -50,15 +61,22 @@ class MCPServer(ABC):
|
|
|
50
61
|
timeout: float = 5
|
|
51
62
|
process_tool_call: ProcessToolCallback | None = None
|
|
52
63
|
allow_sampling: bool = True
|
|
64
|
+
max_retries: int = 1
|
|
65
|
+
sampling_model: models.Model | None = None
|
|
53
66
|
# } end of "abstract fields"
|
|
54
67
|
|
|
55
|
-
|
|
68
|
+
_enter_lock: Lock = field(compare=False)
|
|
69
|
+
_running_count: int
|
|
70
|
+
_exit_stack: AsyncExitStack | None
|
|
56
71
|
|
|
57
72
|
_client: ClientSession
|
|
58
73
|
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
|
59
74
|
_write_stream: MemoryObjectSendStream[SessionMessage]
|
|
60
|
-
|
|
61
|
-
|
|
75
|
+
|
|
76
|
+
def __post_init__(self):
|
|
77
|
+
self._enter_lock = Lock()
|
|
78
|
+
self._running_count = 0
|
|
79
|
+
self._exit_stack = None
|
|
62
80
|
|
|
63
81
|
@abstractmethod
|
|
64
82
|
@asynccontextmanager
|
|
@@ -74,47 +92,36 @@ class MCPServer(ABC):
|
|
|
74
92
|
raise NotImplementedError('MCP Server subclasses must implement this method.')
|
|
75
93
|
yield
|
|
76
94
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
return
|
|
80
|
-
|
|
81
|
-
def get_unprefixed_tool_name(self, tool_name: str) -> str:
|
|
82
|
-
"""Get original tool name without prefix for calling tools."""
|
|
83
|
-
return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name
|
|
95
|
+
@property
|
|
96
|
+
def name(self) -> str:
|
|
97
|
+
return repr(self)
|
|
84
98
|
|
|
85
99
|
@property
|
|
86
|
-
def
|
|
87
|
-
|
|
88
|
-
return bool(self._running_count)
|
|
100
|
+
def tool_name_conflict_hint(self) -> str:
|
|
101
|
+
return 'Consider setting `tool_prefix` to avoid name conflicts.'
|
|
89
102
|
|
|
90
|
-
async def list_tools(self) -> list[
|
|
103
|
+
async def list_tools(self) -> list[mcp_types.Tool]:
|
|
91
104
|
"""Retrieve tools that are currently active on the server.
|
|
92
105
|
|
|
93
106
|
Note:
|
|
94
107
|
- We don't cache tools as they might change.
|
|
95
108
|
- We also don't subscribe to the server to avoid complexity.
|
|
96
109
|
"""
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
name=self.get_prefixed_tool_name(tool.name),
|
|
101
|
-
description=tool.description,
|
|
102
|
-
parameters_json_schema=tool.inputSchema,
|
|
103
|
-
)
|
|
104
|
-
for tool in mcp_tools.tools
|
|
105
|
-
]
|
|
110
|
+
async with self: # Ensure server is running
|
|
111
|
+
result = await self._client.list_tools()
|
|
112
|
+
return result.tools
|
|
106
113
|
|
|
107
|
-
async def
|
|
114
|
+
async def direct_call_tool(
|
|
108
115
|
self,
|
|
109
|
-
|
|
110
|
-
|
|
116
|
+
name: str,
|
|
117
|
+
args: dict[str, Any],
|
|
111
118
|
metadata: dict[str, Any] | None = None,
|
|
112
119
|
) -> ToolResult:
|
|
113
120
|
"""Call a tool on the server.
|
|
114
121
|
|
|
115
122
|
Args:
|
|
116
|
-
|
|
117
|
-
|
|
123
|
+
name: The name of the tool to call.
|
|
124
|
+
args: The arguments to pass to the tool.
|
|
118
125
|
metadata: Request-level metadata (optional)
|
|
119
126
|
|
|
120
127
|
Returns:
|
|
@@ -123,23 +130,23 @@ class MCPServer(ABC):
|
|
|
123
130
|
Raises:
|
|
124
131
|
ModelRetry: If the tool call fails.
|
|
125
132
|
"""
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
133
|
+
async with self: # Ensure server is running
|
|
134
|
+
try:
|
|
135
|
+
result = await self._client.send_request(
|
|
136
|
+
mcp_types.ClientRequest(
|
|
137
|
+
mcp_types.CallToolRequest(
|
|
138
|
+
method='tools/call',
|
|
139
|
+
params=mcp_types.CallToolRequestParams(
|
|
140
|
+
name=name,
|
|
141
|
+
arguments=args,
|
|
142
|
+
_meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
|
|
143
|
+
),
|
|
144
|
+
)
|
|
145
|
+
),
|
|
146
|
+
mcp_types.CallToolResult,
|
|
147
|
+
)
|
|
148
|
+
except McpError as e:
|
|
149
|
+
raise exceptions.ModelRetry(e.error.message)
|
|
143
150
|
|
|
144
151
|
content = [self._map_tool_result_part(part) for part in result.content]
|
|
145
152
|
|
|
@@ -149,36 +156,80 @@ class MCPServer(ABC):
|
|
|
149
156
|
else:
|
|
150
157
|
return content[0] if len(content) == 1 else content
|
|
151
158
|
|
|
152
|
-
async def
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
159
|
+
async def call_tool(
|
|
160
|
+
self,
|
|
161
|
+
name: str,
|
|
162
|
+
tool_args: dict[str, Any],
|
|
163
|
+
ctx: RunContext[Any],
|
|
164
|
+
tool: ToolsetTool[Any],
|
|
165
|
+
) -> ToolResult:
|
|
166
|
+
if self.tool_prefix:
|
|
167
|
+
name = name.removeprefix(f'{self.tool_prefix}_')
|
|
168
|
+
ctx = replace(ctx, tool_name=name)
|
|
169
|
+
|
|
170
|
+
if self.process_tool_call is not None:
|
|
171
|
+
return await self.process_tool_call(ctx, self.direct_call_tool, name, tool_args)
|
|
172
|
+
else:
|
|
173
|
+
return await self.direct_call_tool(name, tool_args)
|
|
174
|
+
|
|
175
|
+
async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
|
|
176
|
+
return {
|
|
177
|
+
name: ToolsetTool(
|
|
178
|
+
toolset=self,
|
|
179
|
+
tool_def=ToolDefinition(
|
|
180
|
+
name=name,
|
|
181
|
+
description=mcp_tool.description,
|
|
182
|
+
parameters_json_schema=mcp_tool.inputSchema,
|
|
183
|
+
),
|
|
184
|
+
max_retries=self.max_retries,
|
|
185
|
+
args_validator=TOOL_SCHEMA_VALIDATOR,
|
|
162
186
|
)
|
|
163
|
-
|
|
187
|
+
for mcp_tool in await self.list_tools()
|
|
188
|
+
if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name)
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
async def __aenter__(self) -> Self:
|
|
192
|
+
"""Enter the MCP server context.
|
|
193
|
+
|
|
194
|
+
This will initialize the connection to the server.
|
|
195
|
+
If this server is an [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio], the server will first be started as a subprocess.
|
|
164
196
|
|
|
165
|
-
|
|
166
|
-
|
|
197
|
+
This is a no-op if the MCP server has already been entered.
|
|
198
|
+
"""
|
|
199
|
+
async with self._enter_lock:
|
|
200
|
+
if self._running_count == 0:
|
|
201
|
+
self._exit_stack = AsyncExitStack()
|
|
202
|
+
|
|
203
|
+
self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(
|
|
204
|
+
self.client_streams()
|
|
205
|
+
)
|
|
206
|
+
client = ClientSession(
|
|
207
|
+
read_stream=self._read_stream,
|
|
208
|
+
write_stream=self._write_stream,
|
|
209
|
+
sampling_callback=self._sampling_callback if self.allow_sampling else None,
|
|
210
|
+
logging_callback=self.log_handler,
|
|
211
|
+
)
|
|
212
|
+
self._client = await self._exit_stack.enter_async_context(client)
|
|
213
|
+
|
|
214
|
+
with anyio.fail_after(self.timeout):
|
|
215
|
+
await self._client.initialize()
|
|
167
216
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
217
|
+
if log_level := self.log_level:
|
|
218
|
+
await self._client.set_logging_level(log_level)
|
|
219
|
+
self._running_count += 1
|
|
171
220
|
return self
|
|
172
221
|
|
|
173
|
-
async def __aexit__(
|
|
174
|
-
self
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
222
|
+
async def __aexit__(self, *args: Any) -> bool | None:
|
|
223
|
+
async with self._enter_lock:
|
|
224
|
+
self._running_count -= 1
|
|
225
|
+
if self._running_count == 0 and self._exit_stack is not None:
|
|
226
|
+
await self._exit_stack.aclose()
|
|
227
|
+
self._exit_stack = None
|
|
228
|
+
|
|
229
|
+
@property
|
|
230
|
+
def is_running(self) -> bool:
|
|
231
|
+
"""Check if the MCP server is running."""
|
|
232
|
+
return bool(self._running_count)
|
|
182
233
|
|
|
183
234
|
async def _sampling_callback(
|
|
184
235
|
self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
|
|
@@ -271,10 +322,10 @@ class MCPServerStdio(MCPServer):
|
|
|
271
322
|
'stdio',
|
|
272
323
|
]
|
|
273
324
|
)
|
|
274
|
-
agent = Agent('openai:gpt-4o',
|
|
325
|
+
agent = Agent('openai:gpt-4o', toolsets=[server])
|
|
275
326
|
|
|
276
327
|
async def main():
|
|
277
|
-
async with agent
|
|
328
|
+
async with agent: # (2)!
|
|
278
329
|
...
|
|
279
330
|
```
|
|
280
331
|
|
|
@@ -327,6 +378,12 @@ class MCPServerStdio(MCPServer):
|
|
|
327
378
|
allow_sampling: bool = True
|
|
328
379
|
"""Whether to allow MCP sampling through this client."""
|
|
329
380
|
|
|
381
|
+
max_retries: int = 1
|
|
382
|
+
"""The maximum number of times to retry a tool call."""
|
|
383
|
+
|
|
384
|
+
sampling_model: models.Model | None = None
|
|
385
|
+
"""The model to use for sampling."""
|
|
386
|
+
|
|
330
387
|
@asynccontextmanager
|
|
331
388
|
async def client_streams(
|
|
332
389
|
self,
|
|
@@ -422,6 +479,12 @@ class _MCPServerHTTP(MCPServer):
|
|
|
422
479
|
allow_sampling: bool = True
|
|
423
480
|
"""Whether to allow MCP sampling through this client."""
|
|
424
481
|
|
|
482
|
+
max_retries: int = 1
|
|
483
|
+
"""The maximum number of times to retry a tool call."""
|
|
484
|
+
|
|
485
|
+
sampling_model: models.Model | None = None
|
|
486
|
+
"""The model to use for sampling."""
|
|
487
|
+
|
|
425
488
|
@property
|
|
426
489
|
@abstractmethod
|
|
427
490
|
def _transport_client(
|
|
@@ -503,10 +566,10 @@ class MCPServerSSE(_MCPServerHTTP):
|
|
|
503
566
|
from pydantic_ai.mcp import MCPServerSSE
|
|
504
567
|
|
|
505
568
|
server = MCPServerSSE('http://localhost:3001/sse') # (1)!
|
|
506
|
-
agent = Agent('openai:gpt-4o',
|
|
569
|
+
agent = Agent('openai:gpt-4o', toolsets=[server])
|
|
507
570
|
|
|
508
571
|
async def main():
|
|
509
|
-
async with agent
|
|
572
|
+
async with agent: # (2)!
|
|
510
573
|
...
|
|
511
574
|
```
|
|
512
575
|
|
|
@@ -537,10 +600,10 @@ class MCPServerHTTP(MCPServerSSE):
|
|
|
537
600
|
from pydantic_ai.mcp import MCPServerHTTP
|
|
538
601
|
|
|
539
602
|
server = MCPServerHTTP('http://localhost:3001/sse') # (1)!
|
|
540
|
-
agent = Agent('openai:gpt-4o',
|
|
603
|
+
agent = Agent('openai:gpt-4o', toolsets=[server])
|
|
541
604
|
|
|
542
605
|
async def main():
|
|
543
|
-
async with agent
|
|
606
|
+
async with agent: # (2)!
|
|
544
607
|
...
|
|
545
608
|
```
|
|
546
609
|
|
|
@@ -566,10 +629,10 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
|
|
|
566
629
|
from pydantic_ai.mcp import MCPServerStreamableHTTP
|
|
567
630
|
|
|
568
631
|
server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)!
|
|
569
|
-
agent = Agent('openai:gpt-4o',
|
|
632
|
+
agent = Agent('openai:gpt-4o', toolsets=[server])
|
|
570
633
|
|
|
571
634
|
async def main():
|
|
572
|
-
async with agent
|
|
635
|
+
async with agent: # (2)!
|
|
573
636
|
...
|
|
574
637
|
```
|
|
575
638
|
"""
|
|
@@ -586,14 +649,14 @@ ToolResult = (
|
|
|
586
649
|
| list[Any]
|
|
587
650
|
| Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
|
|
588
651
|
)
|
|
589
|
-
"""The result type of
|
|
652
|
+
"""The result type of an MCP tool call."""
|
|
590
653
|
|
|
591
654
|
CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]]
|
|
592
655
|
"""A function type that represents a tool call."""
|
|
593
656
|
|
|
594
657
|
ProcessToolCallback = Callable[
|
|
595
658
|
[
|
|
596
|
-
|
|
659
|
+
RunContext[Any],
|
|
597
660
|
CallToolFunc,
|
|
598
661
|
str,
|
|
599
662
|
dict[str, Any],
|
pydantic_ai/messages.py
CHANGED
|
@@ -282,6 +282,14 @@ class BinaryContent:
|
|
|
282
282
|
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
|
|
283
283
|
"""The media type of the binary data."""
|
|
284
284
|
|
|
285
|
+
identifier: str | None = None
|
|
286
|
+
"""Identifier for the binary content, such as a URL or unique ID.
|
|
287
|
+
|
|
288
|
+
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
|
|
289
|
+
|
|
290
|
+
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file <identifier>:" preceding the `BinaryContent`.
|
|
291
|
+
"""
|
|
292
|
+
|
|
285
293
|
vendor_metadata: dict[str, Any] | None = None
|
|
286
294
|
"""Vendor-specific metadata for the file.
|
|
287
295
|
|
|
@@ -521,7 +529,7 @@ class RetryPromptPart:
|
|
|
521
529
|
tool_call_id: str = field(default_factory=_generate_tool_call_id)
|
|
522
530
|
"""The tool call identifier, this is used by some models including OpenAI.
|
|
523
531
|
|
|
524
|
-
In case the tool call id is not provided by the model,
|
|
532
|
+
In case the tool call id is not provided by the model, Pydantic AI will generate a random one.
|
|
525
533
|
"""
|
|
526
534
|
|
|
527
535
|
timestamp: datetime = field(default_factory=_now_utc)
|
|
@@ -562,12 +570,12 @@ class RetryPromptPart:
|
|
|
562
570
|
ModelRequestPart = Annotated[
|
|
563
571
|
Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
|
|
564
572
|
]
|
|
565
|
-
"""A message part sent by
|
|
573
|
+
"""A message part sent by Pydantic AI to a model."""
|
|
566
574
|
|
|
567
575
|
|
|
568
576
|
@dataclass(repr=False)
|
|
569
577
|
class ModelRequest:
|
|
570
|
-
"""A request generated by
|
|
578
|
+
"""A request generated by Pydantic AI and sent to a model, e.g. a message from the Pydantic AI app to the model."""
|
|
571
579
|
|
|
572
580
|
parts: list[ModelRequestPart]
|
|
573
581
|
"""The parts of the user message."""
|
|
@@ -645,7 +653,7 @@ class ToolCallPart:
|
|
|
645
653
|
tool_call_id: str = field(default_factory=_generate_tool_call_id)
|
|
646
654
|
"""The tool call identifier, this is used by some models including OpenAI.
|
|
647
655
|
|
|
648
|
-
In case the tool call id is not provided by the model,
|
|
656
|
+
In case the tool call id is not provided by the model, Pydantic AI will generate a random one.
|
|
649
657
|
"""
|
|
650
658
|
|
|
651
659
|
part_kind: Literal['tool-call'] = 'tool-call'
|
|
@@ -693,7 +701,7 @@ ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, ThinkingPart], pydan
|
|
|
693
701
|
|
|
694
702
|
@dataclass(repr=False)
|
|
695
703
|
class ModelResponse:
|
|
696
|
-
"""A response from a model, e.g. a message from the model to the
|
|
704
|
+
"""A response from a model, e.g. a message from the model to the Pydantic AI app."""
|
|
697
705
|
|
|
698
706
|
parts: list[ModelResponsePart]
|
|
699
707
|
"""The parts of the model message."""
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -134,31 +134,15 @@ KnownModelName = TypeAliasType(
|
|
|
134
134
|
'cohere:command-r7b-12-2024',
|
|
135
135
|
'deepseek:deepseek-chat',
|
|
136
136
|
'deepseek:deepseek-reasoner',
|
|
137
|
-
'google-gla:gemini-1.5-flash',
|
|
138
|
-
'google-gla:gemini-1.5-flash-8b',
|
|
139
|
-
'google-gla:gemini-1.5-pro',
|
|
140
|
-
'google-gla:gemini-1.0-pro',
|
|
141
137
|
'google-gla:gemini-2.0-flash',
|
|
142
|
-
'google-gla:gemini-2.0-flash-lite
|
|
143
|
-
'google-gla:gemini-2.0-pro-exp-02-05',
|
|
144
|
-
'google-gla:gemini-2.5-flash-preview-05-20',
|
|
138
|
+
'google-gla:gemini-2.0-flash-lite',
|
|
145
139
|
'google-gla:gemini-2.5-flash',
|
|
146
140
|
'google-gla:gemini-2.5-flash-lite-preview-06-17',
|
|
147
|
-
'google-gla:gemini-2.5-pro-exp-03-25',
|
|
148
|
-
'google-gla:gemini-2.5-pro-preview-05-06',
|
|
149
141
|
'google-gla:gemini-2.5-pro',
|
|
150
|
-
'google-vertex:gemini-1.5-flash',
|
|
151
|
-
'google-vertex:gemini-1.5-flash-8b',
|
|
152
|
-
'google-vertex:gemini-1.5-pro',
|
|
153
|
-
'google-vertex:gemini-1.0-pro',
|
|
154
142
|
'google-vertex:gemini-2.0-flash',
|
|
155
|
-
'google-vertex:gemini-2.0-flash-lite
|
|
156
|
-
'google-vertex:gemini-2.0-pro-exp-02-05',
|
|
157
|
-
'google-vertex:gemini-2.5-flash-preview-05-20',
|
|
143
|
+
'google-vertex:gemini-2.0-flash-lite',
|
|
158
144
|
'google-vertex:gemini-2.5-flash',
|
|
159
145
|
'google-vertex:gemini-2.5-flash-lite-preview-06-17',
|
|
160
|
-
'google-vertex:gemini-2.5-pro-exp-03-25',
|
|
161
|
-
'google-vertex:gemini-2.5-pro-preview-05-06',
|
|
162
146
|
'google-vertex:gemini-2.5-pro',
|
|
163
147
|
'gpt-3.5-turbo',
|
|
164
148
|
'gpt-3.5-turbo-0125',
|
|
@@ -192,6 +176,7 @@ KnownModelName = TypeAliasType(
|
|
|
192
176
|
'gpt-4o-audio-preview',
|
|
193
177
|
'gpt-4o-audio-preview-2024-10-01',
|
|
194
178
|
'gpt-4o-audio-preview-2024-12-17',
|
|
179
|
+
'gpt-4o-audio-preview-2025-06-03',
|
|
195
180
|
'gpt-4o-mini',
|
|
196
181
|
'gpt-4o-mini-2024-07-18',
|
|
197
182
|
'gpt-4o-mini-audio-preview',
|
|
@@ -200,6 +185,14 @@ KnownModelName = TypeAliasType(
|
|
|
200
185
|
'gpt-4o-mini-search-preview-2025-03-11',
|
|
201
186
|
'gpt-4o-search-preview',
|
|
202
187
|
'gpt-4o-search-preview-2025-03-11',
|
|
188
|
+
'grok:grok-4',
|
|
189
|
+
'grok:grok-4-0709',
|
|
190
|
+
'grok:grok-3',
|
|
191
|
+
'grok:grok-3-mini',
|
|
192
|
+
'grok:grok-3-fast',
|
|
193
|
+
'grok:grok-3-mini-fast',
|
|
194
|
+
'grok:grok-2-vision-1212',
|
|
195
|
+
'grok:grok-2-image-1212',
|
|
203
196
|
'groq:distil-whisper-large-v3-en',
|
|
204
197
|
'groq:gemma2-9b-it',
|
|
205
198
|
'groq:llama-3.3-70b-versatile',
|
|
@@ -207,6 +200,7 @@ KnownModelName = TypeAliasType(
|
|
|
207
200
|
'groq:llama-guard-3-8b',
|
|
208
201
|
'groq:llama3-70b-8192',
|
|
209
202
|
'groq:llama3-8b-8192',
|
|
203
|
+
'groq:moonshotai/kimi-k2-instruct',
|
|
210
204
|
'groq:whisper-large-v3',
|
|
211
205
|
'groq:whisper-large-v3-turbo',
|
|
212
206
|
'groq:playai-tts',
|
|
@@ -245,11 +239,18 @@ KnownModelName = TypeAliasType(
|
|
|
245
239
|
'o1-mini-2024-09-12',
|
|
246
240
|
'o1-preview',
|
|
247
241
|
'o1-preview-2024-09-12',
|
|
242
|
+
'o1-pro',
|
|
243
|
+
'o1-pro-2025-03-19',
|
|
248
244
|
'o3',
|
|
249
245
|
'o3-2025-04-16',
|
|
246
|
+
'o3-deep-research',
|
|
247
|
+
'o3-deep-research-2025-06-26',
|
|
250
248
|
'o3-mini',
|
|
251
249
|
'o3-mini-2025-01-31',
|
|
250
|
+
'o3-pro',
|
|
251
|
+
'o3-pro-2025-06-10',
|
|
252
252
|
'openai:chatgpt-4o-latest',
|
|
253
|
+
'openai:codex-mini-latest',
|
|
253
254
|
'openai:gpt-3.5-turbo',
|
|
254
255
|
'openai:gpt-3.5-turbo-0125',
|
|
255
256
|
'openai:gpt-3.5-turbo-0301',
|
|
@@ -282,6 +283,7 @@ KnownModelName = TypeAliasType(
|
|
|
282
283
|
'openai:gpt-4o-audio-preview',
|
|
283
284
|
'openai:gpt-4o-audio-preview-2024-10-01',
|
|
284
285
|
'openai:gpt-4o-audio-preview-2024-12-17',
|
|
286
|
+
'openai:gpt-4o-audio-preview-2025-06-03',
|
|
285
287
|
'openai:gpt-4o-mini',
|
|
286
288
|
'openai:gpt-4o-mini-2024-07-18',
|
|
287
289
|
'openai:gpt-4o-mini-audio-preview',
|
|
@@ -296,12 +298,22 @@ KnownModelName = TypeAliasType(
|
|
|
296
298
|
'openai:o1-mini-2024-09-12',
|
|
297
299
|
'openai:o1-preview',
|
|
298
300
|
'openai:o1-preview-2024-09-12',
|
|
301
|
+
'openai:o1-pro',
|
|
302
|
+
'openai:o1-pro-2025-03-19',
|
|
299
303
|
'openai:o3',
|
|
300
304
|
'openai:o3-2025-04-16',
|
|
305
|
+
'openai:o3-deep-research',
|
|
306
|
+
'openai:o3-deep-research-2025-06-26',
|
|
301
307
|
'openai:o3-mini',
|
|
302
308
|
'openai:o3-mini-2025-01-31',
|
|
303
309
|
'openai:o4-mini',
|
|
304
310
|
'openai:o4-mini-2025-04-16',
|
|
311
|
+
'openai:o4-mini-deep-research',
|
|
312
|
+
'openai:o4-mini-deep-research-2025-06-26',
|
|
313
|
+
'openai:o3-pro',
|
|
314
|
+
'openai:o3-pro-2025-06-10',
|
|
315
|
+
'openai:computer-use-preview',
|
|
316
|
+
'openai:computer-use-preview-2025-03-11',
|
|
305
317
|
'test',
|
|
306
318
|
],
|
|
307
319
|
)
|