pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +6 -0
- pydantic_ai/_agent_graph.py +67 -20
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/_output.py +20 -12
- pydantic_ai/_run_context.py +6 -2
- pydantic_ai/_utils.py +26 -8
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/__init__.py +13 -25
- pydantic_ai/agent/abstract.py +146 -9
- pydantic_ai/builtin_tools.py +106 -4
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +3 -0
- pydantic_ai/durable_exec/prefect/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/__init__.py +11 -0
- pydantic_ai/durable_exec/temporal/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/exceptions.py +6 -1
- pydantic_ai/mcp.py +1 -22
- pydantic_ai/messages.py +46 -8
- pydantic_ai/models/__init__.py +87 -38
- pydantic_ai/models/anthropic.py +132 -11
- pydantic_ai/models/bedrock.py +4 -4
- pydantic_ai/models/cohere.py +0 -7
- pydantic_ai/models/gemini.py +9 -2
- pydantic_ai/models/google.py +26 -23
- pydantic_ai/models/groq.py +13 -5
- pydantic_ai/models/huggingface.py +2 -2
- pydantic_ai/models/openai.py +251 -52
- pydantic_ai/models/outlines.py +563 -0
- pydantic_ai/models/test.py +6 -3
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/__init__.py +25 -12
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/bedrock.py +60 -16
- pydantic_ai/providers/gateway.py +60 -72
- pydantic_ai/providers/google.py +91 -24
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/providers/outlines.py +40 -0
- pydantic_ai/providers/ovhcloud.py +95 -0
- pydantic_ai/result.py +173 -8
- pydantic_ai/run.py +40 -24
- pydantic_ai/settings.py +8 -0
- pydantic_ai/tools.py +10 -6
- pydantic_ai/toolsets/fastmcp.py +215 -0
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
- pydantic_ai/ui/ag_ui/app.py +148 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- pydantic_ai/usage.py +13 -2
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/result.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
|
3
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
|
|
4
4
|
from copy import deepcopy
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime
|
|
@@ -35,6 +35,7 @@ __all__ = (
|
|
|
35
35
|
'OutputDataT_inv',
|
|
36
36
|
'ToolOutput',
|
|
37
37
|
'OutputValidatorFunc',
|
|
38
|
+
'StreamedRunResultSync',
|
|
38
39
|
)
|
|
39
40
|
|
|
40
41
|
|
|
@@ -60,14 +61,26 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
60
61
|
|
|
61
62
|
async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
|
|
62
63
|
"""Asynchronously stream the (validated) agent outputs."""
|
|
64
|
+
last_response: _messages.ModelResponse | None = None
|
|
63
65
|
async for response in self.stream_responses(debounce_by=debounce_by):
|
|
64
|
-
if self._raw_stream_response.final_result_event is
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
66
|
+
if self._raw_stream_response.final_result_event is None or (
|
|
67
|
+
last_response and response.parts == last_response.parts
|
|
68
|
+
):
|
|
69
|
+
continue
|
|
70
|
+
last_response = response
|
|
71
|
+
|
|
72
|
+
try:
|
|
73
|
+
yield await self.validate_response_output(response, allow_partial=True)
|
|
74
|
+
except ValidationError:
|
|
75
|
+
pass
|
|
76
|
+
|
|
77
|
+
response = self.response
|
|
78
|
+
if self._raw_stream_response.final_result_event is None or (
|
|
79
|
+
last_response and response.parts == last_response.parts
|
|
80
|
+
):
|
|
81
|
+
return
|
|
82
|
+
|
|
83
|
+
yield await self.validate_response_output(response)
|
|
71
84
|
|
|
72
85
|
async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
|
|
73
86
|
"""Asynchronously stream the (unvalidated) model responses for the agent."""
|
|
@@ -543,6 +556,158 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
|
|
|
543
556
|
await self._on_complete()
|
|
544
557
|
|
|
545
558
|
|
|
559
|
+
@dataclass(init=False)
|
|
560
|
+
class StreamedRunResultSync(Generic[AgentDepsT, OutputDataT]):
|
|
561
|
+
"""Synchronous wrapper for [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] that only exposes sync methods."""
|
|
562
|
+
|
|
563
|
+
_streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
|
|
564
|
+
|
|
565
|
+
def __init__(self, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]) -> None:
|
|
566
|
+
self._streamed_run_result = streamed_run_result
|
|
567
|
+
|
|
568
|
+
def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
|
|
569
|
+
"""Return the history of messages.
|
|
570
|
+
|
|
571
|
+
Args:
|
|
572
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
573
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
574
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
575
|
+
not be modified.
|
|
576
|
+
|
|
577
|
+
Returns:
|
|
578
|
+
List of messages.
|
|
579
|
+
"""
|
|
580
|
+
return self._streamed_run_result.all_messages(output_tool_return_content=output_tool_return_content)
|
|
581
|
+
|
|
582
|
+
def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: # pragma: no cover
|
|
583
|
+
"""Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResultSync.all_messages] as JSON bytes.
|
|
584
|
+
|
|
585
|
+
Args:
|
|
586
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
587
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
588
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
589
|
+
not be modified.
|
|
590
|
+
|
|
591
|
+
Returns:
|
|
592
|
+
JSON bytes representing the messages.
|
|
593
|
+
"""
|
|
594
|
+
return self._streamed_run_result.all_messages_json(output_tool_return_content=output_tool_return_content)
|
|
595
|
+
|
|
596
|
+
def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
|
|
597
|
+
"""Return new messages associated with this run.
|
|
598
|
+
|
|
599
|
+
Messages from older runs are excluded.
|
|
600
|
+
|
|
601
|
+
Args:
|
|
602
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
603
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
604
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
605
|
+
not be modified.
|
|
606
|
+
|
|
607
|
+
Returns:
|
|
608
|
+
List of new messages.
|
|
609
|
+
"""
|
|
610
|
+
return self._streamed_run_result.new_messages(output_tool_return_content=output_tool_return_content)
|
|
611
|
+
|
|
612
|
+
def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: # pragma: no cover
|
|
613
|
+
"""Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResultSync.new_messages] as JSON bytes.
|
|
614
|
+
|
|
615
|
+
Args:
|
|
616
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
617
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
618
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
619
|
+
not be modified.
|
|
620
|
+
|
|
621
|
+
Returns:
|
|
622
|
+
JSON bytes representing the new messages.
|
|
623
|
+
"""
|
|
624
|
+
return self._streamed_run_result.new_messages_json(output_tool_return_content=output_tool_return_content)
|
|
625
|
+
|
|
626
|
+
def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
|
|
627
|
+
"""Stream the output as an iterable.
|
|
628
|
+
|
|
629
|
+
The pydantic validator for structured data will be called in
|
|
630
|
+
[partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
|
|
631
|
+
on each iteration.
|
|
632
|
+
|
|
633
|
+
Args:
|
|
634
|
+
debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
|
|
635
|
+
Debouncing is particularly important for long structured outputs to reduce the overhead of
|
|
636
|
+
performing validation as each token is received.
|
|
637
|
+
|
|
638
|
+
Returns:
|
|
639
|
+
An iterable of the response data.
|
|
640
|
+
"""
|
|
641
|
+
return _utils.sync_async_iterator(self._streamed_run_result.stream_output(debounce_by=debounce_by))
|
|
642
|
+
|
|
643
|
+
def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
|
|
644
|
+
"""Stream the text result as an iterable.
|
|
645
|
+
|
|
646
|
+
!!! note
|
|
647
|
+
Result validators will NOT be called on the text result if `delta=True`.
|
|
648
|
+
|
|
649
|
+
Args:
|
|
650
|
+
delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
|
|
651
|
+
up to the current point.
|
|
652
|
+
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
|
|
653
|
+
Debouncing is particularly important for long structured responses to reduce the overhead of
|
|
654
|
+
performing validation as each token is received.
|
|
655
|
+
"""
|
|
656
|
+
return _utils.sync_async_iterator(self._streamed_run_result.stream_text(delta=delta, debounce_by=debounce_by))
|
|
657
|
+
|
|
658
|
+
def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]:
|
|
659
|
+
"""Stream the response as an iterable of Structured LLM Messages.
|
|
660
|
+
|
|
661
|
+
Args:
|
|
662
|
+
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
|
|
663
|
+
Debouncing is particularly important for long structured responses to reduce the overhead of
|
|
664
|
+
performing validation as each token is received.
|
|
665
|
+
|
|
666
|
+
Returns:
|
|
667
|
+
An iterable of the structured response message and whether that is the last message.
|
|
668
|
+
"""
|
|
669
|
+
return _utils.sync_async_iterator(self._streamed_run_result.stream_responses(debounce_by=debounce_by))
|
|
670
|
+
|
|
671
|
+
def get_output(self) -> OutputDataT:
|
|
672
|
+
"""Stream the whole response, validate and return it."""
|
|
673
|
+
return _utils.get_event_loop().run_until_complete(self._streamed_run_result.get_output())
|
|
674
|
+
|
|
675
|
+
@property
|
|
676
|
+
def response(self) -> _messages.ModelResponse:
|
|
677
|
+
"""Return the current state of the response."""
|
|
678
|
+
return self._streamed_run_result.response
|
|
679
|
+
|
|
680
|
+
def usage(self) -> RunUsage:
|
|
681
|
+
"""Return the usage of the whole run.
|
|
682
|
+
|
|
683
|
+
!!! note
|
|
684
|
+
This won't return the full usage until the stream is finished.
|
|
685
|
+
"""
|
|
686
|
+
return self._streamed_run_result.usage()
|
|
687
|
+
|
|
688
|
+
def timestamp(self) -> datetime:
|
|
689
|
+
"""Get the timestamp of the response."""
|
|
690
|
+
return self._streamed_run_result.timestamp()
|
|
691
|
+
|
|
692
|
+
def validate_response_output(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
|
|
693
|
+
"""Validate a structured result message."""
|
|
694
|
+
return _utils.get_event_loop().run_until_complete(
|
|
695
|
+
self._streamed_run_result.validate_response_output(message, allow_partial=allow_partial)
|
|
696
|
+
)
|
|
697
|
+
|
|
698
|
+
@property
|
|
699
|
+
def is_complete(self) -> bool:
|
|
700
|
+
"""Whether the stream has all been received.
|
|
701
|
+
|
|
702
|
+
This is set to `True` when one of
|
|
703
|
+
[`stream_output`][pydantic_ai.result.StreamedRunResultSync.stream_output],
|
|
704
|
+
[`stream_text`][pydantic_ai.result.StreamedRunResultSync.stream_text],
|
|
705
|
+
[`stream_responses`][pydantic_ai.result.StreamedRunResultSync.stream_responses] or
|
|
706
|
+
[`get_output`][pydantic_ai.result.StreamedRunResultSync.get_output] completes.
|
|
707
|
+
"""
|
|
708
|
+
return self._streamed_run_result.is_complete
|
|
709
|
+
|
|
710
|
+
|
|
546
711
|
@dataclass(repr=False)
|
|
547
712
|
class FinalResult(Generic[OutputDataT]):
|
|
548
713
|
"""Marker class storing the final output of an agent run and associated metadata."""
|
pydantic_ai/run.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import dataclasses
|
|
4
|
-
from collections.abc import AsyncIterator
|
|
4
|
+
from collections.abc import AsyncIterator, Sequence
|
|
5
5
|
from copy import deepcopy
|
|
6
6
|
from datetime import datetime
|
|
7
7
|
from typing import TYPE_CHECKING, Any, Generic, Literal, overload
|
|
8
8
|
|
|
9
|
-
from pydantic_graph import
|
|
9
|
+
from pydantic_graph import BaseNode, End, GraphRunContext
|
|
10
|
+
from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
|
|
11
|
+
from pydantic_graph.beta.step import NodeStep
|
|
10
12
|
|
|
11
13
|
from . import (
|
|
12
14
|
_agent_graph,
|
|
@@ -112,12 +114,8 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
112
114
|
|
|
113
115
|
This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
|
|
114
116
|
"""
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
return next_node
|
|
118
|
-
if _agent_graph.is_agent_node(next_node):
|
|
119
|
-
return next_node
|
|
120
|
-
raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover
|
|
117
|
+
task = self._graph_run.next_task
|
|
118
|
+
return self._task_to_node(task)
|
|
121
119
|
|
|
122
120
|
@property
|
|
123
121
|
def result(self) -> AgentRunResult[OutputDataT] | None:
|
|
@@ -126,13 +124,13 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
126
124
|
Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated
|
|
127
125
|
with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult].
|
|
128
126
|
"""
|
|
129
|
-
|
|
130
|
-
if
|
|
127
|
+
graph_run_output = self._graph_run.output
|
|
128
|
+
if graph_run_output is None:
|
|
131
129
|
return None
|
|
132
130
|
return AgentRunResult(
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
131
|
+
graph_run_output.output,
|
|
132
|
+
graph_run_output.tool_name,
|
|
133
|
+
self._graph_run.state,
|
|
136
134
|
self._graph_run.deps.new_message_index,
|
|
137
135
|
self._traceparent(required=False),
|
|
138
136
|
)
|
|
@@ -147,11 +145,28 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
147
145
|
self,
|
|
148
146
|
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
149
147
|
"""Advance to the next node automatically based on the last returned node."""
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
148
|
+
task = await anext(self._graph_run)
|
|
149
|
+
return self._task_to_node(task)
|
|
150
|
+
|
|
151
|
+
def _task_to_node(
|
|
152
|
+
self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
|
|
153
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
154
|
+
if isinstance(task, Sequence) and len(task) == 1:
|
|
155
|
+
first_task = task[0]
|
|
156
|
+
if isinstance(first_task.inputs, BaseNode): # pragma: no branch
|
|
157
|
+
base_node: BaseNode[
|
|
158
|
+
_agent_graph.GraphAgentState,
|
|
159
|
+
_agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT],
|
|
160
|
+
FinalResult[OutputDataT],
|
|
161
|
+
] = first_task.inputs # type: ignore[reportUnknownMemberType]
|
|
162
|
+
if _agent_graph.is_agent_node(node=base_node): # pragma: no branch
|
|
163
|
+
return base_node
|
|
164
|
+
if isinstance(task, EndMarker):
|
|
165
|
+
return End(task.value)
|
|
166
|
+
raise exceptions.AgentRunError(f'Unexpected node: {task}') # pragma: no cover
|
|
167
|
+
|
|
168
|
+
def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
|
|
169
|
+
return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
|
|
155
170
|
|
|
156
171
|
async def next(
|
|
157
172
|
self,
|
|
@@ -222,11 +237,12 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
222
237
|
"""
|
|
223
238
|
# Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
|
|
224
239
|
# on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
240
|
+
task = [self._node_to_task(node)]
|
|
241
|
+
try:
|
|
242
|
+
task = await self._graph_run.next(task)
|
|
243
|
+
except StopAsyncIteration:
|
|
244
|
+
pass
|
|
245
|
+
return self._task_to_node(task)
|
|
230
246
|
|
|
231
247
|
# TODO (v2): Make this a property
|
|
232
248
|
def usage(self) -> _usage.RunUsage:
|
|
@@ -234,7 +250,7 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
234
250
|
return self._graph_run.state.usage
|
|
235
251
|
|
|
236
252
|
def __repr__(self) -> str: # pragma: no cover
|
|
237
|
-
result = self._graph_run.
|
|
253
|
+
result = self._graph_run.output
|
|
238
254
|
result_repr = '<run not finished>' if result is None else repr(result.output)
|
|
239
255
|
return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>'
|
|
240
256
|
|
pydantic_ai/settings.py
CHANGED
|
@@ -24,6 +24,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
24
24
|
* Mistral
|
|
25
25
|
* Bedrock
|
|
26
26
|
* MCP Sampling
|
|
27
|
+
* Outlines (all providers)
|
|
27
28
|
"""
|
|
28
29
|
|
|
29
30
|
temperature: float
|
|
@@ -43,6 +44,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
43
44
|
* Cohere
|
|
44
45
|
* Mistral
|
|
45
46
|
* Bedrock
|
|
47
|
+
* Outlines (Transformers, LlamaCpp, SgLang, VLLMOffline)
|
|
46
48
|
"""
|
|
47
49
|
|
|
48
50
|
top_p: float
|
|
@@ -61,6 +63,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
61
63
|
* Cohere
|
|
62
64
|
* Mistral
|
|
63
65
|
* Bedrock
|
|
66
|
+
* Outlines (Transformers, LlamaCpp, SgLang, VLLMOffline)
|
|
64
67
|
"""
|
|
65
68
|
|
|
66
69
|
timeout: float | Timeout
|
|
@@ -95,6 +98,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
95
98
|
* Cohere
|
|
96
99
|
* Mistral
|
|
97
100
|
* Gemini
|
|
101
|
+
* Outlines (LlamaCpp, VLLMOffline)
|
|
98
102
|
"""
|
|
99
103
|
|
|
100
104
|
presence_penalty: float
|
|
@@ -107,6 +111,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
107
111
|
* Cohere
|
|
108
112
|
* Gemini
|
|
109
113
|
* Mistral
|
|
114
|
+
* Outlines (LlamaCpp, SgLang, VLLMOffline)
|
|
110
115
|
"""
|
|
111
116
|
|
|
112
117
|
frequency_penalty: float
|
|
@@ -119,6 +124,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
119
124
|
* Cohere
|
|
120
125
|
* Gemini
|
|
121
126
|
* Mistral
|
|
127
|
+
* Outlines (LlamaCpp, SgLang, VLLMOffline)
|
|
122
128
|
"""
|
|
123
129
|
|
|
124
130
|
logit_bias: dict[str, int]
|
|
@@ -128,6 +134,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
128
134
|
|
|
129
135
|
* OpenAI
|
|
130
136
|
* Groq
|
|
137
|
+
* Outlines (Transformers, LlamaCpp, VLLMOffline)
|
|
131
138
|
"""
|
|
132
139
|
|
|
133
140
|
stop_sequences: list[str]
|
|
@@ -162,6 +169,7 @@ class ModelSettings(TypedDict, total=False):
|
|
|
162
169
|
* OpenAI
|
|
163
170
|
* Anthropic
|
|
164
171
|
* Groq
|
|
172
|
+
* Outlines (all providers)
|
|
165
173
|
"""
|
|
166
174
|
|
|
167
175
|
|
pydantic_ai/tools.py
CHANGED
|
@@ -240,16 +240,20 @@ class GenerateToolJsonSchema(GenerateJsonSchema):
|
|
|
240
240
|
return s
|
|
241
241
|
|
|
242
242
|
|
|
243
|
+
ToolAgentDepsT = TypeVar('ToolAgentDepsT', default=object, contravariant=True)
|
|
244
|
+
"""Type variable for agent dependencies for a tool."""
|
|
245
|
+
|
|
246
|
+
|
|
243
247
|
@dataclass(init=False)
|
|
244
|
-
class Tool(Generic[
|
|
248
|
+
class Tool(Generic[ToolAgentDepsT]):
|
|
245
249
|
"""A tool function for an agent."""
|
|
246
250
|
|
|
247
|
-
function: ToolFuncEither[
|
|
251
|
+
function: ToolFuncEither[ToolAgentDepsT]
|
|
248
252
|
takes_ctx: bool
|
|
249
253
|
max_retries: int | None
|
|
250
254
|
name: str
|
|
251
255
|
description: str | None
|
|
252
|
-
prepare: ToolPrepareFunc[
|
|
256
|
+
prepare: ToolPrepareFunc[ToolAgentDepsT] | None
|
|
253
257
|
docstring_format: DocstringFormat
|
|
254
258
|
require_parameter_descriptions: bool
|
|
255
259
|
strict: bool | None
|
|
@@ -265,13 +269,13 @@ class Tool(Generic[AgentDepsT]):
|
|
|
265
269
|
|
|
266
270
|
def __init__(
|
|
267
271
|
self,
|
|
268
|
-
function: ToolFuncEither[
|
|
272
|
+
function: ToolFuncEither[ToolAgentDepsT],
|
|
269
273
|
*,
|
|
270
274
|
takes_ctx: bool | None = None,
|
|
271
275
|
max_retries: int | None = None,
|
|
272
276
|
name: str | None = None,
|
|
273
277
|
description: str | None = None,
|
|
274
|
-
prepare: ToolPrepareFunc[
|
|
278
|
+
prepare: ToolPrepareFunc[ToolAgentDepsT] | None = None,
|
|
275
279
|
docstring_format: DocstringFormat = 'auto',
|
|
276
280
|
require_parameter_descriptions: bool = False,
|
|
277
281
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
@@ -413,7 +417,7 @@ class Tool(Generic[AgentDepsT]):
|
|
|
413
417
|
metadata=self.metadata,
|
|
414
418
|
)
|
|
415
419
|
|
|
416
|
-
async def prepare_tool_def(self, ctx: RunContext[
|
|
420
|
+
async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
|
|
417
421
|
"""Get the tool definition.
|
|
418
422
|
|
|
419
423
|
By default, this method creates a tool definition, then either returns it, or calls `self.prepare`
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
from asyncio import Lock
|
|
5
|
+
from contextlib import AsyncExitStack
|
|
6
|
+
from dataclasses import KW_ONLY, dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
9
|
+
|
|
10
|
+
from pydantic import AnyUrl
|
|
11
|
+
from typing_extensions import Self, assert_never
|
|
12
|
+
|
|
13
|
+
from pydantic_ai import messages
|
|
14
|
+
from pydantic_ai.exceptions import ModelRetry
|
|
15
|
+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
|
|
16
|
+
from pydantic_ai.toolsets import AbstractToolset
|
|
17
|
+
from pydantic_ai.toolsets.abstract import ToolsetTool
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
from fastmcp.client import Client
|
|
21
|
+
from fastmcp.client.transports import ClientTransport
|
|
22
|
+
from fastmcp.exceptions import ToolError
|
|
23
|
+
from fastmcp.mcp_config import MCPConfig
|
|
24
|
+
from fastmcp.server import FastMCP
|
|
25
|
+
from mcp.server.fastmcp import FastMCP as FastMCP1Server
|
|
26
|
+
from mcp.types import (
|
|
27
|
+
AudioContent,
|
|
28
|
+
BlobResourceContents,
|
|
29
|
+
ContentBlock,
|
|
30
|
+
EmbeddedResource,
|
|
31
|
+
ImageContent,
|
|
32
|
+
ResourceLink,
|
|
33
|
+
TextContent,
|
|
34
|
+
TextResourceContents,
|
|
35
|
+
Tool as MCPTool,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from pydantic_ai.mcp import TOOL_SCHEMA_VALIDATOR
|
|
39
|
+
|
|
40
|
+
except ImportError as _import_error:
|
|
41
|
+
raise ImportError(
|
|
42
|
+
'Please install the `fastmcp` package to use the FastMCP server, '
|
|
43
|
+
'you can use the `fastmcp` optional group — `pip install "pydantic-ai-slim[fastmcp]"`'
|
|
44
|
+
) from _import_error
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
if TYPE_CHECKING:
|
|
48
|
+
from fastmcp.client.client import CallToolResult
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
FastMCPToolResult = messages.BinaryContent | dict[str, Any] | str | None
|
|
52
|
+
|
|
53
|
+
ToolErrorBehavior = Literal['model_retry', 'error']
|
|
54
|
+
|
|
55
|
+
UNKNOWN_BINARY_MEDIA_TYPE = 'application/octet-stream'
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass(init=False)
|
|
59
|
+
class FastMCPToolset(AbstractToolset[AgentDepsT]):
|
|
60
|
+
"""A FastMCP Toolset that uses the FastMCP Client to call tools from a local or remote MCP Server.
|
|
61
|
+
|
|
62
|
+
The Toolset can accept a FastMCP Client, a FastMCP Transport, or any other object which a FastMCP Transport can be created from.
|
|
63
|
+
|
|
64
|
+
See https://gofastmcp.com/clients/transports for a full list of transports available.
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
client: Client[Any]
|
|
68
|
+
"""The FastMCP client to use."""
|
|
69
|
+
|
|
70
|
+
_: KW_ONLY
|
|
71
|
+
|
|
72
|
+
tool_error_behavior: Literal['model_retry', 'error']
|
|
73
|
+
"""The behavior to take when a tool error occurs."""
|
|
74
|
+
|
|
75
|
+
max_retries: int
|
|
76
|
+
"""The maximum number of retries to attempt if a tool call fails."""
|
|
77
|
+
|
|
78
|
+
_id: str | None
|
|
79
|
+
|
|
80
|
+
def __init__(
|
|
81
|
+
self,
|
|
82
|
+
client: Client[Any]
|
|
83
|
+
| ClientTransport
|
|
84
|
+
| FastMCP
|
|
85
|
+
| FastMCP1Server
|
|
86
|
+
| AnyUrl
|
|
87
|
+
| Path
|
|
88
|
+
| MCPConfig
|
|
89
|
+
| dict[str, Any]
|
|
90
|
+
| str,
|
|
91
|
+
*,
|
|
92
|
+
max_retries: int = 1,
|
|
93
|
+
tool_error_behavior: Literal['model_retry', 'error'] = 'model_retry',
|
|
94
|
+
id: str | None = None,
|
|
95
|
+
) -> None:
|
|
96
|
+
if isinstance(client, Client):
|
|
97
|
+
self.client = client
|
|
98
|
+
else:
|
|
99
|
+
self.client = Client[Any](transport=client)
|
|
100
|
+
|
|
101
|
+
self._id = id
|
|
102
|
+
self.max_retries = max_retries
|
|
103
|
+
self.tool_error_behavior = tool_error_behavior
|
|
104
|
+
|
|
105
|
+
self._enter_lock: Lock = Lock()
|
|
106
|
+
self._running_count: int = 0
|
|
107
|
+
self._exit_stack: AsyncExitStack | None = None
|
|
108
|
+
|
|
109
|
+
@property
|
|
110
|
+
def id(self) -> str | None:
|
|
111
|
+
return self._id
|
|
112
|
+
|
|
113
|
+
async def __aenter__(self) -> Self:
|
|
114
|
+
async with self._enter_lock:
|
|
115
|
+
if self._running_count == 0:
|
|
116
|
+
self._exit_stack = AsyncExitStack()
|
|
117
|
+
await self._exit_stack.enter_async_context(self.client)
|
|
118
|
+
|
|
119
|
+
self._running_count += 1
|
|
120
|
+
|
|
121
|
+
return self
|
|
122
|
+
|
|
123
|
+
async def __aexit__(self, *args: Any) -> bool | None:
|
|
124
|
+
async with self._enter_lock:
|
|
125
|
+
self._running_count -= 1
|
|
126
|
+
if self._running_count == 0 and self._exit_stack:
|
|
127
|
+
await self._exit_stack.aclose()
|
|
128
|
+
self._exit_stack = None
|
|
129
|
+
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
|
|
133
|
+
async with self:
|
|
134
|
+
mcp_tools: list[MCPTool] = await self.client.list_tools()
|
|
135
|
+
|
|
136
|
+
return {
|
|
137
|
+
tool.name: _convert_mcp_tool_to_toolset_tool(toolset=self, mcp_tool=tool, retries=self.max_retries)
|
|
138
|
+
for tool in mcp_tools
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
async def call_tool(
|
|
142
|
+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
|
|
143
|
+
) -> Any:
|
|
144
|
+
async with self:
|
|
145
|
+
try:
|
|
146
|
+
call_tool_result: CallToolResult = await self.client.call_tool(name=name, arguments=tool_args)
|
|
147
|
+
except ToolError as e:
|
|
148
|
+
if self.tool_error_behavior == 'model_retry':
|
|
149
|
+
raise ModelRetry(message=str(e)) from e
|
|
150
|
+
else:
|
|
151
|
+
raise e
|
|
152
|
+
|
|
153
|
+
# If we have structured content, return that
|
|
154
|
+
if call_tool_result.structured_content:
|
|
155
|
+
return call_tool_result.structured_content
|
|
156
|
+
|
|
157
|
+
# Otherwise, return the content
|
|
158
|
+
return _map_fastmcp_tool_results(parts=call_tool_result.content)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _convert_mcp_tool_to_toolset_tool(
|
|
162
|
+
toolset: FastMCPToolset[AgentDepsT],
|
|
163
|
+
mcp_tool: MCPTool,
|
|
164
|
+
retries: int,
|
|
165
|
+
) -> ToolsetTool[AgentDepsT]:
|
|
166
|
+
"""Convert an MCP tool to a toolset tool."""
|
|
167
|
+
return ToolsetTool[AgentDepsT](
|
|
168
|
+
tool_def=ToolDefinition(
|
|
169
|
+
name=mcp_tool.name,
|
|
170
|
+
description=mcp_tool.description,
|
|
171
|
+
parameters_json_schema=mcp_tool.inputSchema,
|
|
172
|
+
metadata={
|
|
173
|
+
'meta': mcp_tool.meta,
|
|
174
|
+
'annotations': mcp_tool.annotations.model_dump() if mcp_tool.annotations else None,
|
|
175
|
+
'output_schema': mcp_tool.outputSchema or None,
|
|
176
|
+
},
|
|
177
|
+
),
|
|
178
|
+
toolset=toolset,
|
|
179
|
+
max_retries=retries,
|
|
180
|
+
args_validator=TOOL_SCHEMA_VALIDATOR,
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _map_fastmcp_tool_results(parts: list[ContentBlock]) -> list[FastMCPToolResult] | FastMCPToolResult:
|
|
185
|
+
"""Map FastMCP tool results to toolset tool results."""
|
|
186
|
+
mapped_results = [_map_fastmcp_tool_result(part) for part in parts]
|
|
187
|
+
|
|
188
|
+
if len(mapped_results) == 1:
|
|
189
|
+
return mapped_results[0]
|
|
190
|
+
|
|
191
|
+
return mapped_results
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _map_fastmcp_tool_result(part: ContentBlock) -> FastMCPToolResult:
|
|
195
|
+
if isinstance(part, TextContent):
|
|
196
|
+
return part.text
|
|
197
|
+
elif isinstance(part, ImageContent | AudioContent):
|
|
198
|
+
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
|
|
199
|
+
elif isinstance(part, EmbeddedResource):
|
|
200
|
+
if isinstance(part.resource, BlobResourceContents):
|
|
201
|
+
return messages.BinaryContent(
|
|
202
|
+
data=base64.b64decode(part.resource.blob),
|
|
203
|
+
media_type=part.resource.mimeType or UNKNOWN_BINARY_MEDIA_TYPE,
|
|
204
|
+
)
|
|
205
|
+
elif isinstance(part.resource, TextResourceContents):
|
|
206
|
+
return part.resource.text
|
|
207
|
+
else:
|
|
208
|
+
assert_never(part.resource)
|
|
209
|
+
elif isinstance(part, ResourceLink):
|
|
210
|
+
# ResourceLink is not yet supported by the FastMCP toolset as reading resources is not yet supported.
|
|
211
|
+
raise NotImplementedError(
|
|
212
|
+
'ResourceLink is not supported by the FastMCP toolset as reading resources is not yet supported.'
|
|
213
|
+
)
|
|
214
|
+
else:
|
|
215
|
+
assert_never(part)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from ._adapter import StateDeps, StateHandler, UIAdapter
|
|
4
|
+
from ._event_stream import SSE_CONTENT_TYPE, NativeEvent, OnCompleteFunc, UIEventStream
|
|
5
|
+
from ._messages_builder import MessagesBuilder
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
'UIAdapter',
|
|
9
|
+
'UIEventStream',
|
|
10
|
+
'SSE_CONTENT_TYPE',
|
|
11
|
+
'StateDeps',
|
|
12
|
+
'StateHandler',
|
|
13
|
+
'NativeEvent',
|
|
14
|
+
'OnCompleteFunc',
|
|
15
|
+
'MessagesBuilder',
|
|
16
|
+
]
|