pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +219 -315
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +296 -226
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -155
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +14 -2
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +19 -9
- pydantic_ai/models/__init__.py +43 -19
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +3 -11
- pydantic_ai/models/google.py +3 -12
- pydantic_ai/models/groq.py +2 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +3 -3
- pydantic_ai/models/openai.py +5 -5
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +57 -33
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
- pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
- pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/common_tools/duckduckgo.py
CHANGED
@@ -9,10 +9,13 @@ from typing_extensions import TypedDict
 from pydantic_ai.tools import Tool
 
 try:
-    from duckduckgo_search import DDGS
+    try:
+        from ddgs import DDGS
+    except ImportError:  # Fallback for older versions of ddgs
+        from duckduckgo_search import DDGS
 except ImportError as _import_error:
     raise ImportError(
-        'Please install `duckduckgo-search` to use the DuckDuckGo search tool, '
+        'Please install `ddgs` to use the DuckDuckGo search tool, '
         'you can use the `duckduckgo` optional group — `pip install "pydantic-ai-slim[duckduckgo]"`'
     ) from _import_error
 
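The nested try/except above prefers the renamed `ddgs` package while keeping the legacy `duckduckgo_search` import as a fallback, so existing installs keep working. For context, a minimal sketch of how this tool is typically attached to an agent (per the documented `duckduckgo` extra; the model name is illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.common_tools.duckduckgo import duckduckgo_search_tool

    # requires: pip install "pydantic-ai-slim[duckduckgo]"
    agent = Agent(
        'openai:gpt-4o',
        tools=[duckduckgo_search_tool()],  # DDGS-backed web search tool
    )

    result = agent.run_sync('What are the latest Python releases?')
    print(result.output)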
pydantic_ai/exceptions.py
CHANGED
@@ -2,11 +2,15 @@ from __future__ import annotations as _annotations
 
 import json
 import sys
+from typing import TYPE_CHECKING
 
 if sys.version_info < (3, 11):
-    from exceptiongroup import ExceptionGroup
+    from exceptiongroup import ExceptionGroup
 else:
-    ExceptionGroup = ExceptionGroup
+    ExceptionGroup = ExceptionGroup
+
+if TYPE_CHECKING:
+    from .messages import RetryPromptPart
 
 __all__ = (
     'ModelRetry',
@@ -113,3 +117,11 @@ class ModelHTTPError(AgentRunError):
 
 class FallbackExceptionGroup(ExceptionGroup):
     """A group of exceptions that can be raised when all fallback models fail."""
+
+
+class ToolRetryError(Exception):
+    """Exception used to signal a `ToolRetry` message should be returned to the LLM."""
+
+    def __init__(self, tool_retry: RetryPromptPart):
+        self.tool_retry = tool_retry
+        super().__init__()
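`ToolRetryError` is an internal signal: it carries a `RetryPromptPart` up from tool handling so the agent graph can hand the retry prompt back to the model. A minimal sketch of the pattern, assuming a hypothetical argument check inside a tool implementation:

    from pydantic_ai.exceptions import ToolRetryError
    from pydantic_ai.messages import RetryPromptPart


    def check_args(args: dict) -> None:
        # hypothetical validation step; the wrapped part becomes a retry message for the LLM
        if 'city' not in args:
            raise ToolRetryError(RetryPromptPart(content='Missing required argument `city`.', tool_name='get_weather'))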
pydantic_ai/ext/aci.py
CHANGED
@@ -4,11 +4,13 @@ try:
 except ImportError as _import_error:
     raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error
 
+from collections.abc import Sequence
 from typing import Any
 
 from aci import ACI
 
-from pydantic_ai import Tool
+from pydantic_ai.tools import Tool
+from pydantic_ai.toolsets.function import FunctionToolset
 
 
 def _clean_schema(schema):
@@ -22,10 +24,10 @@ def _clean_schema(schema):
 
 
 def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool:
-    """Creates a Pydantic AI tool proxy from an ACI function.
+    """Creates a Pydantic AI tool proxy from an ACI.dev function.
 
     Args:
-        aci_function: The ACI function to wrap.
+        aci_function: The ACI.dev function to wrap.
         linked_account_owner_id: The ACI user ID to execute the function on behalf of.
 
     Returns:
@@ -64,3 +66,10 @@ def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool:
         description=function_description,
         json_schema=json_schema,
     )
+
+
+class ACIToolset(FunctionToolset):
+    """A toolset that wraps ACI.dev tools."""
+
+    def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str):
+        super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions])
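The new `ACIToolset` wraps a whole list of ACI.dev functions at once, building on `tool_from_aci`. A sketch of the intended usage (the function names and owner id are placeholders, and `toolsets=` is the new `Agent` parameter used throughout this release):

    from pydantic_ai import Agent
    from pydantic_ai.ext.aci import ACIToolset

    toolset = ACIToolset(
        ['OPEN_WEATHER_MAP__CURRENT_WEATHER', 'OPEN_WEATHER_MAP__FORECAST'],
        linked_account_owner_id='your-user-id',
    )
    agent = Agent('openai:gpt-4o', toolsets=[toolset])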
pydantic_ai/ext/langchain.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Protocol
 from pydantic.json_schema import JsonSchemaValue
 
 from pydantic_ai.tools import Tool
+from pydantic_ai.toolsets.function import FunctionToolset
 
 
 class LangChainTool(Protocol):
@@ -23,7 +24,7 @@ class LangChainTool(Protocol):
     def run(self, *args: Any, **kwargs: Any) -> str: ...
 
 
-__all__ = ('tool_from_langchain',)
+__all__ = ('tool_from_langchain', 'LangChainToolset')
 
 
 def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
@@ -59,3 +60,10 @@ def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
         description=function_description,
         json_schema=schema,
     )
+
+
+class LangChainToolset(FunctionToolset):
+    """A toolset that wraps LangChain tools."""
+
+    def __init__(self, tools: list[LangChainTool]):
+        super().__init__([tool_from_langchain(tool) for tool in tools])
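`LangChainToolset` does the same for LangChain tools, converting each one via `tool_from_langchain`. A sketch, assuming `langchain-community` is installed (any objects satisfying the `LangChainTool` protocol work):

    from langchain_community.agent_toolkits import FileManagementToolkit

    from pydantic_ai import Agent
    from pydantic_ai.ext.langchain import LangChainToolset

    toolkit = FileManagementToolkit(root_dir='/tmp/agent_files')
    agent = Agent('openai:gpt-4o', toolsets=[LangChainToolset(toolkit.get_tools())])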
pydantic_ai/mcp.py
CHANGED
@@ -3,11 +3,11 @@ from __future__ import annotations
 import base64
 import functools
 from abc import ABC, abstractmethod
+from asyncio import Lock
 from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field, replace
 from pathlib import Path
-from types import TracebackType
 from typing import Any, Callable
 
 import anyio
@@ -16,6 +16,11 @@ import pydantic_core
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from typing_extensions import Self, assert_never, deprecated
 
+from pydantic_ai._run_context import RunContext
+from pydantic_ai.tools import ToolDefinition
+
+from .toolsets.abstract import AbstractToolset, ToolsetTool
+
 try:
     from mcp import types as mcp_types
     from mcp.client.session import ClientSession, LoggingFnT
@@ -32,12 +37,18 @@ except ImportError as _import_error:
     ) from _import_error
 
 # after mcp imports so any import error maps to this file, not _mcp.py
-from . import _mcp, exceptions, messages, models
+from . import _mcp, exceptions, messages, models
 
 __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
 
+TOOL_SCHEMA_VALIDATOR = pydantic_core.SchemaValidator(
+    schema=pydantic_core.core_schema.dict_schema(
+        pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema()
+    )
+)
+
 
-class MCPServer(ABC):
+class MCPServer(AbstractToolset[Any], ABC):
     """Base class for attaching agents to MCP servers.
 
     See <https://modelcontextprotocol.io> for more information.
@@ -50,15 +61,22 @@ class MCPServer(ABC):
     timeout: float = 5
     process_tool_call: ProcessToolCallback | None = None
     allow_sampling: bool = True
+    max_retries: int = 1
+    sampling_model: models.Model | None = None
     # } end of "abstract fields"
 
-    _running_count: int = 0
+    _enter_lock: Lock = field(compare=False)
+    _running_count: int
+    _exit_stack: AsyncExitStack | None
 
     _client: ClientSession
     _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
     _write_stream: MemoryObjectSendStream[SessionMessage]
-    _exit_stack: AsyncExitStack
-
+
+    def __post_init__(self):
+        self._enter_lock = Lock()
+        self._running_count = 0
+        self._exit_stack = None
 
     @abstractmethod
     @asynccontextmanager
@@ -74,47 +92,36 @@ class MCPServer(ABC):
         raise NotImplementedError('MCP Server subclasses must implement this method.')
         yield
 
-    def get_prefixed_tool_name(self, tool_name: str) -> str:
-        """Get the tool name with prefix if `tool_prefix` is set."""
-        return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name
-
-    def get_unprefixed_tool_name(self, tool_name: str) -> str:
-        """Get original tool name without prefix for calling tools."""
-        return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name
+    @property
+    def name(self) -> str:
+        return repr(self)
 
     @property
-    def is_running(self) -> bool:
-        """Check if the MCP server is running."""
-        return bool(self._running_count)
+    def tool_name_conflict_hint(self) -> str:
+        return 'Consider setting `tool_prefix` to avoid name conflicts.'
 
-    async def list_tools(self) -> list[ToolDefinition]:
+    async def list_tools(self) -> list[mcp_types.Tool]:
         """Retrieve tools that are currently active on the server.
 
         Note:
         - We don't cache tools as they might change.
         - We also don't subscribe to the server to avoid complexity.
         """
-        mcp_tools = await self._client.list_tools()
-        return [
-            ToolDefinition(
-                name=self.get_prefixed_tool_name(tool.name),
-                description=tool.description,
-                parameters_json_schema=tool.inputSchema,
-            )
-            for tool in mcp_tools.tools
-        ]
+        async with self:  # Ensure server is running
+            result = await self._client.list_tools()
+            return result.tools
 
-    async def call_tool(
+    async def direct_call_tool(
         self,
-        tool_name: str,
-        arguments: dict[str, Any],
+        name: str,
+        args: dict[str, Any],
         metadata: dict[str, Any] | None = None,
     ) -> ToolResult:
         """Call a tool on the server.
 
         Args:
-            tool_name: The name of the tool to call.
-            arguments: The arguments to pass to the tool.
+            name: The name of the tool to call.
+            args: The arguments to pass to the tool.
            metadata: Request-level metadata (optional)
 
         Returns:
@@ -123,23 +130,23 @@
         Raises:
             ModelRetry: If the tool call fails.
         """
-        try:
-            # meta param is not provided by session yet, so build and can send_request directly.
-            result = await self._client.send_request(
-                mcp_types.ClientRequest(
-                    mcp_types.CallToolRequest(
-                        method='tools/call',
-                        params=mcp_types.CallToolRequestParams(
-                            name=self.get_unprefixed_tool_name(tool_name),
-                            arguments=arguments,
-                            _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
-                        ),
-                    )
-                ),
-                mcp_types.CallToolResult,
-            )
-        except McpError as e:
-            raise exceptions.ModelRetry(e.error.message)
+        async with self:  # Ensure server is running
+            try:
+                result = await self._client.send_request(
+                    mcp_types.ClientRequest(
+                        mcp_types.CallToolRequest(
+                            method='tools/call',
+                            params=mcp_types.CallToolRequestParams(
+                                name=name,
+                                arguments=args,
+                                _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
+                            ),
+                        )
+                    ),
+                    mcp_types.CallToolResult,
+                )
+            except McpError as e:
+                raise exceptions.ModelRetry(e.error.message)
 
         content = [self._map_tool_result_part(part) for part in result.content]
 
@@ -149,36 +156,80 @@
         else:
             return content[0] if len(content) == 1 else content
 
-    async def __aenter__(self) -> Self:
-        if self._running_count == 0:
-            self._exit_stack = AsyncExitStack()
-
-            self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
-            client = ClientSession(
-                read_stream=self._read_stream,
-                write_stream=self._write_stream,
-                sampling_callback=self._sampling_callback if self.allow_sampling else None,
-                logging_callback=self.log_handler,
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[Any],
+        tool: ToolsetTool[Any],
+    ) -> ToolResult:
+        if self.tool_prefix:
+            name = name.removeprefix(f'{self.tool_prefix}_')
+            ctx = replace(ctx, tool_name=name)
+
+        if self.process_tool_call is not None:
+            return await self.process_tool_call(ctx, self.direct_call_tool, name, tool_args)
+        else:
+            return await self.direct_call_tool(name, tool_args)
+
+    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
+        return {
+            name: ToolsetTool(
+                toolset=self,
+                tool_def=ToolDefinition(
+                    name=name,
+                    description=mcp_tool.description,
+                    parameters_json_schema=mcp_tool.inputSchema,
+                ),
+                max_retries=self.max_retries,
+                args_validator=TOOL_SCHEMA_VALIDATOR,
             )
-            self._client = await self._exit_stack.enter_async_context(client)
+            for mcp_tool in await self.list_tools()
+            if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name)
+        }
+
+    async def __aenter__(self) -> Self:
+        """Enter the MCP server context.
+
+        This will initialize the connection to the server.
+        If this server is an [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio], the server will first be started as a subprocess.
 
-            with anyio.fail_after(self.timeout):
-                await self._client.initialize()
+        This is a no-op if the MCP server has already been entered.
+        """
+        async with self._enter_lock:
+            if self._running_count == 0:
+                self._exit_stack = AsyncExitStack()
+
+                self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(
+                    self.client_streams()
+                )
+                client = ClientSession(
+                    read_stream=self._read_stream,
+                    write_stream=self._write_stream,
+                    sampling_callback=self._sampling_callback if self.allow_sampling else None,
+                    logging_callback=self.log_handler,
+                )
+                self._client = await self._exit_stack.enter_async_context(client)
+
+                with anyio.fail_after(self.timeout):
+                    await self._client.initialize()
 
-            if log_level := self.log_level:
-                await self._client.set_logging_level(log_level)
-        self._running_count += 1
+                if log_level := self.log_level:
+                    await self._client.set_logging_level(log_level)
+            self._running_count += 1
         return self
 
-    async def __aexit__(
-        self,
-        exc_type: type[BaseException] | None,
-        exc_value: BaseException | None,
-        traceback: TracebackType | None,
-    ) -> bool | None:
-        self._running_count -= 1
-        if self._running_count <= 0:
-            await self._exit_stack.aclose()
+    async def __aexit__(self, *args: Any) -> bool | None:
+        async with self._enter_lock:
+            self._running_count -= 1
+            if self._running_count == 0 and self._exit_stack is not None:
+                await self._exit_stack.aclose()
+                self._exit_stack = None
+
+    @property
+    def is_running(self) -> bool:
+        """Check if the MCP server is running."""
+        return bool(self._running_count)
 
     async def _sampling_callback(
         self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
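Entering and exiting the server is now reference-counted behind an `asyncio.Lock`: only the first `__aenter__` starts the connection and only the matching last `__aexit__` tears it down, so nested `async with` blocks are cheap no-ops. A sketch of what that enables (the server module name is hypothetical):

    import asyncio

    from pydantic_ai.mcp import MCPServerStdio

    server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])  # hypothetical MCP server module


    async def main():
        async with server:  # first enter: starts the subprocess, initializes the session
            async with server:  # nested enter: only bumps the reference count
                print(server.is_running)  # True
            # inner exit decrements the count; the session is still alive here
        # outer exit reaches zero and closes the exit stack


    asyncio.run(main())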
@@ -271,10 +322,10 @@ class MCPServerStdio(MCPServer):
             'stdio',
         ]
     )
-    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+    agent = Agent('openai:gpt-4o', toolsets=[server])
 
     async def main():
-        async with agent.run_mcp_servers():  # (2)!
+        async with agent:  # (2)!
             ...
     ```
 
@@ -327,6 +378,12 @@ class MCPServerStdio(MCPServer):
     allow_sampling: bool = True
     """Whether to allow MCP sampling through this client."""
 
+    max_retries: int = 1
+    """The maximum number of times to retry a tool call."""
+
+    sampling_model: models.Model | None = None
+    """The model to use for sampling."""
+
     @asynccontextmanager
     async def client_streams(
         self,
@@ -422,6 +479,12 @@ class _MCPServerHTTP(MCPServer):
     allow_sampling: bool = True
     """Whether to allow MCP sampling through this client."""
 
+    max_retries: int = 1
+    """The maximum number of times to retry a tool call."""
+
+    sampling_model: models.Model | None = None
+    """The model to use for sampling."""
+
     @property
     @abstractmethod
     def _transport_client(
@@ -503,10 +566,10 @@ class MCPServerSSE(_MCPServerHTTP):
     from pydantic_ai.mcp import MCPServerSSE
 
     server = MCPServerSSE('http://localhost:3001/sse')  # (1)!
-    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+    agent = Agent('openai:gpt-4o', toolsets=[server])
 
     async def main():
-        async with agent.run_mcp_servers():  # (2)!
+        async with agent:  # (2)!
             ...
     ```
 
@@ -537,10 +600,10 @@ class MCPServerHTTP(MCPServerSSE):
     from pydantic_ai.mcp import MCPServerHTTP
 
     server = MCPServerHTTP('http://localhost:3001/sse')  # (1)!
-    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+    agent = Agent('openai:gpt-4o', toolsets=[server])
 
     async def main():
-        async with agent.run_mcp_servers():  # (2)!
+        async with agent:  # (2)!
             ...
     ```
 
@@ -566,10 +629,10 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
     from pydantic_ai.mcp import MCPServerStreamableHTTP
 
     server = MCPServerStreamableHTTP('http://localhost:8000/mcp')  # (1)!
-    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+    agent = Agent('openai:gpt-4o', toolsets=[server])
 
     async def main():
-        async with agent.run_mcp_servers():  # (2)!
+        async with agent:  # (2)!
             ...
     ```
     """
@@ -586,14 +649,14 @@ ToolResult = (
     | list[Any]
     | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
 )
-"""The result type of a tool call."""
+"""The result type of an MCP tool call."""
 
 CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]]
 """A function type that represents a tool call."""
 
 ProcessToolCallback = Callable[
     [
-
+        RunContext[Any],
         CallToolFunc,
         str,
         dict[str, Any],
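Per the `ProcessToolCallback` type above, a callback receives the run context, a direct-call function, the tool name, and its arguments, and may pass request-level metadata as the third argument of the call function. A sketch (the server command and metadata values are illustrative):

    from typing import Any

    from pydantic_ai import Agent
    from pydantic_ai.mcp import CallToolFunc, MCPServerStdio, ToolResult
    from pydantic_ai.tools import RunContext


    async def process_tool_call(
        ctx: RunContext[Any],
        call_tool: CallToolFunc,
        name: str,
        tool_args: dict[str, Any],
    ) -> ToolResult:
        # wraps every MCP tool call, e.g. to log it or attach request metadata
        return await call_tool(name, tool_args, {'run_step': str(ctx.run_step)})


    server = MCPServerStdio('python', args=['-m', 'my_mcp_server'], process_tool_call=process_tool_call)
    agent = Agent('openai:gpt-4o', toolsets=[server])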
pydantic_ai/messages.py
CHANGED
@@ -282,6 +282,14 @@ class BinaryContent:
     media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
     """The media type of the binary data."""
 
+    identifier: str | None = None
+    """Identifier for the binary content, such as a URL or unique ID.
+
+    This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
+
+    This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file <identifier>:" preceding the `BinaryContent`.
+    """
+
     vendor_metadata: dict[str, Any] | None = None
     """Vendor-specific metadata for the file.
 
@@ -411,9 +419,9 @@ class UserPromptPart:
     """Part type identifier, this is available on all parts as a discriminator."""
 
     def otel_event(self, settings: InstrumentationSettings) -> Event:
-        content: str | list[dict[str, Any] | str]
+        content: str | list[dict[str, Any] | str] | dict[str, Any]
         if isinstance(self.content, str):
-            content = self.content
+            content = self.content if settings.include_content else {'kind': 'text'}
         else:
             content = []
             for part in self.content:
@@ -433,7 +441,9 @@ class UserPromptPart:
     __repr__ = _utils.dataclasses_no_defaults_repr
 
 
-tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
+tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(
+    Any, config=pydantic.ConfigDict(defer_build=True, ser_json_bytes='base64', val_json_bytes='base64')
+)
 
 
 @dataclass(repr=False)
@@ -519,7 +529,7 @@ class RetryPromptPart:
     tool_call_id: str = field(default_factory=_generate_tool_call_id)
     """The tool call identifier, this is used by some models including OpenAI.
 
-    In case the tool call id is not provided by the model, PydanticAI will generate a random one.
+    In case the tool call id is not provided by the model, Pydantic AI will generate a random one.
     """
 
     timestamp: datetime = field(default_factory=_now_utc)
@@ -560,12 +570,12 @@
 ModelRequestPart = Annotated[
     Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
 ]
-"""A message part sent by PydanticAI to a model."""
+"""A message part sent by Pydantic AI to a model."""
 
 
 @dataclass(repr=False)
 class ModelRequest:
-    """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""
+    """A request generated by Pydantic AI and sent to a model, e.g. a message from the Pydantic AI app to the model."""
 
     parts: list[ModelRequestPart]
     """The parts of the user message."""
@@ -643,7 +653,7 @@ class ToolCallPart:
     tool_call_id: str = field(default_factory=_generate_tool_call_id)
     """The tool call identifier, this is used by some models including OpenAI.
 
-    In case the tool call id is not provided by the model, PydanticAI will generate a random one.
+    In case the tool call id is not provided by the model, Pydantic AI will generate a random one.
     """
 
     part_kind: Literal['tool-call'] = 'tool-call'
@@ -691,7 +701,7 @@ ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, ThinkingPart], pydan
 
 @dataclass(repr=False)
 class ModelResponse:
-    """A response from a model, e.g. a message from the model to the PydanticAI app."""
+    """A response from a model, e.g. a message from the model to the Pydantic AI app."""
 
     parts: list[ModelResponsePart]
     """The parts of the model message."""
@@ -743,7 +753,7 @@ class ModelResponse:
                 'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
                 'function': {
                     'name': part.tool_name,
-                    'arguments': part.args,
+                    **({'arguments': part.args} if settings.include_content else {}),
                 },
             }
         )
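The new `BinaryContent.identifier` supports the lookup pattern its docstring describes: the model names a file by identifier in a tool call, and the tool scans the message history for the matching content. A minimal sketch of such a tool (the model and wiring are illustrative):

    from pydantic_ai import Agent, RunContext
    from pydantic_ai.messages import BinaryContent, ModelRequest, UserPromptPart

    agent = Agent('openai:gpt-4o')


    @agent.tool
    def file_size(ctx: RunContext[None], identifier: str) -> str:
        """Return the size of a previously supplied file, looked up by its identifier."""
        for message in ctx.messages:
            if not isinstance(message, ModelRequest):
                continue
            for part in message.parts:
                if isinstance(part, UserPromptPart) and isinstance(part.content, list):
                    for item in part.content:
                        if isinstance(item, BinaryContent) and item.identifier == identifier:
                            return f'{len(item.data)} bytes of {item.media_type}'
        return 'unknown file identifier'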