pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_a2a.py +3 -3
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +378 -164
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/format_prompt.py +3 -6
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -18
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/instrumented.py +6 -1
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +16 -4
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/RECORD +48 -35
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py
CHANGED
```diff
@@ -266,7 +266,7 @@ class AnthropicModel(Model):
                 items.append(TextPart(content=item.text))
             elif isinstance(item, BetaRedactedThinkingBlock):  # pragma: no cover
                 warnings.warn(
-                    'PydanticAI currently does not handle redacted thinking blocks. '
+                    'Pydantic AI currently does not handle redacted thinking blocks. '
                     'If you have a suggestion on how we should handle them, please open an issue.',
                     UserWarning,
                 )
```
pydantic_ai/models/function.py
CHANGED
```diff
@@ -214,21 +214,39 @@ class DeltaToolCall:
     """Incremental change to the tool call ID."""
 
 
+@dataclass
+class DeltaThinkingPart:
+    """Incremental change to a thinking part.
+
+    Used to describe a chunk when streaming thinking responses.
+    """
+
+    content: str | None = None
+    """Incremental change to the thinking content."""
+    signature: str | None = None
+    """Incremental change to the thinking signature."""
+
+
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 """A mapping of tool call IDs to incremental changes."""
 
+DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
+"""A mapping of thinking call IDs to incremental changes."""
+
 # TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
 FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
 """A function used to generate a non-streamed response."""
 
 # TODO: Change signature as indicated above
-StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
+StreamFunctionDef: TypeAlias = Callable[
+    [list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]
+]
 """A function used to generate a streamed response.
 
-While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
-really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]`,
+While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]`, it should
+really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls], AsyncIterator[DeltaThinkingCalls]]`,
 
-E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
+E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
 """
@@ -237,7 +255,7 @@ class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
 
     _model_name: str
-    _iter: AsyncIterator[str | DeltaToolCalls]
+    _iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
     def __post_init__(self):
@@ -249,20 +267,31 @@ class FunctionStreamedResponse(StreamedResponse):
                 response_tokens = _estimate_string_tokens(item)
                 self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
                 yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
-            else:
-                delta_tool_calls = item
-                for dtc_index, delta_tool_call in delta_tool_calls.items():
-                    if delta_tool_call.json_args:
-                        response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
-                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
-                    maybe_event = self._parts_manager.handle_tool_call_delta(
-                        vendor_part_id=dtc_index,
-                        tool_name=delta_tool_call.name,
-                        args=delta_tool_call.json_args,
-                        tool_call_id=delta_tool_call.tool_call_id,
-                    )
-                    if maybe_event is not None:
-                        yield maybe_event
+            elif isinstance(item, dict) and item:
+                for dtc_index, delta in item.items():
+                    if isinstance(delta, DeltaThinkingPart):
+                        if delta.content:  # pragma: no branch
+                            response_tokens = _estimate_string_tokens(delta.content)
+                            self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        yield self._parts_manager.handle_thinking_delta(
+                            vendor_part_id=dtc_index,
+                            content=delta.content,
+                            signature=delta.signature,
+                        )
+                    elif isinstance(delta, DeltaToolCall):
+                        if delta.json_args:
+                            response_tokens = _estimate_string_tokens(delta.json_args)
+                            self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        maybe_event = self._parts_manager.handle_tool_call_delta(
+                            vendor_part_id=dtc_index,
+                            tool_name=delta.name,
+                            args=delta.json_args,
+                            tool_call_id=delta.tool_call_id,
+                        )
+                        if maybe_event is not None:
+                            yield maybe_event
+                    else:
+                        assert_never(delta)
 
     @property
     def model_name(self) -> str:
@@ -299,12 +328,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
             if isinstance(part, TextPart):
                 response_tokens += _estimate_string_tokens(part.content)
             elif isinstance(part, ThinkingPart):
-                # NOTE: We don't send ThinkingPart to the providers yet.
-                # If you are unsatisfied with this, please open an issue.
-                pass
+                response_tokens += _estimate_string_tokens(part.content)
             elif isinstance(part, ToolCallPart):
-                call = part
-                response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
+                response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
             else:
                 assert_never(part)
         else:
```
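To illustrate the new delta type, here is a minimal sketch of a stream function that emits thinking deltas via `FunctionModel`. The `Agent`/`FunctionModel` wiring follows the existing public API; note that, per the `StreamFunctionDef` docstring above, a single stream should yield only one kind of delta.

```python
from collections.abc import AsyncIterator

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import (
    AgentInfo,
    DeltaThinkingCalls,
    DeltaThinkingPart,
    FunctionModel,
)


async def stream_thinking(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaThinkingCalls]:
    # Deltas that share a vendor part ID (the dict key) are accumulated
    # into a single thinking part by the parts manager.
    yield {0: DeltaThinkingPart(content='Considering the question... ')}
    yield {0: DeltaThinkingPart(content='done.', signature='sig-abc')}


agent = Agent(FunctionModel(stream_function=stream_thinking))
```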
pydantic_ai/models/gemini.py
CHANGED
```diff
@@ -48,18 +48,10 @@ from . import (
 )
 
 LatestGeminiModelNames = Literal[
-    'gemini-1.5-flash',
-    'gemini-1.5-flash-8b',
-    'gemini-1.5-pro',
-    'gemini-1.0-pro',
     'gemini-2.0-flash',
-    'gemini-2.0-flash-lite-preview-02-05',
-    'gemini-2.0-pro-exp-02-05',
-    'gemini-2.5-flash-preview-05-20',
+    'gemini-2.0-flash-lite',
     'gemini-2.5-flash',
     'gemini-2.5-flash-lite-preview-06-17',
-    'gemini-2.5-pro-exp-03-25',
-    'gemini-2.5-pro-preview-05-06',
     'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
@@ -99,15 +91,6 @@ class GeminiModelSettings(ModelSettings, total=False):
     See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
     """
 
-    gemini_thinking_config: ThinkingConfig
-    """Thinking is on by default in both the API and AI Studio.
-
-    Being on by default doesn't mean the model will send back thoughts. For that, you need to set `include_thoughts`
-    to `True`. If you want to turn it off, set `thinking_budget` to `0`.
-
-    See more about it on <https://ai.google.dev/gemini-api/docs/thinking>.
-    """
-
 
 @dataclass(init=False)
 class GeminiModel(Model):
```
pydantic_ai/models/google.py
CHANGED
```diff
@@ -73,18 +73,10 @@ except ImportError as _import_error:
 ) from _import_error
 
 LatestGoogleModelNames = Literal[
-    'gemini-1.5-flash',
-    'gemini-1.5-flash-8b',
-    'gemini-1.5-pro',
-    'gemini-1.0-pro',
     'gemini-2.0-flash',
-    'gemini-2.0-flash-lite-preview-02-05',
-    'gemini-2.0-pro-exp-02-05',
-    'gemini-2.5-flash-preview-05-20',
+    'gemini-2.0-flash-lite',
     'gemini-2.5-flash',
     'gemini-2.5-flash-lite-preview-06-17',
-    'gemini-2.5-pro-exp-03-25',
-    'gemini-2.5-pro-preview-05-06',
     'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
@@ -492,8 +484,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
         function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
         parts.append({'function_call': function_call})
     elif isinstance(item, TextPart):
-        if item.content:
-            parts.append({'text': item.content})
+        parts.append({'text': item.content})
     elif isinstance(item, ThinkingPart):  # pragma: no cover
         # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
         # please open an issue. The below code is the code to send thinking to the provider.
```
pydantic_ai/models/instrumented.py
CHANGED
```diff
@@ -156,7 +156,12 @@ class InstrumentationSettings:
         events: list[Event] = []
         instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
         if instructions is not None:
-            events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'}))
+            events.append(
+                Event(
+                    'gen_ai.system.message',
+                    body={**({'content': instructions} if self.include_content else {}), 'role': 'system'},
+                )
+            )
 
         for message_index, message in enumerate(messages):
             message_events: list[Event] = []
```
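The new event body uses a conditional dict-spread so the `content` key is only present when `include_content` is enabled. A library-independent sketch of the idiom:

```python
def system_message_body(instructions: str, include_content: bool) -> dict[str, str]:
    # Merge in the 'content' key only when content capture is enabled.
    return {**({'content': instructions} if include_content else {}), 'role': 'system'}


assert system_message_body('be terse', True) == {'content': 'be terse', 'role': 'system'}
assert system_message_body('be terse', False) == {'role': 'system'}
```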
pydantic_ai/models/mistral.py
CHANGED
```diff
@@ -428,7 +428,7 @@ class MistralModel(Model):
         if value_type == 'object':
             additional_properties = value.get('additionalProperties', {})
             if isinstance(additional_properties, bool):
-                return 'bool'  # pragma: no cover
+                return 'bool'  # pragma: lax no cover
             additional_properties_type = additional_properties.get('type')
             if (
                 additional_properties_type in SIMPLE_JSON_TYPE_MAPPING
```
pydantic_ai/models/openai.py
CHANGED
```diff
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Literal, Union, cast, overload
 
+from pydantic import ValidationError
 from typing_extensions import assert_never
 
 from pydantic_ai._thinking_part import split_content_into_text_and_thinking
@@ -50,7 +51,7 @@ from . import (
 
 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
-    from openai.types import ChatModel, chat, responses
+    from openai.types import AllModels, chat, responses
     from openai.types.chat import (
         ChatCompletionChunk,
         ChatCompletionContentPartImageParam,
@@ -80,7 +81,7 @@ __all__ = (
     'OpenAIModelName',
 )
 
-OpenAIModelName = Union[str, ChatModel]
+OpenAIModelName = Union[str, AllModels]
 """
 Possible OpenAI model names.
 
@@ -347,8 +348,19 @@ class OpenAIModel(Model):
             raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
         raise  # pragma: no cover
 
-    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        # Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function:
+        # * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!)
+        # * if the endpoint returns plain text, the return type is a string
+        # Thus we validate it fully here.
+        if not isinstance(response, chat.ChatCompletion):
+            raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
+
+        try:
+            response = chat.ChatCompletion.model_validate(response.model_dump())
+        except ValidationError as e:
+            raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e
         timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
@@ -1051,7 +1063,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     vendor_part_id=chunk.item_id,
                     tool_name=None,
                     args=chunk.delta,
-                    tool_call_id=chunk.item_id,
+                    tool_call_id=None,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
```
pydantic_ai/output.py
CHANGED
```diff
@@ -10,7 +10,8 @@ from pydantic_core import core_schema
 from typing_extensions import TypeAliasType, TypeVar
 
 from . import _utils
-from .tools import RunContext
+from .messages import ToolCallPart
+from .tools import RunContext, ToolDefinition
 
 __all__ = (
     # classes
@@ -330,15 +331,17 @@ def StructuredDict(
     return _StructuredDict
 
 
+_OutputSpecItem = TypeAliasType(
+    '_OutputSpecItem',
+    Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]],
+    type_params=(T_co,),
+)
+
 OutputSpec = TypeAliasType(
     'OutputSpec',
     Union[
-        OutputTypeOrFunction[T_co],
-        ToolOutput[T_co],
-        NativeOutput[T_co],
-        PromptedOutput[T_co],
-        TextOutput[T_co],
-        Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]],
+        _OutputSpecItem[T_co],
+        Sequence['OutputSpec[T_co]'],
     ],
     type_params=(T_co,),
 )
@@ -354,3 +357,14 @@
 You should not need to import or use this type directly.
 See [output docs](../output.md) for more information.
 """
+
+
+@dataclass
+class DeferredToolCalls:
+    """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.
+
+    See [deferred toolset docs](../toolsets.md#deferred-toolset) for more information.
+    """
+
+    tool_calls: list[ToolCallPart]
+    tool_defs: dict[str, ToolDefinition]
```
pydantic_ai/profiles/google.py
CHANGED
```diff
@@ -43,7 +43,7 @@ class GoogleJsonSchemaTransformer(JsonSchemaTransformer):
             f' Full schema: {self.schema}\n\n'
             f'Source of additionalProperties within the full schema: {original_schema}\n\n'
             'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-            "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
+            "If Google's APIs are updated to support this properly, please create an issue on the Pydantic AI GitHub"
             ' and we will fix this behavior.',
             UserWarning,
         )
```
pydantic_ai/providers/grok.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import os
-from typing import overload
+from typing import Literal, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from openai import AsyncOpenAI
@@ -21,6 +21,18 @@ except ImportError as _import_error:  # pragma: no cover
     'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
 ) from _import_error
 
+# https://docs.x.ai/docs/models
+GrokModelName = Literal[
+    'grok-4',
+    'grok-4-0709',
+    'grok-3',
+    'grok-3-mini',
+    'grok-3-fast',
+    'grok-3-mini-fast',
+    'grok-2-vision-1212',
+    'grok-2-image-1212',
+]
+
 
 class GrokProvider(Provider[AsyncOpenAI]):
     """Provider for Grok API."""
```
pydantic_ai/providers/groq.py
CHANGED
```diff
@@ -12,6 +12,7 @@ from pydantic_ai.profiles.deepseek import deepseek_model_profile
 from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
@@ -47,6 +48,7 @@ class GroqProvider(Provider[AsyncGroq]):
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
+            'moonshotai/': moonshotai_model_profile,
         }
 
         for prefix, profile_func in prefix_to_profile.items():
```
pydantic_ai/result.py
CHANGED
```diff
@@ -5,11 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
 from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic
+from typing import Generic, cast
 
 from pydantic import ValidationError
 from typing_extensions import TypeVar, deprecated, overload
 
+from pydantic_ai._tool_manager import ToolManager
+
 from . import _utils, exceptions, messages as _messages, models
 from ._output import (
     OutputDataT_inv,
@@ -19,7 +21,6 @@ from ._output import (
     PlainTextOutputSchema,
     TextOutputSchema,
     ToolOutputSchema,
-    TraceContext,
 )
 from ._run_context import AgentDepsT, RunContext
 from .messages import AgentStreamEvent, FinalResultEvent
@@ -47,8 +48,8 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
     _output_schema: OutputSchema[OutputDataT]
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _run_ctx: RunContext[AgentDepsT]
-    _trace_ctx: TraceContext
     _usage_limits: UsageLimits | None
+    _tool_manager: ToolManager[AgentDepsT]
 
     _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
     _final_result_event: FinalResultEvent | None = field(default=None, init=False)
@@ -97,37 +98,40 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
-            match = self._output_schema.find_named_tool(message.parts, output_tool_name)
-            if match is None:
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {output_tool_name!r}'
+                    f'Invalid response, unable to find tool call for {output_tool_name!r}'
                 )
-
-            found_tool, call = match
-            result_data = await found_tool.process(
-                call,
-                self._run_ctx,
-                self._trace_ctx,
-                allow_partial=allow_partial,
-                wrap_validation_errors=False,
-            )
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(  # pragma: no cover
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
-                text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+                text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
             )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)
-        return result_data
-
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
 
@@ -145,13 +149,19 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
         if isinstance(e, _messages.PartStartEvent):
             new_part = e.part
-            if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema):
-                for call, _ in output_schema.find_tool([new_part]):  # pragma: no branch
-                    return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id)
-            elif isinstance(new_part, _messages.TextPart) and isinstance(
+            if isinstance(new_part, _messages.TextPart) and isinstance(
                 output_schema, TextOutputSchema
             ):  # pragma: no branch
                 return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
+            elif isinstance(new_part, _messages.ToolCallPart) and (
+                tool_def := self._tool_manager.get_tool_def(new_part.tool_name)
+            ):
+                if tool_def.kind == 'output':
+                    return _messages.FinalResultEvent(
+                        tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id
+                    )
+                elif tool_def.kind == 'deferred':
+                    return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
 
         usage_checking_stream = _get_usage_checking_stream_response(
             self._raw_stream_response, self._usage_limits, self.usage
@@ -183,10 +193,10 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _stream_response: models.StreamedResponse
     _output_schema: OutputSchema[OutputDataT]
     _run_ctx: RunContext[AgentDepsT]
-    _trace_ctx: TraceContext
     _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _output_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
+    _tool_manager: ToolManager[AgentDepsT]
 
     _initial_run_ctx_usage: Usage = field(init=False)
     is_complete: bool = field(default=False, init=False)
@@ -420,40 +430,43 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        call = None
         if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None:
-            match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
-            if match is None:
+            tool_call = next(
+                (
+                    part
+                    for part in message.parts
+                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name
+                ),
+                None,
+            )
+            if tool_call is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {self._output_tool_name!r}'
+                    f'Invalid response, unable to find tool call for {self._output_tool_name!r}'
                 )
-
-            found_tool, call = match
-            result_data = await found_tool.process(
-                call,
-                self._run_ctx,
-                self._trace_ctx,
-                allow_partial=allow_partial,
-                wrap_validation_errors=False,
-            )
+            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
+        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
+            if not self._output_schema.allows_deferred_tool_calls:
+                raise exceptions.UserError(
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
+                )
+            return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
 
             result_data = await self._output_schema.process(
-                text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+                text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
            )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, self._run_ctx)  # pragma: no cover
+            return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
                 'Invalid response, unable to process text output'
             )
 
-        for validator in self._output_validators:
-            result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
-        return result_data
-
     async def _validate_text_output(self, text: str) -> str:
         for validator in self._output_validators:
-            text = await validator.validate(text, None, self._run_ctx)  # pragma: no cover
+            text = await validator.validate(text, self._run_ctx)  # pragma: no cover
         return text
 
     async def _marked_completed(self, message: _messages.ModelResponse) -> None:
```
|