pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +37 -37
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +93 -9
- pydantic_ai/models/anthropic.py +16 -7
- pydantic_ai/models/bedrock.py +8 -5
- pydantic_ai/models/cohere.py +1 -4
- pydantic_ai/models/fallback.py +10 -3
- pydantic_ai/models/function.py +9 -4
- pydantic_ai/models/gemini.py +15 -9
- pydantic_ai/models/google.py +84 -20
- pydantic_ai/models/groq.py +17 -14
- pydantic_ai/models/huggingface.py +18 -12
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +12 -18
- pydantic_ai/models/openai.py +57 -30
- pydantic_ai/models/test.py +3 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- pydantic_ai/usage.py +17 -1
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/METADATA +7 -5
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/RECORD +58 -45
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/mcp.py
CHANGED
|
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
|
|
7
7
|
from asyncio import Lock
|
|
8
8
|
from collections.abc import AsyncIterator, Awaitable, Sequence
|
|
9
9
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
10
|
-
from dataclasses import
|
|
10
|
+
from dataclasses import field, replace
|
|
11
11
|
from datetime import timedelta
|
|
12
12
|
from pathlib import Path
|
|
13
13
|
from typing import Any, Callable
|
|
@@ -56,17 +56,17 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
56
56
|
See <https://modelcontextprotocol.io> for more information.
|
|
57
57
|
"""
|
|
58
58
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
max_retries: int
|
|
68
|
-
|
|
69
|
-
|
|
59
|
+
tool_prefix: str | None
|
|
60
|
+
log_level: mcp_types.LoggingLevel | None
|
|
61
|
+
log_handler: LoggingFnT | None
|
|
62
|
+
timeout: float
|
|
63
|
+
read_timeout: float
|
|
64
|
+
process_tool_call: ProcessToolCallback | None
|
|
65
|
+
allow_sampling: bool
|
|
66
|
+
sampling_model: models.Model | None
|
|
67
|
+
max_retries: int
|
|
68
|
+
|
|
69
|
+
_id: str | None
|
|
70
70
|
|
|
71
71
|
_enter_lock: Lock = field(compare=False)
|
|
72
72
|
_running_count: int
|
|
@@ -76,6 +76,34 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
76
76
|
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
|
77
77
|
_write_stream: MemoryObjectSendStream[SessionMessage]
|
|
78
78
|
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
tool_prefix: str | None = None,
|
|
82
|
+
log_level: mcp_types.LoggingLevel | None = None,
|
|
83
|
+
log_handler: LoggingFnT | None = None,
|
|
84
|
+
timeout: float = 5,
|
|
85
|
+
read_timeout: float = 5 * 60,
|
|
86
|
+
process_tool_call: ProcessToolCallback | None = None,
|
|
87
|
+
allow_sampling: bool = True,
|
|
88
|
+
sampling_model: models.Model | None = None,
|
|
89
|
+
max_retries: int = 1,
|
|
90
|
+
*,
|
|
91
|
+
id: str | None = None,
|
|
92
|
+
):
|
|
93
|
+
self.tool_prefix = tool_prefix
|
|
94
|
+
self.log_level = log_level
|
|
95
|
+
self.log_handler = log_handler
|
|
96
|
+
self.timeout = timeout
|
|
97
|
+
self.read_timeout = read_timeout
|
|
98
|
+
self.process_tool_call = process_tool_call
|
|
99
|
+
self.allow_sampling = allow_sampling
|
|
100
|
+
self.sampling_model = sampling_model
|
|
101
|
+
self.max_retries = max_retries
|
|
102
|
+
|
|
103
|
+
self._id = id or tool_prefix
|
|
104
|
+
|
|
105
|
+
self.__post_init__()
|
|
106
|
+
|
|
79
107
|
def __post_init__(self):
|
|
80
108
|
self._enter_lock = Lock()
|
|
81
109
|
self._running_count = 0
|
|
@@ -96,12 +124,19 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
96
124
|
yield
|
|
97
125
|
|
|
98
126
|
@property
|
|
99
|
-
def
|
|
100
|
-
return
|
|
127
|
+
def id(self) -> str | None:
|
|
128
|
+
return self._id
|
|
129
|
+
|
|
130
|
+
@property
|
|
131
|
+
def label(self) -> str:
|
|
132
|
+
if self.id:
|
|
133
|
+
return super().label # pragma: no cover
|
|
134
|
+
else:
|
|
135
|
+
return repr(self)
|
|
101
136
|
|
|
102
137
|
@property
|
|
103
138
|
def tool_name_conflict_hint(self) -> str:
|
|
104
|
-
return '
|
|
139
|
+
return 'Set the `tool_prefix` attribute to avoid name conflicts.'
|
|
105
140
|
|
|
106
141
|
async def list_tools(self) -> list[mcp_types.Tool]:
|
|
107
142
|
"""Retrieve tools that are currently active on the server.
|
|
@@ -177,20 +212,25 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
177
212
|
|
|
178
213
|
async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
|
|
179
214
|
return {
|
|
180
|
-
name:
|
|
181
|
-
|
|
182
|
-
tool_def=ToolDefinition(
|
|
215
|
+
name: self.tool_for_tool_def(
|
|
216
|
+
ToolDefinition(
|
|
183
217
|
name=name,
|
|
184
218
|
description=mcp_tool.description,
|
|
185
219
|
parameters_json_schema=mcp_tool.inputSchema,
|
|
186
220
|
),
|
|
187
|
-
max_retries=self.max_retries,
|
|
188
|
-
args_validator=TOOL_SCHEMA_VALIDATOR,
|
|
189
221
|
)
|
|
190
222
|
for mcp_tool in await self.list_tools()
|
|
191
223
|
if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name)
|
|
192
224
|
}
|
|
193
225
|
|
|
226
|
+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]:
|
|
227
|
+
return ToolsetTool(
|
|
228
|
+
toolset=self,
|
|
229
|
+
tool_def=tool_def,
|
|
230
|
+
max_retries=self.max_retries,
|
|
231
|
+
args_validator=TOOL_SCHEMA_VALIDATOR,
|
|
232
|
+
)
|
|
233
|
+
|
|
194
234
|
async def __aenter__(self) -> Self:
|
|
195
235
|
"""Enter the MCP server context.
|
|
196
236
|
|
|
@@ -308,7 +348,6 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
308
348
|
assert_never(resource)
|
|
309
349
|
|
|
310
350
|
|
|
311
|
-
@dataclass
|
|
312
351
|
class MCPServerStdio(MCPServer):
|
|
313
352
|
"""Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
|
|
314
353
|
|
|
@@ -353,18 +392,18 @@ class MCPServerStdio(MCPServer):
|
|
|
353
392
|
args: Sequence[str]
|
|
354
393
|
"""The arguments to pass to the command."""
|
|
355
394
|
|
|
356
|
-
env: dict[str, str] | None
|
|
395
|
+
env: dict[str, str] | None
|
|
357
396
|
"""The environment variables the CLI server will have access to.
|
|
358
397
|
|
|
359
398
|
By default the subprocess will not inherit any environment variables from the parent process.
|
|
360
399
|
If you want to inherit the environment variables from the parent process, use `env=os.environ`.
|
|
361
400
|
"""
|
|
362
401
|
|
|
363
|
-
cwd: str | Path | None
|
|
402
|
+
cwd: str | Path | None
|
|
364
403
|
"""The working directory to use when spawning the process."""
|
|
365
404
|
|
|
366
405
|
# last fields are re-defined from the parent class so they appear as fields
|
|
367
|
-
tool_prefix: str | None
|
|
406
|
+
tool_prefix: str | None
|
|
368
407
|
"""A prefix to add to all tools that are registered with the server.
|
|
369
408
|
|
|
370
409
|
If not empty, will include a trailing underscore(`_`).
|
|
@@ -372,7 +411,7 @@ class MCPServerStdio(MCPServer):
|
|
|
372
411
|
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
373
412
|
"""
|
|
374
413
|
|
|
375
|
-
log_level: mcp_types.LoggingLevel | None
|
|
414
|
+
log_level: mcp_types.LoggingLevel | None
|
|
376
415
|
"""The log level to set when connecting to the server, if any.
|
|
377
416
|
|
|
378
417
|
See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
|
|
@@ -380,23 +419,85 @@ class MCPServerStdio(MCPServer):
|
|
|
380
419
|
If `None`, no log level will be set.
|
|
381
420
|
"""
|
|
382
421
|
|
|
383
|
-
log_handler: LoggingFnT | None
|
|
422
|
+
log_handler: LoggingFnT | None
|
|
384
423
|
"""A handler for logging messages from the server."""
|
|
385
424
|
|
|
386
|
-
timeout: float
|
|
425
|
+
timeout: float
|
|
387
426
|
"""The timeout in seconds to wait for the client to initialize."""
|
|
388
427
|
|
|
389
|
-
|
|
428
|
+
read_timeout: float
|
|
429
|
+
"""Maximum time in seconds to wait for new messages before timing out.
|
|
430
|
+
|
|
431
|
+
This timeout applies to the long-lived connection after it's established.
|
|
432
|
+
If no new messages are received within this time, the connection will be considered stale
|
|
433
|
+
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
434
|
+
"""
|
|
435
|
+
|
|
436
|
+
process_tool_call: ProcessToolCallback | None
|
|
390
437
|
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
391
438
|
|
|
392
|
-
allow_sampling: bool
|
|
439
|
+
allow_sampling: bool
|
|
393
440
|
"""Whether to allow MCP sampling through this client."""
|
|
394
441
|
|
|
395
|
-
|
|
442
|
+
sampling_model: models.Model | None
|
|
443
|
+
"""The model to use for sampling."""
|
|
444
|
+
|
|
445
|
+
max_retries: int
|
|
396
446
|
"""The maximum number of times to retry a tool call."""
|
|
397
447
|
|
|
398
|
-
|
|
399
|
-
|
|
448
|
+
def __init__(
|
|
449
|
+
self,
|
|
450
|
+
command: str,
|
|
451
|
+
args: Sequence[str],
|
|
452
|
+
env: dict[str, str] | None = None,
|
|
453
|
+
cwd: str | Path | None = None,
|
|
454
|
+
tool_prefix: str | None = None,
|
|
455
|
+
log_level: mcp_types.LoggingLevel | None = None,
|
|
456
|
+
log_handler: LoggingFnT | None = None,
|
|
457
|
+
timeout: float = 5,
|
|
458
|
+
read_timeout: float = 5 * 60,
|
|
459
|
+
process_tool_call: ProcessToolCallback | None = None,
|
|
460
|
+
allow_sampling: bool = True,
|
|
461
|
+
sampling_model: models.Model | None = None,
|
|
462
|
+
max_retries: int = 1,
|
|
463
|
+
*,
|
|
464
|
+
id: str | None = None,
|
|
465
|
+
):
|
|
466
|
+
"""Build a new MCP server.
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
command: The command to run.
|
|
470
|
+
args: The arguments to pass to the command.
|
|
471
|
+
env: The environment variables to set in the subprocess.
|
|
472
|
+
cwd: The working directory to use when spawning the process.
|
|
473
|
+
tool_prefix: A prefix to add to all tools that are registered with the server.
|
|
474
|
+
log_level: The log level to set when connecting to the server, if any.
|
|
475
|
+
log_handler: A handler for logging messages from the server.
|
|
476
|
+
timeout: The timeout in seconds to wait for the client to initialize.
|
|
477
|
+
read_timeout: Maximum time in seconds to wait for new messages before timing out.
|
|
478
|
+
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
|
|
479
|
+
allow_sampling: Whether to allow MCP sampling through this client.
|
|
480
|
+
sampling_model: The model to use for sampling.
|
|
481
|
+
max_retries: The maximum number of times to retry a tool call.
|
|
482
|
+
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
|
|
483
|
+
"""
|
|
484
|
+
self.command = command
|
|
485
|
+
self.args = args
|
|
486
|
+
self.env = env
|
|
487
|
+
self.cwd = cwd
|
|
488
|
+
|
|
489
|
+
super().__init__(
|
|
490
|
+
tool_prefix,
|
|
491
|
+
log_level,
|
|
492
|
+
log_handler,
|
|
493
|
+
timeout,
|
|
494
|
+
read_timeout,
|
|
495
|
+
process_tool_call,
|
|
496
|
+
allow_sampling,
|
|
497
|
+
sampling_model,
|
|
498
|
+
max_retries,
|
|
499
|
+
id=id,
|
|
500
|
+
)
|
|
400
501
|
|
|
401
502
|
@asynccontextmanager
|
|
402
503
|
async def client_streams(
|
|
@@ -412,15 +513,20 @@ class MCPServerStdio(MCPServer):
|
|
|
412
513
|
yield read_stream, write_stream
|
|
413
514
|
|
|
414
515
|
def __repr__(self) -> str:
|
|
415
|
-
|
|
516
|
+
repr_args = [
|
|
517
|
+
f'command={self.command!r}',
|
|
518
|
+
f'args={self.args!r}',
|
|
519
|
+
]
|
|
520
|
+
if self.id:
|
|
521
|
+
repr_args.append(f'id={self.id!r}') # pragma: no cover
|
|
522
|
+
return f'{self.__class__.__name__}({", ".join(repr_args)})'
|
|
416
523
|
|
|
417
524
|
|
|
418
|
-
@dataclass(init=False)
|
|
419
525
|
class _MCPServerHTTP(MCPServer):
|
|
420
526
|
url: str
|
|
421
527
|
"""The URL of the endpoint on the MCP server."""
|
|
422
528
|
|
|
423
|
-
headers: dict[str, Any] | None
|
|
529
|
+
headers: dict[str, Any] | None
|
|
424
530
|
"""Optional HTTP headers to be sent with each request to the endpoint.
|
|
425
531
|
|
|
426
532
|
These headers will be passed directly to the underlying `httpx.AsyncClient`.
|
|
@@ -432,7 +538,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
432
538
|
See [`MCPServerHTTP.http_client`][pydantic_ai.mcp.MCPServerHTTP.http_client] for more information.
|
|
433
539
|
"""
|
|
434
540
|
|
|
435
|
-
http_client: httpx.AsyncClient | None
|
|
541
|
+
http_client: httpx.AsyncClient | None
|
|
436
542
|
"""An `httpx.AsyncClient` to use with the endpoint.
|
|
437
543
|
|
|
438
544
|
This client may be configured to use customized connection parameters like self-signed certificates.
|
|
@@ -452,16 +558,8 @@ class _MCPServerHTTP(MCPServer):
|
|
|
452
558
|
```
|
|
453
559
|
"""
|
|
454
560
|
|
|
455
|
-
read_timeout: float = 5 * 60
|
|
456
|
-
"""Maximum time in seconds to wait for new messages before timing out.
|
|
457
|
-
|
|
458
|
-
This timeout applies to the long-lived connection after it's established.
|
|
459
|
-
If no new messages are received within this time, the connection will be considered stale
|
|
460
|
-
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
461
|
-
"""
|
|
462
|
-
|
|
463
561
|
# last fields are re-defined from the parent class so they appear as fields
|
|
464
|
-
tool_prefix: str | None
|
|
562
|
+
tool_prefix: str | None
|
|
465
563
|
"""A prefix to add to all tools that are registered with the server.
|
|
466
564
|
|
|
467
565
|
If not empty, will include a trailing underscore (`_`).
|
|
@@ -469,7 +567,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
469
567
|
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
470
568
|
"""
|
|
471
569
|
|
|
472
|
-
log_level: mcp_types.LoggingLevel | None
|
|
570
|
+
log_level: mcp_types.LoggingLevel | None
|
|
473
571
|
"""The log level to set when connecting to the server, if any.
|
|
474
572
|
|
|
475
573
|
See <https://modelcontextprotocol.io/introduction#logging> for more details.
|
|
@@ -477,56 +575,81 @@ class _MCPServerHTTP(MCPServer):
|
|
|
477
575
|
If `None`, no log level will be set.
|
|
478
576
|
"""
|
|
479
577
|
|
|
480
|
-
log_handler: LoggingFnT | None
|
|
578
|
+
log_handler: LoggingFnT | None
|
|
481
579
|
"""A handler for logging messages from the server."""
|
|
482
580
|
|
|
483
|
-
timeout: float
|
|
581
|
+
timeout: float
|
|
484
582
|
"""Initial connection timeout in seconds for establishing the connection.
|
|
485
583
|
|
|
486
584
|
This timeout applies to the initial connection setup and handshake.
|
|
487
585
|
If the connection cannot be established within this time, the operation will fail.
|
|
488
586
|
"""
|
|
489
587
|
|
|
490
|
-
|
|
588
|
+
read_timeout: float
|
|
589
|
+
"""Maximum time in seconds to wait for new messages before timing out.
|
|
590
|
+
|
|
591
|
+
This timeout applies to the long-lived connection after it's established.
|
|
592
|
+
If no new messages are received within this time, the connection will be considered stale
|
|
593
|
+
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
594
|
+
"""
|
|
595
|
+
|
|
596
|
+
process_tool_call: ProcessToolCallback | None
|
|
491
597
|
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
492
598
|
|
|
493
|
-
allow_sampling: bool
|
|
599
|
+
allow_sampling: bool
|
|
494
600
|
"""Whether to allow MCP sampling through this client."""
|
|
495
601
|
|
|
496
|
-
|
|
497
|
-
"""The maximum number of times to retry a tool call."""
|
|
498
|
-
|
|
499
|
-
sampling_model: models.Model | None = None
|
|
602
|
+
sampling_model: models.Model | None
|
|
500
603
|
"""The model to use for sampling."""
|
|
501
604
|
|
|
605
|
+
max_retries: int
|
|
606
|
+
"""The maximum number of times to retry a tool call."""
|
|
607
|
+
|
|
502
608
|
def __init__(
|
|
503
609
|
self,
|
|
504
610
|
*,
|
|
505
611
|
url: str,
|
|
506
612
|
headers: dict[str, str] | None = None,
|
|
507
613
|
http_client: httpx.AsyncClient | None = None,
|
|
508
|
-
|
|
614
|
+
id: str | None = None,
|
|
509
615
|
tool_prefix: str | None = None,
|
|
510
616
|
log_level: mcp_types.LoggingLevel | None = None,
|
|
511
617
|
log_handler: LoggingFnT | None = None,
|
|
512
618
|
timeout: float = 5,
|
|
619
|
+
read_timeout: float | None = None,
|
|
513
620
|
process_tool_call: ProcessToolCallback | None = None,
|
|
514
621
|
allow_sampling: bool = True,
|
|
515
|
-
max_retries: int = 1,
|
|
516
622
|
sampling_model: models.Model | None = None,
|
|
517
|
-
|
|
623
|
+
max_retries: int = 1,
|
|
624
|
+
**_deprecated_kwargs: Any,
|
|
518
625
|
):
|
|
519
|
-
|
|
520
|
-
|
|
626
|
+
"""Build a new MCP server.
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
url: The URL of the endpoint on the MCP server.
|
|
630
|
+
headers: Optional HTTP headers to be sent with each request to the endpoint.
|
|
631
|
+
http_client: An `httpx.AsyncClient` to use with the endpoint.
|
|
632
|
+
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
|
|
633
|
+
tool_prefix: A prefix to add to all tools that are registered with the server.
|
|
634
|
+
log_level: The log level to set when connecting to the server, if any.
|
|
635
|
+
log_handler: A handler for logging messages from the server.
|
|
636
|
+
timeout: The timeout in seconds to wait for the client to initialize.
|
|
637
|
+
read_timeout: Maximum time in seconds to wait for new messages before timing out.
|
|
638
|
+
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
|
|
639
|
+
allow_sampling: Whether to allow MCP sampling through this client.
|
|
640
|
+
sampling_model: The model to use for sampling.
|
|
641
|
+
max_retries: The maximum number of times to retry a tool call.
|
|
642
|
+
"""
|
|
643
|
+
if 'sse_read_timeout' in _deprecated_kwargs:
|
|
521
644
|
if read_timeout is not None:
|
|
522
645
|
raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
|
|
523
646
|
|
|
524
647
|
warnings.warn(
|
|
525
648
|
"'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
|
|
526
649
|
)
|
|
527
|
-
read_timeout =
|
|
650
|
+
read_timeout = _deprecated_kwargs.pop('sse_read_timeout')
|
|
528
651
|
|
|
529
|
-
_utils.validate_empty_kwargs(
|
|
652
|
+
_utils.validate_empty_kwargs(_deprecated_kwargs)
|
|
530
653
|
|
|
531
654
|
if read_timeout is None:
|
|
532
655
|
read_timeout = 5 * 60
|
|
@@ -534,16 +657,19 @@ class _MCPServerHTTP(MCPServer):
|
|
|
534
657
|
self.url = url
|
|
535
658
|
self.headers = headers
|
|
536
659
|
self.http_client = http_client
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
660
|
+
|
|
661
|
+
super().__init__(
|
|
662
|
+
tool_prefix,
|
|
663
|
+
log_level,
|
|
664
|
+
log_handler,
|
|
665
|
+
timeout,
|
|
666
|
+
read_timeout,
|
|
667
|
+
process_tool_call,
|
|
668
|
+
allow_sampling,
|
|
669
|
+
sampling_model,
|
|
670
|
+
max_retries,
|
|
671
|
+
id=id,
|
|
672
|
+
)
|
|
547
673
|
|
|
548
674
|
@property
|
|
549
675
|
@abstractmethod
|
|
@@ -606,10 +732,14 @@ class _MCPServerHTTP(MCPServer):
|
|
|
606
732
|
yield read_stream, write_stream
|
|
607
733
|
|
|
608
734
|
def __repr__(self) -> str: # pragma: no cover
|
|
609
|
-
|
|
735
|
+
repr_args = [
|
|
736
|
+
f'url={self.url!r}',
|
|
737
|
+
]
|
|
738
|
+
if self.id:
|
|
739
|
+
repr_args.append(f'id={self.id!r}')
|
|
740
|
+
return f'{self.__class__.__name__}({", ".join(repr_args)})'
|
|
610
741
|
|
|
611
742
|
|
|
612
|
-
@dataclass(init=False)
|
|
613
743
|
class MCPServerSSE(_MCPServerHTTP):
|
|
614
744
|
"""An MCP server that connects over streamable HTTP connections.
|
|
615
745
|
|
|
@@ -643,7 +773,6 @@ class MCPServerSSE(_MCPServerHTTP):
|
|
|
643
773
|
|
|
644
774
|
|
|
645
775
|
@deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
|
|
646
|
-
@dataclass
|
|
647
776
|
class MCPServerHTTP(MCPServerSSE):
|
|
648
777
|
"""An MCP server that connects over HTTP using the old SSE transport.
|
|
649
778
|
|
|
@@ -672,7 +801,6 @@ class MCPServerHTTP(MCPServerSSE):
|
|
|
672
801
|
"""
|
|
673
802
|
|
|
674
803
|
|
|
675
|
-
@dataclass
|
|
676
804
|
class MCPServerStreamableHTTP(_MCPServerHTTP):
|
|
677
805
|
"""An MCP server that connects over HTTP using the Streamable HTTP transport.
|
|
678
806
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -490,8 +490,8 @@ _video_format_lookup: dict[str, VideoFormat] = {
|
|
|
490
490
|
class UserPromptPart:
|
|
491
491
|
"""A user prompt, generally written by the end user.
|
|
492
492
|
|
|
493
|
-
Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.
|
|
494
|
-
[`Agent.run_sync`][pydantic_ai.
|
|
493
|
+
Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.agent.AbstractAgent.run],
|
|
494
|
+
[`Agent.run_sync`][pydantic_ai.agent.AbstractAgent.run_sync], and [`Agent.run_stream`][pydantic_ai.agent.AbstractAgent.run_stream].
|
|
495
495
|
"""
|
|
496
496
|
|
|
497
497
|
content: str | Sequence[UserContent]
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -13,20 +13,32 @@ from contextlib import asynccontextmanager, contextmanager
|
|
|
13
13
|
from dataclasses import dataclass, field, replace
|
|
14
14
|
from datetime import datetime
|
|
15
15
|
from functools import cache, cached_property
|
|
16
|
-
from typing import Generic, TypeVar, overload
|
|
16
|
+
from typing import Any, Generic, TypeVar, overload
|
|
17
17
|
|
|
18
18
|
import httpx
|
|
19
19
|
from typing_extensions import Literal, TypeAliasType, TypedDict
|
|
20
20
|
|
|
21
|
-
from pydantic_ai.builtin_tools import AbstractBuiltinTool
|
|
22
|
-
from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
|
|
23
|
-
|
|
24
21
|
from .. import _utils
|
|
25
22
|
from .._output import OutputObjectDefinition
|
|
26
23
|
from .._parts_manager import ModelResponsePartsManager
|
|
24
|
+
from .._run_context import RunContext
|
|
25
|
+
from ..builtin_tools import AbstractBuiltinTool
|
|
27
26
|
from ..exceptions import UserError
|
|
28
|
-
from ..messages import
|
|
27
|
+
from ..messages import (
|
|
28
|
+
AgentStreamEvent,
|
|
29
|
+
FileUrl,
|
|
30
|
+
FinalResultEvent,
|
|
31
|
+
ModelMessage,
|
|
32
|
+
ModelRequest,
|
|
33
|
+
ModelResponse,
|
|
34
|
+
ModelResponseStreamEvent,
|
|
35
|
+
PartStartEvent,
|
|
36
|
+
TextPart,
|
|
37
|
+
ToolCallPart,
|
|
38
|
+
VideoUrl,
|
|
39
|
+
)
|
|
29
40
|
from ..output import OutputMode
|
|
41
|
+
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
|
|
30
42
|
from ..profiles._json_schema import JsonSchemaTransformer
|
|
31
43
|
from ..settings import ModelSettings
|
|
32
44
|
from ..tools import ToolDefinition
|
|
@@ -182,6 +194,13 @@ KnownModelName = TypeAliasType(
|
|
|
182
194
|
'gpt-4o-mini-search-preview-2025-03-11',
|
|
183
195
|
'gpt-4o-search-preview',
|
|
184
196
|
'gpt-4o-search-preview-2025-03-11',
|
|
197
|
+
'gpt-5',
|
|
198
|
+
'gpt-5-2025-08-07',
|
|
199
|
+
'gpt-5-chat-latest',
|
|
200
|
+
'gpt-5-mini',
|
|
201
|
+
'gpt-5-mini-2025-08-07',
|
|
202
|
+
'gpt-5-nano',
|
|
203
|
+
'gpt-5-nano-2025-08-07',
|
|
185
204
|
'grok:grok-4',
|
|
186
205
|
'grok:grok-4-0709',
|
|
187
206
|
'grok:grok-3',
|
|
@@ -301,11 +320,18 @@ KnownModelName = TypeAliasType(
|
|
|
301
320
|
'openai:gpt-4o-mini-search-preview-2025-03-11',
|
|
302
321
|
'openai:gpt-4o-search-preview',
|
|
303
322
|
'openai:gpt-4o-search-preview-2025-03-11',
|
|
323
|
+
'openai:gpt-5',
|
|
324
|
+
'openai:gpt-5-2025-08-07',
|
|
304
325
|
'openai:o1',
|
|
326
|
+
'openai:gpt-5-chat-latest',
|
|
305
327
|
'openai:o1-2024-12-17',
|
|
328
|
+
'openai:gpt-5-mini',
|
|
306
329
|
'openai:o1-mini',
|
|
330
|
+
'openai:gpt-5-mini-2025-08-07',
|
|
307
331
|
'openai:o1-mini-2024-09-12',
|
|
332
|
+
'openai:gpt-5-nano',
|
|
308
333
|
'openai:o1-preview',
|
|
334
|
+
'openai:gpt-5-nano-2025-08-07',
|
|
309
335
|
'openai:o1-preview-2024-09-12',
|
|
310
336
|
'openai:o1-pro',
|
|
311
337
|
'openai:o1-pro-2025-03-19',
|
|
@@ -344,6 +370,10 @@ class ModelRequestParameters:
|
|
|
344
370
|
output_tools: list[ToolDefinition] = field(default_factory=list)
|
|
345
371
|
allow_text_output: bool = True
|
|
346
372
|
|
|
373
|
+
@cached_property
|
|
374
|
+
def tool_defs(self) -> dict[str, ToolDefinition]:
|
|
375
|
+
return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]}
|
|
376
|
+
|
|
347
377
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
348
378
|
|
|
349
379
|
|
|
@@ -383,12 +413,23 @@ class Model(ABC):
|
|
|
383
413
|
"""Make a request to the model."""
|
|
384
414
|
raise NotImplementedError()
|
|
385
415
|
|
|
416
|
+
async def count_tokens(
|
|
417
|
+
self,
|
|
418
|
+
messages: list[ModelMessage],
|
|
419
|
+
model_settings: ModelSettings | None,
|
|
420
|
+
model_request_parameters: ModelRequestParameters,
|
|
421
|
+
) -> Usage:
|
|
422
|
+
"""Make a request to the model for counting tokens."""
|
|
423
|
+
# This method is not required, but you need to implement it if you want to support `UsageLimits.count_tokens_before_request`.
|
|
424
|
+
raise NotImplementedError(f'Token counting ahead of the request is not supported by {self.__class__.__name__}')
|
|
425
|
+
|
|
386
426
|
@asynccontextmanager
|
|
387
427
|
async def request_stream(
|
|
388
428
|
self,
|
|
389
429
|
messages: list[ModelMessage],
|
|
390
430
|
model_settings: ModelSettings | None,
|
|
391
431
|
model_request_parameters: ModelRequestParameters,
|
|
432
|
+
run_context: RunContext[Any] | None = None,
|
|
392
433
|
) -> AsyncIterator[StreamedResponse]:
|
|
393
434
|
"""Make a request to the model and return a streaming response."""
|
|
394
435
|
# This method is not required, but you need to implement it if you want to support streamed responses
|
|
@@ -501,14 +542,40 @@ class Model(ABC):
|
|
|
501
542
|
class StreamedResponse(ABC):
|
|
502
543
|
"""Streamed response from an LLM when calling a tool."""
|
|
503
544
|
|
|
545
|
+
model_request_parameters: ModelRequestParameters
|
|
546
|
+
final_result_event: FinalResultEvent | None = field(default=None, init=False)
|
|
547
|
+
|
|
504
548
|
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
|
|
505
|
-
_event_iterator: AsyncIterator[
|
|
549
|
+
_event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
|
|
506
550
|
_usage: Usage = field(default_factory=Usage, init=False)
|
|
507
551
|
|
|
508
|
-
def __aiter__(self) -> AsyncIterator[
|
|
509
|
-
"""Stream the response as an async iterable of [`
|
|
552
|
+
def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
|
|
553
|
+
"""Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
|
|
554
|
+
|
|
555
|
+
This proxies the `_event_iterator()` and emits all events, while also checking for matches
|
|
556
|
+
on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
|
|
557
|
+
first match is found.
|
|
558
|
+
"""
|
|
510
559
|
if self._event_iterator is None:
|
|
511
|
-
|
|
560
|
+
|
|
561
|
+
async def iterator_with_final_event(
|
|
562
|
+
iterator: AsyncIterator[ModelResponseStreamEvent],
|
|
563
|
+
) -> AsyncIterator[AgentStreamEvent]:
|
|
564
|
+
async for event in iterator:
|
|
565
|
+
yield event
|
|
566
|
+
if (
|
|
567
|
+
final_result_event := _get_final_result_event(event, self.model_request_parameters)
|
|
568
|
+
) is not None:
|
|
569
|
+
self.final_result_event = final_result_event
|
|
570
|
+
yield final_result_event
|
|
571
|
+
break
|
|
572
|
+
|
|
573
|
+
# If we broke out of the above loop, we need to yield the rest of the events
|
|
574
|
+
# If we didn't, this will just be a no-op
|
|
575
|
+
async for event in iterator:
|
|
576
|
+
yield event
|
|
577
|
+
|
|
578
|
+
self._event_iterator = iterator_with_final_event(self._get_event_iterator())
|
|
512
579
|
return self._event_iterator
|
|
513
580
|
|
|
514
581
|
@abstractmethod
|
|
@@ -636,6 +703,10 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
|
|
|
636
703
|
from .openai import OpenAIModel
|
|
637
704
|
|
|
638
705
|
return OpenAIModel(model_name, provider=provider)
|
|
706
|
+
elif provider == 'openai-responses':
|
|
707
|
+
from .openai import OpenAIResponsesModel
|
|
708
|
+
|
|
709
|
+
return OpenAIResponsesModel(model_name, provider='openai')
|
|
639
710
|
elif provider in ('google-gla', 'google-vertex'):
|
|
640
711
|
from .google import GoogleModel
|
|
641
712
|
|
|
@@ -810,3 +881,16 @@ def _customize_output_object(transformer: type[JsonSchemaTransformer], o: Output
|
|
|
810
881
|
json_schema=json_schema,
|
|
811
882
|
strict=schema_transformer.is_strict_compatible if o.strict is None else o.strict,
|
|
812
883
|
)
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
def _get_final_result_event(e: ModelResponseStreamEvent, params: ModelRequestParameters) -> FinalResultEvent | None:
|
|
887
|
+
"""Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
|
|
888
|
+
if isinstance(e, PartStartEvent):
|
|
889
|
+
new_part = e.part
|
|
890
|
+
if isinstance(new_part, TextPart) and params.allow_text_output: # pragma: no branch
|
|
891
|
+
return FinalResultEvent(tool_name=None, tool_call_id=None)
|
|
892
|
+
elif isinstance(new_part, ToolCallPart) and (tool_def := params.tool_defs.get(new_part.tool_name)):
|
|
893
|
+
if tool_def.kind == 'output':
|
|
894
|
+
return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id)
|
|
895
|
+
elif tool_def.kind == 'deferred':
|
|
896
|
+
return FinalResultEvent(tool_name=None, tool_call_id=None)
|