pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (48)
  1. pydantic_ai/_a2a.py +3 -3
  2. pydantic_ai/_agent_graph.py +220 -319
  3. pydantic_ai/_cli.py +9 -7
  4. pydantic_ai/_output.py +295 -331
  5. pydantic_ai/_parts_manager.py +2 -2
  6. pydantic_ai/_run_context.py +8 -14
  7. pydantic_ai/_tool_manager.py +190 -0
  8. pydantic_ai/_utils.py +18 -1
  9. pydantic_ai/ag_ui.py +675 -0
  10. pydantic_ai/agent.py +378 -164
  11. pydantic_ai/exceptions.py +12 -0
  12. pydantic_ai/ext/aci.py +12 -3
  13. pydantic_ai/ext/langchain.py +9 -1
  14. pydantic_ai/format_prompt.py +3 -6
  15. pydantic_ai/mcp.py +147 -84
  16. pydantic_ai/messages.py +13 -5
  17. pydantic_ai/models/__init__.py +30 -18
  18. pydantic_ai/models/anthropic.py +1 -1
  19. pydantic_ai/models/function.py +50 -24
  20. pydantic_ai/models/gemini.py +1 -18
  21. pydantic_ai/models/google.py +2 -11
  22. pydantic_ai/models/groq.py +1 -0
  23. pydantic_ai/models/instrumented.py +6 -1
  24. pydantic_ai/models/mistral.py +1 -1
  25. pydantic_ai/models/openai.py +16 -4
  26. pydantic_ai/output.py +21 -7
  27. pydantic_ai/profiles/google.py +1 -1
  28. pydantic_ai/profiles/moonshotai.py +8 -0
  29. pydantic_ai/providers/grok.py +13 -1
  30. pydantic_ai/providers/groq.py +2 -0
  31. pydantic_ai/result.py +58 -45
  32. pydantic_ai/tools.py +26 -119
  33. pydantic_ai/toolsets/__init__.py +22 -0
  34. pydantic_ai/toolsets/abstract.py +155 -0
  35. pydantic_ai/toolsets/combined.py +88 -0
  36. pydantic_ai/toolsets/deferred.py +38 -0
  37. pydantic_ai/toolsets/filtered.py +24 -0
  38. pydantic_ai/toolsets/function.py +238 -0
  39. pydantic_ai/toolsets/prefixed.py +37 -0
  40. pydantic_ai/toolsets/prepared.py +36 -0
  41. pydantic_ai/toolsets/renamed.py +42 -0
  42. pydantic_ai/toolsets/wrapper.py +37 -0
  43. pydantic_ai/usage.py +14 -8
  44. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/METADATA +10 -7
  45. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/RECORD +48 -35
  46. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/WHEEL +0 -0
  47. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/entry_points.txt +0 -0
  48. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py CHANGED
@@ -266,7 +266,7 @@ class AnthropicModel(Model):
                 items.append(TextPart(content=item.text))
             elif isinstance(item, BetaRedactedThinkingBlock):  # pragma: no cover
                 warnings.warn(
-                    'PydanticAI currently does not handle redacted thinking blocks. '
+                    'Pydantic AI currently does not handle redacted thinking blocks. '
                     'If you have a suggestion on how we should handle them, please open an issue.',
                     UserWarning,
                 )
pydantic_ai/models/function.py CHANGED
@@ -214,21 +214,39 @@ class DeltaToolCall:
     """Incremental change to the tool call ID."""
 
 
+@dataclass
+class DeltaThinkingPart:
+    """Incremental change to a thinking part.
+
+    Used to describe a chunk when streaming thinking responses.
+    """
+
+    content: str | None = None
+    """Incremental change to the thinking content."""
+    signature: str | None = None
+    """Incremental change to the thinking signature."""
+
+
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""
 
+DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
+"""A mapping of thinking call IDs to incremental changes."""
+
 # TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
 FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""
 
 # TODO: Change signature as indicated above
-StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
+StreamFunctionDef: TypeAlias = Callable[
+    [list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]
+]
 """A function used to generate a streamed response.
 
-While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
-really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]`,
+While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]`, it should
+really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls], AsyncIterator[DeltaThinkingCalls]]`,
 
-E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
+E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
 """
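For context, a minimal sketch (not part of the diff) of a stream function that yields the new `DeltaThinkingCalls`; per the docstring above, a single stream should stick to one kind of item:

```python
from collections.abc import AsyncIterator

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import AgentInfo, DeltaThinkingCalls, DeltaThinkingPart, FunctionModel


async def thinking_stream(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaThinkingCalls]:
    # Deltas are keyed by part index; repeated keys extend the same thinking part.
    yield {0: DeltaThinkingPart(content='Let me think about this')}
    yield {0: DeltaThinkingPart(content=' step by step...', signature='sig')}


model = FunctionModel(stream_function=thinking_stream)
agent = Agent(model)
```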
@@ -237,7 +255,7 @@ class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
     _model_name: str
-    _iter: AsyncIterator[str | DeltaToolCalls]
+    _iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
     def __post_init__(self):
@@ -249,20 +267,31 @@ class FunctionStreamedResponse(StreamedResponse):
                 response_tokens = _estimate_string_tokens(item)
                 self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
                 yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
-            else:
-                delta_tool_calls = item
-                for dtc_index, delta_tool_call in delta_tool_calls.items():
-                    if delta_tool_call.json_args:
-                        response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
-                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
-                    maybe_event = self._parts_manager.handle_tool_call_delta(
-                        vendor_part_id=dtc_index,
-                        tool_name=delta_tool_call.name,
-                        args=delta_tool_call.json_args,
-                        tool_call_id=delta_tool_call.tool_call_id,
-                    )
-                    if maybe_event is not None:
-                        yield maybe_event
+            elif isinstance(item, dict) and item:
+                for dtc_index, delta in item.items():
+                    if isinstance(delta, DeltaThinkingPart):
+                        if delta.content:  # pragma: no branch
+                            response_tokens = _estimate_string_tokens(delta.content)
+                            self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        yield self._parts_manager.handle_thinking_delta(
+                            vendor_part_id=dtc_index,
+                            content=delta.content,
+                            signature=delta.signature,
+                        )
+                    elif isinstance(delta, DeltaToolCall):
+                        if delta.json_args:
+                            response_tokens = _estimate_string_tokens(delta.json_args)
+                            self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        maybe_event = self._parts_manager.handle_tool_call_delta(
+                            vendor_part_id=dtc_index,
+                            tool_name=delta.name,
+                            args=delta.json_args,
+                            tool_call_id=delta.tool_call_id,
+                        )
+                        if maybe_event is not None:
+                            yield maybe_event
+                    else:
+                        assert_never(delta)
 
     @property
     def model_name(self) -> str:
@@ -299,12 +328,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
             if isinstance(part, TextPart):
                 response_tokens += _estimate_string_tokens(part.content)
             elif isinstance(part, ThinkingPart):
-                # NOTE: We don't send ThinkingPart to the providers yet.
-                # If you are unsatisfied with this, please open an issue.
-                pass
+                response_tokens += _estimate_string_tokens(part.content)
             elif isinstance(part, ToolCallPart):
-                call = part
-                response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
+                response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
             else:
                 assert_never(part)
     else:
pydantic_ai/models/gemini.py CHANGED
@@ -48,18 +48,10 @@ from . import (
 )
 
 LatestGeminiModelNames = Literal[
-    'gemini-1.5-flash',
-    'gemini-1.5-flash-8b',
-    'gemini-1.5-pro',
-    'gemini-1.0-pro',
     'gemini-2.0-flash',
-    'gemini-2.0-flash-lite-preview-02-05',
-    'gemini-2.0-pro-exp-02-05',
-    'gemini-2.5-flash-preview-05-20',
+    'gemini-2.0-flash-lite',
     'gemini-2.5-flash',
     'gemini-2.5-flash-lite-preview-06-17',
-    'gemini-2.5-pro-exp-03-25',
-    'gemini-2.5-pro-preview-05-06',
     'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
@@ -99,15 +91,6 @@ class GeminiModelSettings(ModelSettings, total=False):
     See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
     """
 
-    gemini_thinking_config: ThinkingConfig
-    """Thinking is on by default in both the API and AI Studio.
-
-    Being on by default doesn't mean the model will send back thoughts. For that, you need to set `include_thoughts`
-    to `True`. If you want to turn it off, set `thinking_budget` to `0`.
-
-    See more about it on <https://ai.google.dev/gemini-api/docs/thinking>.
-    """
-
 
 @dataclass(init=False)
 class GeminiModel(Model):
pydantic_ai/models/google.py CHANGED
@@ -73,18 +73,10 @@ except ImportError as _import_error:
 ) from _import_error
 
 LatestGoogleModelNames = Literal[
-    'gemini-1.5-flash',
-    'gemini-1.5-flash-8b',
-    'gemini-1.5-pro',
-    'gemini-1.0-pro',
     'gemini-2.0-flash',
-    'gemini-2.0-flash-lite-preview-02-05',
-    'gemini-2.0-pro-exp-02-05',
-    'gemini-2.5-flash-preview-05-20',
+    'gemini-2.0-flash-lite',
     'gemini-2.5-flash',
     'gemini-2.5-flash-lite-preview-06-17',
-    'gemini-2.5-pro-exp-03-25',
-    'gemini-2.5-pro-preview-05-06',
     'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
@@ -492,8 +484,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
             function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
             parts.append({'function_call': function_call})
         elif isinstance(item, TextPart):
-            if item.content:  # pragma: no branch
-                parts.append({'text': item.content})
+            parts.append({'text': item.content})
         elif isinstance(item, ThinkingPart):  # pragma: no cover
             # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
             # please open an issue. The below code is the code to send thinking to the provider.
pydantic_ai/models/groq.py CHANGED
@@ -79,6 +79,7 @@ PreviewGroqModelNames = Literal[
     'llama-3.2-3b-preview',
     'llama-3.2-11b-vision-preview',
     'llama-3.2-90b-vision-preview',
+    'moonshotai/kimi-k2-instruct',
 ]
 """Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""
 
pydantic_ai/models/instrumented.py CHANGED
@@ -156,7 +156,12 @@ class InstrumentationSettings:
         events: list[Event] = []
         instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
         if instructions is not None:
-            events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'}))
+            events.append(
+                Event(
+                    'gen_ai.system.message',
+                    body={**({'content': instructions} if self.include_content else {}), 'role': 'system'},
+                )
+            )
 
         for message_index, message in enumerate(messages):
            message_events: list[Event] = []
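For context: with the new `include_content` gate, prompt text can be kept out of telemetry events. A hedged sketch of opting out (assumes the `include_content` setting accepted by `InstrumentationSettings`):

```python
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

# With include_content=False, the 'gen_ai.system.message' event carries only
# {'role': 'system'}; the instructions text is omitted from the event body.
agent = Agent('openai:gpt-4o', instrument=InstrumentationSettings(include_content=False))
```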
pydantic_ai/models/mistral.py CHANGED
@@ -428,7 +428,7 @@ class MistralModel(Model):
         if value_type == 'object':
             additional_properties = value.get('additionalProperties', {})
             if isinstance(additional_properties, bool):
-                return 'bool'  # pragma: no cover
+                return 'bool'  # pragma: lax no cover
             additional_properties_type = additional_properties.get('type')
             if (
                 additional_properties_type in SIMPLE_JSON_TYPE_MAPPING
pydantic_ai/models/openai.py CHANGED
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Literal, Union, cast, overload
 
+from pydantic import ValidationError
 from typing_extensions import assert_never
 
 from pydantic_ai._thinking_part import split_content_into_text_and_thinking
@@ -50,7 +51,7 @@ from . import (
 
 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
-    from openai.types import ChatModel, chat, responses
+    from openai.types import AllModels, chat, responses
     from openai.types.chat import (
         ChatCompletionChunk,
         ChatCompletionContentPartImageParam,
@@ -80,7 +81,7 @@ __all__ = (
     'OpenAIModelName',
 )
 
-OpenAIModelName = Union[str, ChatModel]
+OpenAIModelName = Union[str, AllModels]
 """
 Possible OpenAI model names.
 
@@ -347,8 +348,19 @@ class OpenAIModel(Model):
             raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
         raise  # pragma: no cover
 
-    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        # Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function:
+        # * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!)
+        # * if the endpoint returns plain text, the return type is a string
+        # Thus we validate it fully here.
+        if not isinstance(response, chat.ChatCompletion):
+            raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
+
+        try:
+            response = chat.ChatCompletion.model_validate(response.model_dump())
+        except ValidationError as e:
+            raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e
         timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
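To illustrate the failure mode this guards against, a sketch (assumption-based, not from the diff): a payload wrapped without validation only fails once `model_validate` runs:

```python
from openai.types import chat
from pydantic import ValidationError

# A payload the SDK would wrap without validating (mirrors the model_construct behaviour noted above).
raw = {'id': 'x', 'object': 'chat.completion', 'created': 0, 'model': 'gpt-4o', 'choices': 'not-a-list'}
unchecked = chat.ChatCompletion.model_construct(**raw)  # no validation is performed here

try:
    chat.ChatCompletion.model_validate(raw)  # full validation catches the malformed 'choices'
except ValidationError as e:
    print(e.errors()[0]['loc'])  # ('choices',)
```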
@@ -1051,7 +1063,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     vendor_part_id=chunk.item_id,
                     tool_name=None,
                     args=chunk.delta,
-                    tool_call_id=chunk.item_id,
+                    tool_call_id=None,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
pydantic_ai/output.py CHANGED
@@ -10,7 +10,8 @@ from pydantic_core import core_schema
 from typing_extensions import TypeAliasType, TypeVar
 
 from . import _utils
-from .tools import RunContext
+from .messages import ToolCallPart
+from .tools import RunContext, ToolDefinition
 
 __all__ = (
     # classes
@@ -330,15 +331,17 @@ def StructuredDict(
     return _StructuredDict
 
 
+_OutputSpecItem = TypeAliasType(
+    '_OutputSpecItem',
+    Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]],
+    type_params=(T_co,),
+)
+
 OutputSpec = TypeAliasType(
     'OutputSpec',
     Union[
-        OutputTypeOrFunction[T_co],
-        ToolOutput[T_co],
-        NativeOutput[T_co],
-        PromptedOutput[T_co],
-        TextOutput[T_co],
-        Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]],
+        _OutputSpecItem[T_co],
+        Sequence['OutputSpec[T_co]'],
     ],
     type_params=(T_co,),
 )
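Since `OutputSpec` is now defined recursively, nested sequences of output specs should be representable. An illustrative sketch with hypothetical models `Foo` and `Bar` (not from the diff):

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class Foo(BaseModel):
    a: int


class Bar(BaseModel):
    b: str


# A nested spec like this is now expressible; previously only one level of Sequence was allowed.
agent = Agent('openai:gpt-4o', output_type=[Foo, [Bar, str]])
```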
@@ -354,3 +357,14 @@ You should not need to import or use this type directly.
 
 See [output docs](../output.md) for more information.
 """
+
+
+@dataclass
+class DeferredToolCalls:
+    """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.
+
+    See [deferred toolset docs](../toolsets.md#deferred-toolset) for more information.
+    """
+
+    tool_calls: list[ToolCallPart]
+    tool_defs: dict[str, ToolDefinition]
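Based on the docstring above, a hedged sketch of using `DeferredToolCalls` together with the new `DeferredToolset`; the tool name and arguments here are illustrative only:

```python
from pydantic_ai import Agent
from pydantic_ai.output import DeferredToolCalls
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets.deferred import DeferredToolset

# A toolset whose tools are executed outside the agent run (e.g. client-side).
toolset = DeferredToolset([ToolDefinition(name='my_external_tool')])
agent = Agent('openai:gpt-4o', output_type=[str, DeferredToolCalls], toolsets=[toolset])

result = agent.run_sync('Use my_external_tool')
if isinstance(result.output, DeferredToolCalls):
    for call in result.output.tool_calls:
        print(call.tool_name, call.args)  # run these externally, then start a new run with the results
```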
pydantic_ai/profiles/google.py CHANGED
@@ -43,7 +43,7 @@ class GoogleJsonSchemaTransformer(JsonSchemaTransformer):
             f' Full schema: {self.schema}\n\n'
             f'Source of additionalProperties within the full schema: {original_schema}\n\n'
             'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-            "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
+            "If Google's APIs are updated to support this properly, please create an issue on the Pydantic AI GitHub"
             ' and we will fix this behavior.',
             UserWarning,
         )
pydantic_ai/profiles/moonshotai.py ADDED
@@ -0,0 +1,8 @@
+from __future__ import annotations as _annotations
+
+from . import ModelProfile
+
+
+def moonshotai_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a MoonshotAI model."""
+    return None
pydantic_ai/providers/grok.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import os
-from typing import overload
+from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from openai import AsyncOpenAI
@@ -21,6 +21,18 @@ except ImportError as _import_error:  # pragma: no cover
         'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
+# https://docs.x.ai/docs/models
+GrokModelName = Literal[
+    'grok-4',
+    'grok-4-0709',
+    'grok-3',
+    'grok-3-mini',
+    'grok-3-fast',
+    'grok-3-mini-fast',
+    'grok-2-vision-1212',
+    'grok-2-image-1212',
+]
+
 
 class GrokProvider(Provider[AsyncOpenAI]):
     """Provider for Grok API."""
pydantic_ai/providers/groq.py CHANGED
@@ -12,6 +12,7 @@ from pydantic_ai.profiles.deepseek import deepseek_model_profile
 from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
@@ -47,6 +48,7 @@ class GroqProvider(Provider[AsyncGroq]):
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
+            'moonshotai/': moonshotai_model_profile,
         }
 
         for prefix, profile_func in prefix_to_profile.items():
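For context, a minimal sketch (assuming `GROQ_API_KEY` is set) showing that Groq-hosted Kimi model names now resolve the MoonshotAI profile via the new `'moonshotai/'` prefix:

```python
from pydantic_ai import Agent

# The 'moonshotai/' prefix in the model name now maps to moonshotai_model_profile.
agent = Agent('groq:moonshotai/kimi-k2-instruct')
```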
pydantic_ai/result.py CHANGED
@@ -5,11 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic
+from typing import Generic, cast
 
 from pydantic import ValidationError
 from typing_extensions import TypeVar, deprecated, overload
 
+from pydantic_ai._tool_manager import ToolManager
+
 from . import _utils, exceptions, messages as _messages, models
 from ._output import (
     OutputDataT_inv,
@@ -19,7 +21,6 @@ from ._output import (
     PlainTextOutputSchema,
     TextOutputSchema,
     ToolOutputSchema,
-    TraceContext,
 )
 from ._run_context import AgentDepsT, RunContext
 from .messages import AgentStreamEvent, FinalResultEvent
@@ -47,8 +48,8 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _output_schema: OutputSchema[OutputDataT]
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _run_ctx: RunContext[AgentDepsT]
-    _trace_ctx: TraceContext
     _usage_limits: UsageLimits | None
+    _tool_manager: ToolManager[AgentDepsT]
 
     _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
     _final_result_event: FinalResultEvent | None = field(default=None, init=False)
@@ -97,37 +98,40 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
-            match = self._output_schema.find_named_tool(message.parts, output_tool_name)
-            if match is None:
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                    f'Invalid response, unable to find tool call for {output_tool_name!r}'
                 )
-
-            call, output_tool = match
-            result_data = await output_tool.process(
-                call,
-                self._run_ctx,
-                self._trace_ctx,
-                allow_partial=allow_partial,
-                wrap_validation_errors=False,
-            )
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(  # pragma: no cover
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
-                text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+                text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
             )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)
-        return result_data
-
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
 
@@ -145,13 +149,19 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
             """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
             if isinstance(e, _messages.PartStartEvent):
                 new_part = e.part
-                if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema):
-                    for call, _ in output_schema.find_tool([new_part]):  # pragma: no branch
-                        return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id)
-                elif isinstance(new_part, _messages.TextPart) and isinstance(
+                if isinstance(new_part, _messages.TextPart) and isinstance(
                     output_schema, TextOutputSchema
                 ):  # pragma: no branch
                     return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
+                elif isinstance(new_part, _messages.ToolCallPart) and (
+                    tool_def := self._tool_manager.get_tool_def(new_part.tool_name)
+                ):
+                    if tool_def.kind == 'output':
+                        return _messages.FinalResultEvent(
+                            tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id
+                        )
+                    elif tool_def.kind == 'deferred':
+                        return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
 
         usage_checking_stream = _get_usage_checking_stream_response(
             self._raw_stream_response, self._usage_limits, self.usage
@@ -183,10 +193,10 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _stream_response: models.StreamedResponse
     _output_schema: OutputSchema[OutputDataT]
     _run_ctx: RunContext[AgentDepsT]
-    _trace_ctx: TraceContext
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _output_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
+    _tool_manager: ToolManager[AgentDepsT]
 
     _initial_run_ctx_usage: Usage = field(init=False)
     is_complete: bool = field(default=False, init=False)
@@ -420,40 +430,43 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None:
-            match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
-            if match is None:
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                    f'Invalid response, unable to find tool call for {self._output_tool_name!r}'
                 )
-
-            call, output_tool = match
-            result_data = await output_tool.process(
-                call,
-                self._run_ctx,
-                self._trace_ctx,
-                allow_partial=allow_partial,
-                wrap_validation_errors=False,
-            )
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
-                text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+                text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)  # pragma: no cover
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
             )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
-        return result_data
-
     async def _validate_text_output(self, text: str) -> str:
         for validator in self._output_validators:
-            text = await validator.validate(text, None, self._run_ctx)  # pragma: no cover
+            text = await validator.validate(text, self._run_ctx)  # pragma: no cover
         return text
 
     async def _marked_completed(self, message: _messages.ModelResponse) -> None: