pydantic-ai-slim 1.0.5__py3-none-any.whl → 1.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +208 -127
- pydantic_ai/ag_ui.py +44 -33
- pydantic_ai/agent/__init__.py +38 -46
- pydantic_ai/agent/abstract.py +7 -7
- pydantic_ai/agent/wrapper.py +0 -1
- pydantic_ai/builtin_tools.py +18 -9
- pydantic_ai/durable_exec/dbos/_agent.py +14 -10
- pydantic_ai/durable_exec/dbos/_mcp_server.py +4 -2
- pydantic_ai/durable_exec/temporal/_agent.py +0 -1
- pydantic_ai/durable_exec/temporal/_logfire.py +15 -3
- pydantic_ai/durable_exec/temporal/_toolset.py +17 -12
- pydantic_ai/mcp.py +120 -2
- pydantic_ai/models/cohere.py +2 -2
- pydantic_ai/models/openai.py +54 -9
- pydantic_ai/run.py +0 -2
- pydantic_ai/tools.py +11 -0
- pydantic_ai/toolsets/function.py +50 -9
- {pydantic_ai_slim-1.0.5.dist-info → pydantic_ai_slim-1.0.7.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-1.0.5.dist-info → pydantic_ai_slim-1.0.7.dist-info}/RECORD +22 -22
- {pydantic_ai_slim-1.0.5.dist-info → pydantic_ai_slim-1.0.7.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.5.dist-info → pydantic_ai_slim-1.0.7.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.5.dist-info → pydantic_ai_slim-1.0.7.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/agent/__init__.py
CHANGED
|
@@ -45,15 +45,11 @@ from ..run import AgentRun, AgentRunResult
|
|
|
45
45
|
from ..settings import ModelSettings, merge_model_settings
|
|
46
46
|
from ..tools import (
|
|
47
47
|
AgentDepsT,
|
|
48
|
-
DeferredToolCallResult,
|
|
49
|
-
DeferredToolResult,
|
|
50
48
|
DeferredToolResults,
|
|
51
49
|
DocstringFormat,
|
|
52
50
|
GenerateToolJsonSchema,
|
|
53
51
|
RunContext,
|
|
54
52
|
Tool,
|
|
55
|
-
ToolApproved,
|
|
56
|
-
ToolDenied,
|
|
57
53
|
ToolFuncContext,
|
|
58
54
|
ToolFuncEither,
|
|
59
55
|
ToolFuncPlain,
|
|
@@ -462,7 +458,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
462
458
|
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
|
|
463
459
|
|
|
464
460
|
@asynccontextmanager
|
|
465
|
-
async def iter(
|
|
461
|
+
async def iter(
|
|
466
462
|
self,
|
|
467
463
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
468
464
|
*,
|
|
@@ -505,7 +501,6 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
505
501
|
[
|
|
506
502
|
UserPromptNode(
|
|
507
503
|
user_prompt='What is the capital of France?',
|
|
508
|
-
instructions=None,
|
|
509
504
|
instructions_functions=[],
|
|
510
505
|
system_prompts=(),
|
|
511
506
|
system_prompt_functions=[],
|
|
@@ -559,7 +554,6 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
559
554
|
del model
|
|
560
555
|
|
|
561
556
|
deps = self._get_deps(deps)
|
|
562
|
-
new_message_index = len(message_history) if message_history else 0
|
|
563
557
|
output_schema = self._prepare_output_schema(output_type, model_used.profile)
|
|
564
558
|
|
|
565
559
|
output_type_ = output_type or self.output_type
|
|
@@ -620,27 +614,12 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
620
614
|
instrumentation_settings = None
|
|
621
615
|
tracer = NoOpTracer()
|
|
622
616
|
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
for tool_call_id, approval in deferred_tool_results.approvals.items():
|
|
627
|
-
if approval is True:
|
|
628
|
-
approval = ToolApproved()
|
|
629
|
-
elif approval is False:
|
|
630
|
-
approval = ToolDenied()
|
|
631
|
-
tool_call_results[tool_call_id] = approval
|
|
632
|
-
|
|
633
|
-
if calls := deferred_tool_results.calls:
|
|
634
|
-
call_result_types = _utils.get_union_args(DeferredToolCallResult)
|
|
635
|
-
for tool_call_id, result in calls.items():
|
|
636
|
-
if not isinstance(result, call_result_types):
|
|
637
|
-
result = _messages.ToolReturn(result)
|
|
638
|
-
tool_call_results[tool_call_id] = result
|
|
639
|
-
|
|
640
|
-
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
|
|
617
|
+
graph_deps = _agent_graph.GraphAgentDeps[
|
|
618
|
+
AgentDepsT, RunOutputDataT
|
|
619
|
+
](
|
|
641
620
|
user_deps=deps,
|
|
642
621
|
prompt=user_prompt,
|
|
643
|
-
new_message_index=
|
|
622
|
+
new_message_index=0, # This will be set in `UserPromptNode` based on the length of the cleaned message history
|
|
644
623
|
model=model_used,
|
|
645
624
|
model_settings=model_settings,
|
|
646
625
|
usage_limits=usage_limits,
|
|
@@ -651,13 +630,13 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
651
630
|
history_processors=self.history_processors,
|
|
652
631
|
builtin_tools=list(self._builtin_tools),
|
|
653
632
|
tool_manager=tool_manager,
|
|
654
|
-
tool_call_results=tool_call_results,
|
|
655
633
|
tracer=tracer,
|
|
656
634
|
get_instructions=get_instructions,
|
|
657
635
|
instrumentation_settings=instrumentation_settings,
|
|
658
636
|
)
|
|
659
637
|
start_node = _agent_graph.UserPromptNode[AgentDepsT](
|
|
660
638
|
user_prompt=user_prompt,
|
|
639
|
+
deferred_tool_results=deferred_tool_results,
|
|
661
640
|
instructions=self._instructions,
|
|
662
641
|
instructions_functions=self._instructions_functions,
|
|
663
642
|
system_prompts=self._system_prompts,
|
|
@@ -1005,7 +984,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
1005
984
|
require_parameter_descriptions: bool = False,
|
|
1006
985
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
1007
986
|
strict: bool | None = None,
|
|
987
|
+
sequential: bool = False,
|
|
1008
988
|
requires_approval: bool = False,
|
|
989
|
+
metadata: dict[str, Any] | None = None,
|
|
1009
990
|
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
|
|
1010
991
|
|
|
1011
992
|
def tool(
|
|
@@ -1020,7 +1001,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
1020
1001
|
require_parameter_descriptions: bool = False,
|
|
1021
1002
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
1022
1003
|
strict: bool | None = None,
|
|
1004
|
+
sequential: bool = False,
|
|
1023
1005
|
requires_approval: bool = False,
|
|
1006
|
+
metadata: dict[str, Any] | None = None,
|
|
1024
1007
|
) -> Any:
|
|
1025
1008
|
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
1026
1009
|
|
|
@@ -1065,8 +1048,10 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
1065
1048
|
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
|
|
1066
1049
|
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
1067
1050
|
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
1051
|
+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
|
|
1068
1052
|
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
|
|
1069
1053
|
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
|
|
1054
|
+
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
|
|
1070
1055
|
"""
|
|
1071
1056
|
|
|
1072
1057
|
def tool_decorator(
|
|
@@ -1075,15 +1060,17 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
1075
1060
|
# noinspection PyTypeChecker
|
|
1076
1061
|
self._function_toolset.add_function(
|
|
1077
1062
|
func_,
|
|
1078
|
-
True,
|
|
1079
|
-
name,
|
|
1080
|
-
retries,
|
|
1081
|
-
prepare,
|
|
1082
|
-
docstring_format,
|
|
1083
|
-
require_parameter_descriptions,
|
|
1084
|
-
schema_generator,
|
|
1085
|
-
strict,
|
|
1086
|
-
|
|
1063
|
+
takes_ctx=True,
|
|
1064
|
+
name=name,
|
|
1065
|
+
retries=retries,
|
|
1066
|
+
prepare=prepare,
|
|
1067
|
+
docstring_format=docstring_format,
|
|
1068
|
+
require_parameter_descriptions=require_parameter_descriptions,
|
|
1069
|
+
schema_generator=schema_generator,
|
|
1070
|
+
strict=strict,
|
|
1071
|
+
sequential=sequential,
|
|
1072
|
+
requires_approval=requires_approval,
|
|
1073
|
+
metadata=metadata,
|
|
1087
1074
|
)
|
|
1088
1075
|
return func_
|
|
1089
1076
|
|
|
@@ -1104,7 +1091,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
1104
1091
|
require_parameter_descriptions: bool = False,
|
|
1105
1092
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
1106
1093
|
strict: bool | None = None,
|
|
1094
|
+
sequential: bool = False,
|
|
1107
1095
|
requires_approval: bool = False,
|
|
1096
|
+
metadata: dict[str, Any] | None = None,
|
|
1108
1097
|
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
1109
1098
|
|
|
1110
1099
|
def tool_plain(
|
|
@@ -1121,6 +1110,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
1121
1110
|
strict: bool | None = None,
|
|
1122
1111
|
sequential: bool = False,
|
|
1123
1112
|
requires_approval: bool = False,
|
|
1113
|
+
metadata: dict[str, Any] | None = None,
|
|
1124
1114
|
) -> Any:
|
|
1125
1115
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
1126
1116
|
|
|
@@ -1168,22 +1158,24 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
|
|
|
1168
1158
|
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
|
|
1169
1159
|
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
|
|
1170
1160
|
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
|
|
1161
|
+
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
|
|
1171
1162
|
"""
|
|
1172
1163
|
|
|
1173
1164
|
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
1174
1165
|
# noinspection PyTypeChecker
|
|
1175
1166
|
self._function_toolset.add_function(
|
|
1176
1167
|
func_,
|
|
1177
|
-
False,
|
|
1178
|
-
name,
|
|
1179
|
-
retries,
|
|
1180
|
-
prepare,
|
|
1181
|
-
docstring_format,
|
|
1182
|
-
require_parameter_descriptions,
|
|
1183
|
-
schema_generator,
|
|
1184
|
-
strict,
|
|
1185
|
-
sequential,
|
|
1186
|
-
requires_approval,
|
|
1168
|
+
takes_ctx=False,
|
|
1169
|
+
name=name,
|
|
1170
|
+
retries=retries,
|
|
1171
|
+
prepare=prepare,
|
|
1172
|
+
docstring_format=docstring_format,
|
|
1173
|
+
require_parameter_descriptions=require_parameter_descriptions,
|
|
1174
|
+
schema_generator=schema_generator,
|
|
1175
|
+
strict=strict,
|
|
1176
|
+
sequential=sequential,
|
|
1177
|
+
requires_approval=requires_approval,
|
|
1178
|
+
metadata=metadata,
|
|
1187
1179
|
)
|
|
1188
1180
|
return func_
|
|
1189
1181
|
|
pydantic_ai/agent/abstract.py
CHANGED
|
@@ -499,12 +499,13 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
|
|
|
499
499
|
]
|
|
500
500
|
|
|
501
501
|
parts: list[_messages.ModelRequestPart] = []
|
|
502
|
-
async for _event in _agent_graph.
|
|
503
|
-
graph_ctx.deps.tool_manager,
|
|
504
|
-
tool_calls,
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
502
|
+
async for _event in _agent_graph.process_tool_calls(
|
|
503
|
+
tool_manager=graph_ctx.deps.tool_manager,
|
|
504
|
+
tool_calls=tool_calls,
|
|
505
|
+
tool_call_results=None,
|
|
506
|
+
final_result=final_result,
|
|
507
|
+
ctx=graph_ctx,
|
|
508
|
+
output_parts=parts,
|
|
508
509
|
):
|
|
509
510
|
pass
|
|
510
511
|
if parts:
|
|
@@ -621,7 +622,6 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
|
|
|
621
622
|
[
|
|
622
623
|
UserPromptNode(
|
|
623
624
|
user_prompt='What is the capital of France?',
|
|
624
|
-
instructions=None,
|
|
625
625
|
instructions_functions=[],
|
|
626
626
|
system_prompts=(),
|
|
627
627
|
system_prompt_functions=[],
|
pydantic_ai/agent/wrapper.py
CHANGED
pydantic_ai/builtin_tools.py
CHANGED
|
@@ -26,8 +26,9 @@ class WebSearchTool(AbstractBuiltinTool):
|
|
|
26
26
|
The parameters that PydanticAI passes depend on the model, as some parameters may not be supported by certain models.
|
|
27
27
|
|
|
28
28
|
Supported by:
|
|
29
|
+
|
|
29
30
|
* Anthropic
|
|
30
|
-
* OpenAI
|
|
31
|
+
* OpenAI Responses
|
|
31
32
|
* Groq
|
|
32
33
|
* Google
|
|
33
34
|
"""
|
|
@@ -36,15 +37,17 @@ class WebSearchTool(AbstractBuiltinTool):
|
|
|
36
37
|
"""The `search_context_size` parameter controls how much context is retrieved from the web to help the tool formulate a response.
|
|
37
38
|
|
|
38
39
|
Supported by:
|
|
39
|
-
|
|
40
|
+
|
|
41
|
+
* OpenAI Responses
|
|
40
42
|
"""
|
|
41
43
|
|
|
42
44
|
user_location: WebSearchUserLocation | None = None
|
|
43
45
|
"""The `user_location` parameter allows you to localize search results based on a user's location.
|
|
44
46
|
|
|
45
47
|
Supported by:
|
|
48
|
+
|
|
46
49
|
* Anthropic
|
|
47
|
-
* OpenAI
|
|
50
|
+
* OpenAI Responses
|
|
48
51
|
"""
|
|
49
52
|
|
|
50
53
|
blocked_domains: list[str] | None = None
|
|
@@ -53,8 +56,9 @@ class WebSearchTool(AbstractBuiltinTool):
|
|
|
53
56
|
With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.
|
|
54
57
|
|
|
55
58
|
Supported by:
|
|
56
|
-
|
|
57
|
-
*
|
|
59
|
+
|
|
60
|
+
* Anthropic, see <https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering>
|
|
61
|
+
* Groq, see <https://console.groq.com/docs/agentic-tooling#search-settings>
|
|
58
62
|
"""
|
|
59
63
|
|
|
60
64
|
allowed_domains: list[str] | None = None
|
|
@@ -63,14 +67,16 @@ class WebSearchTool(AbstractBuiltinTool):
|
|
|
63
67
|
With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.
|
|
64
68
|
|
|
65
69
|
Supported by:
|
|
66
|
-
|
|
67
|
-
*
|
|
70
|
+
|
|
71
|
+
* Anthropic, see <https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering>
|
|
72
|
+
* Groq, see <https://console.groq.com/docs/agentic-tooling#search-settings>
|
|
68
73
|
"""
|
|
69
74
|
|
|
70
75
|
max_uses: int | None = None
|
|
71
76
|
"""If provided, the tool will stop searching the web after the given number of uses.
|
|
72
77
|
|
|
73
78
|
Supported by:
|
|
79
|
+
|
|
74
80
|
* Anthropic
|
|
75
81
|
"""
|
|
76
82
|
|
|
@@ -79,8 +85,9 @@ class WebSearchUserLocation(TypedDict, total=False):
|
|
|
79
85
|
"""Allows you to localize search results based on a user's location.
|
|
80
86
|
|
|
81
87
|
Supported by:
|
|
88
|
+
|
|
82
89
|
* Anthropic
|
|
83
|
-
* OpenAI
|
|
90
|
+
* OpenAI Responses
|
|
84
91
|
"""
|
|
85
92
|
|
|
86
93
|
city: str
|
|
@@ -100,8 +107,9 @@ class CodeExecutionTool(AbstractBuiltinTool):
|
|
|
100
107
|
"""A builtin tool that allows your agent to execute code.
|
|
101
108
|
|
|
102
109
|
Supported by:
|
|
110
|
+
|
|
103
111
|
* Anthropic
|
|
104
|
-
* OpenAI
|
|
112
|
+
* OpenAI Responses
|
|
105
113
|
* Google
|
|
106
114
|
"""
|
|
107
115
|
|
|
@@ -110,5 +118,6 @@ class UrlContextTool(AbstractBuiltinTool):
|
|
|
110
118
|
"""Allows your agent to access contents from URLs.
|
|
111
119
|
|
|
112
120
|
Supported by:
|
|
121
|
+
|
|
113
122
|
* Google
|
|
114
123
|
"""
|
|
@@ -15,7 +15,6 @@ from pydantic_ai import (
|
|
|
15
15
|
)
|
|
16
16
|
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
|
|
17
17
|
from pydantic_ai.exceptions import UserError
|
|
18
|
-
from pydantic_ai.mcp import MCPServer
|
|
19
18
|
from pydantic_ai.models import Model
|
|
20
19
|
from pydantic_ai.output import OutputDataT, OutputSpec
|
|
21
20
|
from pydantic_ai.result import StreamedRunResult
|
|
@@ -29,7 +28,6 @@ from pydantic_ai.tools import (
|
|
|
29
28
|
)
|
|
30
29
|
from pydantic_ai.toolsets import AbstractToolset
|
|
31
30
|
|
|
32
|
-
from ._mcp_server import DBOSMCPServer
|
|
33
31
|
from ._model import DBOSModel
|
|
34
32
|
from ._utils import StepConfig
|
|
35
33
|
|
|
@@ -86,14 +84,21 @@ class DBOSAgent(WrapperAgent[AgentDepsT, OutputDataT], DBOSConfiguredInstance):
|
|
|
86
84
|
|
|
87
85
|
def dbosify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
|
|
88
86
|
# Replace MCPServer with DBOSMCPServer
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
87
|
+
try:
|
|
88
|
+
from pydantic_ai.mcp import MCPServer
|
|
89
|
+
|
|
90
|
+
from ._mcp_server import DBOSMCPServer
|
|
91
|
+
except ImportError:
|
|
92
|
+
pass
|
|
95
93
|
else:
|
|
96
|
-
|
|
94
|
+
if isinstance(toolset, MCPServer):
|
|
95
|
+
return DBOSMCPServer(
|
|
96
|
+
wrapped=toolset,
|
|
97
|
+
step_name_prefix=dbosagent_name,
|
|
98
|
+
step_config=self._mcp_step_config,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
return toolset
|
|
97
102
|
|
|
98
103
|
dbos_toolsets = [toolset.visit_and_replace(dbosify_toolset) for toolset in wrapped.toolsets]
|
|
99
104
|
self._toolsets = dbos_toolsets
|
|
@@ -622,7 +627,6 @@ class DBOSAgent(WrapperAgent[AgentDepsT, OutputDataT], DBOSConfiguredInstance):
|
|
|
622
627
|
[
|
|
623
628
|
UserPromptNode(
|
|
624
629
|
user_prompt='What is the capital of France?',
|
|
625
|
-
instructions=None,
|
|
626
630
|
instructions_functions=[],
|
|
627
631
|
system_prompts=(),
|
|
628
632
|
system_prompt_functions=[],
|
|
@@ -2,18 +2,20 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from abc import ABC
|
|
4
4
|
from collections.abc import Callable
|
|
5
|
-
from typing import Any
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
6
|
|
|
7
7
|
from dbos import DBOS
|
|
8
8
|
from typing_extensions import Self
|
|
9
9
|
|
|
10
|
-
from pydantic_ai.mcp import MCPServer, ToolResult
|
|
11
10
|
from pydantic_ai.tools import AgentDepsT, RunContext
|
|
12
11
|
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
|
|
13
12
|
from pydantic_ai.toolsets.wrapper import WrapperToolset
|
|
14
13
|
|
|
15
14
|
from ._utils import StepConfig
|
|
16
15
|
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from pydantic_ai.mcp import MCPServer, ToolResult
|
|
18
|
+
|
|
17
19
|
|
|
18
20
|
class DBOSMCPServer(WrapperToolset[AgentDepsT], ABC):
|
|
19
21
|
"""A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
|
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
4
5
|
|
|
5
|
-
from logfire import Logfire
|
|
6
|
-
from opentelemetry.trace import get_tracer
|
|
7
6
|
from temporalio.client import ClientConfig, Plugin as ClientPlugin
|
|
8
|
-
from temporalio.contrib.opentelemetry import TracingInterceptor
|
|
9
7
|
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
|
|
10
8
|
from temporalio.service import ConnectConfig, ServiceClient
|
|
11
9
|
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from logfire import Logfire
|
|
12
|
+
|
|
12
13
|
|
|
13
14
|
def _default_setup_logfire() -> Logfire:
|
|
14
15
|
import logfire
|
|
@@ -22,6 +23,14 @@ class LogfirePlugin(ClientPlugin):
|
|
|
22
23
|
"""Temporal client plugin for Logfire."""
|
|
23
24
|
|
|
24
25
|
def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True):
|
|
26
|
+
try:
|
|
27
|
+
import logfire # noqa: F401 # pyright: ignore[reportUnusedImport]
|
|
28
|
+
except ImportError as _import_error:
|
|
29
|
+
raise ImportError(
|
|
30
|
+
'Please install the `logfire` package to use the Logfire plugin, '
|
|
31
|
+
'you can use the `logfire` optional group — `pip install "pydantic-ai-slim[logfire]"`'
|
|
32
|
+
) from _import_error
|
|
33
|
+
|
|
25
34
|
self.setup_logfire = setup_logfire
|
|
26
35
|
self.metrics = metrics
|
|
27
36
|
|
|
@@ -29,6 +38,9 @@ class LogfirePlugin(ClientPlugin):
|
|
|
29
38
|
self.next_client_plugin = next
|
|
30
39
|
|
|
31
40
|
def configure_client(self, config: ClientConfig) -> ClientConfig:
|
|
41
|
+
from opentelemetry.trace import get_tracer
|
|
42
|
+
from temporalio.contrib.opentelemetry import TracingInterceptor
|
|
43
|
+
|
|
32
44
|
interceptors = config.get('interceptors', [])
|
|
33
45
|
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
|
|
34
46
|
return self.next_client_plugin.configure_client(config)
|
|
@@ -6,7 +6,6 @@ from typing import Any, Literal
|
|
|
6
6
|
|
|
7
7
|
from temporalio.workflow import ActivityConfig
|
|
8
8
|
|
|
9
|
-
from pydantic_ai.mcp import MCPServer
|
|
10
9
|
from pydantic_ai.tools import AgentDepsT
|
|
11
10
|
from pydantic_ai.toolsets.abstract import AbstractToolset
|
|
12
11
|
from pydantic_ai.toolsets.function import FunctionToolset
|
|
@@ -63,16 +62,22 @@ def temporalize_toolset(
|
|
|
63
62
|
deps_type=deps_type,
|
|
64
63
|
run_context_type=run_context_type,
|
|
65
64
|
)
|
|
66
|
-
elif isinstance(toolset, MCPServer):
|
|
67
|
-
from ._mcp_server import TemporalMCPServer
|
|
68
65
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
run_context_type=run_context_type,
|
|
76
|
-
)
|
|
66
|
+
try:
|
|
67
|
+
from pydantic_ai.mcp import MCPServer
|
|
68
|
+
|
|
69
|
+
from ._mcp_server import TemporalMCPServer
|
|
70
|
+
except ImportError:
|
|
71
|
+
pass
|
|
77
72
|
else:
|
|
78
|
-
|
|
73
|
+
if isinstance(toolset, MCPServer):
|
|
74
|
+
return TemporalMCPServer(
|
|
75
|
+
toolset,
|
|
76
|
+
activity_name_prefix=activity_name_prefix,
|
|
77
|
+
activity_config=activity_config,
|
|
78
|
+
tool_activity_config=tool_activity_config,
|
|
79
|
+
deps_type=deps_type,
|
|
80
|
+
run_context_type=run_context_type,
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
return toolset
|
pydantic_ai/mcp.py
CHANGED
|
@@ -10,12 +10,14 @@ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontext
|
|
|
10
10
|
from dataclasses import field, replace
|
|
11
11
|
from datetime import timedelta
|
|
12
12
|
from pathlib import Path
|
|
13
|
-
from typing import Any
|
|
13
|
+
from typing import Annotated, Any
|
|
14
14
|
|
|
15
15
|
import anyio
|
|
16
16
|
import httpx
|
|
17
17
|
import pydantic_core
|
|
18
18
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
19
|
+
from pydantic import BaseModel, Discriminator, Field, Tag
|
|
20
|
+
from pydantic_core import CoreSchema, core_schema
|
|
19
21
|
from typing_extensions import Self, assert_never, deprecated
|
|
20
22
|
|
|
21
23
|
from pydantic_ai.tools import RunContext, ToolDefinition
|
|
@@ -41,7 +43,7 @@ except ImportError as _import_error:
|
|
|
41
43
|
# after mcp imports so any import error maps to this file, not _mcp.py
|
|
42
44
|
from . import _mcp, _utils, exceptions, messages, models
|
|
43
45
|
|
|
44
|
-
__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
|
|
46
|
+
__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP', 'load_mcp_servers'
|
|
45
47
|
|
|
46
48
|
TOOL_SCHEMA_VALIDATOR = pydantic_core.SchemaValidator(
|
|
47
49
|
schema=pydantic_core.core_schema.dict_schema(
|
|
@@ -254,6 +256,11 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
254
256
|
name=name,
|
|
255
257
|
description=mcp_tool.description,
|
|
256
258
|
parameters_json_schema=mcp_tool.inputSchema,
|
|
259
|
+
metadata={
|
|
260
|
+
'meta': mcp_tool.meta,
|
|
261
|
+
'annotations': mcp_tool.annotations.model_dump() if mcp_tool.annotations else None,
|
|
262
|
+
'output_schema': mcp_tool.outputSchema or None,
|
|
263
|
+
},
|
|
257
264
|
),
|
|
258
265
|
)
|
|
259
266
|
for mcp_tool in await self.list_tools()
|
|
@@ -498,6 +505,22 @@ class MCPServerStdio(MCPServer):
|
|
|
498
505
|
id=id,
|
|
499
506
|
)
|
|
500
507
|
|
|
508
|
+
@classmethod
|
|
509
|
+
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> CoreSchema:
|
|
510
|
+
return core_schema.no_info_after_validator_function(
|
|
511
|
+
lambda dct: MCPServerStdio(**dct),
|
|
512
|
+
core_schema.typed_dict_schema(
|
|
513
|
+
{
|
|
514
|
+
'command': core_schema.typed_dict_field(core_schema.str_schema()),
|
|
515
|
+
'args': core_schema.typed_dict_field(core_schema.list_schema(core_schema.str_schema())),
|
|
516
|
+
'env': core_schema.typed_dict_field(
|
|
517
|
+
core_schema.dict_schema(core_schema.str_schema(), core_schema.str_schema()),
|
|
518
|
+
required=False,
|
|
519
|
+
),
|
|
520
|
+
}
|
|
521
|
+
),
|
|
522
|
+
)
|
|
523
|
+
|
|
501
524
|
@asynccontextmanager
|
|
502
525
|
async def client_streams(
|
|
503
526
|
self,
|
|
@@ -520,6 +543,16 @@ class MCPServerStdio(MCPServer):
|
|
|
520
543
|
repr_args.append(f'id={self.id!r}')
|
|
521
544
|
return f'{self.__class__.__name__}({", ".join(repr_args)})'
|
|
522
545
|
|
|
546
|
+
def __eq__(self, value: object, /) -> bool:
|
|
547
|
+
if not isinstance(value, MCPServerStdio):
|
|
548
|
+
return False # pragma: no cover
|
|
549
|
+
return (
|
|
550
|
+
self.command == value.command
|
|
551
|
+
and self.args == value.args
|
|
552
|
+
and self.env == value.env
|
|
553
|
+
and self.cwd == value.cwd
|
|
554
|
+
)
|
|
555
|
+
|
|
523
556
|
|
|
524
557
|
class _MCPServerHTTP(MCPServer):
|
|
525
558
|
url: str
|
|
@@ -733,10 +766,29 @@ class MCPServerSSE(_MCPServerHTTP):
|
|
|
733
766
|
1. This will connect to a server running on `localhost:3001`.
|
|
734
767
|
"""
|
|
735
768
|
|
|
769
|
+
@classmethod
|
|
770
|
+
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> CoreSchema:
|
|
771
|
+
return core_schema.no_info_after_validator_function(
|
|
772
|
+
lambda dct: MCPServerSSE(**dct),
|
|
773
|
+
core_schema.typed_dict_schema(
|
|
774
|
+
{
|
|
775
|
+
'url': core_schema.typed_dict_field(core_schema.str_schema()),
|
|
776
|
+
'headers': core_schema.typed_dict_field(
|
|
777
|
+
core_schema.dict_schema(core_schema.str_schema(), core_schema.str_schema()), required=False
|
|
778
|
+
),
|
|
779
|
+
}
|
|
780
|
+
),
|
|
781
|
+
)
|
|
782
|
+
|
|
736
783
|
@property
|
|
737
784
|
def _transport_client(self):
|
|
738
785
|
return sse_client # pragma: no cover
|
|
739
786
|
|
|
787
|
+
def __eq__(self, value: object, /) -> bool:
|
|
788
|
+
if not isinstance(value, MCPServerSSE):
|
|
789
|
+
return False # pragma: no cover
|
|
790
|
+
return self.url == value.url
|
|
791
|
+
|
|
740
792
|
|
|
741
793
|
@deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
|
|
742
794
|
class MCPServerHTTP(MCPServerSSE):
|
|
@@ -790,10 +842,29 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
|
|
|
790
842
|
```
|
|
791
843
|
"""
|
|
792
844
|
|
|
845
|
+
@classmethod
|
|
846
|
+
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> CoreSchema:
|
|
847
|
+
return core_schema.no_info_after_validator_function(
|
|
848
|
+
lambda dct: MCPServerStreamableHTTP(**dct),
|
|
849
|
+
core_schema.typed_dict_schema(
|
|
850
|
+
{
|
|
851
|
+
'url': core_schema.typed_dict_field(core_schema.str_schema()),
|
|
852
|
+
'headers': core_schema.typed_dict_field(
|
|
853
|
+
core_schema.dict_schema(core_schema.str_schema(), core_schema.str_schema()), required=False
|
|
854
|
+
),
|
|
855
|
+
}
|
|
856
|
+
),
|
|
857
|
+
)
|
|
858
|
+
|
|
793
859
|
@property
|
|
794
860
|
def _transport_client(self):
|
|
795
861
|
return streamablehttp_client # pragma: no cover
|
|
796
862
|
|
|
863
|
+
def __eq__(self, value: object, /) -> bool:
|
|
864
|
+
if not isinstance(value, MCPServerStreamableHTTP):
|
|
865
|
+
return False # pragma: no cover
|
|
866
|
+
return self.url == value.url
|
|
867
|
+
|
|
797
868
|
|
|
798
869
|
ToolResult = (
|
|
799
870
|
str
|
|
@@ -823,3 +894,50 @@ It accepts a run context, the original tool call function, a tool name, and argu
|
|
|
823
894
|
Allows wrapping an MCP server tool call to customize it, including adding extra request
|
|
824
895
|
metadata.
|
|
825
896
|
"""
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
def _mcp_server_discriminator(value: dict[str, Any]) -> str | None:
|
|
900
|
+
if 'url' in value:
|
|
901
|
+
if value['url'].endswith('/sse'):
|
|
902
|
+
return 'sse'
|
|
903
|
+
return 'streamable-http'
|
|
904
|
+
return 'stdio'
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
class MCPServerConfig(BaseModel):
|
|
908
|
+
"""Configuration for MCP servers."""
|
|
909
|
+
|
|
910
|
+
mcp_servers: Annotated[
|
|
911
|
+
dict[
|
|
912
|
+
str,
|
|
913
|
+
Annotated[
|
|
914
|
+
Annotated[MCPServerStdio, Tag('stdio')]
|
|
915
|
+
| Annotated[MCPServerStreamableHTTP, Tag('streamable-http')]
|
|
916
|
+
| Annotated[MCPServerSSE, Tag('sse')],
|
|
917
|
+
Discriminator(_mcp_server_discriminator),
|
|
918
|
+
],
|
|
919
|
+
],
|
|
920
|
+
Field(alias='mcpServers'),
|
|
921
|
+
]
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
def load_mcp_servers(config_path: str | Path) -> list[MCPServerStdio | MCPServerStreamableHTTP | MCPServerSSE]:
|
|
925
|
+
"""Load MCP servers from a configuration file.
|
|
926
|
+
|
|
927
|
+
Args:
|
|
928
|
+
config_path: The path to the configuration file.
|
|
929
|
+
|
|
930
|
+
Returns:
|
|
931
|
+
A list of MCP servers.
|
|
932
|
+
|
|
933
|
+
Raises:
|
|
934
|
+
FileNotFoundError: If the configuration file does not exist.
|
|
935
|
+
ValidationError: If the configuration file does not match the schema.
|
|
936
|
+
"""
|
|
937
|
+
config_path = Path(config_path)
|
|
938
|
+
|
|
939
|
+
if not config_path.exists():
|
|
940
|
+
raise FileNotFoundError(f'Config file {config_path} not found')
|
|
941
|
+
|
|
942
|
+
config = MCPServerConfig.model_validate_json(config_path.read_bytes())
|
|
943
|
+
return list(config.mcp_servers.values())
|