pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +25 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +65 -9
- pydantic_ai/models/anthropic.py +16 -7
- pydantic_ai/models/bedrock.py +8 -5
- pydantic_ai/models/cohere.py +1 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +9 -4
- pydantic_ai/models/gemini.py +15 -9
- pydantic_ai/models/google.py +18 -14
- pydantic_ai/models/groq.py +17 -14
- pydantic_ai/models/huggingface.py +18 -12
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +12 -18
- pydantic_ai/models/openai.py +29 -26
- pydantic_ai/models/test.py +3 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/mcp.py
CHANGED
|
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
|
|
7
7
|
from asyncio import Lock
|
|
8
8
|
from collections.abc import AsyncIterator, Awaitable, Sequence
|
|
9
9
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
10
|
-
from dataclasses import
|
|
10
|
+
from dataclasses import field, replace
|
|
11
11
|
from datetime import timedelta
|
|
12
12
|
from pathlib import Path
|
|
13
13
|
from typing import Any, Callable
|
|
@@ -56,17 +56,17 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
56
56
|
See <https://modelcontextprotocol.io> for more information.
|
|
57
57
|
"""
|
|
58
58
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
max_retries: int
|
|
68
|
-
|
|
69
|
-
|
|
59
|
+
tool_prefix: str | None
|
|
60
|
+
log_level: mcp_types.LoggingLevel | None
|
|
61
|
+
log_handler: LoggingFnT | None
|
|
62
|
+
timeout: float
|
|
63
|
+
read_timeout: float
|
|
64
|
+
process_tool_call: ProcessToolCallback | None
|
|
65
|
+
allow_sampling: bool
|
|
66
|
+
sampling_model: models.Model | None
|
|
67
|
+
max_retries: int
|
|
68
|
+
|
|
69
|
+
_id: str | None
|
|
70
70
|
|
|
71
71
|
_enter_lock: Lock = field(compare=False)
|
|
72
72
|
_running_count: int
|
|
@@ -76,6 +76,34 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
76
76
|
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
|
77
77
|
_write_stream: MemoryObjectSendStream[SessionMessage]
|
|
78
78
|
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
tool_prefix: str | None = None,
|
|
82
|
+
log_level: mcp_types.LoggingLevel | None = None,
|
|
83
|
+
log_handler: LoggingFnT | None = None,
|
|
84
|
+
timeout: float = 5,
|
|
85
|
+
read_timeout: float = 5 * 60,
|
|
86
|
+
process_tool_call: ProcessToolCallback | None = None,
|
|
87
|
+
allow_sampling: bool = True,
|
|
88
|
+
sampling_model: models.Model | None = None,
|
|
89
|
+
max_retries: int = 1,
|
|
90
|
+
*,
|
|
91
|
+
id: str | None = None,
|
|
92
|
+
):
|
|
93
|
+
self.tool_prefix = tool_prefix
|
|
94
|
+
self.log_level = log_level
|
|
95
|
+
self.log_handler = log_handler
|
|
96
|
+
self.timeout = timeout
|
|
97
|
+
self.read_timeout = read_timeout
|
|
98
|
+
self.process_tool_call = process_tool_call
|
|
99
|
+
self.allow_sampling = allow_sampling
|
|
100
|
+
self.sampling_model = sampling_model
|
|
101
|
+
self.max_retries = max_retries
|
|
102
|
+
|
|
103
|
+
self._id = id or tool_prefix
|
|
104
|
+
|
|
105
|
+
self.__post_init__()
|
|
106
|
+
|
|
79
107
|
def __post_init__(self):
|
|
80
108
|
self._enter_lock = Lock()
|
|
81
109
|
self._running_count = 0
|
|
@@ -96,12 +124,19 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
96
124
|
yield
|
|
97
125
|
|
|
98
126
|
@property
|
|
99
|
-
def
|
|
100
|
-
return
|
|
127
|
+
def id(self) -> str | None:
|
|
128
|
+
return self._id
|
|
129
|
+
|
|
130
|
+
@property
|
|
131
|
+
def label(self) -> str:
|
|
132
|
+
if self.id:
|
|
133
|
+
return super().label # pragma: no cover
|
|
134
|
+
else:
|
|
135
|
+
return repr(self)
|
|
101
136
|
|
|
102
137
|
@property
|
|
103
138
|
def tool_name_conflict_hint(self) -> str:
|
|
104
|
-
return '
|
|
139
|
+
return 'Set the `tool_prefix` attribute to avoid name conflicts.'
|
|
105
140
|
|
|
106
141
|
async def list_tools(self) -> list[mcp_types.Tool]:
|
|
107
142
|
"""Retrieve tools that are currently active on the server.
|
|
@@ -177,20 +212,25 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
177
212
|
|
|
178
213
|
async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
|
|
179
214
|
return {
|
|
180
|
-
name:
|
|
181
|
-
|
|
182
|
-
tool_def=ToolDefinition(
|
|
215
|
+
name: self.tool_for_tool_def(
|
|
216
|
+
ToolDefinition(
|
|
183
217
|
name=name,
|
|
184
218
|
description=mcp_tool.description,
|
|
185
219
|
parameters_json_schema=mcp_tool.inputSchema,
|
|
186
220
|
),
|
|
187
|
-
max_retries=self.max_retries,
|
|
188
|
-
args_validator=TOOL_SCHEMA_VALIDATOR,
|
|
189
221
|
)
|
|
190
222
|
for mcp_tool in await self.list_tools()
|
|
191
223
|
if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name)
|
|
192
224
|
}
|
|
193
225
|
|
|
226
|
+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]:
|
|
227
|
+
return ToolsetTool(
|
|
228
|
+
toolset=self,
|
|
229
|
+
tool_def=tool_def,
|
|
230
|
+
max_retries=self.max_retries,
|
|
231
|
+
args_validator=TOOL_SCHEMA_VALIDATOR,
|
|
232
|
+
)
|
|
233
|
+
|
|
194
234
|
async def __aenter__(self) -> Self:
|
|
195
235
|
"""Enter the MCP server context.
|
|
196
236
|
|
|
@@ -308,7 +348,6 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
308
348
|
assert_never(resource)
|
|
309
349
|
|
|
310
350
|
|
|
311
|
-
@dataclass
|
|
312
351
|
class MCPServerStdio(MCPServer):
|
|
313
352
|
"""Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
|
|
314
353
|
|
|
@@ -353,18 +392,18 @@ class MCPServerStdio(MCPServer):
|
|
|
353
392
|
args: Sequence[str]
|
|
354
393
|
"""The arguments to pass to the command."""
|
|
355
394
|
|
|
356
|
-
env: dict[str, str] | None
|
|
395
|
+
env: dict[str, str] | None
|
|
357
396
|
"""The environment variables the CLI server will have access to.
|
|
358
397
|
|
|
359
398
|
By default the subprocess will not inherit any environment variables from the parent process.
|
|
360
399
|
If you want to inherit the environment variables from the parent process, use `env=os.environ`.
|
|
361
400
|
"""
|
|
362
401
|
|
|
363
|
-
cwd: str | Path | None
|
|
402
|
+
cwd: str | Path | None
|
|
364
403
|
"""The working directory to use when spawning the process."""
|
|
365
404
|
|
|
366
405
|
# last fields are re-defined from the parent class so they appear as fields
|
|
367
|
-
tool_prefix: str | None
|
|
406
|
+
tool_prefix: str | None
|
|
368
407
|
"""A prefix to add to all tools that are registered with the server.
|
|
369
408
|
|
|
370
409
|
If not empty, will include a trailing underscore(`_`).
|
|
@@ -372,7 +411,7 @@ class MCPServerStdio(MCPServer):
|
|
|
372
411
|
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
373
412
|
"""
|
|
374
413
|
|
|
375
|
-
log_level: mcp_types.LoggingLevel | None
|
|
414
|
+
log_level: mcp_types.LoggingLevel | None
|
|
376
415
|
"""The log level to set when connecting to the server, if any.
|
|
377
416
|
|
|
378
417
|
See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
|
|
@@ -380,23 +419,85 @@ class MCPServerStdio(MCPServer):
|
|
|
380
419
|
If `None`, no log level will be set.
|
|
381
420
|
"""
|
|
382
421
|
|
|
383
|
-
log_handler: LoggingFnT | None
|
|
422
|
+
log_handler: LoggingFnT | None
|
|
384
423
|
"""A handler for logging messages from the server."""
|
|
385
424
|
|
|
386
|
-
timeout: float
|
|
425
|
+
timeout: float
|
|
387
426
|
"""The timeout in seconds to wait for the client to initialize."""
|
|
388
427
|
|
|
389
|
-
|
|
428
|
+
read_timeout: float
|
|
429
|
+
"""Maximum time in seconds to wait for new messages before timing out.
|
|
430
|
+
|
|
431
|
+
This timeout applies to the long-lived connection after it's established.
|
|
432
|
+
If no new messages are received within this time, the connection will be considered stale
|
|
433
|
+
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
434
|
+
"""
|
|
435
|
+
|
|
436
|
+
process_tool_call: ProcessToolCallback | None
|
|
390
437
|
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
391
438
|
|
|
392
|
-
allow_sampling: bool
|
|
439
|
+
allow_sampling: bool
|
|
393
440
|
"""Whether to allow MCP sampling through this client."""
|
|
394
441
|
|
|
395
|
-
|
|
442
|
+
sampling_model: models.Model | None
|
|
443
|
+
"""The model to use for sampling."""
|
|
444
|
+
|
|
445
|
+
max_retries: int
|
|
396
446
|
"""The maximum number of times to retry a tool call."""
|
|
397
447
|
|
|
398
|
-
|
|
399
|
-
|
|
448
|
+
def __init__(
|
|
449
|
+
self,
|
|
450
|
+
command: str,
|
|
451
|
+
args: Sequence[str],
|
|
452
|
+
env: dict[str, str] | None = None,
|
|
453
|
+
cwd: str | Path | None = None,
|
|
454
|
+
tool_prefix: str | None = None,
|
|
455
|
+
log_level: mcp_types.LoggingLevel | None = None,
|
|
456
|
+
log_handler: LoggingFnT | None = None,
|
|
457
|
+
timeout: float = 5,
|
|
458
|
+
read_timeout: float = 5 * 60,
|
|
459
|
+
process_tool_call: ProcessToolCallback | None = None,
|
|
460
|
+
allow_sampling: bool = True,
|
|
461
|
+
sampling_model: models.Model | None = None,
|
|
462
|
+
max_retries: int = 1,
|
|
463
|
+
*,
|
|
464
|
+
id: str | None = None,
|
|
465
|
+
):
|
|
466
|
+
"""Build a new MCP server.
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
command: The command to run.
|
|
470
|
+
args: The arguments to pass to the command.
|
|
471
|
+
env: The environment variables to set in the subprocess.
|
|
472
|
+
cwd: The working directory to use when spawning the process.
|
|
473
|
+
tool_prefix: A prefix to add to all tools that are registered with the server.
|
|
474
|
+
log_level: The log level to set when connecting to the server, if any.
|
|
475
|
+
log_handler: A handler for logging messages from the server.
|
|
476
|
+
timeout: The timeout in seconds to wait for the client to initialize.
|
|
477
|
+
read_timeout: Maximum time in seconds to wait for new messages before timing out.
|
|
478
|
+
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
|
|
479
|
+
allow_sampling: Whether to allow MCP sampling through this client.
|
|
480
|
+
sampling_model: The model to use for sampling.
|
|
481
|
+
max_retries: The maximum number of times to retry a tool call.
|
|
482
|
+
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
|
|
483
|
+
"""
|
|
484
|
+
self.command = command
|
|
485
|
+
self.args = args
|
|
486
|
+
self.env = env
|
|
487
|
+
self.cwd = cwd
|
|
488
|
+
|
|
489
|
+
super().__init__(
|
|
490
|
+
tool_prefix,
|
|
491
|
+
log_level,
|
|
492
|
+
log_handler,
|
|
493
|
+
timeout,
|
|
494
|
+
read_timeout,
|
|
495
|
+
process_tool_call,
|
|
496
|
+
allow_sampling,
|
|
497
|
+
sampling_model,
|
|
498
|
+
max_retries,
|
|
499
|
+
id=id,
|
|
500
|
+
)
|
|
400
501
|
|
|
401
502
|
@asynccontextmanager
|
|
402
503
|
async def client_streams(
|
|
@@ -412,15 +513,20 @@ class MCPServerStdio(MCPServer):
|
|
|
412
513
|
yield read_stream, write_stream
|
|
413
514
|
|
|
414
515
|
def __repr__(self) -> str:
|
|
415
|
-
|
|
516
|
+
repr_args = [
|
|
517
|
+
f'command={self.command!r}',
|
|
518
|
+
f'args={self.args!r}',
|
|
519
|
+
]
|
|
520
|
+
if self.id:
|
|
521
|
+
repr_args.append(f'id={self.id!r}') # pragma: no cover
|
|
522
|
+
return f'{self.__class__.__name__}({", ".join(repr_args)})'
|
|
416
523
|
|
|
417
524
|
|
|
418
|
-
@dataclass(init=False)
|
|
419
525
|
class _MCPServerHTTP(MCPServer):
|
|
420
526
|
url: str
|
|
421
527
|
"""The URL of the endpoint on the MCP server."""
|
|
422
528
|
|
|
423
|
-
headers: dict[str, Any] | None
|
|
529
|
+
headers: dict[str, Any] | None
|
|
424
530
|
"""Optional HTTP headers to be sent with each request to the endpoint.
|
|
425
531
|
|
|
426
532
|
These headers will be passed directly to the underlying `httpx.AsyncClient`.
|
|
@@ -432,7 +538,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
432
538
|
See [`MCPServerHTTP.http_client`][pydantic_ai.mcp.MCPServerHTTP.http_client] for more information.
|
|
433
539
|
"""
|
|
434
540
|
|
|
435
|
-
http_client: httpx.AsyncClient | None
|
|
541
|
+
http_client: httpx.AsyncClient | None
|
|
436
542
|
"""An `httpx.AsyncClient` to use with the endpoint.
|
|
437
543
|
|
|
438
544
|
This client may be configured to use customized connection parameters like self-signed certificates.
|
|
@@ -452,16 +558,8 @@ class _MCPServerHTTP(MCPServer):
|
|
|
452
558
|
```
|
|
453
559
|
"""
|
|
454
560
|
|
|
455
|
-
read_timeout: float = 5 * 60
|
|
456
|
-
"""Maximum time in seconds to wait for new messages before timing out.
|
|
457
|
-
|
|
458
|
-
This timeout applies to the long-lived connection after it's established.
|
|
459
|
-
If no new messages are received within this time, the connection will be considered stale
|
|
460
|
-
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
461
|
-
"""
|
|
462
|
-
|
|
463
561
|
# last fields are re-defined from the parent class so they appear as fields
|
|
464
|
-
tool_prefix: str | None
|
|
562
|
+
tool_prefix: str | None
|
|
465
563
|
"""A prefix to add to all tools that are registered with the server.
|
|
466
564
|
|
|
467
565
|
If not empty, will include a trailing underscore (`_`).
|
|
@@ -469,7 +567,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
469
567
|
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
470
568
|
"""
|
|
471
569
|
|
|
472
|
-
log_level: mcp_types.LoggingLevel | None
|
|
570
|
+
log_level: mcp_types.LoggingLevel | None
|
|
473
571
|
"""The log level to set when connecting to the server, if any.
|
|
474
572
|
|
|
475
573
|
See <https://modelcontextprotocol.io/introduction#logging> for more details.
|
|
@@ -477,56 +575,81 @@ class _MCPServerHTTP(MCPServer):
|
|
|
477
575
|
If `None`, no log level will be set.
|
|
478
576
|
"""
|
|
479
577
|
|
|
480
|
-
log_handler: LoggingFnT | None
|
|
578
|
+
log_handler: LoggingFnT | None
|
|
481
579
|
"""A handler for logging messages from the server."""
|
|
482
580
|
|
|
483
|
-
timeout: float
|
|
581
|
+
timeout: float
|
|
484
582
|
"""Initial connection timeout in seconds for establishing the connection.
|
|
485
583
|
|
|
486
584
|
This timeout applies to the initial connection setup and handshake.
|
|
487
585
|
If the connection cannot be established within this time, the operation will fail.
|
|
488
586
|
"""
|
|
489
587
|
|
|
490
|
-
|
|
588
|
+
read_timeout: float
|
|
589
|
+
"""Maximum time in seconds to wait for new messages before timing out.
|
|
590
|
+
|
|
591
|
+
This timeout applies to the long-lived connection after it's established.
|
|
592
|
+
If no new messages are received within this time, the connection will be considered stale
|
|
593
|
+
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
594
|
+
"""
|
|
595
|
+
|
|
596
|
+
process_tool_call: ProcessToolCallback | None
|
|
491
597
|
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
492
598
|
|
|
493
|
-
allow_sampling: bool
|
|
599
|
+
allow_sampling: bool
|
|
494
600
|
"""Whether to allow MCP sampling through this client."""
|
|
495
601
|
|
|
496
|
-
|
|
497
|
-
"""The maximum number of times to retry a tool call."""
|
|
498
|
-
|
|
499
|
-
sampling_model: models.Model | None = None
|
|
602
|
+
sampling_model: models.Model | None
|
|
500
603
|
"""The model to use for sampling."""
|
|
501
604
|
|
|
605
|
+
max_retries: int
|
|
606
|
+
"""The maximum number of times to retry a tool call."""
|
|
607
|
+
|
|
502
608
|
def __init__(
|
|
503
609
|
self,
|
|
504
610
|
*,
|
|
505
611
|
url: str,
|
|
506
612
|
headers: dict[str, str] | None = None,
|
|
507
613
|
http_client: httpx.AsyncClient | None = None,
|
|
508
|
-
|
|
614
|
+
id: str | None = None,
|
|
509
615
|
tool_prefix: str | None = None,
|
|
510
616
|
log_level: mcp_types.LoggingLevel | None = None,
|
|
511
617
|
log_handler: LoggingFnT | None = None,
|
|
512
618
|
timeout: float = 5,
|
|
619
|
+
read_timeout: float | None = None,
|
|
513
620
|
process_tool_call: ProcessToolCallback | None = None,
|
|
514
621
|
allow_sampling: bool = True,
|
|
515
|
-
max_retries: int = 1,
|
|
516
622
|
sampling_model: models.Model | None = None,
|
|
517
|
-
|
|
623
|
+
max_retries: int = 1,
|
|
624
|
+
**_deprecated_kwargs: Any,
|
|
518
625
|
):
|
|
519
|
-
|
|
520
|
-
|
|
626
|
+
"""Build a new MCP server.
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
url: The URL of the endpoint on the MCP server.
|
|
630
|
+
headers: Optional HTTP headers to be sent with each request to the endpoint.
|
|
631
|
+
http_client: An `httpx.AsyncClient` to use with the endpoint.
|
|
632
|
+
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
|
|
633
|
+
tool_prefix: A prefix to add to all tools that are registered with the server.
|
|
634
|
+
log_level: The log level to set when connecting to the server, if any.
|
|
635
|
+
log_handler: A handler for logging messages from the server.
|
|
636
|
+
timeout: The timeout in seconds to wait for the client to initialize.
|
|
637
|
+
read_timeout: Maximum time in seconds to wait for new messages before timing out.
|
|
638
|
+
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
|
|
639
|
+
allow_sampling: Whether to allow MCP sampling through this client.
|
|
640
|
+
sampling_model: The model to use for sampling.
|
|
641
|
+
max_retries: The maximum number of times to retry a tool call.
|
|
642
|
+
"""
|
|
643
|
+
if 'sse_read_timeout' in _deprecated_kwargs:
|
|
521
644
|
if read_timeout is not None:
|
|
522
645
|
raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
|
|
523
646
|
|
|
524
647
|
warnings.warn(
|
|
525
648
|
"'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
|
|
526
649
|
)
|
|
527
|
-
read_timeout =
|
|
650
|
+
read_timeout = _deprecated_kwargs.pop('sse_read_timeout')
|
|
528
651
|
|
|
529
|
-
_utils.validate_empty_kwargs(
|
|
652
|
+
_utils.validate_empty_kwargs(_deprecated_kwargs)
|
|
530
653
|
|
|
531
654
|
if read_timeout is None:
|
|
532
655
|
read_timeout = 5 * 60
|
|
@@ -534,16 +657,19 @@ class _MCPServerHTTP(MCPServer):
|
|
|
534
657
|
self.url = url
|
|
535
658
|
self.headers = headers
|
|
536
659
|
self.http_client = http_client
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
660
|
+
|
|
661
|
+
super().__init__(
|
|
662
|
+
tool_prefix,
|
|
663
|
+
log_level,
|
|
664
|
+
log_handler,
|
|
665
|
+
timeout,
|
|
666
|
+
read_timeout,
|
|
667
|
+
process_tool_call,
|
|
668
|
+
allow_sampling,
|
|
669
|
+
sampling_model,
|
|
670
|
+
max_retries,
|
|
671
|
+
id=id,
|
|
672
|
+
)
|
|
547
673
|
|
|
548
674
|
@property
|
|
549
675
|
@abstractmethod
|
|
@@ -606,10 +732,14 @@ class _MCPServerHTTP(MCPServer):
|
|
|
606
732
|
yield read_stream, write_stream
|
|
607
733
|
|
|
608
734
|
def __repr__(self) -> str: # pragma: no cover
|
|
609
|
-
|
|
735
|
+
repr_args = [
|
|
736
|
+
f'url={self.url!r}',
|
|
737
|
+
]
|
|
738
|
+
if self.id:
|
|
739
|
+
repr_args.append(f'id={self.id!r}')
|
|
740
|
+
return f'{self.__class__.__name__}({", ".join(repr_args)})'
|
|
610
741
|
|
|
611
742
|
|
|
612
|
-
@dataclass(init=False)
|
|
613
743
|
class MCPServerSSE(_MCPServerHTTP):
|
|
614
744
|
"""An MCP server that connects over streamable HTTP connections.
|
|
615
745
|
|
|
@@ -643,7 +773,6 @@ class MCPServerSSE(_MCPServerHTTP):
|
|
|
643
773
|
|
|
644
774
|
|
|
645
775
|
@deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
|
|
646
|
-
@dataclass
|
|
647
776
|
class MCPServerHTTP(MCPServerSSE):
|
|
648
777
|
"""An MCP server that connects over HTTP using the old SSE transport.
|
|
649
778
|
|
|
@@ -672,7 +801,6 @@ class MCPServerHTTP(MCPServerSSE):
|
|
|
672
801
|
"""
|
|
673
802
|
|
|
674
803
|
|
|
675
|
-
@dataclass
|
|
676
804
|
class MCPServerStreamableHTTP(_MCPServerHTTP):
|
|
677
805
|
"""An MCP server that connects over HTTP using the Streamable HTTP transport.
|
|
678
806
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -490,8 +490,8 @@ _video_format_lookup: dict[str, VideoFormat] = {
|
|
|
490
490
|
class UserPromptPart:
|
|
491
491
|
"""A user prompt, generally written by the end user.
|
|
492
492
|
|
|
493
|
-
Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.
|
|
494
|
-
[`Agent.run_sync`][pydantic_ai.
|
|
493
|
+
Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.agent.AbstractAgent.run],
|
|
494
|
+
[`Agent.run_sync`][pydantic_ai.agent.AbstractAgent.run_sync], and [`Agent.run_stream`][pydantic_ai.agent.AbstractAgent.run_stream].
|
|
495
495
|
"""
|
|
496
496
|
|
|
497
497
|
content: str | Sequence[UserContent]
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -13,20 +13,32 @@ from contextlib import asynccontextmanager, contextmanager
|
|
|
13
13
|
from dataclasses import dataclass, field, replace
|
|
14
14
|
from datetime import datetime
|
|
15
15
|
from functools import cache, cached_property
|
|
16
|
-
from typing import Generic, TypeVar, overload
|
|
16
|
+
from typing import Any, Generic, TypeVar, overload
|
|
17
17
|
|
|
18
18
|
import httpx
|
|
19
19
|
from typing_extensions import Literal, TypeAliasType, TypedDict
|
|
20
20
|
|
|
21
|
-
from pydantic_ai.builtin_tools import AbstractBuiltinTool
|
|
22
|
-
from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
|
|
23
|
-
|
|
24
21
|
from .. import _utils
|
|
25
22
|
from .._output import OutputObjectDefinition
|
|
26
23
|
from .._parts_manager import ModelResponsePartsManager
|
|
24
|
+
from .._run_context import RunContext
|
|
25
|
+
from ..builtin_tools import AbstractBuiltinTool
|
|
27
26
|
from ..exceptions import UserError
|
|
28
|
-
from ..messages import
|
|
27
|
+
from ..messages import (
|
|
28
|
+
AgentStreamEvent,
|
|
29
|
+
FileUrl,
|
|
30
|
+
FinalResultEvent,
|
|
31
|
+
ModelMessage,
|
|
32
|
+
ModelRequest,
|
|
33
|
+
ModelResponse,
|
|
34
|
+
ModelResponseStreamEvent,
|
|
35
|
+
PartStartEvent,
|
|
36
|
+
TextPart,
|
|
37
|
+
ToolCallPart,
|
|
38
|
+
VideoUrl,
|
|
39
|
+
)
|
|
29
40
|
from ..output import OutputMode
|
|
41
|
+
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
|
|
30
42
|
from ..profiles._json_schema import JsonSchemaTransformer
|
|
31
43
|
from ..settings import ModelSettings
|
|
32
44
|
from ..tools import ToolDefinition
|
|
@@ -344,6 +356,10 @@ class ModelRequestParameters:
|
|
|
344
356
|
output_tools: list[ToolDefinition] = field(default_factory=list)
|
|
345
357
|
allow_text_output: bool = True
|
|
346
358
|
|
|
359
|
+
@cached_property
|
|
360
|
+
def tool_defs(self) -> dict[str, ToolDefinition]:
|
|
361
|
+
return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]}
|
|
362
|
+
|
|
347
363
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
348
364
|
|
|
349
365
|
|
|
@@ -389,6 +405,7 @@ class Model(ABC):
|
|
|
389
405
|
messages: list[ModelMessage],
|
|
390
406
|
model_settings: ModelSettings | None,
|
|
391
407
|
model_request_parameters: ModelRequestParameters,
|
|
408
|
+
run_context: RunContext[Any] | None = None,
|
|
392
409
|
) -> AsyncIterator[StreamedResponse]:
|
|
393
410
|
"""Make a request to the model and return a streaming response."""
|
|
394
411
|
# This method is not required, but you need to implement it if you want to support streamed responses
|
|
@@ -501,14 +518,40 @@ class Model(ABC):
|
|
|
501
518
|
class StreamedResponse(ABC):
|
|
502
519
|
"""Streamed response from an LLM when calling a tool."""
|
|
503
520
|
|
|
521
|
+
model_request_parameters: ModelRequestParameters
|
|
522
|
+
final_result_event: FinalResultEvent | None = field(default=None, init=False)
|
|
523
|
+
|
|
504
524
|
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
|
|
505
|
-
_event_iterator: AsyncIterator[
|
|
525
|
+
_event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
|
|
506
526
|
_usage: Usage = field(default_factory=Usage, init=False)
|
|
507
527
|
|
|
508
|
-
def __aiter__(self) -> AsyncIterator[
|
|
509
|
-
"""Stream the response as an async iterable of [`
|
|
528
|
+
def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
|
|
529
|
+
"""Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
|
|
530
|
+
|
|
531
|
+
This proxies the `_event_iterator()` and emits all events, while also checking for matches
|
|
532
|
+
on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
|
|
533
|
+
first match is found.
|
|
534
|
+
"""
|
|
510
535
|
if self._event_iterator is None:
|
|
511
|
-
|
|
536
|
+
|
|
537
|
+
async def iterator_with_final_event(
|
|
538
|
+
iterator: AsyncIterator[ModelResponseStreamEvent],
|
|
539
|
+
) -> AsyncIterator[AgentStreamEvent]:
|
|
540
|
+
async for event in iterator:
|
|
541
|
+
yield event
|
|
542
|
+
if (
|
|
543
|
+
final_result_event := _get_final_result_event(event, self.model_request_parameters)
|
|
544
|
+
) is not None:
|
|
545
|
+
self.final_result_event = final_result_event
|
|
546
|
+
yield final_result_event
|
|
547
|
+
break
|
|
548
|
+
|
|
549
|
+
# If we broke out of the above loop, we need to yield the rest of the events
|
|
550
|
+
# If we didn't, this will just be a no-op
|
|
551
|
+
async for event in iterator:
|
|
552
|
+
yield event
|
|
553
|
+
|
|
554
|
+
self._event_iterator = iterator_with_final_event(self._get_event_iterator())
|
|
512
555
|
return self._event_iterator
|
|
513
556
|
|
|
514
557
|
@abstractmethod
|
|
@@ -810,3 +853,16 @@ def _customize_output_object(transformer: type[JsonSchemaTransformer], o: Output
|
|
|
810
853
|
json_schema=json_schema,
|
|
811
854
|
strict=schema_transformer.is_strict_compatible if o.strict is None else o.strict,
|
|
812
855
|
)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def _get_final_result_event(e: ModelResponseStreamEvent, params: ModelRequestParameters) -> FinalResultEvent | None:
|
|
859
|
+
"""Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
|
|
860
|
+
if isinstance(e, PartStartEvent):
|
|
861
|
+
new_part = e.part
|
|
862
|
+
if isinstance(new_part, TextPart) and params.allow_text_output: # pragma: no branch
|
|
863
|
+
return FinalResultEvent(tool_name=None, tool_call_id=None)
|
|
864
|
+
elif isinstance(new_part, ToolCallPart) and (tool_def := params.tool_defs.get(new_part.tool_name)):
|
|
865
|
+
if tool_def.kind == 'output':
|
|
866
|
+
return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id)
|
|
867
|
+
elif tool_def.kind == 'deferred':
|
|
868
|
+
return FinalResultEvent(tool_name=None, tool_call_id=None)
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -21,7 +21,9 @@ from typing_extensions import assert_never
|
|
|
21
21
|
from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
|
|
22
22
|
|
|
23
23
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
24
|
+
from .._run_context import RunContext
|
|
24
25
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
26
|
+
from ..exceptions import UserError
|
|
25
27
|
from ..messages import (
|
|
26
28
|
BinaryContent,
|
|
27
29
|
BuiltinToolCallPart,
|
|
@@ -196,13 +198,14 @@ class AnthropicModel(Model):
|
|
|
196
198
|
messages: list[ModelMessage],
|
|
197
199
|
model_settings: ModelSettings | None,
|
|
198
200
|
model_request_parameters: ModelRequestParameters,
|
|
201
|
+
run_context: RunContext[Any] | None = None,
|
|
199
202
|
) -> AsyncIterator[StreamedResponse]:
|
|
200
203
|
check_allow_model_requests()
|
|
201
204
|
response = await self._messages_create(
|
|
202
205
|
messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
|
|
203
206
|
)
|
|
204
207
|
async with response:
|
|
205
|
-
yield await self._process_streamed_response(response)
|
|
208
|
+
yield await self._process_streamed_response(response, model_request_parameters)
|
|
206
209
|
|
|
207
210
|
@property
|
|
208
211
|
def model_name(self) -> AnthropicModelName:
|
|
@@ -329,7 +332,9 @@ class AnthropicModel(Model):
|
|
|
329
332
|
|
|
330
333
|
return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
|
|
331
334
|
|
|
332
|
-
async def _process_streamed_response(
|
|
335
|
+
async def _process_streamed_response(
|
|
336
|
+
self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
|
|
337
|
+
) -> StreamedResponse:
|
|
333
338
|
peekable_response = _utils.PeekableAsyncStream(response)
|
|
334
339
|
first_chunk = await peekable_response.peek()
|
|
335
340
|
if isinstance(first_chunk, _utils.Unset):
|
|
@@ -338,14 +343,14 @@ class AnthropicModel(Model):
|
|
|
338
343
|
# Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
|
|
339
344
|
timestamp = datetime.now(tz=timezone.utc)
|
|
340
345
|
return AnthropicStreamedResponse(
|
|
341
|
-
|
|
346
|
+
model_request_parameters=model_request_parameters,
|
|
347
|
+
_model_name=self._model_name,
|
|
348
|
+
_response=peekable_response,
|
|
349
|
+
_timestamp=timestamp,
|
|
342
350
|
)
|
|
343
351
|
|
|
344
352
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
|
|
345
|
-
|
|
346
|
-
if model_request_parameters.output_tools:
|
|
347
|
-
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
|
|
348
|
-
return tools
|
|
353
|
+
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
|
|
349
354
|
|
|
350
355
|
def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
|
|
351
356
|
tools: list[BetaToolUnionParam] = []
|
|
@@ -363,6 +368,10 @@ class AnthropicModel(Model):
|
|
|
363
368
|
)
|
|
364
369
|
elif isinstance(tool, CodeExecutionTool): # pragma: no branch
|
|
365
370
|
tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
|
|
371
|
+
else: # pragma: no cover
|
|
372
|
+
raise UserError(
|
|
373
|
+
f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
|
|
374
|
+
)
|
|
366
375
|
return tools
|
|
367
376
|
|
|
368
377
|
async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901
|