pydantic-ai-slim 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +5 -2
- pydantic_ai/_agent_graph.py +33 -15
- pydantic_ai/_cli.py +7 -3
- pydantic_ai/_function_schema.py +1 -4
- pydantic_ai/_mcp.py +123 -0
- pydantic_ai/_output.py +654 -159
- pydantic_ai/_run_context.py +56 -0
- pydantic_ai/_system_prompt.py +2 -1
- pydantic_ai/_utils.py +111 -1
- pydantic_ai/agent.py +66 -35
- pydantic_ai/mcp.py +144 -115
- pydantic_ai/models/__init__.py +21 -2
- pydantic_ai/models/function.py +21 -3
- pydantic_ai/models/gemini.py +27 -4
- pydantic_ai/models/google.py +29 -4
- pydantic_ai/models/mcp_sampling.py +95 -0
- pydantic_ai/models/mistral.py +5 -1
- pydantic_ai/models/openai.py +70 -9
- pydantic_ai/models/test.py +1 -1
- pydantic_ai/models/wrapper.py +6 -0
- pydantic_ai/output.py +288 -0
- pydantic_ai/profiles/__init__.py +21 -0
- pydantic_ai/profiles/_json_schema.py +1 -1
- pydantic_ai/profiles/google.py +6 -2
- pydantic_ai/profiles/openai.py +5 -0
- pydantic_ai/result.py +52 -26
- pydantic_ai/settings.py +1 -0
- pydantic_ai/tools.py +2 -47
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/RECORD +33 -29
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/mcp.py
CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import base64
 import functools
-import json
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
@@ -13,41 +12,28 @@ from typing import Any, Callable
 
 import anyio
 import httpx
+import pydantic_core
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
-from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
-from mcp.shared.exceptions import McpError
-from mcp.shared.message import SessionMessage
-from mcp.types import (
-    AudioContent,
-    BlobResourceContents,
-    CallToolRequest,
-    CallToolRequestParams,
-    CallToolResult,
-    ClientRequest,
-    Content,
-    EmbeddedResource,
-    ImageContent,
-    LoggingLevel,
-    RequestParams,
-    TextContent,
-    TextResourceContents,
-)
 from typing_extensions import Self, assert_never, deprecated
 
-from pydantic_ai.exceptions import ModelRetry
-from pydantic_ai.messages import BinaryContent
-from pydantic_ai.tools import RunContext, ToolDefinition
-
 try:
-    from mcp.client.session import ClientSession, LoggingFnT
+    from mcp import types as mcp_types
+    from mcp.client.session import ClientSession, LoggingFnT
     from mcp.client.sse import sse_client
     from mcp.client.stdio import StdioServerParameters, stdio_client
+    from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
+    from mcp.shared.context import RequestContext
+    from mcp.shared.exceptions import McpError
+    from mcp.shared.message import SessionMessage
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `mcp` package to use the MCP server, '
         'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
     ) from _import_error
 
+# after mcp imports so any import error maps to this file, not _mcp.py
+from . import _mcp, exceptions, messages, models, tools
+
 __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
 
 
@@ -57,22 +43,22 @@ class MCPServer(ABC):
     See <https://modelcontextprotocol.io> for more information.
     """
 
-    is_running: bool = False
+    # these fields should be re-defined by dataclass subclasses so they appear as fields {
     tool_prefix: str | None = None
-    """A prefix to add to all tools that are registered with the server.
-
-    If not empty, will include a trailing underscore (`_`).
-
-    e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
-    """
-
+    log_level: mcp_types.LoggingLevel | None = None
+    log_handler: LoggingFnT | None = None
+    timeout: float = 5
     process_tool_call: ProcessToolCallback | None = None
-    """Hook to customize tool calling and optionally pass extra metadata."""
+    allow_sampling: bool = True
+    # } end of "abstract fields"
+
+    _running_count: int = 0
 
     _client: ClientSession
     _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
     _write_stream: MemoryObjectSendStream[SessionMessage]
     _exit_stack: AsyncExitStack
+    sampling_model: models.Model | None = None
 
     @abstractmethod
     @asynccontextmanager
@@ -88,14 +74,6 @@ class MCPServer(ABC):
         raise NotImplementedError('MCP Server subclasses must implement this method.')
         yield
 
-    @abstractmethod
-    def _get_log_level(self) -> LoggingLevel | None:
-        """Get the log level for the MCP server."""
-        raise NotImplementedError('MCP Server subclasses must implement this method.')
-
-    def _get_client_initialize_timeout(self) -> float:
-        return 5  # pragma: no cover
-
     def get_prefixed_tool_name(self, tool_name: str) -> str:
         """Get the tool name with prefix if `tool_prefix` is set."""
         return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name
@@ -104,21 +82,26 @@
         """Get original tool name without prefix for calling tools."""
         return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name
 
-    async def list_tools(self) -> list[ToolDefinition]:
+    @property
+    def is_running(self) -> bool:
+        """Check if the MCP server is running."""
+        return bool(self._running_count)
+
+    async def list_tools(self) -> list[tools.ToolDefinition]:
         """Retrieve tools that are currently active on the server.
 
         Note:
         - We don't cache tools as they might change.
         - We also don't subscribe to the server to avoid complexity.
         """
-        tools = await self._client.list_tools()
+        mcp_tools = await self._client.list_tools()
         return [
-            ToolDefinition(
+            tools.ToolDefinition(
                 name=self.get_prefixed_tool_name(tool.name),
                 description=tool.description or '',
                 parameters_json_schema=tool.inputSchema,
             )
-            for tool in tools.tools
+            for tool in mcp_tools.tools
         ]
 
     async def call_tool(
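The prefixing helpers and `list_tools` above are what let two servers expose identically named tools to a single agent. A minimal sketch (the server commands and prefixes here are illustrative, not from this diff):

from pydantic_ai.mcp import MCPServerStdio

# Both servers may define a tool named `query`; with prefixes they are
# registered as `db_query` and `api_query` respectively.
db_server = MCPServerStdio('python', args=['db_server.py'], tool_prefix='db')
api_server = MCPServerStdio('python', args=['api_server.py'], tool_prefix='api')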
@@ -143,44 +126,48 @@
         try:
             # meta param is not provided by session yet, so build and can send_request directly.
             result = await self._client.send_request(
-                ClientRequest(
-                    CallToolRequest(
+                mcp_types.ClientRequest(
+                    mcp_types.CallToolRequest(
                         method='tools/call',
-                        params=CallToolRequestParams(
+                        params=mcp_types.CallToolRequestParams(
                             name=self.get_unprefixed_tool_name(tool_name),
                             arguments=arguments,
-                            _meta=RequestParams.Meta(**metadata) if metadata else None,
+                            _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
                         ),
                     )
                 ),
-                CallToolResult,
+                mcp_types.CallToolResult,
             )
         except McpError as e:
-            raise ModelRetry(e.error.message)
+            raise exceptions.ModelRetry(e.error.message)
 
         content = [self._map_tool_result_part(part) for part in result.content]
 
         if result.isError:
             text = '\n'.join(str(part) for part in content)
-            raise ModelRetry(text)
-
-        if len(content) == 1:
-            return content[0]
-        return content
+            raise exceptions.ModelRetry(text)
+        else:
+            return content[0] if len(content) == 1 else content
 
     async def __aenter__(self) -> Self:
-        self._exit_stack = AsyncExitStack()
-
-        self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
-        client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
-        self._client = await self._exit_stack.enter_async_context(client)
+        if self._running_count == 0:
+            self._exit_stack = AsyncExitStack()
+
+            self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
+            client = ClientSession(
+                read_stream=self._read_stream,
+                write_stream=self._write_stream,
+                sampling_callback=self._sampling_callback if self.allow_sampling else None,
+                logging_callback=self.log_handler,
+            )
+            self._client = await self._exit_stack.enter_async_context(client)
 
-        with anyio.fail_after(self._get_client_initialize_timeout()):
-            await self._client.initialize()
+            with anyio.fail_after(self.timeout):
+                await self._client.initialize()
 
-        if log_level := self._get_log_level():
-            await self._client.set_logging_level(log_level)
-        self.is_running = True
+            if log_level := self.log_level:
+                await self._client.set_logging_level(log_level)
+        self._running_count += 1
         return self
 
     async def __aexit__(
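The `_running_count` bookkeeping introduced above makes the async context manager re-entrant: nested `async with` blocks on the same server reuse a single session, and only the outermost exit tears it down. A sketch of the resulting behaviour (server construction elided):

async def nested(server):
    async with server:  # count 0 -> 1: streams and ClientSession are created
        async with server:  # count 1 -> 2: the existing session is reused
            assert server.is_running
        assert server.is_running  # count 2 -> 1: still connected
    assert not server.is_running  # count 1 -> 0: exit stack closed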
@@ -189,32 +176,64 @@
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> bool | None:
-        await self._exit_stack.aclose()
-        self.is_running = False
+        self._running_count -= 1
+        if self._running_count <= 0:
+            await self._exit_stack.aclose()
+
+    async def _sampling_callback(
+        self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
+    ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
+        """MCP sampling callback."""
+        if self.sampling_model is None:
+            raise ValueError('Sampling model is not set')  # pragma: no cover
+
+        pai_messages = _mcp.map_from_mcp_params(params)
+        model_settings = models.ModelSettings()
+        if max_tokens := params.maxTokens:  # pragma: no branch
+            model_settings['max_tokens'] = max_tokens
+        if temperature := params.temperature:  # pragma: no branch
+            model_settings['temperature'] = temperature
+        if stop_sequences := params.stopSequences:  # pragma: no branch
+            model_settings['stop_sequences'] = stop_sequences
+
+        model_response = await self.sampling_model.request(
+            pai_messages,
+            model_settings,
+            models.ModelRequestParameters(),
+        )
+        return mcp_types.CreateMessageResult(
+            role='assistant',
+            content=_mcp.map_from_model_response(model_response),
+            model=self.sampling_model.model_name,
+        )
 
-    def _map_tool_result_part(self, part: Content) -> str | BinaryContent | dict[str, Any] | list[Any]:
+    def _map_tool_result_part(
+        self, part: mcp_types.Content
+    ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
         # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
 
-        if isinstance(part, TextContent):
+        if isinstance(part, mcp_types.TextContent):
             text = part.text
             if text.startswith(('[', '{')):
                 try:
-                    return json.loads(text)
+                    return pydantic_core.from_json(text)
                 except ValueError:
                     pass
             return text
-        elif isinstance(part, ImageContent):
-            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
-        elif isinstance(part, AudioContent):
+        elif isinstance(part, mcp_types.ImageContent):
+            return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
+        elif isinstance(part, mcp_types.AudioContent):
             # NOTE: The FastMCP server doesn't support audio content.
             # See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
-            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)  # pragma: no cover
-        elif isinstance(part, EmbeddedResource):
+            return messages.BinaryContent(
+                data=base64.b64decode(part.data), media_type=part.mimeType
+            )  # pragma: no cover
+        elif isinstance(part, mcp_types.EmbeddedResource):
             resource = part.resource
-            if isinstance(resource, TextResourceContents):
+            if isinstance(resource, mcp_types.TextResourceContents):
                 return resource.text
-            elif isinstance(resource, BlobResourceContents):
-                return BinaryContent(
+            elif isinstance(resource, mcp_types.BlobResourceContents):
+                return messages.BinaryContent(
                     data=base64.b64decode(resource.blob),
                     media_type=resource.mimeType or 'application/octet-stream',
                 )
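The new `_sampling_callback` above services MCP sampling requests (server-initiated LLM calls) by delegating to whatever model is attached as `sampling_model`. A hedged sketch of wiring that up by assigning the field directly; any higher-level helper for this lives outside this file:

from pydantic_ai.mcp import MCPServerStdio
from pydantic_ai.models.openai import OpenAIModel

server = MCPServerStdio('python', args=['server.py'])  # illustrative command
server.sampling_model = OpenAIModel('gpt-4o')  # answers createMessage requests
# Setting allow_sampling=False would pass sampling_callback=None to the
# ClientSession, refusing sampling requests instead.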
@@ -275,17 +294,11 @@ class MCPServerStdio(MCPServer):
     By default the subprocess will not inherit any environment variables from the parent process.
     If you want to inherit the environment variables from the parent process, use `env=os.environ`.
     """
-    log_level: LoggingLevel | None = None
-    """The log level to set when connecting to the server, if any.
-
-    See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
-
-    If `None`, no log level will be set.
-    """
 
     cwd: str | Path | None = None
     """The working directory to use when spawning the process."""
 
+    # last fields are re-defined from the parent class so they appear as fields
     tool_prefix: str | None = None
     """A prefix to add to all tools that are registered with the server.
 
@@ -294,11 +307,25 @@
     e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
     """
 
+    log_level: mcp_types.LoggingLevel | None = None
+    """The log level to set when connecting to the server, if any.
+
+    See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
+
+    If `None`, no log level will be set.
+    """
+
+    log_handler: LoggingFnT | None = None
+    """A handler for logging messages from the server."""
+
+    timeout: float = 5
+    """The timeout in seconds to wait for the client to initialize."""
+
     process_tool_call: ProcessToolCallback | None = None
     """Hook to customize tool calling and optionally pass extra metadata."""
 
-    timeout: float = 5
-    """The timeout in seconds to wait for the client to initialize."""
+    allow_sampling: bool = True
+    """Whether to allow MCP sampling through this client."""
 
     @asynccontextmanager
     async def client_streams(
@@ -313,15 +340,9 @@
         async with stdio_client(server=server) as (read_stream, write_stream):
             yield read_stream, write_stream
 
-    def _get_log_level(self) -> LoggingLevel | None:
-        return self.log_level
-
     def __repr__(self) -> str:
         return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
 
-    def _get_client_initialize_timeout(self) -> float:
-        return self.timeout
-
 
 @dataclass
 class _MCPServerHTTP(MCPServer):
@@ -360,13 +381,6 @@ class _MCPServerHTTP(MCPServer):
     ```
     """
 
-    timeout: float = 5
-    """Initial connection timeout in seconds for establishing the connection.
-
-    This timeout applies to the initial connection setup and handshake.
-    If the connection cannot be established within this time, the operation will fail.
-    """
-
     sse_read_timeout: float = 5 * 60
     """Maximum time in seconds to wait for new SSE messages before timing out.
 
@@ -375,7 +389,16 @@
     and may be closed. Defaults to 5 minutes (300 seconds).
     """
 
-    log_level: LoggingLevel | None = None
+    # last fields are re-defined from the parent class so they appear as fields
+    tool_prefix: str | None = None
+    """A prefix to add to all tools that are registered with the server.
+
+    If not empty, will include a trailing underscore (`_`).
+
+    For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
+    """
+
+    log_level: mcp_types.LoggingLevel | None = None
     """The log level to set when connecting to the server, if any.
 
     See <https://modelcontextprotocol.io/introduction#logging> for more details.
@@ -383,17 +406,22 @@
     If `None`, no log level will be set.
     """
 
-    tool_prefix: str | None = None
-    """A prefix to add to all tools that are registered with the server.
+    log_handler: LoggingFnT | None = None
+    """A handler for logging messages from the server."""
 
-    If not empty, will include a trailing underscore (`_`).
+    timeout: float = 5
+    """Initial connection timeout in seconds for establishing the connection.
 
-    For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
+    This timeout applies to the initial connection setup and handshake.
+    If the connection cannot be established within this time, the operation will fail.
     """
 
     process_tool_call: ProcessToolCallback | None = None
     """Hook to customize tool calling and optionally pass extra metadata."""
 
+    allow_sampling: bool = True
+    """Whether to allow MCP sampling through this client."""
+
     @property
     @abstractmethod
     def _transport_client(
@@ -419,7 +447,10 @@
     async def client_streams(
         self,
     ) -> AsyncIterator[
-        tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
+        tuple[
+            MemoryObjectReceiveStream[SessionMessage | Exception],
+            MemoryObjectSendStream[SessionMessage],
+        ]
     ]:  # pragma: no cover
         if self.http_client and self.headers:
             raise ValueError('`http_client` is mutually exclusive with `headers`.')
@@ -451,15 +482,9 @@
         async with transport_client_partial(headers=self.headers) as (read_stream, write_stream, *_):
             yield read_stream, write_stream
 
-    def _get_log_level(self) -> LoggingLevel | None:
-        return self.log_level
-
     def __repr__(self) -> str:  # pragma: no cover
         return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
 
-    def _get_client_initialize_timeout(self) -> float:  # pragma: no cover
-        return self.timeout
-
 
 @dataclass
 class MCPServerSSE(_MCPServerHTTP):
@@ -555,7 +580,11 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
 
 
 ToolResult = (
-    str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]
+    str
+    | messages.BinaryContent
+    | dict[str, Any]
+    | list[Any]
+    | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
 )
 """The result type of a tool call."""
 
@@ -564,7 +593,7 @@ CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[
 
 ProcessToolCallback = Callable[
     [
-        RunContext[Any],
+        tools.RunContext[Any],
         CallToolFunc,
         str,
         dict[str, Any],
pydantic_ai/models/__init__.py
CHANGED
@@ -20,9 +20,12 @@ from typing_extensions import Literal, TypeAliasType, TypedDict
 
 from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
 
+from .. import _utils
+from .._output import OutputObjectDefinition
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
+from ..output import OutputMode
 from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -300,13 +303,18 @@ KnownModelName = TypeAliasType(
 """
 
 
-@dataclass
+@dataclass(repr=False)
 class ModelRequestParameters:
     """Configuration for an agent's request to a model, specifically related to tools and output handling."""
 
     function_tools: list[ToolDefinition] = field(default_factory=list)
-    allow_text_output: bool = True
+
+    output_mode: OutputMode = 'text'
+    output_object: OutputObjectDefinition | None = None
     output_tools: list[ToolDefinition] = field(default_factory=list)
+    allow_text_output: bool = True
+
+    __repr__ = _utils.dataclasses_no_defaults_repr
 
 
 class Model(ABC):
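Two things change in `ModelRequestParameters` above: the new `output_mode`/`output_object` fields carry structured-output configuration to the model classes, and `@dataclass(repr=False)` plus `_utils.dataclasses_no_defaults_repr` gives a repr that omits fields still at their defaults. A sketch of the intent (the printed form is approximate):

from pydantic_ai.models import ModelRequestParameters

params = ModelRequestParameters(output_mode='native', allow_text_output=False)
print(repr(params))
# Default-valued fields (function_tools, output_object, output_tools) are
# omitted, e.g.: ModelRequestParameters(output_mode='native', allow_text_output=False)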
@@ -351,6 +359,11 @@ class Model(ABC):
             function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
             output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
         )
+        if output_object := model_request_parameters.output_object:
+            model_request_parameters = replace(
+                model_request_parameters,
+                output_object=_customize_output_object(transformer, output_object),
+            )
 
         return model_request_parameters
 
@@ -718,3 +731,9 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinit
     if t.strict is None:
         t = replace(t, strict=schema_transformer.is_strict_compatible)
     return replace(t, parameters_json_schema=parameters_json_schema)
+
+
+def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition):
+    schema_transformer = transformer(o.json_schema, strict=True)
+    son_schema = schema_transformer.walk()
+    return replace(o, json_schema=son_schema)
pydantic_ai/models/function.py
CHANGED
@@ -11,6 +11,8 @@ from typing import Callable, Union
 
 from typing_extensions import TypeAlias, assert_never, overload
 
+from pydantic_ai.profiles import ModelProfileSpec
+
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
@@ -49,14 +51,27 @@ class FunctionModel(Model):
     _system: str = field(default='function', repr=False)
 
     @overload
-    def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
+    def __init__(
+        self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None
+    ) -> None: ...
 
     @overload
-    def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
+    def __init__(
+        self,
+        *,
+        stream_function: StreamFunctionDef,
+        model_name: str | None = None,
+        profile: ModelProfileSpec | None = None,
+    ) -> None: ...
 
     @overload
     def __init__(
-        self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
+        self,
+        function: FunctionDef,
+        *,
+        stream_function: StreamFunctionDef,
+        model_name: str | None = None,
+        profile: ModelProfileSpec | None = None,
     ) -> None: ...
 
     def __init__(
@@ -65,6 +80,7 @@ class FunctionModel(Model):
         *,
         stream_function: StreamFunctionDef | None = None,
         model_name: str | None = None,
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a `FunctionModel`.
 
@@ -74,6 +90,7 @@ class FunctionModel(Model):
             function: The function to call for non-streamed requests.
             stream_function: The function to call for streamed requests.
             model_name: The name of the model. If not provided, a name is generated from the function names.
+            profile: The model profile to use.
         """
         if function is None and stream_function is None:
             raise TypeError('Either `function` or `stream_function` must be provided')
@@ -83,6 +100,7 @@ class FunctionModel(Model):
         function_name = self.function.__name__ if self.function is not None else ''
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
         self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
+        self._profile = profile
 
     async def request(
         self,
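The `profile` parameter added to `FunctionModel` above mirrors the other model classes; a small sketch (the `ModelProfile` arguments are elided since its fields are outside this diff):

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.profiles import ModelProfile


def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    return ModelResponse(parts=[TextPart('ok')])


# Before this release the profile could not be set on a FunctionModel at all.
model = FunctionModel(echo, profile=ModelProfile())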
pydantic_ai/models/gemini.py
CHANGED
@@ -16,6 +16,8 @@ from typing_extensions import NotRequired, TypedDict, assert_never
 from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._output import OutputObjectDefinition
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     FileUrl,
@@ -203,12 +205,10 @@
     def _get_tool_config(
         self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
     ) -> _GeminiToolConfig | None:
-        if model_request_parameters.allow_text_output:
-            return None
-        elif tools:
+        if not model_request_parameters.allow_text_output and tools:
             return _tool_config([t['name'] for t in tools['function_declarations']])
         else:
-            return
+            return None
 
     @asynccontextmanager
     async def _make_request(
@@ -231,6 +231,18 @@
             request_data['toolConfig'] = tool_config
 
         generation_config = _settings_to_generation_config(model_settings)
+        if model_request_parameters.output_mode == 'native':
+            if tools:
+                raise UserError('Gemini does not support structured output and tools at the same time.')
+
+            generation_config['response_mime_type'] = 'application/json'
+
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            generation_config['response_schema'] = self._map_response_schema(output_object)
+        elif model_request_parameters.output_mode == 'prompted' and not tools:
+            generation_config['response_mime_type'] = 'application/json'
+
         if generation_config:
             request_data['generationConfig'] = generation_config
 
@@ -376,6 +388,15 @@
                 assert_never(item)
         return content
 
+    def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]:
+        response_schema = o.json_schema.copy()
+        if o.name:
+            response_schema['title'] = o.name
+        if o.description:
+            response_schema['description'] = o.description
+
+        return response_schema
+
 
 def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig:
     config: _GeminiGenerationConfig = {}
@@ -577,6 +598,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     frequency_penalty: float
     stop_sequences: list[str]
     thinking_config: ThinkingConfig
+    response_mime_type: str
+    response_schema: dict[str, Any]
 
 
 class _GeminiContent(TypedDict):