pydantic-ai-slim 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +229 -134
- pydantic_ai/ag_ui.py +51 -40
- pydantic_ai/agent/__init__.py +35 -45
- pydantic_ai/agent/abstract.py +7 -7
- pydantic_ai/agent/wrapper.py +0 -1
- pydantic_ai/durable_exec/dbos/_agent.py +14 -10
- pydantic_ai/durable_exec/dbos/_mcp_server.py +4 -2
- pydantic_ai/durable_exec/temporal/_agent.py +0 -1
- pydantic_ai/durable_exec/temporal/_logfire.py +15 -3
- pydantic_ai/durable_exec/temporal/_toolset.py +17 -12
- pydantic_ai/mcp.py +5 -0
- pydantic_ai/models/__init__.py +4 -6
- pydantic_ai/result.py +3 -5
- pydantic_ai/run.py +0 -2
- pydantic_ai/tools.py +11 -0
- pydantic_ai/toolsets/function.py +50 -9
- pydantic_ai/usage.py +2 -2
- {pydantic_ai_slim-1.0.6.dist-info → pydantic_ai_slim-1.0.8.dist-info}/METADATA +3 -3
- {pydantic_ai_slim-1.0.6.dist-info → pydantic_ai_slim-1.0.8.dist-info}/RECORD +22 -22
- {pydantic_ai_slim-1.0.6.dist-info → pydantic_ai_slim-1.0.8.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.6.dist-info → pydantic_ai_slim-1.0.8.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.6.dist-info → pydantic_ai_slim-1.0.8.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/ag_ui.py
CHANGED
@@ -23,6 +23,7 @@ from typing import (
 )

 from pydantic import BaseModel, ValidationError
+from typing_extensions import assert_never

 from . import _utils
 from ._agent_graph import CallToolsNode, ModelRequestNode
@@ -32,7 +33,9 @@ from .messages import (
     FunctionToolResultEvent,
     ModelMessage,
     ModelRequest,
+    ModelRequestPart,
     ModelResponse,
+    ModelResponsePart,
     ModelResponseStreamEvent,
     PartDeltaEvent,
     PartStartEvent,
@@ -556,15 +559,15 @@ async def _handle_tool_result_event(
         content=result.model_response_str(),
     )

-    # Now check for
-
-    if isinstance(
-        yield
-    elif isinstance(
+    # Now check for AG-UI events returned by the tool calls.
+    possible_event = result.metadata or result.content
+    if isinstance(possible_event, BaseEvent):
+        yield possible_event
+    elif isinstance(possible_event, str | bytes):  # pragma: no branch
         # Avoid iterable check for strings and bytes.
         pass
-    elif isinstance(
-        for item in
+    elif isinstance(possible_event, Iterable):  # pragma: no branch
+        for item in possible_event:  # type: ignore[reportUnknownMemberType]
             if isinstance(item, BaseEvent):  # pragma: no branch
                 yield item

@@ -573,49 +576,57 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
     """Convert a AG-UI history to a Pydantic AI one."""
     result: list[ModelMessage] = []
     tool_calls: dict[str, str] = {}  # Tool call ID to tool name mapping.
+    request_parts: list[ModelRequestPart] | None = None
+    response_parts: list[ModelResponsePart] | None = None
     for msg in messages:
-        if isinstance(msg, UserMessage):
-
+        if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage | ToolMessage):
+            if request_parts is None:
+                request_parts = []
+                result.append(ModelRequest(parts=request_parts))
+                response_parts = None
+
+            if isinstance(msg, UserMessage):
+                request_parts.append(UserPromptPart(content=msg.content))
+            elif isinstance(msg, SystemMessage | DeveloperMessage):
+                request_parts.append(SystemPromptPart(content=msg.content))
+            elif isinstance(msg, ToolMessage):
+                tool_name = tool_calls.get(msg.tool_call_id)
+                if tool_name is None:  # pragma: no cover
+                    raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
+
+                request_parts.append(
+                    ToolReturnPart(
+                        tool_name=tool_name,
+                        content=msg.content,
+                        tool_call_id=msg.tool_call_id,
+                    )
+                )
+            else:
+                assert_never(msg)
+
         elif isinstance(msg, AssistantMessage):
+            if response_parts is None:
+                response_parts = []
+                result.append(ModelResponse(parts=response_parts))
+                request_parts = None
+
             if msg.content:
-
+                response_parts.append(TextPart(content=msg.content))

             if msg.tool_calls:
                 for tool_call in msg.tool_calls:
                     tool_calls[tool_call.id] = tool_call.function.name

-
-
-
-
-
-                                tool_call_id=tool_call.id,
-                                args=tool_call.function.arguments,
-                            )
-                            for tool_call in msg.tool_calls
-                        ]
+                response_parts.extend(
+                    ToolCallPart(
+                        tool_name=tool_call.function.name,
+                        tool_call_id=tool_call.id,
+                        args=tool_call.function.arguments,
                     )
+                    for tool_call in msg.tool_calls
                 )
-
-
-        elif isinstance(msg, ToolMessage):
-            tool_name = tool_calls.get(msg.tool_call_id)
-            if tool_name is None:  # pragma: no cover
-                raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
-
-            result.append(
-                ModelRequest(
-                    parts=[
-                        ToolReturnPart(
-                            tool_name=tool_name,
-                            content=msg.content,
-                            tool_call_id=msg.tool_call_id,
-                        )
-                    ]
-                )
-            )
-        elif isinstance(msg, DeveloperMessage):  # pragma: no branch
-            result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
+        else:
+            assert_never(msg)

     return result

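The `_handle_tool_result_event` hunk above now looks at `result.metadata` as well as `result.content` when deciding whether a tool produced AG-UI events to forward. A minimal sketch of what that enables — not taken from the package itself; the tool name and event payload are illustrative, and it assumes the `ag-ui-protocol` dependency and `pydantic_ai.messages.ToolReturn`:

```python
from ag_ui.core import EventType, StateSnapshotEvent

from pydantic_ai import Agent
from pydantic_ai.messages import ToolReturn

agent = Agent('test')  # built-in TestModel keeps the sketch self-contained


@agent.tool_plain
def update_counter() -> ToolReturn:
    """Bump a counter and emit an AG-UI state snapshot alongside the tool result."""
    return ToolReturn(
        return_value='counter updated',
        # As of 1.0.8 the AG-UI adapter also inspects `metadata` (not just `content`)
        # for `BaseEvent` instances, or iterables of them, and forwards these to the UI.
        metadata=[StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot={'counter': 1})],
    )
```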
pydantic_ai/agent/__init__.py
CHANGED
@@ -45,15 +45,11 @@ from ..run import AgentRun, AgentRunResult
 from ..settings import ModelSettings, merge_model_settings
 from ..tools import (
     AgentDepsT,
-    DeferredToolCallResult,
-    DeferredToolResult,
     DeferredToolResults,
     DocstringFormat,
     GenerateToolJsonSchema,
     RunContext,
     Tool,
-    ToolApproved,
-    ToolDenied,
     ToolFuncContext,
     ToolFuncEither,
     ToolFuncPlain,
@@ -462,7 +458,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
     ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...

     @asynccontextmanager
-    async def iter(
+    async def iter(
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
@@ -505,7 +501,6 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
                 [
                     UserPromptNode(
                         user_prompt='What is the capital of France?',
-                        instructions=None,
                         instructions_functions=[],
                         system_prompts=(),
                         system_prompt_functions=[],
@@ -559,7 +554,6 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
        del model

        deps = self._get_deps(deps)
-        new_message_index = len(message_history) if message_history else 0
        output_schema = self._prepare_output_schema(output_type, model_used.profile)

        output_type_ = output_type or self.output_type
@@ -620,27 +614,10 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
            instrumentation_settings = None
            tracer = NoOpTracer()

-        tool_call_results: dict[str, DeferredToolResult] | None = None
-        if deferred_tool_results is not None:
-            tool_call_results = {}
-            for tool_call_id, approval in deferred_tool_results.approvals.items():
-                if approval is True:
-                    approval = ToolApproved()
-                elif approval is False:
-                    approval = ToolDenied()
-                tool_call_results[tool_call_id] = approval
-
-            if calls := deferred_tool_results.calls:
-                call_result_types = _utils.get_union_args(DeferredToolCallResult)
-                for tool_call_id, result in calls.items():
-                    if not isinstance(result, call_result_types):
-                        result = _messages.ToolReturn(result)
-                    tool_call_results[tool_call_id] = result
-
        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
            user_deps=deps,
            prompt=user_prompt,
-            new_message_index=
+            new_message_index=len(message_history) if message_history else 0,
            model=model_used,
            model_settings=model_settings,
            usage_limits=usage_limits,
@@ -651,13 +628,13 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
            history_processors=self.history_processors,
            builtin_tools=list(self._builtin_tools),
            tool_manager=tool_manager,
-            tool_call_results=tool_call_results,
            tracer=tracer,
            get_instructions=get_instructions,
            instrumentation_settings=instrumentation_settings,
        )
        start_node = _agent_graph.UserPromptNode[AgentDepsT](
            user_prompt=user_prompt,
+            deferred_tool_results=deferred_tool_results,
            instructions=self._instructions,
            instructions_functions=self._instructions_functions,
            system_prompts=self._system_prompts,
@@ -1005,7 +982,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
+        sequential: bool = False,
        requires_approval: bool = False,
+        metadata: dict[str, Any] | None = None,
    ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...

    def tool(
@@ -1020,7 +999,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
+        sequential: bool = False,
        requires_approval: bool = False,
+        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

@@ -1065,8 +1046,10 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
            schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
+            sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
            requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
                See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
+            metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
        """

        def tool_decorator(
@@ -1075,15 +1058,17 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
            # noinspection PyTypeChecker
            self._function_toolset.add_function(
                func_,
-                True,
-                name,
-                retries,
-                prepare,
-                docstring_format,
-                require_parameter_descriptions,
-                schema_generator,
-                strict,
-
+                takes_ctx=True,
+                name=name,
+                retries=retries,
+                prepare=prepare,
+                docstring_format=docstring_format,
+                require_parameter_descriptions=require_parameter_descriptions,
+                schema_generator=schema_generator,
+                strict=strict,
+                sequential=sequential,
+                requires_approval=requires_approval,
+                metadata=metadata,
            )
            return func_

@@ -1104,7 +1089,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
        require_parameter_descriptions: bool = False,
        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
        strict: bool | None = None,
+        sequential: bool = False,
        requires_approval: bool = False,
+        metadata: dict[str, Any] | None = None,
    ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

    def tool_plain(
@@ -1121,6 +1108,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
        strict: bool | None = None,
        sequential: bool = False,
        requires_approval: bool = False,
+        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

@@ -1168,22 +1156,24 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
            sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
            requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
                See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
+            metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
        """

        def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
            # noinspection PyTypeChecker
            self._function_toolset.add_function(
                func_,
-                False,
-                name,
-                retries,
-                prepare,
-                docstring_format,
-                require_parameter_descriptions,
-                schema_generator,
-                strict,
-                sequential,
-                requires_approval,
+                takes_ctx=False,
+                name=name,
+                retries=retries,
+                prepare=prepare,
+                docstring_format=docstring_format,
+                require_parameter_descriptions=require_parameter_descriptions,
+                schema_generator=schema_generator,
+                strict=strict,
+                sequential=sequential,
+                requires_approval=requires_approval,
+                metadata=metadata,
            )
            return func_

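The decorator overloads above thread two keyword arguments through to `FunctionToolset.add_function`: `sequential` and the new `metadata` dict, which is kept out of the model request but available for filtering and tool behavior customization. A small sketch of how they might be used; the metadata keys and the tool body are illustrative:

```python
import os

from pydantic_ai import Agent

agent = Agent('test')  # built-in TestModel, so the sketch runs without an API key


@agent.tool_plain(sequential=True, metadata={'category': 'filesystem', 'danger': 'low'})
def list_files(path: str) -> list[str]:
    """List files in a directory; `sequential=True` asks for serial execution."""
    return os.listdir(path) if os.path.isdir(path) else []


result = agent.run_sync('List the files in the current directory.')
print(result.output)
```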
pydantic_ai/agent/abstract.py
CHANGED
@@ -499,12 +499,13 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
        ]

        parts: list[_messages.ModelRequestPart] = []
-        async for _event in _agent_graph.
-            graph_ctx.deps.tool_manager,
-            tool_calls,
-
-
-
+        async for _event in _agent_graph.process_tool_calls(
+            tool_manager=graph_ctx.deps.tool_manager,
+            tool_calls=tool_calls,
+            tool_call_results=None,
+            final_result=final_result,
+            ctx=graph_ctx,
+            output_parts=parts,
        ):
            pass
        if parts:
@@ -621,7 +622,6 @@ class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
                [
                    UserPromptNode(
                        user_prompt='What is the capital of France?',
-                        instructions=None,
                        instructions_functions=[],
                        system_prompts=(),
                        system_prompt_functions=[],
pydantic_ai/agent/wrapper.py
CHANGED
pydantic_ai/durable_exec/dbos/_agent.py
CHANGED
@@ -15,7 +15,6 @@ from pydantic_ai import (
 )
 from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
 from pydantic_ai.exceptions import UserError
-from pydantic_ai.mcp import MCPServer
 from pydantic_ai.models import Model
 from pydantic_ai.output import OutputDataT, OutputSpec
 from pydantic_ai.result import StreamedRunResult
@@ -29,7 +28,6 @@ from pydantic_ai.tools import (
 )
 from pydantic_ai.toolsets import AbstractToolset

-from ._mcp_server import DBOSMCPServer
 from ._model import DBOSModel
 from ._utils import StepConfig

@@ -86,14 +84,21 @@ class DBOSAgent(WrapperAgent[AgentDepsT, OutputDataT], DBOSConfiguredInstance):

        def dbosify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
            # Replace MCPServer with DBOSMCPServer
-
-
-
-
-
-
+            try:
+                from pydantic_ai.mcp import MCPServer
+
+                from ._mcp_server import DBOSMCPServer
+            except ImportError:
+                pass
            else:
-
+                if isinstance(toolset, MCPServer):
+                    return DBOSMCPServer(
+                        wrapped=toolset,
+                        step_name_prefix=dbosagent_name,
+                        step_config=self._mcp_step_config,
+                    )
+
+            return toolset

        dbos_toolsets = [toolset.visit_and_replace(dbosify_toolset) for toolset in wrapped.toolsets]
        self._toolsets = dbos_toolsets
@@ -622,7 +627,6 @@ class DBOSAgent(WrapperAgent[AgentDepsT, OutputDataT], DBOSConfiguredInstance):
                [
                    UserPromptNode(
                        user_prompt='What is the capital of France?',
-                        instructions=None,
                        instructions_functions=[],
                        system_prompts=(),
                        system_prompt_functions=[],
pydantic_ai/durable_exec/dbos/_mcp_server.py
CHANGED
@@ -2,18 +2,20 @@ from __future__ import annotations

 from abc import ABC
 from collections.abc import Callable
-from typing import Any
+from typing import TYPE_CHECKING, Any

 from dbos import DBOS
 from typing_extensions import Self

-from pydantic_ai.mcp import MCPServer, ToolResult
 from pydantic_ai.tools import AgentDepsT, RunContext
 from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
 from pydantic_ai.toolsets.wrapper import WrapperToolset

 from ._utils import StepConfig

+if TYPE_CHECKING:
+    from pydantic_ai.mcp import MCPServer, ToolResult
+

 class DBOSMCPServer(WrapperToolset[AgentDepsT], ABC):
     """A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
pydantic_ai/durable_exec/temporal/_logfire.py
CHANGED
@@ -1,14 +1,15 @@
 from __future__ import annotations

 from collections.abc import Callable
+from typing import TYPE_CHECKING

-from logfire import Logfire
-from opentelemetry.trace import get_tracer
 from temporalio.client import ClientConfig, Plugin as ClientPlugin
-from temporalio.contrib.opentelemetry import TracingInterceptor
 from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
 from temporalio.service import ConnectConfig, ServiceClient

+if TYPE_CHECKING:
+    from logfire import Logfire
+

 def _default_setup_logfire() -> Logfire:
     import logfire
@@ -22,6 +23,14 @@ class LogfirePlugin(ClientPlugin):
     """Temporal client plugin for Logfire."""

     def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True):
+        try:
+            import logfire  # noqa: F401  # pyright: ignore[reportUnusedImport]
+        except ImportError as _import_error:
+            raise ImportError(
+                'Please install the `logfire` package to use the Logfire plugin, '
+                'you can use the `logfire` optional group — `pip install "pydantic-ai-slim[logfire]"`'
+            ) from _import_error
+
         self.setup_logfire = setup_logfire
         self.metrics = metrics

@@ -29,6 +38,9 @@ class LogfirePlugin(ClientPlugin):
         self.next_client_plugin = next

     def configure_client(self, config: ClientConfig) -> ClientConfig:
+        from opentelemetry.trace import get_tracer
+        from temporalio.contrib.opentelemetry import TracingInterceptor
+
         interceptors = config.get('interceptors', [])
         config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
         return self.next_client_plugin.configure_client(config)
pydantic_ai/durable_exec/temporal/_toolset.py
CHANGED
@@ -6,7 +6,6 @@ from typing import Any, Literal

 from temporalio.workflow import ActivityConfig

-from pydantic_ai.mcp import MCPServer
 from pydantic_ai.tools import AgentDepsT
 from pydantic_ai.toolsets.abstract import AbstractToolset
 from pydantic_ai.toolsets.function import FunctionToolset
@@ -63,16 +62,22 @@ def temporalize_toolset(
            deps_type=deps_type,
            run_context_type=run_context_type,
        )
-    elif isinstance(toolset, MCPServer):
-        from ._mcp_server import TemporalMCPServer

-
-
-
-
-
-
-            run_context_type=run_context_type,
-        )
+    try:
+        from pydantic_ai.mcp import MCPServer
+
+        from ._mcp_server import TemporalMCPServer
+    except ImportError:
+        pass
    else:
-
+        if isinstance(toolset, MCPServer):
+            return TemporalMCPServer(
+                toolset,
+                activity_name_prefix=activity_name_prefix,
+                activity_config=activity_config,
+                tool_activity_config=tool_activity_config,
+                deps_type=deps_type,
+                run_context_type=run_context_type,
+            )
+
+    return toolset
pydantic_ai/mcp.py
CHANGED
@@ -256,6 +256,11 @@ class MCPServer(AbstractToolset[Any], ABC):
                name=name,
                description=mcp_tool.description,
                parameters_json_schema=mcp_tool.inputSchema,
+                metadata={
+                    'meta': mcp_tool.meta,
+                    'annotations': mcp_tool.annotations.model_dump() if mcp_tool.annotations else None,
+                    'output_schema': mcp_tool.outputSchema or None,
+                },
            ),
        )
        for mcp_tool in await self.list_tools()
pydantic_ai/models/__init__.py
CHANGED
@@ -783,6 +783,8 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
     The client is cached based on the provider parameter. If provider is None, it's used for non-provider specific
     requests (like downloading images). Multiple agents and calls can share the same client when they use the same provider.

+    Each client will get its own transport with its own connection pool. The default pool size is defined by `httpx.DEFAULT_LIMITS`.
+
     There are good reasons why in production you should use a `httpx.AsyncClient` as an async context manager as
     described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing
     examples, it's very useful not to.
@@ -793,6 +795,8 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
     client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
     if client.is_closed:
         # This happens if the context manager is used, so we need to create a new client.
+        # Since there is no API from `functools.cache` to clear the cache for a specific
+        # key, clear the entire cache here as a workaround.
         _cached_async_http_client.cache_clear()
         client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
     return client
@@ -801,17 +805,11 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
 @cache
 def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     return httpx.AsyncClient(
-        transport=_cached_async_http_transport(),
         timeout=httpx.Timeout(timeout=timeout, connect=connect),
         headers={'User-Agent': get_user_agent()},
     )


-@cache
-def _cached_async_http_transport() -> httpx.AsyncHTTPTransport:
-    return httpx.AsyncHTTPTransport()
-
-
 DataT = TypeVar('DataT', str, bytes)

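Since the shared cached transport is gone, every cached client now builds its own `httpx` transport and connection pool, sized by `httpx.DEFAULT_LIMITS`. A small sketch of the per-provider caching behaviour (the provider names are just examples):

```python
import asyncio

import httpx

from pydantic_ai.models import cached_async_http_client


async def main() -> None:
    a = cached_async_http_client(provider='openai')
    b = cached_async_http_client(provider='openai')
    c = cached_async_http_client(provider='anthropic')

    assert a is b  # same provider -> the same cached client is reused
    assert a is not c  # different provider -> separate client, separate connection pool
    print(httpx.DEFAULT_LIMITS)  # pool limits each client's own transport will use


asyncio.run(main())
```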
pydantic_ai/result.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations

 from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
-from copy import
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, cast, overload
@@ -56,7 +56,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _initial_run_ctx_usage: RunUsage = field(init=False)

     def __post_init__(self):
-        self._initial_run_ctx_usage =
+        self._initial_run_ctx_usage = deepcopy(self._run_ctx.usage)

     async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
         """Asynchronously stream the (validated) agent outputs."""
@@ -322,9 +322,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
             self.all_messages(output_tool_return_content=output_tool_return_content)
         )

-    def new_messages(
-        self, *, output_tool_return_content: str | None = None
-    ) -> list[_messages.ModelMessage]:  # pragma: no cover
+    def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.

         Messages from older runs are excluded.
pydantic_ai/run.py
CHANGED
@@ -48,7 +48,6 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
                [
                    UserPromptNode(
                        user_prompt='What is the capital of France?',
-                        instructions=None,
                        instructions_functions=[],
                        system_prompts=(),
                        system_prompt_functions=[],
@@ -183,7 +182,6 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
                [
                    UserPromptNode(
                        user_prompt='What is the capital of France?',
-                        instructions=None,
                        instructions_functions=[],
                        system_prompts=(),
                        system_prompt_functions=[],