pydantic-ai-slim 1.2.1 → 1.10.0 (py3-none-any.whl)

This diff shows the changes between two publicly released versions of the package, as published to their public registry; it is provided for informational purposes only.
Files changed (67)
  1. pydantic_ai/__init__.py +6 -0
  2. pydantic_ai/_agent_graph.py +67 -20
  3. pydantic_ai/_cli.py +2 -2
  4. pydantic_ai/_output.py +20 -12
  5. pydantic_ai/_run_context.py +6 -2
  6. pydantic_ai/_utils.py +26 -8
  7. pydantic_ai/ag_ui.py +50 -696
  8. pydantic_ai/agent/__init__.py +13 -25
  9. pydantic_ai/agent/abstract.py +146 -9
  10. pydantic_ai/builtin_tools.py +106 -4
  11. pydantic_ai/direct.py +16 -4
  12. pydantic_ai/durable_exec/dbos/_agent.py +3 -0
  13. pydantic_ai/durable_exec/prefect/_agent.py +3 -0
  14. pydantic_ai/durable_exec/temporal/__init__.py +11 -0
  15. pydantic_ai/durable_exec/temporal/_agent.py +3 -0
  16. pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
  17. pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
  18. pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
  19. pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
  20. pydantic_ai/exceptions.py +6 -1
  21. pydantic_ai/mcp.py +1 -22
  22. pydantic_ai/messages.py +46 -8
  23. pydantic_ai/models/__init__.py +87 -38
  24. pydantic_ai/models/anthropic.py +132 -11
  25. pydantic_ai/models/bedrock.py +4 -4
  26. pydantic_ai/models/cohere.py +0 -7
  27. pydantic_ai/models/gemini.py +9 -2
  28. pydantic_ai/models/google.py +26 -23
  29. pydantic_ai/models/groq.py +13 -5
  30. pydantic_ai/models/huggingface.py +2 -2
  31. pydantic_ai/models/openai.py +251 -52
  32. pydantic_ai/models/outlines.py +563 -0
  33. pydantic_ai/models/test.py +6 -3
  34. pydantic_ai/profiles/openai.py +7 -0
  35. pydantic_ai/providers/__init__.py +25 -12
  36. pydantic_ai/providers/anthropic.py +2 -2
  37. pydantic_ai/providers/bedrock.py +60 -16
  38. pydantic_ai/providers/gateway.py +60 -72
  39. pydantic_ai/providers/google.py +91 -24
  40. pydantic_ai/providers/openrouter.py +3 -0
  41. pydantic_ai/providers/outlines.py +40 -0
  42. pydantic_ai/providers/ovhcloud.py +95 -0
  43. pydantic_ai/result.py +173 -8
  44. pydantic_ai/run.py +40 -24
  45. pydantic_ai/settings.py +8 -0
  46. pydantic_ai/tools.py +10 -6
  47. pydantic_ai/toolsets/fastmcp.py +215 -0
  48. pydantic_ai/ui/__init__.py +16 -0
  49. pydantic_ai/ui/_adapter.py +386 -0
  50. pydantic_ai/ui/_event_stream.py +591 -0
  51. pydantic_ai/ui/_messages_builder.py +28 -0
  52. pydantic_ai/ui/ag_ui/__init__.py +9 -0
  53. pydantic_ai/ui/ag_ui/_adapter.py +187 -0
  54. pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
  55. pydantic_ai/ui/ag_ui/app.py +148 -0
  56. pydantic_ai/ui/vercel_ai/__init__.py +16 -0
  57. pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
  58. pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
  59. pydantic_ai/ui/vercel_ai/_utils.py +16 -0
  60. pydantic_ai/ui/vercel_ai/request_types.py +275 -0
  61. pydantic_ai/ui/vercel_ai/response_types.py +230 -0
  62. pydantic_ai/usage.py +13 -2
  63. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
  64. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
  65. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
  66. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
  67. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/result.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -35,6 +35,7 @@ __all__ = (
     'OutputDataT_inv',
     'ToolOutput',
     'OutputValidatorFunc',
+    'StreamedRunResultSync',
 )
 
 
@@ -60,14 +61,26 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
 
     async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
         """Asynchronously stream the (validated) agent outputs."""
+        last_response: _messages.ModelResponse | None = None
         async for response in self.stream_responses(debounce_by=debounce_by):
-            if self._raw_stream_response.final_result_event is not None:
-                try:
-                    yield await self.validate_response_output(response, allow_partial=True)
-                except ValidationError:
-                    pass
-        if self._raw_stream_response.final_result_event is not None:  # pragma: no branch
-            yield await self.validate_response_output(self.response)
+            if self._raw_stream_response.final_result_event is None or (
+                last_response and response.parts == last_response.parts
+            ):
+                continue
+            last_response = response
+
+            try:
+                yield await self.validate_response_output(response, allow_partial=True)
+            except ValidationError:
+                pass
+
+        response = self.response
+        if self._raw_stream_response.final_result_event is None or (
+            last_response and response.parts == last_response.parts
+        ):
+            return
+
+        yield await self.validate_response_output(response)
 
     async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
         """Asynchronously stream the (unvalidated) model responses for the agent."""
@@ -543,6 +556,158 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         await self._on_complete()
 
 
+@dataclass(init=False)
+class StreamedRunResultSync(Generic[AgentDepsT, OutputDataT]):
+    """Synchronous wrapper for [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] that only exposes sync methods."""
+
+    _streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
+
+    def __init__(self, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]) -> None:
+        self._streamed_run_result = streamed_run_result
+
+    def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of messages.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
+        return self._streamed_run_result.all_messages(output_tool_return_content=output_tool_return_content)
+
+    def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:  # pragma: no cover
+        """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResultSync.all_messages] as JSON bytes.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the messages.
+        """
+        return self._streamed_run_result.all_messages_json(output_tool_return_content=output_tool_return_content)
+
+    def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return new messages associated with this run.
+
+        Messages from older runs are excluded.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of new messages.
+        """
+        return self._streamed_run_result.new_messages(output_tool_return_content=output_tool_return_content)
+
+    def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:  # pragma: no cover
+        """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResultSync.new_messages] as JSON bytes.
+
+        Args:
+            output_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the output tool call if you want to continue
+                the conversation and want to set the response to the output tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the new messages.
+        """
+        return self._streamed_run_result.new_messages_json(output_tool_return_content=output_tool_return_content)
+
+    def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
+        """Stream the output as an iterable.
+
+        The pydantic validator for structured data will be called in
+        [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
+        on each iteration.
+
+        Args:
+            debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured outputs to reduce the overhead of
+                performing validation as each token is received.
+
+        Returns:
+            An iterable of the response data.
+        """
+        return _utils.sync_async_iterator(self._streamed_run_result.stream_output(debounce_by=debounce_by))
+
+    def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
+        """Stream the text result as an iterable.
+
+        !!! note
+            Result validators will NOT be called on the text result if `delta=True`.
+
+        Args:
+            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
+                up to the current point.
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+        """
+        return _utils.sync_async_iterator(self._streamed_run_result.stream_text(delta=delta, debounce_by=debounce_by))
+
+    def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]:
+        """Stream the response as an iterable of Structured LLM Messages.
+
+        Args:
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+
+        Returns:
+            An iterable of the structured response message and whether that is the last message.
+        """
+        return _utils.sync_async_iterator(self._streamed_run_result.stream_responses(debounce_by=debounce_by))
+
+    def get_output(self) -> OutputDataT:
+        """Stream the whole response, validate and return it."""
+        return _utils.get_event_loop().run_until_complete(self._streamed_run_result.get_output())
+
+    @property
+    def response(self) -> _messages.ModelResponse:
+        """Return the current state of the response."""
+        return self._streamed_run_result.response
+
+    def usage(self) -> RunUsage:
+        """Return the usage of the whole run.
+
+        !!! note
+            This won't return the full usage until the stream is finished.
+        """
+        return self._streamed_run_result.usage()
+
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._streamed_run_result.timestamp()
+
+    def validate_response_output(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
+        """Validate a structured result message."""
+        return _utils.get_event_loop().run_until_complete(
+            self._streamed_run_result.validate_response_output(message, allow_partial=allow_partial)
+        )
+
+    @property
+    def is_complete(self) -> bool:
+        """Whether the stream has all been received.
+
+        This is set to `True` when one of
+        [`stream_output`][pydantic_ai.result.StreamedRunResultSync.stream_output],
+        [`stream_text`][pydantic_ai.result.StreamedRunResultSync.stream_text],
+        [`stream_responses`][pydantic_ai.result.StreamedRunResultSync.stream_responses] or
+        [`get_output`][pydantic_ai.result.StreamedRunResultSync.get_output] completes.
+        """
+        return self._streamed_run_result.is_complete
+
+
 @dataclass(repr=False)
 class FinalResult(Generic[OutputDataT]):
     """Marker class storing the final output of an agent run and associated metadata."""
pydantic_ai/run.py CHANGED
@@ -1,12 +1,14 @@
 from __future__ import annotations as _annotations
 
 import dataclasses
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Sequence
 from copy import deepcopy
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Generic, Literal, overload
 
-from pydantic_graph import End, GraphRun, GraphRunContext
+from pydantic_graph import BaseNode, End, GraphRunContext
+from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
+from pydantic_graph.beta.step import NodeStep
 
 from . import (
     _agent_graph,
@@ -112,12 +114,8 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
 
         This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
         """
-        next_node = self._graph_run.next_node
-        if isinstance(next_node, End):
-            return next_node
-        if _agent_graph.is_agent_node(next_node):
-            return next_node
-        raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}')  # pragma: no cover
+        task = self._graph_run.next_task
+        return self._task_to_node(task)
 
     @property
     def result(self) -> AgentRunResult[OutputDataT] | None:
@@ -126,13 +124,13 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
         Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated
         with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult].
         """
-        graph_run_result = self._graph_run.result
-        if graph_run_result is None:
+        graph_run_output = self._graph_run.output
+        if graph_run_output is None:
             return None
         return AgentRunResult(
-            graph_run_result.output.output,
-            graph_run_result.output.tool_name,
-            graph_run_result.state,
+            graph_run_output.output,
+            graph_run_output.tool_name,
+            self._graph_run.state,
             self._graph_run.deps.new_message_index,
             self._traceparent(required=False),
         )
@@ -147,11 +145,28 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
         self,
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         """Advance to the next node automatically based on the last returned node."""
-        next_node = await self._graph_run.__anext__()
-        if _agent_graph.is_agent_node(node=next_node):
-            return next_node
-        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
-        return next_node
+        task = await anext(self._graph_run)
+        return self._task_to_node(task)
+
+    def _task_to_node(
+        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
+    ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
+        if isinstance(task, Sequence) and len(task) == 1:
+            first_task = task[0]
+            if isinstance(first_task.inputs, BaseNode):  # pragma: no branch
+                base_node: BaseNode[
+                    _agent_graph.GraphAgentState,
+                    _agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT],
+                    FinalResult[OutputDataT],
+                ] = first_task.inputs  # type: ignore[reportUnknownMemberType]
+                if _agent_graph.is_agent_node(node=base_node):  # pragma: no branch
+                    return base_node
+        if isinstance(task, EndMarker):
+            return End(task.value)
+        raise exceptions.AgentRunError(f'Unexpected node: {task}')  # pragma: no cover
+
+    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
+        return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
 
     async def next(
         self,
@@ -222,11 +237,12 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
         """
         # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
         # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
-        next_node = await self._graph_run.next(node)
-        if _agent_graph.is_agent_node(next_node):
-            return next_node
-        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
-        return next_node
+        task = [self._node_to_task(node)]
+        try:
+            task = await self._graph_run.next(task)
+        except StopAsyncIteration:
+            pass
+        return self._task_to_node(task)
 
     # TODO (v2): Make this a property
     def usage(self) -> _usage.RunUsage:
@@ -234,7 +250,7 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
         return self._graph_run.state.usage
 
     def __repr__(self) -> str:  # pragma: no cover
-        result = self._graph_run.result
+        result = self._graph_run.output
         result_repr = '<run not finished>' if result is None else repr(result.output)
         return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>'
 
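
The _task_to_node / _node_to_task shims above exist so the move to pydantic_graph.beta.graph.GraphRun (which operates on GraphTasks) stays invisible at the public API, which still yields agent nodes and an End. A sketch of the iteration surface being preserved, using the documented Agent.iter API:

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main() -> None:
    async with agent.iter('What is 2 + 2?') as agent_run:
        async for node in agent_run:  # AgentNode subclasses, then End
            print(type(node).__name__)
    assert agent_run.result is not None
    print(agent_run.result.output)


asyncio.run(main())
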
pydantic_ai/settings.py CHANGED
@@ -24,6 +24,7 @@ class ModelSettings(TypedDict, total=False):
     * Mistral
     * Bedrock
     * MCP Sampling
+    * Outlines (all providers)
     """
 
     temperature: float
@@ -43,6 +44,7 @@ class ModelSettings(TypedDict, total=False):
     * Cohere
     * Mistral
     * Bedrock
+    * Outlines (Transformers, LlamaCpp, SgLang, VLLMOffline)
     """
 
     top_p: float
@@ -61,6 +63,7 @@ class ModelSettings(TypedDict, total=False):
     * Cohere
     * Mistral
     * Bedrock
+    * Outlines (Transformers, LlamaCpp, SgLang, VLLMOffline)
     """
 
     timeout: float | Timeout
@@ -95,6 +98,7 @@ class ModelSettings(TypedDict, total=False):
     * Cohere
     * Mistral
     * Gemini
+    * Outlines (LlamaCpp, VLLMOffline)
     """
 
     presence_penalty: float
@@ -107,6 +111,7 @@ class ModelSettings(TypedDict, total=False):
     * Cohere
     * Gemini
     * Mistral
+    * Outlines (LlamaCpp, SgLang, VLLMOffline)
     """
 
     frequency_penalty: float
@@ -119,6 +124,7 @@ class ModelSettings(TypedDict, total=False):
     * Cohere
     * Gemini
     * Mistral
+    * Outlines (LlamaCpp, SgLang, VLLMOffline)
     """
 
     logit_bias: dict[str, int]
@@ -128,6 +134,7 @@ class ModelSettings(TypedDict, total=False):
 
     * OpenAI
     * Groq
+    * Outlines (Transformers, LlamaCpp, VLLMOffline)
     """
 
     stop_sequences: list[str]
@@ -162,6 +169,7 @@ class ModelSettings(TypedDict, total=False):
     * OpenAI
     * Anthropic
     * Groq
+    * Outlines (all providers)
     """
 
 
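
The settings.py additions above are documentation-only: each hunk extends the "supported by" list in a setting's docstring to record which Outlines backends honor it. Settings are still passed the same way regardless of model; for example:

from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings

# ModelSettings is a TypedDict; unsupported keys are ignored by a given model.
agent = Agent(
    'openai:gpt-4o',
    model_settings=ModelSettings(temperature=0.2, top_p=0.9),
)
result = agent.run_sync('Summarize the change in one sentence.')
print(result.output)
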
pydantic_ai/tools.py CHANGED
@@ -240,16 +240,20 @@ class GenerateToolJsonSchema(GenerateJsonSchema):
         return s
 
 
+ToolAgentDepsT = TypeVar('ToolAgentDepsT', default=object, contravariant=True)
+"""Type variable for agent dependencies for a tool."""
+
+
 @dataclass(init=False)
-class Tool(Generic[AgentDepsT]):
+class Tool(Generic[ToolAgentDepsT]):
     """A tool function for an agent."""
 
-    function: ToolFuncEither[AgentDepsT]
+    function: ToolFuncEither[ToolAgentDepsT]
     takes_ctx: bool
     max_retries: int | None
     name: str
     description: str | None
-    prepare: ToolPrepareFunc[AgentDepsT] | None
+    prepare: ToolPrepareFunc[ToolAgentDepsT] | None
     docstring_format: DocstringFormat
     require_parameter_descriptions: bool
     strict: bool | None
@@ -265,13 +269,13 @@ class Tool(Generic[AgentDepsT]):
 
     def __init__(
         self,
-        function: ToolFuncEither[AgentDepsT],
+        function: ToolFuncEither[ToolAgentDepsT],
         *,
         takes_ctx: bool | None = None,
         max_retries: int | None = None,
         name: str | None = None,
         description: str | None = None,
-        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
+        prepare: ToolPrepareFunc[ToolAgentDepsT] | None = None,
         docstring_format: DocstringFormat = 'auto',
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
@@ -413,7 +417,7 @@ class Tool(Generic[AgentDepsT]):
             metadata=self.metadata,
         )
 
-    async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
+    async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
         """Get the tool definition.
 
         By default, this method creates a tool definition, then either returns it, or calls `self.prepare`
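
Because ToolAgentDepsT is contravariant (and defaults to object), a Tool written against broad dependencies is now assignable where a tool for narrower dependencies is expected. A sketch of what that allows (the deps classes here are illustrative):

from dataclasses import dataclass

from pydantic_ai import Agent, RunContext, Tool


@dataclass
class BaseDeps:  # illustrative broad deps
    user_id: int


@dataclass
class AppDeps(BaseDeps):  # illustrative narrower deps
    api_key: str


async def whoami(ctx: RunContext[BaseDeps]) -> str:
    """Report the current user id."""
    return f'user {ctx.deps.user_id}'


# Contravariance: Tool[BaseDeps] is a subtype of Tool[AppDeps], so this tool
# can be registered on an agent whose deps are the narrower AppDeps.
tool: Tool[BaseDeps] = Tool(whoami, takes_ctx=True)
agent = Agent('openai:gpt-4o', deps_type=AppDeps, tools=[tool])
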
pydantic_ai/toolsets/fastmcp.py ADDED
@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+import base64
+from asyncio import Lock
+from contextlib import AsyncExitStack
+from dataclasses import KW_ONLY, dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Literal
+
+from pydantic import AnyUrl
+from typing_extensions import Self, assert_never
+
+from pydantic_ai import messages
+from pydantic_ai.exceptions import ModelRetry
+from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
+from pydantic_ai.toolsets import AbstractToolset
+from pydantic_ai.toolsets.abstract import ToolsetTool
+
+try:
+    from fastmcp.client import Client
+    from fastmcp.client.transports import ClientTransport
+    from fastmcp.exceptions import ToolError
+    from fastmcp.mcp_config import MCPConfig
+    from fastmcp.server import FastMCP
+    from mcp.server.fastmcp import FastMCP as FastMCP1Server
+    from mcp.types import (
+        AudioContent,
+        BlobResourceContents,
+        ContentBlock,
+        EmbeddedResource,
+        ImageContent,
+        ResourceLink,
+        TextContent,
+        TextResourceContents,
+        Tool as MCPTool,
+    )
+
+    from pydantic_ai.mcp import TOOL_SCHEMA_VALIDATOR
+
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install the `fastmcp` package to use the FastMCP server, '
+        'you can use the `fastmcp` optional group — `pip install "pydantic-ai-slim[fastmcp]"`'
+    ) from _import_error
+
+
+if TYPE_CHECKING:
+    from fastmcp.client.client import CallToolResult
+
+
+FastMCPToolResult = messages.BinaryContent | dict[str, Any] | str | None
+
+ToolErrorBehavior = Literal['model_retry', 'error']
+
+UNKNOWN_BINARY_MEDIA_TYPE = 'application/octet-stream'
+
+
+@dataclass(init=False)
+class FastMCPToolset(AbstractToolset[AgentDepsT]):
+    """A FastMCP Toolset that uses the FastMCP Client to call tools from a local or remote MCP Server.
+
+    The Toolset can accept a FastMCP Client, a FastMCP Transport, or any other object which a FastMCP Transport can be created from.
+
+    See https://gofastmcp.com/clients/transports for a full list of transports available.
+    """
+
+    client: Client[Any]
+    """The FastMCP client to use."""
+
+    _: KW_ONLY
+
+    tool_error_behavior: Literal['model_retry', 'error']
+    """The behavior to take when a tool error occurs."""
+
+    max_retries: int
+    """The maximum number of retries to attempt if a tool call fails."""
+
+    _id: str | None
+
+    def __init__(
+        self,
+        client: Client[Any]
+        | ClientTransport
+        | FastMCP
+        | FastMCP1Server
+        | AnyUrl
+        | Path
+        | MCPConfig
+        | dict[str, Any]
+        | str,
+        *,
+        max_retries: int = 1,
+        tool_error_behavior: Literal['model_retry', 'error'] = 'model_retry',
+        id: str | None = None,
+    ) -> None:
+        if isinstance(client, Client):
+            self.client = client
+        else:
+            self.client = Client[Any](transport=client)
+
+        self._id = id
+        self.max_retries = max_retries
+        self.tool_error_behavior = tool_error_behavior
+
+        self._enter_lock: Lock = Lock()
+        self._running_count: int = 0
+        self._exit_stack: AsyncExitStack | None = None
+
+    @property
+    def id(self) -> str | None:
+        return self._id
+
+    async def __aenter__(self) -> Self:
+        async with self._enter_lock:
+            if self._running_count == 0:
+                self._exit_stack = AsyncExitStack()
+                await self._exit_stack.enter_async_context(self.client)
+
+            self._running_count += 1
+
+        return self
+
+    async def __aexit__(self, *args: Any) -> bool | None:
+        async with self._enter_lock:
+            self._running_count -= 1
+            if self._running_count == 0 and self._exit_stack:
+                await self._exit_stack.aclose()
+                self._exit_stack = None
+
+        return None
+
+    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
+        async with self:
+            mcp_tools: list[MCPTool] = await self.client.list_tools()
+
+        return {
+            tool.name: _convert_mcp_tool_to_toolset_tool(toolset=self, mcp_tool=tool, retries=self.max_retries)
+            for tool in mcp_tools
+        }
+
+    async def call_tool(
+        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
+    ) -> Any:
+        async with self:
+            try:
+                call_tool_result: CallToolResult = await self.client.call_tool(name=name, arguments=tool_args)
+            except ToolError as e:
+                if self.tool_error_behavior == 'model_retry':
+                    raise ModelRetry(message=str(e)) from e
+                else:
+                    raise e
+
+        # If we have structured content, return that
+        if call_tool_result.structured_content:
+            return call_tool_result.structured_content
+
+        # Otherwise, return the content
+        return _map_fastmcp_tool_results(parts=call_tool_result.content)
+
+
+def _convert_mcp_tool_to_toolset_tool(
+    toolset: FastMCPToolset[AgentDepsT],
+    mcp_tool: MCPTool,
+    retries: int,
+) -> ToolsetTool[AgentDepsT]:
+    """Convert an MCP tool to a toolset tool."""
+    return ToolsetTool[AgentDepsT](
+        tool_def=ToolDefinition(
+            name=mcp_tool.name,
+            description=mcp_tool.description,
+            parameters_json_schema=mcp_tool.inputSchema,
+            metadata={
+                'meta': mcp_tool.meta,
+                'annotations': mcp_tool.annotations.model_dump() if mcp_tool.annotations else None,
+                'output_schema': mcp_tool.outputSchema or None,
+            },
+        ),
+        toolset=toolset,
+        max_retries=retries,
+        args_validator=TOOL_SCHEMA_VALIDATOR,
+    )
+
+
+def _map_fastmcp_tool_results(parts: list[ContentBlock]) -> list[FastMCPToolResult] | FastMCPToolResult:
+    """Map FastMCP tool results to toolset tool results."""
+    mapped_results = [_map_fastmcp_tool_result(part) for part in parts]
+
+    if len(mapped_results) == 1:
+        return mapped_results[0]
+
+    return mapped_results
+
+
+def _map_fastmcp_tool_result(part: ContentBlock) -> FastMCPToolResult:
+    if isinstance(part, TextContent):
+        return part.text
+    elif isinstance(part, ImageContent | AudioContent):
+        return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
+    elif isinstance(part, EmbeddedResource):
+        if isinstance(part.resource, BlobResourceContents):
+            return messages.BinaryContent(
+                data=base64.b64decode(part.resource.blob),
+                media_type=part.resource.mimeType or UNKNOWN_BINARY_MEDIA_TYPE,
+            )
+        elif isinstance(part.resource, TextResourceContents):
+            return part.resource.text
+        else:
+            assert_never(part.resource)
+    elif isinstance(part, ResourceLink):
+        # ResourceLink is not yet supported by the FastMCP toolset as reading resources is not yet supported.
+        raise NotImplementedError(
+            'ResourceLink is not supported by the FastMCP toolset as reading resources is not yet supported.'
+        )
+    else:
+        assert_never(part)
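
Wiring the new toolset into an agent only requires a transport source and the agent's toolsets argument; a sketch (the server URL is a placeholder):

from pydantic_ai import Agent
from pydantic_ai.toolsets.fastmcp import FastMCPToolset

# Any source a FastMCP transport can be built from works: a URL, a Path to
# a server script, an MCPConfig/dict, or an existing fastmcp Client.
toolset = FastMCPToolset('http://localhost:8000/mcp', tool_error_behavior='error')

agent = Agent('openai:gpt-4o', toolsets=[toolset])
result = agent.run_sync('Use the available tools to answer the question.')
print(result.output)
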
pydantic_ai/ui/__init__.py ADDED
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from ._adapter import StateDeps, StateHandler, UIAdapter
+from ._event_stream import SSE_CONTENT_TYPE, NativeEvent, OnCompleteFunc, UIEventStream
+from ._messages_builder import MessagesBuilder
+
+__all__ = [
+    'UIAdapter',
+    'UIEventStream',
+    'SSE_CONTENT_TYPE',
+    'StateDeps',
+    'StateHandler',
+    'NativeEvent',
+    'OnCompleteFunc',
+    'MessagesBuilder',
+]
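
This new pydantic_ai.ui package generalizes the old ag_ui module into a protocol-agnostic adapter layer, with AG-UI and Vercel AI implementations underneath it (see the ui/ag_ui and ui/vercel_ai entries in the file list). The existing convenience entry point presumably now sits on top of this layer; a sketch using the long-standing Agent.to_ag_ui method:

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

# Returns an ASGI app; serve it with e.g. `uvicorn module:app`.
app = agent.to_ag_ui()
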