pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. pydantic_ai/__init__.py +6 -0
  2. pydantic_ai/_agent_graph.py +67 -20
  3. pydantic_ai/_cli.py +2 -2
  4. pydantic_ai/_output.py +20 -12
  5. pydantic_ai/_run_context.py +6 -2
  6. pydantic_ai/_utils.py +26 -8
  7. pydantic_ai/ag_ui.py +50 -696
  8. pydantic_ai/agent/__init__.py +13 -25
  9. pydantic_ai/agent/abstract.py +146 -9
  10. pydantic_ai/builtin_tools.py +106 -4
  11. pydantic_ai/direct.py +16 -4
  12. pydantic_ai/durable_exec/dbos/_agent.py +3 -0
  13. pydantic_ai/durable_exec/prefect/_agent.py +3 -0
  14. pydantic_ai/durable_exec/temporal/__init__.py +11 -0
  15. pydantic_ai/durable_exec/temporal/_agent.py +3 -0
  16. pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
  17. pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
  18. pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
  19. pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
  20. pydantic_ai/exceptions.py +6 -1
  21. pydantic_ai/mcp.py +1 -22
  22. pydantic_ai/messages.py +46 -8
  23. pydantic_ai/models/__init__.py +87 -38
  24. pydantic_ai/models/anthropic.py +132 -11
  25. pydantic_ai/models/bedrock.py +4 -4
  26. pydantic_ai/models/cohere.py +0 -7
  27. pydantic_ai/models/gemini.py +9 -2
  28. pydantic_ai/models/google.py +26 -23
  29. pydantic_ai/models/groq.py +13 -5
  30. pydantic_ai/models/huggingface.py +2 -2
  31. pydantic_ai/models/openai.py +251 -52
  32. pydantic_ai/models/outlines.py +563 -0
  33. pydantic_ai/models/test.py +6 -3
  34. pydantic_ai/profiles/openai.py +7 -0
  35. pydantic_ai/providers/__init__.py +25 -12
  36. pydantic_ai/providers/anthropic.py +2 -2
  37. pydantic_ai/providers/bedrock.py +60 -16
  38. pydantic_ai/providers/gateway.py +60 -72
  39. pydantic_ai/providers/google.py +91 -24
  40. pydantic_ai/providers/openrouter.py +3 -0
  41. pydantic_ai/providers/outlines.py +40 -0
  42. pydantic_ai/providers/ovhcloud.py +95 -0
  43. pydantic_ai/result.py +173 -8
  44. pydantic_ai/run.py +40 -24
  45. pydantic_ai/settings.py +8 -0
  46. pydantic_ai/tools.py +10 -6
  47. pydantic_ai/toolsets/fastmcp.py +215 -0
  48. pydantic_ai/ui/__init__.py +16 -0
  49. pydantic_ai/ui/_adapter.py +386 -0
  50. pydantic_ai/ui/_event_stream.py +591 -0
  51. pydantic_ai/ui/_messages_builder.py +28 -0
  52. pydantic_ai/ui/ag_ui/__init__.py +9 -0
  53. pydantic_ai/ui/ag_ui/_adapter.py +187 -0
  54. pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
  55. pydantic_ai/ui/ag_ui/app.py +148 -0
  56. pydantic_ai/ui/vercel_ai/__init__.py +16 -0
  57. pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
  58. pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
  59. pydantic_ai/ui/vercel_ai/_utils.py +16 -0
  60. pydantic_ai/ui/vercel_ai/request_types.py +275 -0
  61. pydantic_ai/ui/vercel_ai/response_types.py +230 -0
  62. pydantic_ai/usage.py +13 -2
  63. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
  64. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
  65. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
  66. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
  67. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/__init__.py CHANGED
@@ -12,6 +12,7 @@ from .agent import (
12
12
  from .builtin_tools import (
13
13
  CodeExecutionTool,
14
14
  ImageGenerationTool,
15
+ MCPServerTool,
15
16
  MemoryTool,
16
17
  UrlContextTool,
17
18
  WebSearchTool,
@@ -22,6 +23,7 @@ from .exceptions import (
22
23
  ApprovalRequired,
23
24
  CallDeferred,
24
25
  FallbackExceptionGroup,
26
+ IncompleteToolCall,
25
27
  ModelHTTPError,
26
28
  ModelRetry,
27
29
  UnexpectedModelBehavior,
@@ -63,6 +65,7 @@ from .messages import (
63
65
  ModelResponseStreamEvent,
64
66
  MultiModalContent,
65
67
  PartDeltaEvent,
68
+ PartEndEvent,
66
69
  PartStartEvent,
67
70
  RetryPromptPart,
68
71
  SystemPromptPart,
@@ -124,6 +127,7 @@ __all__ = (
124
127
  'ModelRetry',
125
128
  'ModelHTTPError',
126
129
  'FallbackExceptionGroup',
130
+ 'IncompleteToolCall',
127
131
  'UnexpectedModelBehavior',
128
132
  'UsageLimitExceeded',
129
133
  'UserError',
@@ -161,6 +165,7 @@ __all__ = (
161
165
  'ModelResponseStreamEvent',
162
166
  'MultiModalContent',
163
167
  'PartDeltaEvent',
168
+ 'PartEndEvent',
164
169
  'PartStartEvent',
165
170
  'RetryPromptPart',
166
171
  'SystemPromptPart',
@@ -211,6 +216,7 @@ __all__ = (
211
216
  'CodeExecutionTool',
212
217
  'ImageGenerationTool',
213
218
  'MemoryTool',
219
+ 'MCPServerTool',
214
220
  # output
215
221
  'ToolOutput',
216
222
  'NativeOutput',
pydantic_ai/_agent_graph.py CHANGED
@@ -20,7 +20,8 @@ from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION
20
20
  from pydantic_ai._tool_manager import ToolManager
21
21
  from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
22
22
  from pydantic_ai.builtin_tools import AbstractBuiltinTool
23
- from pydantic_graph import BaseNode, Graph, GraphRunContext
23
+ from pydantic_graph import BaseNode, GraphRunContext
24
+ from pydantic_graph.beta import Graph, GraphBuilder
24
25
  from pydantic_graph.nodes import End, NodeRunEndT
25
26
 
26
27
  from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
@@ -92,9 +93,28 @@ class GraphAgentState:
92
93
  retries: int = 0
93
94
  run_step: int = 0
94
95
 
95
- def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
96
+ def increment_retries(
97
+ self,
98
+ max_result_retries: int,
99
+ error: BaseException | None = None,
100
+ model_settings: ModelSettings | None = None,
101
+ ) -> None:
96
102
  self.retries += 1
97
103
  if self.retries > max_result_retries:
104
+ if (
105
+ self.message_history
106
+ and isinstance(model_response := self.message_history[-1], _messages.ModelResponse)
107
+ and model_response.finish_reason == 'length'
108
+ and model_response.parts
109
+ and isinstance(tool_call := model_response.parts[-1], _messages.ToolCallPart)
110
+ ):
111
+ try:
112
+ tool_call.args_as_dict()
113
+ except Exception:
114
+ max_tokens = (model_settings or {}).get('max_tokens') if model_settings else None
115
+ raise exceptions.IncompleteToolCall(
116
+ f'Model token limit ({max_tokens if max_tokens is not None else "provider default"}) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.'
117
+ )
98
118
  message = f'Exceeded maximum retries ({max_result_retries}) for output validation'
99
119
  if error:
100
120
  if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
@@ -247,6 +267,9 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
247
267
 
248
268
  next_message.instructions = await ctx.deps.get_instructions(run_context)
249
269
 
270
+ if not messages and not next_message.parts and not next_message.instructions:
271
+ raise exceptions.UserError('No message history, user prompt, or instructions provided')
272
+
250
273
  return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
251
274
 
252
275
  async def _handle_deferred_tool_results( # noqa: C901
@@ -568,8 +591,12 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
568
591
  # resubmit the most recent request that resulted in an empty response,
569
592
  # as the empty response and request will not create any items in the API payload,
570
593
  # in the hope the model will return a non-empty response this time.
571
- ctx.state.increment_retries(ctx.deps.max_result_retries)
572
- self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[]))
594
+ ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
595
+ run_context = build_run_context(ctx)
596
+ instructions = await ctx.deps.get_instructions(run_context)
597
+ self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
598
+ _messages.ModelRequest(parts=[], instructions=instructions)
599
+ )
573
600
  return
574
601
 
575
602
  text = ''
@@ -630,8 +657,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
630
657
  )
631
658
  raise ToolRetryError(m)
632
659
  except ToolRetryError as e:
633
- ctx.state.increment_retries(ctx.deps.max_result_retries, e)
634
- self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
660
+ ctx.state.increment_retries(
661
+ ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
662
+ )
663
+ run_context = build_run_context(ctx)
664
+ instructions = await ctx.deps.get_instructions(run_context)
665
+ self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
666
+ _messages.ModelRequest(parts=[e.tool_retry], instructions=instructions)
667
+ )
635
668
 
636
669
  self._events_iterator = _run_stream()
637
670
 
@@ -788,10 +821,14 @@ async def process_tool_calls( # noqa: C901
788
821
  try:
789
822
  result_data = await tool_manager.handle_call(call)
790
823
  except exceptions.UnexpectedModelBehavior as e:
791
- ctx.state.increment_retries(ctx.deps.max_result_retries, e)
824
+ ctx.state.increment_retries(
825
+ ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
826
+ )
792
827
  raise e # pragma: lax no cover
793
828
  except ToolRetryError as e:
794
- ctx.state.increment_retries(ctx.deps.max_result_retries, e)
829
+ ctx.state.increment_retries(
830
+ ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings
831
+ )
795
832
  yield _messages.FunctionToolCallEvent(call)
796
833
  output_parts.append(e.tool_retry)
797
834
  yield _messages.FunctionToolResultEvent(e.tool_retry)
@@ -820,7 +857,7 @@ async def process_tool_calls( # noqa: C901
820
857
 
821
858
  # Then, we handle unknown tool calls
822
859
  if tool_calls_by_kind['unknown']:
823
- ctx.state.increment_retries(ctx.deps.max_result_retries)
860
+ ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
824
861
  calls_to_run.extend(tool_calls_by_kind['unknown'])
825
862
 
826
863
  calls_to_run_results: dict[str, DeferredToolResult] = {}
@@ -1129,22 +1166,32 @@ def build_agent_graph(
1129
1166
  name: str | None,
1130
1167
  deps_type: type[DepsT],
1131
1168
  output_type: OutputSpec[OutputT],
1132
- ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
1169
+ ) -> Graph[
1170
+ GraphAgentState,
1171
+ GraphAgentDeps[DepsT, OutputT],
1172
+ UserPromptNode[DepsT, OutputT],
1173
+ result.FinalResult[OutputT],
1174
+ ]:
1133
1175
  """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
1134
- nodes = (
1135
- UserPromptNode[DepsT],
1136
- ModelRequestNode[DepsT],
1137
- CallToolsNode[DepsT],
1138
- SetFinalResult[DepsT],
1139
- )
1140
- graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
1141
- nodes=nodes,
1176
+ g = GraphBuilder(
1142
1177
  name=name or 'Agent',
1143
1178
  state_type=GraphAgentState,
1144
- run_end_type=result.FinalResult[OutputT],
1179
+ deps_type=GraphAgentDeps[DepsT, OutputT],
1180
+ input_type=UserPromptNode[DepsT, OutputT],
1181
+ output_type=result.FinalResult[OutputT],
1145
1182
  auto_instrument=False,
1146
1183
  )
1147
- return graph
1184
+
1185
+ g.add(
1186
+ g.edge_from(g.start_node).to(UserPromptNode[DepsT, OutputT]),
1187
+ g.node(UserPromptNode[DepsT, OutputT]),
1188
+ g.node(ModelRequestNode[DepsT, OutputT]),
1189
+ g.node(CallToolsNode[DepsT, OutputT]),
1190
+ g.node(
1191
+ SetFinalResult[DepsT, OutputT],
1192
+ ),
1193
+ )
1194
+ return g.build(validate_graph_structure=False)
1148
1195
 
1149
1196
 
1150
1197
  async def _process_message_history(
pydantic_ai/_cli.py CHANGED
@@ -103,7 +103,7 @@ def cli_exit(prog_name: str = 'pai'): # pragma: no cover
103
103
 
104
104
 
105
105
  def cli( # noqa: C901
106
- args_list: Sequence[str] | None = None, *, prog_name: str = 'pai', default_model: str = 'openai:gpt-4.1'
106
+ args_list: Sequence[str] | None = None, *, prog_name: str = 'pai', default_model: str = 'openai:gpt-5'
107
107
  ) -> int:
108
108
  """Run the CLI and return the exit code for the process."""
109
109
  parser = argparse.ArgumentParser(
@@ -124,7 +124,7 @@ Special prompts:
124
124
  '-m',
125
125
  '--model',
126
126
  nargs='?',
127
- help=f'Model to use, in format "<provider>:<model>" e.g. "openai:gpt-4.1" or "anthropic:claude-sonnet-4-0". Defaults to "{default_model}".',
127
+ help=f'Model to use, in format "<provider>:<model>" e.g. "openai:gpt-5" or "anthropic:claude-sonnet-4-5". Defaults to "{default_model}".',
128
128
  )
129
129
  # we don't want to autocomplete or list models that don't include the provider,
130
130
  # e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
pydantic_ai/_output.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import inspect
4
4
  import json
5
+ import re
5
6
  from abc import ABC, abstractmethod
6
7
  from collections.abc import Awaitable, Callable, Sequence
7
8
  from dataclasses import dataclass, field
@@ -70,6 +71,7 @@ Usage `OutputValidatorFunc[AgentDepsT, T]`.
70
71
 
71
72
  DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
72
73
  DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
74
+ OUTPUT_TOOL_NAME_SANITIZER = re.compile(r'[^a-zA-Z0-9-_]')
73
75
 
74
76
 
75
77
  async def execute_traced_output_function(
@@ -554,6 +556,20 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
554
556
  def mode(self) -> OutputMode:
555
557
  return 'prompted'
556
558
 
559
+ @classmethod
560
+ def build_instructions(cls, template: str, object_def: OutputObjectDefinition) -> str:
561
+ """Build instructions from a template and an object definition."""
562
+ schema = object_def.json_schema.copy()
563
+ if object_def.name:
564
+ schema['title'] = object_def.name
565
+ if object_def.description:
566
+ schema['description'] = object_def.description
567
+
568
+ if '{schema}' not in template:
569
+ template = '\n\n'.join([template, '{schema}'])
570
+
571
+ return template.format(schema=json.dumps(schema))
572
+
557
573
  def raise_if_unsupported(self, profile: ModelProfile) -> None:
558
574
  """Raise an error if the mode is not supported by this model."""
559
575
  super().raise_if_unsupported(profile)
@@ -561,18 +577,8 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
561
577
  def instructions(self, default_template: str) -> str:
562
578
  """Get instructions to tell model to output JSON matching the schema."""
563
579
  template = self.template or default_template
564
-
565
- if '{schema}' not in template:
566
- template = '\n\n'.join([template, '{schema}'])
567
-
568
580
  object_def = self.object_def
569
- schema = object_def.json_schema.copy()
570
- if object_def.name:
571
- schema['title'] = object_def.name
572
- if object_def.description:
573
- schema['description'] = object_def.description
574
-
575
- return template.format(schema=json.dumps(schema))
581
+ return self.build_instructions(template, object_def)
576
582
 
577
583
 
578
584
  @dataclass(init=False)
@@ -997,7 +1003,9 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
997
1003
  if name is None:
998
1004
  name = default_name
999
1005
  if multiple:
1000
- name += f'_{object_def.name}'
1006
+ # strip unsupported characters like "[" and "]" from generic class names
1007
+ safe_name = OUTPUT_TOOL_NAME_SANITIZER.sub('', object_def.name or '')
1008
+ name += f'_{safe_name}'
1001
1009
 
1002
1010
  i = 1
1003
1011
  original_name = name
pydantic_ai/_run_context.py CHANGED
@@ -16,15 +16,19 @@ if TYPE_CHECKING:
16
16
  from .models import Model
17
17
  from .result import RunUsage
18
18
 
19
+ # TODO (v2): Change the default for all typevars like this from `None` to `object`
19
20
  AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
20
21
  """Type variable for agent dependencies."""
21
22
 
23
+ RunContextAgentDepsT = TypeVar('RunContextAgentDepsT', default=None, covariant=True)
24
+ """Type variable for the agent dependencies in `RunContext`."""
25
+
22
26
 
23
27
  @dataclasses.dataclass(repr=False, kw_only=True)
24
- class RunContext(Generic[AgentDepsT]):
28
+ class RunContext(Generic[RunContextAgentDepsT]):
25
29
  """Information about the current call."""
26
30
 
27
- deps: AgentDepsT
31
+ deps: RunContextAgentDepsT
28
32
  """Dependencies for the agent."""
29
33
  model: Model
30
34
  """The model used in this run."""
pydantic_ai/_utils.py CHANGED
@@ -147,7 +147,7 @@ async def group_by_temporal(
147
147
  aiterable: The async iterable to group.
148
148
  soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
149
149
  a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
150
- as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed
150
+ as soon as `anext(aiter)` returns. If `None`, no grouping/debouncing is performed
151
151
 
152
152
  Returns:
153
153
  A context manager usable as an async iterable of lists of items produced by the input async iterable.
@@ -171,7 +171,7 @@ async def group_by_temporal(
171
171
  buffer: list[T] = []
172
172
  group_start_time = time.monotonic()
173
173
 
174
- aiterator = aiterable.__aiter__()
174
+ aiterator = aiter(aiterable)
175
175
  while True:
176
176
  if group_start_time is None:
177
177
  # group hasn't started, we just wait for the maximum interval
@@ -182,9 +182,9 @@ async def group_by_temporal(
182
182
 
183
183
  # if there's no current task, we get the next one
184
184
  if task is None:
185
- # aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
185
+ # anext(aiter) returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
186
186
  # so far, this doesn't seem to be a problem
187
- task = asyncio.create_task(aiterator.__anext__()) # pyright: ignore[reportArgumentType]
187
+ task = asyncio.create_task(anext(aiterator)) # pyright: ignore[reportArgumentType]
188
188
 
189
189
  # we use asyncio.wait to avoid cancelling the coroutine if it's not done
190
190
  done, _ = await asyncio.wait((task,), timeout=wait_time)
@@ -234,6 +234,15 @@ def sync_anext(iterator: Iterator[T]) -> T:
234
234
  raise StopAsyncIteration() from e
235
235
 
236
236
 
237
+ def sync_async_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]:
238
+ loop = get_event_loop()
239
+ while True:
240
+ try:
241
+ yield loop.run_until_complete(anext(async_iter))
242
+ except StopAsyncIteration:
243
+ break
244
+
245
+
237
246
  def now_utc() -> datetime:
238
247
  return datetime.now(tz=timezone.utc)
239
248
 
@@ -284,10 +293,10 @@ class PeekableAsyncStream(Generic[T]):
284
293
 
285
294
  # Otherwise, we need to fetch the next item from the underlying iterator.
286
295
  if self._source_iter is None:
287
- self._source_iter = self._source.__aiter__()
296
+ self._source_iter = aiter(self._source)
288
297
 
289
298
  try:
290
- self._buffer = await self._source_iter.__anext__()
299
+ self._buffer = await anext(self._source_iter)
291
300
  except StopAsyncIteration:
292
301
  self._exhausted = True
293
302
  return UNSET
@@ -318,10 +327,10 @@ class PeekableAsyncStream(Generic[T]):
318
327
 
319
328
  # Otherwise, fetch the next item from the source.
320
329
  if self._source_iter is None:
321
- self._source_iter = self._source.__aiter__()
330
+ self._source_iter = aiter(self._source)
322
331
 
323
332
  try:
324
- return await self._source_iter.__anext__()
333
+ return await anext(self._source_iter)
325
334
  except StopAsyncIteration:
326
335
  self._exhausted = True
327
336
  raise
@@ -489,3 +498,12 @@ def get_union_args(tp: Any) -> tuple[Any, ...]:
489
498
  return tuple(_unwrap_annotated(arg) for arg in get_args(tp))
490
499
  else:
491
500
  return ()
501
+
502
+
503
+ def get_event_loop():
504
+ try:
505
+ event_loop = asyncio.get_event_loop()
506
+ except RuntimeError: # pragma: lax no cover
507
+ event_loop = asyncio.new_event_loop()
508
+ asyncio.set_event_loop(event_loop)
509
+ return event_loop