pydantic-ai-slim 0.7.6__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_cli.py +2 -1
- pydantic_ai/ag_ui.py +2 -2
- pydantic_ai/agent/__init__.py +22 -16
- pydantic_ai/agent/abstract.py +31 -18
- pydantic_ai/direct.py +5 -3
- pydantic_ai/durable_exec/temporal/__init__.py +67 -16
- pydantic_ai/durable_exec/temporal/_function_toolset.py +9 -2
- pydantic_ai/durable_exec/temporal/_logfire.py +5 -2
- pydantic_ai/mcp.py +48 -71
- pydantic_ai/messages.py +54 -13
- pydantic_ai/models/__init__.py +18 -8
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/bedrock.py +6 -2
- pydantic_ai/models/gemini.py +1 -1
- pydantic_ai/models/google.py +1 -1
- pydantic_ai/models/groq.py +1 -1
- pydantic_ai/models/huggingface.py +1 -1
- pydantic_ai/models/instrumented.py +14 -5
- pydantic_ai/models/mistral.py +2 -2
- pydantic_ai/models/openai.py +14 -4
- pydantic_ai/result.py +36 -18
- {pydantic_ai_slim-0.7.6.dist-info → pydantic_ai_slim-0.8.1.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-0.7.6.dist-info → pydantic_ai_slim-0.8.1.dist-info}/RECORD +26 -26
- {pydantic_ai_slim-0.7.6.dist-info → pydantic_ai_slim-0.8.1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.6.dist-info → pydantic_ai_slim-0.8.1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.6.dist-info → pydantic_ai_slim-0.8.1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_cli.py
CHANGED
|
@@ -228,6 +228,7 @@ async def run_chat(
|
|
|
228
228
|
prog_name: str,
|
|
229
229
|
config_dir: Path | None = None,
|
|
230
230
|
deps: AgentDepsT = None,
|
|
231
|
+
message_history: list[ModelMessage] | None = None,
|
|
231
232
|
) -> int:
|
|
232
233
|
prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
|
|
233
234
|
prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
|
|
@@ -235,7 +236,7 @@ async def run_chat(
|
|
|
235
236
|
session: PromptSession[Any] = PromptSession(history=FileHistory(str(prompt_history_path)))
|
|
236
237
|
|
|
237
238
|
multiline = False
|
|
238
|
-
messages: list[ModelMessage] = []
|
|
239
|
+
messages: list[ModelMessage] = message_history[:] if message_history else []
|
|
239
240
|
|
|
240
241
|
while True:
|
|
241
242
|
try:
|
pydantic_ai/ag_ui.py
CHANGED
|
@@ -28,11 +28,11 @@ from ._agent_graph import CallToolsNode, ModelRequestNode
|
|
|
28
28
|
from .agent import AbstractAgent, AgentRun
|
|
29
29
|
from .exceptions import UserError
|
|
30
30
|
from .messages import (
|
|
31
|
-
AgentStreamEvent,
|
|
32
31
|
FunctionToolResultEvent,
|
|
33
32
|
ModelMessage,
|
|
34
33
|
ModelRequest,
|
|
35
34
|
ModelResponse,
|
|
35
|
+
ModelResponseStreamEvent,
|
|
36
36
|
PartDeltaEvent,
|
|
37
37
|
PartStartEvent,
|
|
38
38
|
SystemPromptPart,
|
|
@@ -403,7 +403,7 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
|
|
|
403
403
|
|
|
404
404
|
async def _handle_model_request_event(
|
|
405
405
|
stream_ctx: _RequestStreamContext,
|
|
406
|
-
agent_event:
|
|
406
|
+
agent_event: ModelResponseStreamEvent,
|
|
407
407
|
) -> AsyncIterator[BaseEvent]:
|
|
408
408
|
"""Handle an agent event and yield AG-UI protocol events.
|
|
409
409
|
|
pydantic_ai/agent/__init__.py
CHANGED
|
@@ -26,7 +26,14 @@ from .. import (
|
|
|
26
26
|
models,
|
|
27
27
|
usage as _usage,
|
|
28
28
|
)
|
|
29
|
-
from .._agent_graph import
|
|
29
|
+
from .._agent_graph import (
|
|
30
|
+
CallToolsNode,
|
|
31
|
+
EndStrategy,
|
|
32
|
+
HistoryProcessor,
|
|
33
|
+
ModelRequestNode,
|
|
34
|
+
UserPromptNode,
|
|
35
|
+
capture_run_messages,
|
|
36
|
+
)
|
|
30
37
|
from .._output import OutputToolset
|
|
31
38
|
from .._tool_manager import ToolManager
|
|
32
39
|
from ..builtin_tools import AbstractBuiltinTool
|
|
@@ -60,13 +67,6 @@ from ..toolsets.prepared import PreparedToolset
|
|
|
60
67
|
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
|
|
61
68
|
from .wrapper import WrapperAgent
|
|
62
69
|
|
|
63
|
-
# Re-exporting like this improves auto-import behavior in PyCharm
|
|
64
|
-
capture_run_messages = _agent_graph.capture_run_messages
|
|
65
|
-
EndStrategy = _agent_graph.EndStrategy
|
|
66
|
-
CallToolsNode = _agent_graph.CallToolsNode
|
|
67
|
-
ModelRequestNode = _agent_graph.ModelRequestNode
|
|
68
|
-
UserPromptNode = _agent_graph.UserPromptNode
|
|
69
|
-
|
|
70
70
|
if TYPE_CHECKING:
|
|
71
71
|
from ..mcp import MCPServer
|
|
72
72
|
|
|
@@ -678,22 +678,28 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
678
678
|
self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings
|
|
679
679
|
):
|
|
680
680
|
if settings.version == 1:
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
681
|
+
attrs = {
|
|
682
|
+
'all_messages_events': json.dumps(
|
|
683
|
+
[
|
|
684
|
+
InstrumentedModel.event_to_dict(e)
|
|
685
|
+
for e in settings.messages_to_otel_events(state.message_history)
|
|
686
|
+
]
|
|
687
|
+
)
|
|
688
|
+
}
|
|
685
689
|
else:
|
|
686
|
-
|
|
687
|
-
|
|
690
|
+
attrs = {
|
|
691
|
+
'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)),
|
|
692
|
+
**settings.system_instructions_attributes(self._instructions),
|
|
693
|
+
}
|
|
688
694
|
|
|
689
695
|
return {
|
|
690
696
|
**usage.opentelemetry_attributes(),
|
|
691
|
-
|
|
697
|
+
**attrs,
|
|
692
698
|
'logfire.json_schema': json.dumps(
|
|
693
699
|
{
|
|
694
700
|
'type': 'object',
|
|
695
701
|
'properties': {
|
|
696
|
-
|
|
702
|
+
**{attr: {'type': 'array'} for attr in attrs.keys()},
|
|
697
703
|
'final_result': {'type': 'object'},
|
|
698
704
|
},
|
|
699
705
|
}
|
pydantic_ai/agent/abstract.py
CHANGED
|
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|
|
5
5
|
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator, Mapping, Sequence
|
|
6
6
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
|
|
7
7
|
from types import FrameType
|
|
8
|
-
from typing import TYPE_CHECKING, Any, Callable, Generic,
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
|
|
9
9
|
|
|
10
10
|
from typing_extensions import Self, TypeAlias, TypeIs, TypeVar
|
|
11
11
|
|
|
@@ -34,13 +34,6 @@ from ..tools import (
|
|
|
34
34
|
from ..toolsets import AbstractToolset
|
|
35
35
|
from ..usage import RunUsage, UsageLimits
|
|
36
36
|
|
|
37
|
-
# Re-exporting like this improves auto-import behavior in PyCharm
|
|
38
|
-
capture_run_messages = _agent_graph.capture_run_messages
|
|
39
|
-
EndStrategy = _agent_graph.EndStrategy
|
|
40
|
-
CallToolsNode = _agent_graph.CallToolsNode
|
|
41
|
-
ModelRequestNode = _agent_graph.ModelRequestNode
|
|
42
|
-
UserPromptNode = _agent_graph.UserPromptNode
|
|
43
|
-
|
|
44
37
|
if TYPE_CHECKING:
|
|
45
38
|
from fasta2a.applications import FastA2A
|
|
46
39
|
from fasta2a.broker import Broker
|
|
@@ -60,11 +53,7 @@ RunOutputDataT = TypeVar('RunOutputDataT')
|
|
|
60
53
|
"""Type variable for the result data of a run where `output_type` was customized on the run call."""
|
|
61
54
|
|
|
62
55
|
EventStreamHandler: TypeAlias = Callable[
|
|
63
|
-
[
|
|
64
|
-
RunContext[AgentDepsT],
|
|
65
|
-
AsyncIterable[Union[_messages.AgentStreamEvent, _messages.HandleResponseEvent]],
|
|
66
|
-
],
|
|
67
|
-
Awaitable[None],
|
|
56
|
+
[RunContext[AgentDepsT], AsyncIterable[_messages.AgentStreamEvent]], Awaitable[None]
|
|
68
57
|
]
|
|
69
58
|
"""A function that receives agent [`RunContext`][pydantic_ai.tools.RunContext] and an async iterable of events from the model's streaming response and the agent's execution of tools."""
|
|
70
59
|
|
|
@@ -452,7 +441,9 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
|
|
|
452
441
|
async with node.stream(graph_ctx) as stream:
|
|
453
442
|
final_result_event = None
|
|
454
443
|
|
|
455
|
-
async def stream_to_final(
|
|
444
|
+
async def stream_to_final(
|
|
445
|
+
stream: AgentStream,
|
|
446
|
+
) -> AsyncIterator[_messages.ModelResponseStreamEvent]:
|
|
456
447
|
nonlocal final_result_event
|
|
457
448
|
async for event in stream:
|
|
458
449
|
yield event
|
|
@@ -899,12 +890,18 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
|
|
|
899
890
|
lifespan=lifespan,
|
|
900
891
|
)
|
|
901
892
|
|
|
902
|
-
async def to_cli(
|
|
893
|
+
async def to_cli(
|
|
894
|
+
self: Self,
|
|
895
|
+
deps: AgentDepsT = None,
|
|
896
|
+
prog_name: str = 'pydantic-ai',
|
|
897
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
898
|
+
) -> None:
|
|
903
899
|
"""Run the agent in a CLI chat interface.
|
|
904
900
|
|
|
905
901
|
Args:
|
|
906
902
|
deps: The dependencies to pass to the agent.
|
|
907
903
|
prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'.
|
|
904
|
+
message_history: History of the conversation so far.
|
|
908
905
|
|
|
909
906
|
Example:
|
|
910
907
|
```python {title="agent_to_cli.py" test="skip"}
|
|
@@ -920,14 +917,28 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
|
|
|
920
917
|
|
|
921
918
|
from pydantic_ai._cli import run_chat
|
|
922
919
|
|
|
923
|
-
await run_chat(
|
|
920
|
+
await run_chat(
|
|
921
|
+
stream=True,
|
|
922
|
+
agent=self,
|
|
923
|
+
deps=deps,
|
|
924
|
+
console=Console(),
|
|
925
|
+
code_theme='monokai',
|
|
926
|
+
prog_name=prog_name,
|
|
927
|
+
message_history=message_history,
|
|
928
|
+
)
|
|
924
929
|
|
|
925
|
-
def to_cli_sync(
|
|
930
|
+
def to_cli_sync(
|
|
931
|
+
self: Self,
|
|
932
|
+
deps: AgentDepsT = None,
|
|
933
|
+
prog_name: str = 'pydantic-ai',
|
|
934
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
935
|
+
) -> None:
|
|
926
936
|
"""Run the agent in a CLI chat interface with the non-async interface.
|
|
927
937
|
|
|
928
938
|
Args:
|
|
929
939
|
deps: The dependencies to pass to the agent.
|
|
930
940
|
prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'.
|
|
941
|
+
message_history: History of the conversation so far.
|
|
931
942
|
|
|
932
943
|
```python {title="agent_to_cli_sync.py" test="skip"}
|
|
933
944
|
from pydantic_ai import Agent
|
|
@@ -937,4 +948,6 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
|
|
|
937
948
|
agent.to_cli_sync(prog_name='assistant')
|
|
938
949
|
```
|
|
939
950
|
"""
|
|
940
|
-
return get_event_loop().run_until_complete(
|
|
951
|
+
return get_event_loop().run_until_complete(
|
|
952
|
+
self.to_cli(deps=deps, prog_name=prog_name, message_history=message_history)
|
|
953
|
+
)
|
pydantic_ai/direct.py
CHANGED
|
@@ -275,7 +275,9 @@ class StreamedResponseSync:
|
|
|
275
275
|
"""
|
|
276
276
|
|
|
277
277
|
_async_stream_cm: AbstractAsyncContextManager[StreamedResponse]
|
|
278
|
-
_queue: queue.Queue[messages.
|
|
278
|
+
_queue: queue.Queue[messages.ModelResponseStreamEvent | Exception | None] = field(
|
|
279
|
+
default_factory=queue.Queue, init=False
|
|
280
|
+
)
|
|
279
281
|
_thread: threading.Thread | None = field(default=None, init=False)
|
|
280
282
|
_stream_response: StreamedResponse | None = field(default=None, init=False)
|
|
281
283
|
_exception: Exception | None = field(default=None, init=False)
|
|
@@ -295,8 +297,8 @@ class StreamedResponseSync:
|
|
|
295
297
|
) -> None:
|
|
296
298
|
self._cleanup()
|
|
297
299
|
|
|
298
|
-
def __iter__(self) -> Iterator[messages.
|
|
299
|
-
"""Stream the response as an iterable of [`
|
|
300
|
+
def __iter__(self) -> Iterator[messages.ModelResponseStreamEvent]:
|
|
301
|
+
"""Stream the response as an iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
|
|
300
302
|
self._check_context_manager_usage()
|
|
301
303
|
|
|
302
304
|
while True:
|
|
@@ -1,15 +1,24 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
|
-
from collections.abc import Sequence
|
|
4
|
+
from collections.abc import AsyncIterator, Sequence
|
|
5
|
+
from contextlib import AbstractAsyncContextManager
|
|
5
6
|
from dataclasses import replace
|
|
6
7
|
from typing import Any, Callable
|
|
7
8
|
|
|
8
9
|
from pydantic.errors import PydanticUserError
|
|
9
|
-
from temporalio.client import ClientConfig, Plugin as ClientPlugin
|
|
10
|
+
from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory
|
|
10
11
|
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
|
|
11
|
-
from temporalio.converter import DefaultPayloadConverter
|
|
12
|
-
from temporalio.
|
|
12
|
+
from temporalio.converter import DataConverter, DefaultPayloadConverter
|
|
13
|
+
from temporalio.service import ConnectConfig, ServiceClient
|
|
14
|
+
from temporalio.worker import (
|
|
15
|
+
Plugin as WorkerPlugin,
|
|
16
|
+
Replayer,
|
|
17
|
+
ReplayerConfig,
|
|
18
|
+
Worker,
|
|
19
|
+
WorkerConfig,
|
|
20
|
+
WorkflowReplayResult,
|
|
21
|
+
)
|
|
13
22
|
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
|
|
14
23
|
|
|
15
24
|
from ...exceptions import UserError
|
|
@@ -31,17 +40,15 @@ __all__ = [
|
|
|
31
40
|
class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
|
|
32
41
|
"""Temporal client and worker plugin for Pydantic AI."""
|
|
33
42
|
|
|
34
|
-
def
|
|
35
|
-
|
|
36
|
-
DefaultPayloadConverter,
|
|
37
|
-
PydanticPayloadConverter,
|
|
38
|
-
):
|
|
39
|
-
warnings.warn( # pragma: no cover
|
|
40
|
-
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
|
|
41
|
-
)
|
|
43
|
+
def init_client_plugin(self, next: ClientPlugin) -> None:
|
|
44
|
+
self.next_client_plugin = next
|
|
42
45
|
|
|
43
|
-
|
|
44
|
-
|
|
46
|
+
def init_worker_plugin(self, next: WorkerPlugin) -> None:
|
|
47
|
+
self.next_worker_plugin = next
|
|
48
|
+
|
|
49
|
+
def configure_client(self, config: ClientConfig) -> ClientConfig:
|
|
50
|
+
config['data_converter'] = self._get_new_data_converter(config.get('data_converter'))
|
|
51
|
+
return self.next_client_plugin.configure_client(config)
|
|
45
52
|
|
|
46
53
|
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
|
|
47
54
|
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
|
|
@@ -67,7 +74,35 @@ class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
|
|
|
67
74
|
PydanticUserError,
|
|
68
75
|
]
|
|
69
76
|
|
|
70
|
-
return
|
|
77
|
+
return self.next_worker_plugin.configure_worker(config)
|
|
78
|
+
|
|
79
|
+
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
|
|
80
|
+
return await self.next_client_plugin.connect_service_client(config)
|
|
81
|
+
|
|
82
|
+
async def run_worker(self, worker: Worker) -> None:
|
|
83
|
+
await self.next_worker_plugin.run_worker(worker)
|
|
84
|
+
|
|
85
|
+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
|
|
86
|
+
config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType]
|
|
87
|
+
return self.next_worker_plugin.configure_replayer(config)
|
|
88
|
+
|
|
89
|
+
def run_replayer(
|
|
90
|
+
self,
|
|
91
|
+
replayer: Replayer,
|
|
92
|
+
histories: AsyncIterator[WorkflowHistory],
|
|
93
|
+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
|
|
94
|
+
return self.next_worker_plugin.run_replayer(replayer, histories)
|
|
95
|
+
|
|
96
|
+
def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter:
|
|
97
|
+
if converter and converter.payload_converter_class not in (
|
|
98
|
+
DefaultPayloadConverter,
|
|
99
|
+
PydanticPayloadConverter,
|
|
100
|
+
):
|
|
101
|
+
warnings.warn( # pragma: no cover
|
|
102
|
+
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
return pydantic_data_converter
|
|
71
106
|
|
|
72
107
|
|
|
73
108
|
class AgentPlugin(WorkerPlugin):
|
|
@@ -76,8 +111,24 @@ class AgentPlugin(WorkerPlugin):
|
|
|
76
111
|
def __init__(self, agent: TemporalAgent[Any, Any]):
|
|
77
112
|
self.agent = agent
|
|
78
113
|
|
|
114
|
+
def init_worker_plugin(self, next: WorkerPlugin) -> None:
|
|
115
|
+
self.next_worker_plugin = next
|
|
116
|
+
|
|
79
117
|
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
|
|
80
118
|
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
|
|
81
119
|
# Activities are checked for name conflicts by Temporal.
|
|
82
120
|
config['activities'] = [*activities, *self.agent.temporal_activities]
|
|
83
|
-
return
|
|
121
|
+
return self.next_worker_plugin.configure_worker(config)
|
|
122
|
+
|
|
123
|
+
async def run_worker(self, worker: Worker) -> None:
|
|
124
|
+
await self.next_worker_plugin.run_worker(worker)
|
|
125
|
+
|
|
126
|
+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
|
|
127
|
+
return self.next_worker_plugin.configure_replayer(config)
|
|
128
|
+
|
|
129
|
+
def run_replayer(
|
|
130
|
+
self,
|
|
131
|
+
replayer: Replayer,
|
|
132
|
+
histories: AsyncIterator[WorkflowHistory],
|
|
133
|
+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
|
|
134
|
+
return self.next_worker_plugin.run_replayer(replayer, histories)
|
|
@@ -51,7 +51,10 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
51
51
|
'Removing or renaming tools during an agent run is not supported with Temporal.'
|
|
52
52
|
) from e
|
|
53
53
|
|
|
54
|
-
|
|
54
|
+
# The tool args will already have been validated into their proper types in the `ToolManager`,
|
|
55
|
+
# but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
|
|
56
|
+
args_dict = tool.args_validator.validate_python(params.tool_args)
|
|
57
|
+
return await self.wrapped.call_tool(name, args_dict, ctx, tool)
|
|
55
58
|
|
|
56
59
|
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
|
|
57
60
|
call_tool_activity.__annotations__['deps'] = deps_type
|
|
@@ -85,7 +88,11 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
|
|
|
85
88
|
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
|
|
86
89
|
activity=self.call_tool_activity,
|
|
87
90
|
args=[
|
|
88
|
-
_CallToolParams(
|
|
91
|
+
_CallToolParams(
|
|
92
|
+
name=name,
|
|
93
|
+
tool_args=tool_args,
|
|
94
|
+
serialized_run_context=serialized_run_context,
|
|
95
|
+
),
|
|
89
96
|
ctx.deps,
|
|
90
97
|
],
|
|
91
98
|
**tool_activity_config,
|
|
@@ -25,10 +25,13 @@ class LogfirePlugin(ClientPlugin):
|
|
|
25
25
|
self.setup_logfire = setup_logfire
|
|
26
26
|
self.metrics = metrics
|
|
27
27
|
|
|
28
|
+
def init_client_plugin(self, next: ClientPlugin) -> None:
|
|
29
|
+
self.next_client_plugin = next
|
|
30
|
+
|
|
28
31
|
def configure_client(self, config: ClientConfig) -> ClientConfig:
|
|
29
32
|
interceptors = config.get('interceptors', [])
|
|
30
33
|
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
|
|
31
|
-
return
|
|
34
|
+
return self.next_client_plugin.configure_client(config)
|
|
32
35
|
|
|
33
36
|
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
|
|
34
37
|
logfire = self.setup_logfire()
|
|
@@ -45,4 +48,4 @@ class LogfirePlugin(ClientPlugin):
|
|
|
45
48
|
telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
|
|
46
49
|
)
|
|
47
50
|
|
|
48
|
-
return await
|
|
51
|
+
return await self.next_client_plugin.connect_service_client(config)
|
pydantic_ai/mcp.py
CHANGED
|
@@ -18,14 +18,13 @@ import pydantic_core
|
|
|
18
18
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
19
19
|
from typing_extensions import Self, assert_never, deprecated
|
|
20
20
|
|
|
21
|
-
from pydantic_ai.
|
|
22
|
-
from pydantic_ai.tools import ToolDefinition
|
|
21
|
+
from pydantic_ai.tools import RunContext, ToolDefinition
|
|
23
22
|
|
|
24
23
|
from .toolsets.abstract import AbstractToolset, ToolsetTool
|
|
25
24
|
|
|
26
25
|
try:
|
|
27
26
|
from mcp import types as mcp_types
|
|
28
|
-
from mcp.client.session import ClientSession, LoggingFnT
|
|
27
|
+
from mcp.client.session import ClientSession, ElicitationFnT, LoggingFnT
|
|
29
28
|
from mcp.client.sse import sse_client
|
|
30
29
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
31
30
|
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
|
|
@@ -57,14 +56,49 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
57
56
|
"""
|
|
58
57
|
|
|
59
58
|
tool_prefix: str | None
|
|
59
|
+
"""A prefix to add to all tools that are registered with the server.
|
|
60
|
+
|
|
61
|
+
If not empty, will include a trailing underscore(`_`).
|
|
62
|
+
|
|
63
|
+
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
64
|
+
"""
|
|
65
|
+
|
|
60
66
|
log_level: mcp_types.LoggingLevel | None
|
|
67
|
+
"""The log level to set when connecting to the server, if any.
|
|
68
|
+
|
|
69
|
+
See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
|
|
70
|
+
|
|
71
|
+
If `None`, no log level will be set.
|
|
72
|
+
"""
|
|
73
|
+
|
|
61
74
|
log_handler: LoggingFnT | None
|
|
75
|
+
"""A handler for logging messages from the server."""
|
|
76
|
+
|
|
62
77
|
timeout: float
|
|
78
|
+
"""The timeout in seconds to wait for the client to initialize."""
|
|
79
|
+
|
|
63
80
|
read_timeout: float
|
|
81
|
+
"""Maximum time in seconds to wait for new messages before timing out.
|
|
82
|
+
|
|
83
|
+
This timeout applies to the long-lived connection after it's established.
|
|
84
|
+
If no new messages are received within this time, the connection will be considered stale
|
|
85
|
+
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
86
|
+
"""
|
|
87
|
+
|
|
64
88
|
process_tool_call: ProcessToolCallback | None
|
|
89
|
+
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
90
|
+
|
|
65
91
|
allow_sampling: bool
|
|
92
|
+
"""Whether to allow MCP sampling through this client."""
|
|
93
|
+
|
|
66
94
|
sampling_model: models.Model | None
|
|
95
|
+
"""The model to use for sampling."""
|
|
96
|
+
|
|
67
97
|
max_retries: int
|
|
98
|
+
"""The maximum number of times to retry a tool call."""
|
|
99
|
+
|
|
100
|
+
elicitation_callback: ElicitationFnT | None = None
|
|
101
|
+
"""Callback function to handle elicitation requests from the server."""
|
|
68
102
|
|
|
69
103
|
_id: str | None
|
|
70
104
|
|
|
@@ -87,6 +121,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
87
121
|
allow_sampling: bool = True,
|
|
88
122
|
sampling_model: models.Model | None = None,
|
|
89
123
|
max_retries: int = 1,
|
|
124
|
+
elicitation_callback: ElicitationFnT | None = None,
|
|
90
125
|
*,
|
|
91
126
|
id: str | None = None,
|
|
92
127
|
):
|
|
@@ -99,6 +134,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
99
134
|
self.allow_sampling = allow_sampling
|
|
100
135
|
self.sampling_model = sampling_model
|
|
101
136
|
self.max_retries = max_retries
|
|
137
|
+
self.elicitation_callback = elicitation_callback
|
|
102
138
|
|
|
103
139
|
self._id = id or tool_prefix
|
|
104
140
|
|
|
@@ -247,6 +283,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
247
283
|
read_stream=self._read_stream,
|
|
248
284
|
write_stream=self._write_stream,
|
|
249
285
|
sampling_callback=self._sampling_callback if self.allow_sampling else None,
|
|
286
|
+
elicitation_callback=self.elicitation_callback,
|
|
250
287
|
logging_callback=self.log_handler,
|
|
251
288
|
read_timeout_seconds=timedelta(seconds=self.read_timeout),
|
|
252
289
|
)
|
|
@@ -404,46 +441,15 @@ class MCPServerStdio(MCPServer):
|
|
|
404
441
|
|
|
405
442
|
# last fields are re-defined from the parent class so they appear as fields
|
|
406
443
|
tool_prefix: str | None
|
|
407
|
-
"""A prefix to add to all tools that are registered with the server.
|
|
408
|
-
|
|
409
|
-
If not empty, will include a trailing underscore(`_`).
|
|
410
|
-
|
|
411
|
-
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
412
|
-
"""
|
|
413
|
-
|
|
414
444
|
log_level: mcp_types.LoggingLevel | None
|
|
415
|
-
"""The log level to set when connecting to the server, if any.
|
|
416
|
-
|
|
417
|
-
See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
|
|
418
|
-
|
|
419
|
-
If `None`, no log level will be set.
|
|
420
|
-
"""
|
|
421
|
-
|
|
422
445
|
log_handler: LoggingFnT | None
|
|
423
|
-
"""A handler for logging messages from the server."""
|
|
424
|
-
|
|
425
446
|
timeout: float
|
|
426
|
-
"""The timeout in seconds to wait for the client to initialize."""
|
|
427
|
-
|
|
428
447
|
read_timeout: float
|
|
429
|
-
"""Maximum time in seconds to wait for new messages before timing out.
|
|
430
|
-
|
|
431
|
-
This timeout applies to the long-lived connection after it's established.
|
|
432
|
-
If no new messages are received within this time, the connection will be considered stale
|
|
433
|
-
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
434
|
-
"""
|
|
435
|
-
|
|
436
448
|
process_tool_call: ProcessToolCallback | None
|
|
437
|
-
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
438
|
-
|
|
439
449
|
allow_sampling: bool
|
|
440
|
-
"""Whether to allow MCP sampling through this client."""
|
|
441
|
-
|
|
442
450
|
sampling_model: models.Model | None
|
|
443
|
-
"""The model to use for sampling."""
|
|
444
|
-
|
|
445
451
|
max_retries: int
|
|
446
|
-
|
|
452
|
+
elicitation_callback: ElicitationFnT | None = None
|
|
447
453
|
|
|
448
454
|
def __init__(
|
|
449
455
|
self,
|
|
@@ -460,6 +466,7 @@ class MCPServerStdio(MCPServer):
|
|
|
460
466
|
allow_sampling: bool = True,
|
|
461
467
|
sampling_model: models.Model | None = None,
|
|
462
468
|
max_retries: int = 1,
|
|
469
|
+
elicitation_callback: ElicitationFnT | None = None,
|
|
463
470
|
*,
|
|
464
471
|
id: str | None = None,
|
|
465
472
|
):
|
|
@@ -479,6 +486,7 @@ class MCPServerStdio(MCPServer):
|
|
|
479
486
|
allow_sampling: Whether to allow MCP sampling through this client.
|
|
480
487
|
sampling_model: The model to use for sampling.
|
|
481
488
|
max_retries: The maximum number of times to retry a tool call.
|
|
489
|
+
elicitation_callback: Callback function to handle elicitation requests from the server.
|
|
482
490
|
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
|
|
483
491
|
"""
|
|
484
492
|
self.command = command
|
|
@@ -496,6 +504,7 @@ class MCPServerStdio(MCPServer):
|
|
|
496
504
|
allow_sampling,
|
|
497
505
|
sampling_model,
|
|
498
506
|
max_retries,
|
|
507
|
+
elicitation_callback,
|
|
499
508
|
id=id,
|
|
500
509
|
)
|
|
501
510
|
|
|
@@ -560,50 +569,15 @@ class _MCPServerHTTP(MCPServer):
|
|
|
560
569
|
|
|
561
570
|
# last fields are re-defined from the parent class so they appear as fields
|
|
562
571
|
tool_prefix: str | None
|
|
563
|
-
"""A prefix to add to all tools that are registered with the server.
|
|
564
|
-
|
|
565
|
-
If not empty, will include a trailing underscore (`_`).
|
|
566
|
-
|
|
567
|
-
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
568
|
-
"""
|
|
569
|
-
|
|
570
572
|
log_level: mcp_types.LoggingLevel | None
|
|
571
|
-
"""The log level to set when connecting to the server, if any.
|
|
572
|
-
|
|
573
|
-
See <https://modelcontextprotocol.io/introduction#logging> for more details.
|
|
574
|
-
|
|
575
|
-
If `None`, no log level will be set.
|
|
576
|
-
"""
|
|
577
|
-
|
|
578
573
|
log_handler: LoggingFnT | None
|
|
579
|
-
"""A handler for logging messages from the server."""
|
|
580
|
-
|
|
581
574
|
timeout: float
|
|
582
|
-
"""Initial connection timeout in seconds for establishing the connection.
|
|
583
|
-
|
|
584
|
-
This timeout applies to the initial connection setup and handshake.
|
|
585
|
-
If the connection cannot be established within this time, the operation will fail.
|
|
586
|
-
"""
|
|
587
|
-
|
|
588
575
|
read_timeout: float
|
|
589
|
-
"""Maximum time in seconds to wait for new messages before timing out.
|
|
590
|
-
|
|
591
|
-
This timeout applies to the long-lived connection after it's established.
|
|
592
|
-
If no new messages are received within this time, the connection will be considered stale
|
|
593
|
-
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
594
|
-
"""
|
|
595
|
-
|
|
596
576
|
process_tool_call: ProcessToolCallback | None
|
|
597
|
-
"""Hook to customize tool calling and optionally pass extra metadata."""
|
|
598
|
-
|
|
599
577
|
allow_sampling: bool
|
|
600
|
-
"""Whether to allow MCP sampling through this client."""
|
|
601
|
-
|
|
602
578
|
sampling_model: models.Model | None
|
|
603
|
-
"""The model to use for sampling."""
|
|
604
|
-
|
|
605
579
|
max_retries: int
|
|
606
|
-
|
|
580
|
+
elicitation_callback: ElicitationFnT | None = None
|
|
607
581
|
|
|
608
582
|
def __init__(
|
|
609
583
|
self,
|
|
@@ -621,6 +595,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
621
595
|
allow_sampling: bool = True,
|
|
622
596
|
sampling_model: models.Model | None = None,
|
|
623
597
|
max_retries: int = 1,
|
|
598
|
+
elicitation_callback: ElicitationFnT | None = None,
|
|
624
599
|
**_deprecated_kwargs: Any,
|
|
625
600
|
):
|
|
626
601
|
"""Build a new MCP server.
|
|
@@ -639,6 +614,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
639
614
|
allow_sampling: Whether to allow MCP sampling through this client.
|
|
640
615
|
sampling_model: The model to use for sampling.
|
|
641
616
|
max_retries: The maximum number of times to retry a tool call.
|
|
617
|
+
elicitation_callback: Callback function to handle elicitation requests from the server.
|
|
642
618
|
"""
|
|
643
619
|
if 'sse_read_timeout' in _deprecated_kwargs:
|
|
644
620
|
if read_timeout is not None:
|
|
@@ -668,6 +644,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
668
644
|
allow_sampling,
|
|
669
645
|
sampling_model,
|
|
670
646
|
max_retries,
|
|
647
|
+
elicitation_callback,
|
|
671
648
|
id=id,
|
|
672
649
|
)
|
|
673
650
|
|