pydantic-ai-slim 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +5 -0
- pydantic_ai/_a2a.py +6 -4
- pydantic_ai/_agent_graph.py +32 -32
- pydantic_ai/_cli.py +3 -3
- pydantic_ai/_output.py +8 -0
- pydantic_ai/_tool_manager.py +3 -0
- pydantic_ai/_utils.py +7 -1
- pydantic_ai/ag_ui.py +25 -14
- pydantic_ai/{agent.py → agent/__init__.py} +217 -1026
- pydantic_ai/agent/abstract.py +942 -0
- pydantic_ai/agent/wrapper.py +227 -0
- pydantic_ai/builtin_tools.py +105 -0
- pydantic_ai/direct.py +9 -9
- pydantic_ai/durable_exec/__init__.py +0 -0
- pydantic_ai/durable_exec/temporal/__init__.py +83 -0
- pydantic_ai/durable_exec/temporal/_agent.py +699 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
- pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
- pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
- pydantic_ai/durable_exec/temporal/_model.py +168 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
- pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
- pydantic_ai/ext/aci.py +10 -9
- pydantic_ai/ext/langchain.py +4 -2
- pydantic_ai/mcp.py +203 -75
- pydantic_ai/messages.py +75 -13
- pydantic_ai/models/__init__.py +66 -8
- pydantic_ai/models/anthropic.py +135 -18
- pydantic_ai/models/bedrock.py +16 -5
- pydantic_ai/models/cohere.py +11 -4
- pydantic_ai/models/fallback.py +4 -2
- pydantic_ai/models/function.py +18 -4
- pydantic_ai/models/gemini.py +20 -9
- pydantic_ai/models/google.py +53 -15
- pydantic_ai/models/groq.py +47 -11
- pydantic_ai/models/huggingface.py +26 -11
- pydantic_ai/models/instrumented.py +3 -1
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +27 -17
- pydantic_ai/models/openai.py +97 -33
- pydantic_ai/models/test.py +12 -0
- pydantic_ai/models/wrapper.py +6 -2
- pydantic_ai/profiles/groq.py +23 -0
- pydantic_ai/profiles/openai.py +1 -1
- pydantic_ai/providers/google.py +7 -7
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +21 -55
- pydantic_ai/run.py +357 -0
- pydantic_ai/tools.py +0 -1
- pydantic_ai/toolsets/__init__.py +2 -0
- pydantic_ai/toolsets/_dynamic.py +87 -0
- pydantic_ai/toolsets/abstract.py +23 -3
- pydantic_ai/toolsets/combined.py +19 -4
- pydantic_ai/toolsets/deferred.py +10 -2
- pydantic_ai/toolsets/function.py +23 -8
- pydantic_ai/toolsets/prefixed.py +4 -0
- pydantic_ai/toolsets/wrapper.py +14 -1
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +7 -5
- pydantic_ai_slim-0.7.0.dist-info/RECORD +115 -0
- pydantic_ai_slim-0.6.1.dist-info/RECORD +0 -100
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0
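The file list already tells the story of 0.7.0: `agent.py` becomes an `agent/` package, a new `durable_exec/temporal/` package adds Temporal-backed durable execution, and a new `builtin_tools.py` module threads provider-side built-in tools through every model class. The per-file diffs below show how that last change lands. As a rough sketch of the new surface (the `Agent(builtin_tools=...)` parameter and the exact import paths are assumptions inferred from the `__init__.py` and `builtin_tools.py` changes above, not confirmed by this diff):

from pydantic_ai import Agent
from pydantic_ai.builtin_tools import WebSearchTool  # new module in this release

# Hypothetical usage: ask the provider (not your own code) to run a web search.
# Models that can't execute built-in tools now raise UserError, as the diffs
# below show for Mistral and TestModel.
agent = Agent('openai:gpt-4o', builtin_tools=[WebSearchTool()])
result = agent.run_sync('What is new in pydantic-ai 0.7.0?')
print(result.output)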
pydantic_ai/models/mistral.py
CHANGED
@@ -11,12 +11,15 @@ import pydantic_core
 from httpx import Timeout
 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -173,6 +176,7 @@ class MistralModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
@@ -180,7 +184,7 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(
+            yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> MistralModelName:
@@ -199,6 +203,11 @@ class MistralModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
+        # TODO(Marcelo): We need to replace the current MistralAI client to use the beta client.
+        # See https://docs.mistral.ai/agents/connectors/websearch/ to support web search.
+        if model_request_parameters.builtin_tools:
+            raise UserError('Mistral does not support built-in tools')
+
         try:
             response = await self.client.chat.complete_async(
                 model=str(self._model_name),
@@ -233,11 +242,12 @@ class MistralModel(Model):
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = self._map_messages(messages)

-
-
-
-
-
+        # TODO(Marcelo): We need to replace the current MistralAI client to use the beta client.
+        # See https://docs.mistral.ai/agents/connectors/websearch/ to support web search.
+        if model_request_parameters.builtin_tools:
+            raise UserError('Mistral does not support built-in tools')
+
+        if model_request_parameters.function_tools:
             # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self._model_name),
@@ -305,16 +315,13 @@ class MistralModel(Model):

         Returns None if both function_tools and output_tools are empty.
         """
-        all_tools: list[ToolDefinition] = (
-            model_request_parameters.function_tools + model_request_parameters.output_tools
-        )
         tools = [
             MistralTool(
                 function=MistralFunction(
                     name=r.name, parameters=r.parameters_json_schema, description=r.description or ''
                 )
             )
-            for r in
+            for r in model_request_parameters.tool_defs.values()
         ]
         return tools if tools else None

@@ -346,8 +353,8 @@ class MistralModel(Model):

     async def _process_streamed_response(
         self,
-        output_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
+        model_request_parameters: ModelRequestParameters,
     ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -363,10 +370,10 @@ class MistralModel(Model):
         timestamp = _now_utc()

         return MistralStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _response=peekable_response,
             _model_name=self._model_name,
             _timestamp=timestamp,
-            _output_tools={c.name: c for c in output_tools},
         )

     @staticmethod
@@ -502,6 +509,9 @@ class MistralModel(Model):
                     pass
                 elif isinstance(part, ToolCallPart):
                     tool_calls.append(self._map_tool_call(part))
+                elif isinstance(part, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+                    # This is currently never returned from mistral
+                    pass
                 else:
                     assert_never(part)
             mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
@@ -570,7 +580,6 @@ class MistralStreamedResponse(StreamedResponse):
     _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
-    _output_tools: dict[str, ToolDefinition]

     _delta_content: str = field(default='', init=False)

@@ -589,10 +598,11 @@ class MistralStreamedResponse(StreamedResponse):
             text = _map_content(content)
             if text:
                 # Attempt to produce an output tool call from the received text
-
+                output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
+                if output_tools:
                     self._delta_content += text
                     # TODO: Port to native "manual JSON" mode
-                    maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content,
+                    maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, output_tools)
                     if maybe_tool_call_part:
                         yield self._parts_manager.handle_tool_call_part(
                             vendor_part_id='output',
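The Mistral changes are twofold: built-in tools are explicitly rejected on both request paths (with a TODO pointing at the beta client for web search support), and the streamed response now carries the whole `ModelRequestParameters` instead of a pre-extracted `_output_tools` dict. A sketch of the rejection behavior, assuming the `Agent(builtin_tools=...)` parameter from the example above:

from pydantic_ai import Agent
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.exceptions import UserError

agent = Agent('mistral:mistral-large-latest', builtin_tools=[WebSearchTool()])
try:
    agent.run_sync('hello')
except UserError as exc:
    # Raised by the guard added in both request paths above.
    print(exc)  # Mistral does not support built-in tools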
pydantic_ai/models/openai.py
CHANGED
@@ -11,16 +11,18 @@ from typing import Any, Literal, Union, cast, overload
 from pydantic import ValidationError
 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.profiles.openai import OpenAIModelProfile
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
+from ..builtin_tools import CodeExecutionTool, WebSearchTool
+from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -38,16 +40,11 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
+from ..profiles.openai import OpenAIModelProfile
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
@@ -63,6 +60,11 @@ try:
     from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
     from openai.types.chat.chat_completion_content_part_param import File, FileFile
     from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam
+    from openai.types.chat.completion_create_params import (
+        WebSearchOptions,
+        WebSearchOptionsUserLocation,
+        WebSearchOptionsUserLocationApproximate,
+    )
     from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
     from openai.types.responses.response_input_param import FunctionCallOutput, Message
     from openai.types.shared import ReasoningEffort
@@ -254,13 +256,14 @@ class OpenAIModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> OpenAIModelName:
@@ -298,6 +301,8 @@ class OpenAIModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)
+        web_search_options = self._get_web_search_options(model_request_parameters)
+
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
         elif (
@@ -344,6 +349,7 @@ class OpenAIModel(Model):
                 seed=model_settings.get('seed', NOT_GIVEN),
                 reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
                 user=model_settings.get('openai_user', NOT_GIVEN),
+                web_search_options=web_search_options or NOT_GIVEN,
                 service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
                 prediction=model_settings.get('openai_prediction', NOT_GIVEN),
                 temperature=sampling_settings.get('temperature', NOT_GIVEN),
@@ -422,7 +428,9 @@ class OpenAIModel(Model):
             vendor_id=response.id,
         )

-    async def _process_streamed_response(
+    async def _process_streamed_response(
+        self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -432,6 +440,7 @@ class OpenAIModel(Model):
         )

         return OpenAIStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _model_profile=self.profile,
             _response=peekable_response,
@@ -439,10 +448,24 @@ class OpenAIModel(Model):
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
-
-
-
-
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
+
+    def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None:
+        for tool in model_request_parameters.builtin_tools:
+            if isinstance(tool, WebSearchTool):  # pragma: no branch
+                if tool.user_location:
+                    return WebSearchOptions(
+                        search_context_size=tool.search_context_size,
+                        user_location=WebSearchOptionsUserLocation(
+                            type='approximate',
+                            approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location),
+                        ),
+                    )
+                return WebSearchOptions(search_context_size=tool.search_context_size)
+            else:
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `OpenAIModel`. If it should be, please file an issue.'
+                )

     async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
@@ -464,6 +487,9 @@ class OpenAIModel(Model):
                     pass
                 elif isinstance(item, ToolCallPart):
                     tool_calls.append(self._map_tool_call(item))
+                # OpenAI doesn't return built-in tool calls
+                elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+                    pass
                 else:
                     assert_never(item)
             message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
@@ -608,14 +634,6 @@ class OpenAIResponsesModel(Model):
     The [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) is the
     new API for OpenAI models.

-    The Responses API has built-in tools, that you can use instead of building your own:
-
-    - [Web search](https://platform.openai.com/docs/guides/tools-web-search)
-    - [File search](https://platform.openai.com/docs/guides/tools-file-search)
-    - [Computer use](https://platform.openai.com/docs/guides/tools-computer-use)
-
-    Use the `openai_builtin_tools` setting to add these tools to your model.
-
     If you are interested in the differences between the Responses API and the Chat Completions API,
     see the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions).
     """
@@ -679,13 +697,14 @@ class OpenAIResponsesModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._responses_create(
             messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)

     def _process_response(self, response: responses.Response) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
@@ -712,7 +731,9 @@ class OpenAIResponsesModel(Model):
         )

     async def _process_streamed_response(
-        self,
+        self,
+        response: AsyncStream[responses.ResponseStreamEvent],
+        model_request_parameters: ModelRequestParameters,
     ) -> OpenAIResponsesStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -722,6 +743,7 @@ class OpenAIResponsesModel(Model):

         assert isinstance(first_chunk, responses.ResponseCreatedEvent)
         return OpenAIResponsesStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.response.created_at),
@@ -752,8 +774,11 @@ class OpenAIResponsesModel(Model):
         model_settings: OpenAIResponsesModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
-        tools =
-
+        tools = (
+            self._get_builtin_tools(model_request_parameters)
+            + list(model_settings.get('openai_builtin_tools', []))
+            + self._get_tools(model_request_parameters)
+        )

         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
@@ -836,9 +861,26 @@ class OpenAIResponsesModel(Model):
         return Reasoning(effort=reasoning_effort, summary=reasoning_summary)

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
-
-
-
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
+
+    def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]:
+        tools: list[responses.ToolParam] = []
+        for tool in model_request_parameters.builtin_tools:
+            if isinstance(tool, WebSearchTool):
+                web_search_tool = responses.WebSearchToolParam(
+                    type='web_search_preview', search_context_size=tool.search_context_size
+                )
+                if tool.user_location:
+                    web_search_tool['user_location'] = responses.web_search_tool_param.UserLocation(
+                        type='approximate', **tool.user_location
+                    )
+                tools.append(web_search_tool)
+            elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
+                tools.append({'type': 'code_interpreter', 'container': {'type': 'auto'}})
+            else:
+                raise UserError(  # pragma: no cover
+                    f'`{tool.__class__.__name__}` is not supported by `OpenAIResponsesModel`. If it should be, please file an issue.'
+                )
         return tools

     def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam:
@@ -895,6 +937,9 @@ class OpenAIResponsesModel(Model):
                     openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content))
                 elif isinstance(item, ToolCallPart):
                     openai_messages.append(self._map_tool_call(item))
+                # OpenAI doesn't return built-in tool calls
+                elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):
+                    pass
                 elif isinstance(item, ThinkingPart):
                     # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
                     # please open an issue. The below code is the code to send thinking to the provider.
@@ -1071,6 +1116,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:
+            # NOTE: You can inspect the builtin tools used checking the `ResponseCompletedEvent`.
             if isinstance(chunk, responses.ResponseCompletedEvent):
                 self._usage += _map_usage(chunk.response)

@@ -1122,6 +1168,8 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 )
                 elif isinstance(chunk.item, responses.ResponseOutputMessage):
                     pass
+                elif isinstance(chunk.item, responses.ResponseFunctionWebSearch):
+                    pass
                 else:
                     warnings.warn(  # pragma: no cover
                         f'Handling of this item type is not yet implemented. Please report on our GitHub: {chunk}',
@@ -1148,6 +1196,10 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     signature=chunk.item_id,
                 )

+            # TODO(Marcelo): We should support annotations in the future.
+            elif isinstance(chunk, responses.ResponseOutputTextAnnotationAddedEvent):
+                pass  # there's nothing we need to do here
+
             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
                 maybe_event = self._parts_manager.handle_text_delta(
                     vendor_part_id=chunk.content_index, content=chunk.delta
@@ -1158,6 +1210,18 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
             elif isinstance(chunk, responses.ResponseTextDoneEvent):
                 pass  # there's nothing we need to do here

+            elif isinstance(chunk, responses.ResponseWebSearchCallInProgressEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseWebSearchCallSearchingEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseWebSearchCallCompletedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseAudioDeltaEvent):  # pragma: lax no cover
+                pass  # there's nothing we need to do here
+
             else:  # pragma: no cover
                 warnings.warn(
                     f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
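Two independent tool paths are wired up here: `OpenAIModel` translates a `WebSearchTool` into the Chat Completions `web_search_options` request parameter, while `OpenAIResponsesModel` turns built-in tools into native Responses API tool params (`web_search_preview`, `code_interpreter`). Note also that the class docstring advertising `openai_builtin_tools` is deleted: that model-settings escape hatch still exists (it is merged into `tools` in `_responses_create`), but the documented path is now the shared `builtin_tools` abstraction. For illustration only, a standalone mirror of the Chat Completions mapping (the real code uses the typed `WebSearchOptions` classes imported above; this plain-dict version is not part of the package):

from typing import Any


def web_search_options_from_tool(search_context_size: str, user_location: dict[str, Any] | None) -> dict[str, Any]:
    """Illustrative mirror of `OpenAIModel._get_web_search_options` above."""
    options: dict[str, Any] = {'search_context_size': search_context_size}
    if user_location:
        # The API only accepts approximate user locations.
        options['user_location'] = {'type': 'approximate', 'approximate': user_location}
    return options


print(web_search_options_from_tool('high', {'city': 'Berlin', 'country': 'DE'}))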
pydantic_ai/models/test.py
CHANGED
@@ -12,7 +12,11 @@ import pydantic_core
 from typing_extensions import assert_never

 from .. import _utils
+from .._run_context import RunContext
+from ..exceptions import UserError
 from ..messages import (
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -118,11 +122,13 @@ class TestModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         self.last_model_request_parameters = model_request_parameters

         model_response = self._request(messages, model_settings, model_request_parameters)
         yield TestStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _structured_response=model_response,
             _messages=messages,
@@ -179,6 +185,9 @@ class TestModel(Model):
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        if model_request_parameters.builtin_tools:
+            raise UserError('TestModel does not support built-in tools')
+
         tool_calls = self._get_tool_calls(model_request_parameters)
         output_wrapper = self._get_output(model_request_parameters)
         output_tools = model_request_parameters.output_tools
@@ -283,6 +292,9 @@ class TestStreamedResponse(StreamedResponse):
                 yield self._parts_manager.handle_tool_call_part(
                     vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                 )
+            elif isinstance(part, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+                # NOTE: These parts are not generated by TestModel, but we need to handle them for type checking
+                assert False, f'Unexpected part type in TestModel: {type(part).__name__}'
             elif isinstance(part, ThinkingPart):  # pragma: no cover
                 # NOTE: There's no way to reach this part of the code, since we don't generate ThinkingPart on TestModel.
                 assert False, "This should be unreachable — we don't generate ThinkingPart on TestModel."
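Because `TestModel` needs no network access, the new guard is easy to observe directly (a sketch, again assuming the `Agent(builtin_tools=...)` parameter):

from pydantic_ai import Agent
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(), builtin_tools=[WebSearchTool()])
try:
    agent.run_sync('hi')
except UserError as exc:
    print(exc)  # TestModel does not support built-in tools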
pydantic_ai/models/wrapper.py
CHANGED
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from functools import cached_property
 from typing import Any

+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..profiles import ModelProfile
 from ..settings import ModelSettings
@@ -35,8 +36,11 @@ class WrapperModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
-        async with self.wrapped.request_stream(
+        async with self.wrapped.request_stream(
+            messages, model_settings, model_request_parameters, run_context
+        ) as response_stream:
             yield response_stream

     def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
@@ -60,4 +64,4 @@ class WrapperModel(Model):
         return self.wrapped.settings

     def __getattr__(self, item: str):
-        return getattr(self.wrapped, item)
+        return getattr(self.wrapped, item)
pydantic_ai/profiles/groq.py
ADDED
@@ -0,0 +1,23 @@
+from __future__ import annotations as _annotations
+
+from dataclasses import dataclass
+
+from . import ModelProfile
+
+
+@dataclass
+class GroqModelProfile(ModelProfile):
+    """Profile for models used with GroqModel.
+
+    ALL FIELDS MUST BE `groq_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    groq_always_has_web_search_builtin_tool: bool = False
+    """Whether the model always has the web search built-in tool available."""
+
+
+def groq_model_profile(model_name: str) -> ModelProfile:
+    """Get the model profile for a Groq model."""
+    return GroqModelProfile(
+        groq_always_has_web_search_builtin_tool=model_name.startswith('compound-'),
+    )
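The new profile flags Groq's agentic `compound-` models, which always have server-side web search available. Going only by the code above, it behaves like this:

from pydantic_ai.profiles.groq import GroqModelProfile, groq_model_profile

profile = groq_model_profile('compound-beta')
assert isinstance(profile, GroqModelProfile)
assert profile.groq_always_has_web_search_builtin_tool

# Other names leave the flag unset; the provider change below only routes
# 'compound-'-prefixed models to this profile anyway.
assert not groq_model_profile('llama-3.3-70b-versatile').groq_always_has_web_search_builtin_tool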
pydantic_ai/profiles/openai.py
CHANGED
@@ -32,7 +32,7 @@ class OpenAIModelProfile(ModelProfile):

 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""
-    is_reasoning_model = model_name.startswith('o')
+    is_reasoning_model = model_name.startswith('o') or model_name.startswith('gpt-5')
     # Structured Outputs (output mode 'native') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later.
     # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
     # when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
pydantic_ai/providers/google.py
CHANGED
@@ -12,8 +12,8 @@ from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.providers import Provider

 try:
-    from google import genai
     from google.auth.credentials import Credentials
+    from google.genai import Client
     from google.genai.types import HttpOptionsDict
 except ImportError as _import_error:
     raise ImportError(
@@ -22,7 +22,7 @@ except ImportError as _import_error:
     ) from _import_error


-class GoogleProvider(Provider[genai.Client]):
+class GoogleProvider(Provider[Client]):
     """Provider for Google."""

     @property
@@ -34,7 +34,7 @@ class GoogleProvider(Provider[genai.Client]):
         return str(self._client._api_client._http_options.base_url)  # type: ignore[reportPrivateUsage]

     @property
-    def client(self) -> genai.Client:
+    def client(self) -> Client:
         return self._client

     def model_profile(self, model_name: str) -> ModelProfile | None:
@@ -53,7 +53,7 @@ class GoogleProvider(Provider[genai.Client]):
     ) -> None: ...

     @overload
-    def __init__(self, *, client: genai.Client) -> None: ...
+    def __init__(self, *, client: Client) -> None: ...

     @overload
     def __init__(self, *, vertexai: bool = False) -> None: ...
@@ -65,7 +65,7 @@ class GoogleProvider(Provider[genai.Client]):
         credentials: Credentials | None = None,
         project: str | None = None,
         location: VertexAILocation | Literal['global'] | None = None,
-        client: genai.Client | None = None,
+        client: Client | None = None,
         vertexai: bool | None = None,
     ) -> None:
         """Create a new Google provider.
@@ -102,9 +102,9 @@ class GoogleProvider(Provider[genai.Client]):
                     'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`'
                     'to use the Google Generative Language API.'
                 )
-            self._client = genai.Client(vertexai=vertexai, api_key=api_key, http_options=http_options)
+            self._client = Client(vertexai=vertexai, api_key=api_key, http_options=http_options)
         else:
-            self._client = genai.Client(
+            self._client = Client(
                 vertexai=vertexai,
                 project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'),
                 # From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149:
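The Google provider change is purely a typing cleanup: `google.genai.Client` is imported directly instead of through the `genai` module alias. Passing a pre-configured client keeps working as before; a quick sketch with a placeholder API key:

from google.genai import Client

from pydantic_ai.providers.google import GoogleProvider

client = Client(api_key='your-api-key')  # placeholder key
provider = GoogleProvider(client=client)
assert provider.client is client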
pydantic_ai/providers/groq.py
CHANGED
@@ -10,6 +10,7 @@ from pydantic_ai.models import cached_async_http_client
 from pydantic_ai.profiles import ModelProfile
 from pydantic_ai.profiles.deepseek import deepseek_model_profile
 from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.groq import groq_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
 from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
@@ -49,6 +50,7 @@ class GroqProvider(Provider[AsyncGroq]):
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
             'moonshotai/': moonshotai_model_profile,
+            'compound-': groq_model_profile,
         }

         for prefix, profile_func in prefix_to_profile.items():