pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +84 -17
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +70 -17
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +156 -44
  31. pydantic_ai/models/__init__.py +20 -7
  32. pydantic_ai/models/anthropic.py +10 -17
  33. pydantic_ai/models/bedrock.py +55 -57
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +13 -14
  38. pydantic_ai/models/google.py +19 -5
  39. pydantic_ai/models/groq.py +127 -39
  40. pydantic_ai/models/huggingface.py +5 -5
  41. pydantic_ai/models/instrumented.py +49 -21
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +8 -8
  44. pydantic_ai/models/openai.py +37 -42
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +173 -52
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/providers/litellm.py ADDED
@@ -0,0 +1,134 @@
+from __future__ import annotations as _annotations
+
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.amazon import amazon_model_profile
+from pydantic_ai.profiles.anthropic import anthropic_model_profile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.grok import grok_model_profile
+from pydantic_ai.profiles.groq import groq_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the LiteLLM provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class LiteLLMProvider(Provider[AsyncOpenAI]):
+    """Provider for LiteLLM API."""
+
+    @property
+    def name(self) -> str:
+        return 'litellm'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        # Map provider prefixes to their profile functions
+        provider_to_profile = {
+            'anthropic': anthropic_model_profile,
+            'openai': openai_model_profile,
+            'google': google_model_profile,
+            'mistralai': mistral_model_profile,
+            'mistral': mistral_model_profile,
+            'cohere': cohere_model_profile,
+            'amazon': amazon_model_profile,
+            'bedrock': amazon_model_profile,
+            'meta-llama': meta_model_profile,
+            'meta': meta_model_profile,
+            'groq': groq_model_profile,
+            'deepseek': deepseek_model_profile,
+            'moonshotai': moonshotai_model_profile,
+            'x-ai': grok_model_profile,
+            'qwen': qwen_model_profile,
+        }
+
+        profile = None
+
+        # Check if model name contains a provider prefix (e.g., "anthropic/claude-3")
+        if '/' in model_name:
+            provider_prefix, model_suffix = model_name.split('/', 1)
+            if provider_prefix in provider_to_profile:
+                profile = provider_to_profile[provider_prefix](model_suffix)
+
+        # If no profile found, default to OpenAI profile
+        if profile is None:
+            profile = openai_model_profile(model_name)
+
+        # As LiteLLMProvider is used with OpenAIModel, which uses OpenAIJsonSchemaTransformer,
+        # we maintain that behavior
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    @overload
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+        http_client: AsyncHTTPClient,
+    ) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_base: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        """Initialize a LiteLLM provider.
+
+        Args:
+            api_key: API key for the model provider. If None, LiteLLM will try to get it from environment variables.
+            api_base: Base URL for the model provider. Use this for custom endpoints or self-hosted models.
+            openai_client: Pre-configured OpenAI client. If provided, other parameters are ignored.
+            http_client: Custom HTTP client to use.
+        """
+        if openai_client is not None:
+            self._client = openai_client
+            return
+
+        # Create OpenAI client that will be used with LiteLLM's completion function
+        # The actual API calls will be intercepted and routed through LiteLLM
+        if http_client is not None:
+            self._client = AsyncOpenAI(
+                base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
+            )
+        else:
+            http_client = cached_async_http_client(provider='litellm')
+            self._client = AsyncOpenAI(
+                base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
+            )
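Reviewer note: the new provider speaks to a LiteLLM deployment through the OpenAI-compatible client, and the prefix map in `model_profile` selects a per-provider profile from the model name. A minimal usage sketch follows; pairing with `OpenAIModel` is suggested by the diff's own comment, but the import path, proxy URL, key, and model name are illustrative assumptions, not taken from this diff:

```python
# Sketch: wiring LiteLLMProvider into an agent. OpenAIModel is the pairing the
# diff's comment mentions; the URL, key, and model name are placeholders.
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.litellm import LiteLLMProvider

provider = LiteLLMProvider(
    api_key='my-litellm-key',          # if omitted, 'litellm-placeholder' is sent
    api_base='http://localhost:4000',  # LiteLLM proxy endpoint (illustrative)
)

# The 'anthropic/' prefix makes model_profile() pick anthropic_model_profile;
# an unknown or missing prefix falls back to openai_model_profile.
model = OpenAIModel('anthropic/claude-3-5-sonnet-latest', provider=provider)
agent = Agent(model)
```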
pydantic_ai/result.py CHANGED
@@ -1,15 +1,13 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Awaitable, Callable
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
 from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic, cast
+from typing import Generic, cast, overload
 
 from pydantic import ValidationError
-from typing_extensions import TypeVar
-
-from pydantic_ai._tool_manager import ToolManager
+from typing_extensions import TypeVar, deprecated
 
 from . import _utils, exceptions, messages as _messages, models
 from ._output import (
@@ -22,11 +20,14 @@ from ._output import (
     ToolOutputSchema,
 )
 from ._run_context import AgentDepsT, RunContext
+from ._tool_manager import ToolManager
 from .messages import ModelResponseStreamEvent
 from .output import (
+    DeferredToolRequests,
     OutputDataT,
     ToolOutput,
 )
+from .run import AgentRunResult
 from .usage import RunUsage, UsageLimits
 
 __all__ = (
@@ -41,7 +42,7 @@ T = TypeVar('T')
 """An invariant TypeVar."""
 
 
-@dataclass
+@dataclass(kw_only=True)
 class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _raw_stream_response: models.StreamedResponse
     _output_schema: OutputSchema[OutputDataT]
@@ -62,11 +63,11 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         async for response in self.stream_responses(debounce_by=debounce_by):
             if self._raw_stream_response.final_result_event is not None:
                 try:
-                    yield await self._validate_response(response, allow_partial=True)
+                    yield await self.validate_response_output(response, allow_partial=True)
                 except ValidationError:
                     pass
         if self._raw_stream_response.final_result_event is not None:  # pragma: no branch
-            yield await self._validate_response(self._raw_stream_response.get())
+            yield await self.validate_response_output(self._raw_stream_response.get())
 
     async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
         """Asynchronously stream the (unvalidated) model responses for the agent."""
@@ -127,9 +128,11 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         async for _ in self:
             pass
 
-        return await self._validate_response(self._raw_stream_response.get())
+        return await self.validate_response_output(self._raw_stream_response.get())
 
-    async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
+    async def validate_response_output(
+        self, message: _messages.ModelResponse, *, allow_partial: bool = False
+    ) -> OutputDataT:
         """Validate a structured result message."""
         final_result_event = self._raw_stream_response.final_result_event
         if final_result_event is None:
@@ -153,12 +156,12 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                     return await self._tool_manager.handle_call(
                         tool_call, allow_partial=allow_partial, wrap_validation_errors=False
                     )
-        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
-            if not self._output_schema.allows_deferred_tool_calls:
+        elif deferred_tool_requests := _get_deferred_tool_requests(message.parts, self._tool_manager):
+            if not self._output_schema.allows_deferred_tools:
                 raise exceptions.UserError(
-                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                    'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
                 )
-            return cast(OutputDataT, deferred_tool_calls)
+            return cast(OutputDataT, deferred_tool_requests)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
@@ -231,26 +234,61 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         return self._agent_stream_iterator
 
 
-@dataclass
+@dataclass(init=False)
 class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     """Result of a streamed run that returns structured data via a tool call."""
 
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    _stream_response: AgentStream[AgentDepsT, OutputDataT]
-    _on_complete: Callable[[], Awaitable[None]]
+    _stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None
+    _on_complete: Callable[[], Awaitable[None]] | None = None
+
+    _run_result: AgentRunResult[OutputDataT] | None = None
 
     is_complete: bool = field(default=False, init=False)
     """Whether the stream has all been received.
 
     This is set to `True` when one of
-    [`stream`][pydantic_ai.result.StreamedRunResult.stream],
+    [`stream_output`][pydantic_ai.result.StreamedRunResult.stream_output],
    [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text],
-    [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or
+    [`stream_responses`][pydantic_ai.result.StreamedRunResult.stream_responses] or
    [`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes.
    """
 
+    @overload
+    def __init__(
+        self,
+        all_messages: list[_messages.ModelMessage],
+        new_message_index: int,
+        stream_response: AgentStream[AgentDepsT, OutputDataT] | None,
+        on_complete: Callable[[], Awaitable[None]] | None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        all_messages: list[_messages.ModelMessage],
+        new_message_index: int,
+        *,
+        run_result: AgentRunResult[OutputDataT],
+    ) -> None: ...
+
+    def __init__(
+        self,
+        all_messages: list[_messages.ModelMessage],
+        new_message_index: int,
+        stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None,
+        on_complete: Callable[[], Awaitable[None]] | None = None,
+        run_result: AgentRunResult[OutputDataT] | None = None,
+    ) -> None:
+        self._all_messages = all_messages
+        self._new_message_index = new_message_index
+
+        self._stream_response = stream_response
+        self._on_complete = on_complete
+        self._run_result = run_result
+
     def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return the history of _messages.
 
@@ -318,24 +356,35 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
             self.new_messages(output_tool_return_content=output_tool_return_content)
         )
 
+    @deprecated('`StreamedRunResult.stream` is deprecated, use `stream_output` instead.')
     async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
-        """Stream the response as an async iterable.
+        async for output in self.stream_output(debounce_by=debounce_by):
+            yield output
+
+    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
+        """Stream the output as an async iterable.
 
         The pydantic validator for structured data will be called in
         [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
         on each iteration.
 
         Args:
-            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
-                Debouncing is particularly important for long structured responses to reduce the overhead of
+            debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured outputs to reduce the overhead of
                 performing validation as each token is received.
 
         Returns:
            An async iterable of the response data.
        """
-        async for output in self._stream_response.stream_output(debounce_by=debounce_by):
-            yield output
-        await self._marked_completed(self._stream_response.get())
+        if self._run_result is not None:
+            yield self._run_result.output
+            await self._marked_completed()
+        elif self._stream_response is not None:
+            async for output in self._stream_response.stream_output(debounce_by=debounce_by):
+                yield output
+            await self._marked_completed(self._stream_response.get())
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
         """Stream the text result as an async iterable.
@@ -350,12 +399,30 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.
        """
-        async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
-            yield text
-        await self._marked_completed(self._stream_response.get())
+        if self._run_result is not None:  # pragma: no cover
+            # We can't really get here, as `_run_result` is only set in `run_stream` when `CallToolsNode` produces `DeferredToolRequests` output
+            # as a result of a tool function raising `CallDeferred` or `ApprovalRequired`.
+            # That'll change if we ever support something like `raise EndRun(output: OutputT)` where `OutputT` could be `str`.
+            if not isinstance(self._run_result.output, str):
+                raise exceptions.UserError('stream_text() can only be used with text responses')
+            yield self._run_result.output
+            await self._marked_completed()
+        elif self._stream_response is not None:
+            async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                yield text
+            await self._marked_completed(self._stream_response.get())
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
+    @deprecated('`StreamedRunResult.stream_structured` is deprecated, use `stream_responses` instead.')
     async def stream_structured(
         self, *, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
+        async for msg, last in self.stream_responses(debounce_by=debounce_by):
+            yield msg, last
+
+    async def stream_responses(
+        self, *, debounce_by: float | None = 0.1
     ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
        """Stream the response as an async iterable of Structured LLM Messages.
 
@@ -367,20 +434,34 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
        Returns:
            An async iterable of the structured response message and whether that is the last message.
        """
-        # if the message currently has any parts with content, yield before streaming
-        async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
-            yield msg, False
-
-        msg = self._stream_response.get()
-        yield msg, True
-
-        await self._marked_completed(msg)
+        if self._run_result is not None:
+            model_response = cast(_messages.ModelResponse, self.all_messages()[-1])
+            yield model_response, True
+            await self._marked_completed()
+        elif self._stream_response is not None:
+            # if the message currently has any parts with content, yield before streaming
+            async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
+                yield msg, False
+
+            msg = self._stream_response.get()
+            yield msg, True
+
+            await self._marked_completed(msg)
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     async def get_output(self) -> OutputDataT:
        """Stream the whole response, validate and return it."""
-        output = await self._stream_response.get_output()
-        await self._marked_completed(self._stream_response.get())
-        return output
+        if self._run_result is not None:
+            output = self._run_result.output
+            await self._marked_completed()
+            return output
+        elif self._stream_response is not None:
+            output = await self._stream_response.get_output()
+            await self._marked_completed(self._stream_response.get())
+            return output
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     def usage(self) -> RunUsage:
        """Return the usage of the whole run.
@@ -388,24 +469,45 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
        !!! note
            This won't return the full usage until the stream is finished.
        """
-        return self._stream_response.usage()
+        if self._run_result is not None:
+            return self._run_result.usage()
+        elif self._stream_response is not None:
+            return self._stream_response.usage()
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
     def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
-        return self._stream_response.timestamp()
+        if self._run_result is not None:
+            return self._run_result.timestamp()
+        elif self._stream_response is not None:
+            return self._stream_response.timestamp()
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
+    @deprecated('`validate_structured_output` is deprecated, use `validate_response_output` instead.')
     async def validate_structured_output(
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
+    ) -> OutputDataT:
+        return await self.validate_response_output(message, allow_partial=allow_partial)
+
+    async def validate_response_output(
+        self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
        """Validate a structured result message."""
-        return await self._stream_response._validate_response(  # pyright: ignore[reportPrivateUsage]
-            message, allow_partial=allow_partial
-        )
+        if self._run_result is not None:
+            return self._run_result.output
+        elif self._stream_response is not None:
+            return await self._stream_response.validate_response_output(message, allow_partial=allow_partial)
+        else:
+            raise ValueError('No stream response or run result provided')  # pragma: no cover
 
-    async def _marked_completed(self, message: _messages.ModelResponse) -> None:
+    async def _marked_completed(self, message: _messages.ModelResponse | None = None) -> None:
         self.is_complete = True
-        self._all_messages.append(message)
-        await self._on_complete()
+        if message is not None:
+            self._all_messages.append(message)
+        if self._on_complete is not None:
+            await self._on_complete()
 
 
 @dataclass(repr=False)
@@ -414,8 +516,10 @@ class FinalResult(Generic[OutputDataT]):
 
     output: OutputDataT
    """The final result data."""
+
     tool_name: str | None = None
    """Name of the final output tool; `None` if the output came from unstructured text content."""
+
     tool_call_id: str | None = None
    """ID of the tool call that produced the final output; `None` if the output came from unstructured text content."""
 
@@ -436,9 +540,26 @@ def _get_usage_checking_stream_response(
 
         return _usage_checking_iterator()
     else:
-        # TODO: Use `return aiter(stream_response)` once we drop support for Python 3.9
-        async def _iterator():
-            async for item in stream_response:
-                yield item
+        return aiter(stream_response)
+
+
+def _get_deferred_tool_requests(
+    parts: Iterable[_messages.ModelResponsePart], tool_manager: ToolManager[AgentDepsT]
+) -> DeferredToolRequests | None:
+    """Get the deferred tool requests from the model response parts."""
+    approvals: list[_messages.ToolCallPart] = []
+    calls: list[_messages.ToolCallPart] = []
+
+    for part in parts:
+        if isinstance(part, _messages.ToolCallPart):
+            tool_def = tool_manager.get_tool_def(part.tool_name)
+            if tool_def is not None:  # pragma: no branch
+                if tool_def.kind == 'unapproved':
+                    approvals.append(part)
+                elif tool_def.kind == 'external':
+                    calls.append(part)
+
+    if not calls and not approvals:
+        return None
 
-        return _iterator()
+    return DeferredToolRequests(calls=calls, approvals=approvals)
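Reviewer note: two renames above ripple into user code. `stream` and `stream_structured` are deprecated in favor of `stream_output` and `stream_responses`, and deferred tool calls now surface as `DeferredToolRequests`, which must be listed among the agent's output types (see the `UserError` message in `validate_response_output`). A migration sketch; the model string and prompt are placeholders, while the `DeferredToolRequests` import follows this diff's `from .output import DeferredToolRequests`:

```python
# Sketch: migrating to the renamed streaming API shown above.
import asyncio

from pydantic_ai import Agent
from pydantic_ai.output import DeferredToolRequests

# Without DeferredToolRequests among the output types, a deferred tool call
# raises UserError (see validate_response_output above).
agent = Agent('openai:gpt-4o', output_type=[str, DeferredToolRequests])

async def main() -> None:
    async with agent.run_stream('Tell me a joke.') as result:
        # before: `async for output in result.stream():` (now deprecated)
        async for output in result.stream_output(debounce_by=0.1):
            print(output)

asyncio.run(main())
```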
pydantic_ai/retries.py CHANGED
@@ -13,6 +13,8 @@ The module includes:
 
 from __future__ import annotations
 
+from types import TracebackType
+
 from httpx import (
     AsyncBaseTransport,
     AsyncHTTPTransport,
@@ -24,17 +26,17 @@ from httpx import (
 )
 
 try:
-    from tenacity import AsyncRetrying, RetryCallState, RetryError, Retrying, retry, wait_exponential
+    from tenacity import RetryCallState, RetryError, retry, wait_exponential
 except ImportError as _import_error:
     raise ImportError(
         'Please install `tenacity` to use the retries utilities, '
         'you can use the `retries` optional group — `pip install "pydantic-ai-slim[retries]"`'
     ) from _import_error
 
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Callable
 from datetime import datetime, timezone
 from email.utils import parsedate_to_datetime
-from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from typing_extensions import TypedDict
 
@@ -134,8 +136,9 @@ class TenacityTransport(BaseTransport):
 
     Example:
        ```python
-        from httpx import Client, HTTPTransport, HTTPStatusError
-        from tenacity import stop_after_attempt, retry_if_exception_type
+        from httpx import Client, HTTPStatusError, HTTPTransport
+        from tenacity import retry_if_exception_type, stop_after_attempt
+
        from pydantic_ai.retries import RetryConfig, TenacityTransport, wait_retry_after
 
        transport = TenacityTransport(
@@ -157,18 +160,7 @@
         config: RetryConfig,
         wrapped: BaseTransport | None = None,
         validate_response: Callable[[Response], Any] | None = None,
-        **kwargs: NoReturn,
     ):
-        # TODO: Remove the following checks (and **kwargs) during v1 release
-        if 'controller' in kwargs:  # pragma: no cover
-            raise TypeError('The `controller` argument has been renamed to `config`, and now requires a `RetryConfig`.')
-        if kwargs:  # pragma: no cover
-            raise TypeError(f'Unexpected keyword arguments: {", ".join(kwargs)}')
-        if isinstance(config, Retrying):  # pragma: no cover
-            raise ValueError(
-                'Passing a Retrying instance is no longer supported; the `config` argument must be a `pydantic_ai.retries.RetryConfig`.'
-            )
-
         self.config = config
         self.wrapped = wrapped or HTTPTransport()
         self.validate_response = validate_response
@@ -195,11 +187,30 @@
             response.request = req
 
             if self.validate_response:
-                self.validate_response(response)
+                try:
+                    self.validate_response(response)
+                except Exception:
+                    response.close()
+                    raise
             return response
 
         return handle_request(request)
 
+    def __enter__(self) -> TenacityTransport:
+        self.wrapped.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
+    ) -> None:
+        self.wrapped.__exit__(exc_type, exc_value, traceback)
+
+    def close(self) -> None:
+        self.wrapped.close()  # pragma: no cover
+
 
 class AsyncTenacityTransport(AsyncBaseTransport):
     """Asynchronous HTTP transport with tenacity-based retry functionality.
@@ -224,7 +235,8 @@ class AsyncTenacityTransport(AsyncBaseTransport):
     Example:
        ```python
        from httpx import AsyncClient, HTTPStatusError
-        from tenacity import stop_after_attempt, retry_if_exception_type
+        from tenacity import retry_if_exception_type, stop_after_attempt
+
        from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after
 
        transport = AsyncTenacityTransport(
@@ -245,18 +257,7 @@ class AsyncTenacityTransport(AsyncBaseTransport):
         config: RetryConfig,
         wrapped: AsyncBaseTransport | None = None,
         validate_response: Callable[[Response], Any] | None = None,
-        **kwargs: NoReturn,
     ):
-        # TODO: Remove the following checks (and **kwargs) during v1 release
-        if 'controller' in kwargs:  # pragma: no cover
-            raise TypeError('The `controller` argument has been renamed to `config`, and now requires a `RetryConfig`.')
-        if kwargs:  # pragma: no cover
-            raise TypeError(f'Unexpected keyword arguments: {", ".join(kwargs)}')
-        if isinstance(config, AsyncRetrying):  # pragma: no cover
-            raise ValueError(
-                'Passing an AsyncRetrying instance is no longer supported; the `config` argument must be a `pydantic_ai.retries.RetryConfig`.'
-            )
-
         self.config = config
         self.wrapped = wrapped or AsyncHTTPTransport()
         self.validate_response = validate_response
@@ -283,11 +284,30 @@
             response.request = req
 
             if self.validate_response:
-                self.validate_response(response)
+                try:
+                    self.validate_response(response)
+                except Exception:
+                    await response.aclose()
+                    raise
             return response
 
         return await handle_async_request(request)
 
+    async def __aenter__(self) -> AsyncTenacityTransport:
+        await self.wrapped.__aenter__()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
+    ) -> None:
+        await self.wrapped.__aexit__(exc_type, exc_value, traceback)
+
+    async def aclose(self) -> None:
+        await self.wrapped.aclose()
+
 
 def wait_retry_after(
     fallback_strategy: Callable[[RetryCallState], float] | None = None, max_wait: float = 300
@@ -314,7 +334,8 @@ def wait_retry_after(
     Example:
        ```python
        from httpx import AsyncClient, HTTPStatusError
-        from tenacity import stop_after_attempt, retry_if_exception_type
+        from tenacity import retry_if_exception_type, stop_after_attempt
+
        from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after
 
        transport = AsyncTenacityTransport(
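Reviewer note: the new `__enter__`/`__exit__` and `__aenter__`/`__aexit__` methods let the transports participate in the client's lifecycle, and a failing `validate_response` now closes the response before re-raising. A sketch of the async transport in use; the `RetryConfig` fields shown are assumptions based on its tenacity-kwargs design and the (truncated) docstring examples above, and the URL is a placeholder:

```python
# Sketch: AsyncTenacityTransport with the new lifecycle hooks. RetryConfig
# fields and the URL are illustrative assumptions, not taken from this diff.
import asyncio

from httpx import AsyncClient, HTTPStatusError
from tenacity import retry_if_exception_type, stop_after_attempt

from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after

transport = AsyncTenacityTransport(
    config=RetryConfig(
        retry=retry_if_exception_type(HTTPStatusError),
        wait=wait_retry_after(max_wait=60),
        stop=stop_after_attempt(5),
        reraise=True,
    ),
    # A failing validator now closes the response before the retry machinery
    # sees the exception (see the try/except added above).
    validate_response=lambda response: response.raise_for_status(),
)

async def main() -> None:
    # Closing the client propagates through the new aclose()/__aexit__ hooks
    # to the wrapped transport.
    async with AsyncClient(transport=transport) as client:
        response = await client.get('https://example.com')
        print(response.status_code)

asyncio.run(main())
```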