pydantic-ai-slim 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +18 -14
- pydantic_ai/_output.py +20 -105
- pydantic_ai/_run_context.py +2 -0
- pydantic_ai/_tool_manager.py +30 -11
- pydantic_ai/agent/__init__.py +34 -32
- pydantic_ai/agent/abstract.py +26 -0
- pydantic_ai/agent/wrapper.py +5 -0
- pydantic_ai/common_tools/duckduckgo.py +1 -1
- pydantic_ai/durable_exec/dbos/_agent.py +28 -0
- pydantic_ai/durable_exec/prefect/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_agent.py +25 -0
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/mcp.py +4 -4
- pydantic_ai/messages.py +5 -2
- pydantic_ai/models/__init__.py +80 -35
- pydantic_ai/models/anthropic.py +27 -8
- pydantic_ai/models/bedrock.py +3 -3
- pydantic_ai/models/cohere.py +5 -3
- pydantic_ai/models/fallback.py +25 -4
- pydantic_ai/models/function.py +8 -0
- pydantic_ai/models/gemini.py +3 -3
- pydantic_ai/models/google.py +20 -10
- pydantic_ai/models/groq.py +5 -3
- pydantic_ai/models/huggingface.py +3 -3
- pydantic_ai/models/instrumented.py +29 -13
- pydantic_ai/models/mistral.py +6 -4
- pydantic_ai/models/openai.py +11 -6
- pydantic_ai/models/outlines.py +21 -12
- pydantic_ai/models/wrapper.py +1 -1
- pydantic_ai/output.py +3 -2
- pydantic_ai/profiles/openai.py +5 -2
- pydantic_ai/result.py +5 -3
- pydantic_ai/tools.py +2 -4
- {pydantic_ai_slim-1.10.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/METADATA +9 -7
- {pydantic_ai_slim-1.10.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/RECORD +38 -38
- {pydantic_ai_slim-1.10.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.10.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.10.0.dist-info → pydantic_ai_slim-1.12.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/function.py
CHANGED

@@ -135,6 +135,8 @@ class FunctionModel(Model):
             allow_text_output=model_request_parameters.allow_text_output,
             output_tools=model_request_parameters.output_tools,
             model_settings=model_settings,
+            model_request_parameters=model_request_parameters,
+            instructions=self._get_instructions(messages, model_request_parameters),
         )

         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -168,6 +170,8 @@ class FunctionModel(Model):
             allow_text_output=model_request_parameters.allow_text_output,
             output_tools=model_request_parameters.output_tools,
             model_settings=model_settings,
+            model_request_parameters=model_request_parameters,
+            instructions=self._get_instructions(messages, model_request_parameters),
         )

         assert self.stream_function is not None, (
@@ -216,6 +220,10 @@ class AgentInfo:
     """The tools that can called to produce the final output of the run."""
     model_settings: ModelSettings | None
    """The model settings passed to the run call."""
+    model_request_parameters: ModelRequestParameters
+    """The model request parameters passed to the run call."""
+    instructions: str | None
+    """The instructions passed to model."""


 @dataclass
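The two new `AgentInfo` fields are easy to observe from a test double. A minimal sketch, assuming the documented `FunctionModel` testing API (function and prompt below are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def capture_info(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # `info.instructions` and `info.model_request_parameters` are the fields
    # added in this release; earlier versions exposed only tools and settings.
    return ModelResponse(parts=[TextPart(content=info.instructions or '<none>')])


agent = Agent(FunctionModel(capture_info), instructions='Reply tersely.')
print(agent.run_sync('hi').output)  # -> 'Reply tersely.'
```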
pydantic_ai/models/gemini.py
CHANGED

@@ -218,7 +218,7 @@ class GeminiModel(Model):
     ) -> AsyncIterator[HTTPResponse]:
         tools = self._get_tools(model_request_parameters)
         tool_config = self._get_tool_config(model_request_parameters, tools)
-        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)
+        sys_prompt_parts, contents = await self._message_to_gemini_content(messages, model_request_parameters)

         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
@@ -331,7 +331,7 @@ class GeminiModel(Model):
         )

     async def _message_to_gemini_content(
-        self, messages: list[ModelMessage]
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
     ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
         sys_prompt_parts: list[_GeminiTextPart] = []
         contents: list[_GeminiContent] = []
@@ -361,7 +361,7 @@ class GeminiModel(Model):
                 contents.append(_content_model_response(m))
             else:
                 assert_never(m)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             sys_prompt_parts.insert(0, _GeminiTextPart(text=instructions))
         return sys_prompt_parts, contents
pydantic_ai/models/google.py
CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import base64
 from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 from uuid import uuid4
@@ -224,6 +224,18 @@ class GoogleModel(Model):
         """The model provider."""
         return self._provider.name

+    def prepare_request(
+        self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters
+    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
+        if model_request_parameters.builtin_tools and model_request_parameters.output_tools:
+            if model_request_parameters.output_mode == 'auto':
+                model_request_parameters = replace(model_request_parameters, output_mode='prompted')
+            else:
+                raise UserError(
+                    'Google does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
+                )
+        return super().prepare_request(model_settings, model_request_parameters)
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -320,12 +332,8 @@ class GoogleModel(Model):
         ]

         if model_request_parameters.builtin_tools:
-            if model_request_parameters.output_tools:
-                raise UserError(
-                    'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
-                )
             if model_request_parameters.function_tools:
-                raise UserError('Gemini does not support function tools and built-in tools at the same time.')
+                raise UserError('Google does not support function tools and built-in tools at the same time.')

             for tool in model_request_parameters.builtin_tools:
                 if isinstance(tool, WebSearchTool):
@@ -402,7 +410,7 @@ class GoogleModel(Model):
         if model_request_parameters.output_mode == 'native':
             if tools:
                 raise UserError(
-                    'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
+                    'Google does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
                 )
             response_mime_type = 'application/json'
             output_object = model_request_parameters.output_object
@@ -414,7 +422,7 @@ class GoogleModel(Model):
             response_mime_type = 'application/json'

         tool_config = self._get_tool_config(model_request_parameters, tools)
-        system_instruction, contents = await self._map_messages(messages)
+        system_instruction, contents = await self._map_messages(messages, model_request_parameters)

         modalities = [Modality.TEXT.value]
         if self.profile.supports_image_output:
@@ -504,7 +512,9 @@ class GoogleModel(Model):
             _provider_name=self._provider.name,
         )

-    async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
+    async def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> tuple[ContentDict | None, list[ContentUnionDict]]:
         contents: list[ContentUnionDict] = []
         system_parts: list[PartDict] = []

@@ -551,7 +561,7 @@ class GoogleModel(Model):
                 contents.append(_content_model_response(m, self.system))
             else:
                 assert_never(m)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             system_parts.insert(0, {'text': instructions})
         system_instruction = ContentDict(role='user', parts=system_parts) if system_parts else None
         return system_instruction, contents
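Net effect of the new `prepare_request` hook: combining built-in tools with a tool-based output no longer fails outright when the output mode is 'auto'; it silently falls back to prompted output, and only an explicit tool-output request still raises. A hedged sketch (the model name is illustrative, and actually running this requires Google credentials):

```python
from pydantic import BaseModel

from pydantic_ai import Agent, WebSearchTool
from pydantic_ai.models.google import GoogleModel


class Answer(BaseModel):
    city: str


# With the default 'auto' structured output mode this now runs with prompted
# output instead of raising UserError; pairing builtin_tools with an explicit
# ToolOutput would still raise.
agent = Agent(
    GoogleModel('gemini-2.5-flash'),
    builtin_tools=[WebSearchTool()],
    output_type=Answer,
)
```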
pydantic_ai/models/groq.py
CHANGED

@@ -272,7 +272,7 @@ class GroqModel(Model):
         else:
             tool_choice = 'auto'

-        groq_messages = self._map_messages(messages)
+        groq_messages = self._map_messages(messages, model_request_parameters)

         response_format: chat.completion_create_params.ResponseFormat | None = None
         if model_request_parameters.output_mode == 'native':
@@ -388,7 +388,9 @@ class GroqModel(Model):
             )
         return tools

-    def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
+    def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
         groq_messages: list[chat.ChatCompletionMessageParam] = []
         for message in messages:
@@ -423,7 +425,7 @@ class GroqModel(Model):
                 groq_messages.append(message_param)
             else:
                 assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             groq_messages.insert(0, chat.ChatCompletionSystemMessageParam(role='system', content=instructions))
         return groq_messages
pydantic_ai/models/huggingface.py
CHANGED

@@ -231,7 +231,7 @@ class HuggingFaceModel(Model):
         if model_request_parameters.builtin_tools:
             raise UserError('HuggingFace does not support built-in tools')

-        hf_messages = await self._map_messages(messages)
+        hf_messages = await self._map_messages(messages, model_request_parameters)

         try:
             return await self.client.chat.completions.create(  # type: ignore
@@ -322,7 +322,7 @@ class HuggingFaceModel(Model):
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     async def _map_messages(
-        self, messages: list[ModelMessage]
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
     ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
         """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
         hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
@@ -359,7 +359,7 @@ class HuggingFaceModel(Model):
                 hf_messages.append(message_param)
             else:
                 assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
         return hf_messages
pydantic_ai/models/instrumented.py
CHANGED

@@ -178,17 +178,20 @@ class InstrumentationSettings:
         description='Monetary cost',
     )

-    def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
+    def messages_to_otel_events(
+        self, messages: list[ModelMessage], parameters: ModelRequestParameters | None = None
+    ) -> list[Event]:
         """Convert a list of model messages to OpenTelemetry events.

         Args:
             messages: The messages to convert.
+            parameters: The model request parameters.

         Returns:
             A list of OpenTelemetry events.
         """
         events: list[Event] = []
-        instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
+        instructions = InstrumentedModel._get_instructions(messages, parameters)  # pyright: ignore [reportPrivateUsage]
         if instructions is not None:
             events.append(
                 Event(
@@ -235,10 +238,17 @@ class InstrumentationSettings:
             result.append(otel_message)
         return result

-    def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
+    def handle_messages(
+        self,
+        input_messages: list[ModelMessage],
+        response: ModelResponse,
+        system: str,
+        span: Span,
+        parameters: ModelRequestParameters | None = None,
+    ):
         if self.version == 1:
-            events = self.messages_to_otel_events(input_messages)
-            for event in self.messages_to_otel_events([response]):
+            events = self.messages_to_otel_events(input_messages, parameters)
+            for event in self.messages_to_otel_events([response], parameters):
                 events.append(
                     Event(
                         'gen_ai.choice',
@@ -258,7 +268,7 @@ class InstrumentationSettings:
             output_messages = self.messages_to_otel_messages([response])
             assert len(output_messages) == 1
             output_message = output_messages[0]
-            instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
+            instructions = InstrumentedModel._get_instructions(input_messages, parameters)  # pyright: ignore [reportPrivateUsage]
             system_instructions_attributes = self.system_instructions_attributes(instructions)
             attributes: dict[str, AttributeValue] = {
                 'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
@@ -360,7 +370,7 @@ class InstrumentedModel(WrapperModel):
         )
         with self._instrument(messages, prepared_settings, prepared_parameters) as finish:
             response = await self.wrapped.request(messages, model_settings, model_request_parameters)
-            finish(response)
+            finish(response, prepared_parameters)
             return response

     @asynccontextmanager
@@ -384,7 +394,7 @@ class InstrumentedModel(WrapperModel):
                 yield response_stream
             finally:
                 if response_stream:  # pragma: no branch
-                    finish(response_stream.get())
+                    finish(response_stream.get(), prepared_parameters)

     @contextmanager
     def _instrument(
@@ -392,7 +402,7 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> Iterator[Callable[[ModelResponse], None]]:
+    ) -> Iterator[Callable[[ModelResponse, ModelRequestParameters], None]]:
         operation = 'chat'
         span_name = f'{operation} {self.model_name}'
         # TODO Missing attributes:
@@ -401,7 +411,7 @@ class InstrumentedModel(WrapperModel):
         attributes: dict[str, AttributeValue] = {
             'gen_ai.operation.name': operation,
             **self.model_attributes(self.wrapped),
-            'model_request_parameters': json.dumps(self.serialize_any(model_request_parameters)),
+            **self.model_request_parameters_attributes(model_request_parameters),
             'logfire.json_schema': json.dumps(
                 {
                     'type': 'object',
@@ -419,7 +429,7 @@ class InstrumentedModel(WrapperModel):
         try:
             with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:

-                def finish(response: ModelResponse):
+                def finish(response: ModelResponse, parameters: ModelRequestParameters):
                     # FallbackModel updates these span attributes.
                     attributes.update(getattr(span, 'attributes', {}))
                     request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
@@ -443,7 +453,7 @@ class InstrumentedModel(WrapperModel):
                     if not span.is_recording():
                         return

-                    self.instrumentation_settings.handle_messages(messages, response, system, span)
+                    self.instrumentation_settings.handle_messages(messages, response, system, span, parameters)

                     attributes_to_set = {
                         **response.usage.opentelemetry_attributes(),
@@ -476,7 +486,7 @@ class InstrumentedModel(WrapperModel):
             record_metrics()

     @staticmethod
-    def model_attributes(model: Model):
+    def model_attributes(model: Model) -> dict[str, AttributeValue]:
         attributes: dict[str, AttributeValue] = {
             GEN_AI_SYSTEM_ATTRIBUTE: model.system,
             GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
@@ -494,6 +504,12 @@ class InstrumentedModel(WrapperModel):

         return attributes

+    @staticmethod
+    def model_request_parameters_attributes(
+        model_request_parameters: ModelRequestParameters,
+    ) -> dict[str, AttributeValue]:
+        return {'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters))}
+
     @staticmethod
     def event_to_dict(event: Event) -> dict[str, Any]:
         if not event.body:
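Because the new `parameters` argument defaults to `None`, existing callers of `messages_to_otel_events` keep working; passing the prepared parameters lets the instruction lookup see instructions carried on the request parameters rather than in the message history. A rough sketch, assuming `ModelRequestParameters` constructs with defaults:

```python
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.instrumented import InstrumentationSettings

settings = InstrumentationSettings()
# Old call shape still works:
events = settings.messages_to_otel_events([])
# New call shape threads the prepared parameters through:
events = settings.messages_to_otel_events([], ModelRequestParameters())
```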
pydantic_ai/models/mistral.py
CHANGED

@@ -230,7 +230,7 @@ class MistralModel(Model):
         try:
             response = await self.client.chat.complete_async(
                 model=str(self._model_name),
-                messages=self._map_messages(messages),
+                messages=self._map_messages(messages, model_request_parameters),
                 n=1,
                 tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                 tool_choice=self._get_tool_choice(model_request_parameters),
@@ -259,7 +259,7 @@ class MistralModel(Model):
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
-        mistral_messages = self._map_messages(messages)
+        mistral_messages = self._map_messages(messages, model_request_parameters)

         # TODO(Marcelo): We need to replace the current MistralAI client to use the beta client.
         # See https://docs.mistral.ai/agents/connectors/websearch/ to support web search.
@@ -523,7 +523,9 @@ class MistralModel(Model):
         else:
             assert_never(part)

-    def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
+    def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> list[MistralMessages]:
         """Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
         mistral_messages: list[MistralMessages] = []
         for message in messages:
@@ -554,7 +556,7 @@ class MistralModel(Model):
                 mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
             else:
                 assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             mistral_messages.insert(0, MistralSystemMessage(content=instructions))

         # Post-process messages to insert fake assistant message after tool message if followed by user message
pydantic_ai/models/openai.py
CHANGED

@@ -477,7 +477,7 @@ class OpenAIChatModel(Model):
         else:
             tool_choice = 'auto'

-        openai_messages = await self._map_messages(messages)
+        openai_messages = await self._map_messages(messages, model_request_parameters)

         response_format: chat.completion_create_params.ResponseFormat | None = None
         if model_request_parameters.output_mode == 'native':
@@ -672,7 +672,9 @@ class OpenAIChatModel(Model):
                 f'`{tool.__class__.__name__}` is not supported by `OpenAIChatModel`. If it should be, please file an issue.'
             )

-    async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
+    async def _map_messages(
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+    ) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         openai_messages: list[chat.ChatCompletionMessageParam] = []
         for message in messages:
@@ -713,7 +715,7 @@ class OpenAIChatModel(Model):
                 openai_messages.append(message_param)
             else:
                 assert_never(message)
-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system'))
         return openai_messages

@@ -1164,7 +1166,7 @@ class OpenAIResponsesModel(Model):
         if previous_response_id == 'auto':
             previous_response_id, messages = self._get_previous_response_id_and_new_messages(messages)

-        instructions, openai_messages = await self._map_messages(messages, model_settings)
+        instructions, openai_messages = await self._map_messages(messages, model_settings, model_request_parameters)
         reasoning = self._get_reasoning(model_settings)

         text: responses.ResponseTextConfigParam | None = None
@@ -1352,7 +1354,10 @@ class OpenAIResponsesModel(Model):
         return None, messages

     async def _map_messages(  # noqa: C901
-        self, messages: list[ModelMessage], model_settings: OpenAIResponsesModelSettings
+        self,
+        messages: list[ModelMessage],
+        model_settings: OpenAIResponsesModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
         profile = OpenAIModelProfile.from_profile(self.profile)
@@ -1577,7 +1582,7 @@ class OpenAIResponsesModel(Model):
                 assert_never(item)
             else:
                 assert_never(message)
-        instructions = self._get_instructions(messages) or NOT_GIVEN
+        instructions = self._get_instructions(messages, model_request_parameters) or NOT_GIVEN
         return instructions, openai_messages

     def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam:
pydantic_ai/models/outlines.py
CHANGED

@@ -8,14 +8,13 @@ from __future__ import annotations
 import io
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any, Literal, cast

 from typing_extensions import assert_never

 from .. import UnexpectedModelBehavior, _utils
-from .._output import PromptedOutputSchema
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
 from ..exceptions import UserError
@@ -247,6 +246,10 @@ class OutlinesModel(Model):
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
         """Make a request to the model."""
         prompt, output_type, inference_kwargs = await self._build_generation_arguments(
             messages, model_settings, model_request_parameters
@@ -267,6 +270,11 @@ class OutlinesModel(Model):
         model_request_parameters: ModelRequestParameters,
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
+        model_settings, model_request_parameters = self.prepare_request(
+            model_settings,
+            model_request_parameters,
+        )
+
         prompt, output_type, inference_kwargs = await self._build_generation_arguments(
             messages, model_settings, model_request_parameters
         )
@@ -298,15 +306,11 @@ class OutlinesModel(Model):
             raise UserError('Outlines does not support function tools and builtin tools yet.')

         if model_request_parameters.output_object:
-            instructions = PromptedOutputSchema.build_instructions(
-                self.profile.prompted_output_template, model_request_parameters.output_object
-            )
             output_type = JsonSchema(model_request_parameters.output_object.json_schema)
         else:
-            instructions = None
             output_type = None

-        prompt = await self._format_prompt(messages, instructions)
+        prompt = await self._format_prompt(messages, model_request_parameters)
         inference_kwargs = self.format_inference_kwargs(model_settings)

         return prompt, output_type, inference_kwargs
@@ -416,17 +420,14 @@ class OutlinesModel(Model):
         return filtered_settings

     async def _format_prompt(  # noqa: C901
-        self, messages: list[ModelMessage], output_format_instructions: str | None
+        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
     ) -> Chat:
         """Turn the model messages into an Outlines Chat instance."""
         chat = Chat()

-        if instructions := self._get_instructions(messages):
+        if instructions := self._get_instructions(messages, model_request_parameters):
             chat.add_system_message(instructions)

-        if output_format_instructions:
-            chat.add_system_message(output_format_instructions)
-
         for message in messages:
             if isinstance(message, ModelRequest):
                 for part in message.parts:
@@ -525,6 +526,14 @@ class OutlinesModel(Model):
             _provider_name='outlines',
         )

+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        """Customize the model request parameters for the model."""
+        if model_request_parameters.output_mode in ('auto', 'native'):
+            # This way the JSON schema will be included in the instructions.
+            return replace(model_request_parameters, output_mode='prompted')
+        else:
+            return model_request_parameters
+

 @dataclass
 class OutlinesStreamedResponse(StreamedResponse):
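With the `PromptedOutputSchema` handling removed from `_build_generation_arguments`, the new `customize_request_parameters` override is what routes structured output through the prompt. A sketch of the hook in isolation, where `model` stands in for an already-constructed `OutlinesModel` (hypothetical placeholder, not shown being built here):

```python
from pydantic_ai.models import ModelRequestParameters

params = ModelRequestParameters(output_mode='native')
# 'auto' and 'native' are both rewritten to 'prompted', so the JSON schema
# ends up rendered into the system instructions rather than enforced natively.
customized = model.customize_request_parameters(params)  # `model`: an OutlinesModel instance
assert customized.output_mode == 'prompted'
```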
pydantic_ai/models/wrapper.py
CHANGED

@@ -44,7 +44,7 @@ class WrapperModel(Model):
         yield response_stream

     def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        return self.wrapped.customize_request_parameters(model_request_parameters)
+        return self.wrapped.customize_request_parameters(model_request_parameters)  # pragma: no cover

     def prepare_request(
         self,
pydantic_ai/output.py
CHANGED

@@ -37,10 +37,11 @@ T_co = TypeVar('T_co', covariant=True)
 OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
 """Covariant type variable for the output data type of a run."""

-OutputMode = Literal['text', 'tool', 'native', 'prompted', 'tool_or_text', 'image']
+OutputMode = Literal['text', 'tool', 'native', 'prompted', 'tool_or_text', 'image', 'auto']
 """All output modes.

-`tool_or_text` is deprecated and no longer in use.
+- `tool_or_text` is deprecated and no longer in use.
+- `auto` means the model will automatically choose a structured output mode based on the model's `ModelProfile.default_structured_output_mode`.
 """
 StructuredOutputMode = Literal['tool', 'native', 'prompted']
 """Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode"""
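For reference, the concrete mode an 'auto' run resolves to comes from the model's profile; a short illustration (that the common default is 'tool' is an assumption based on the docstring above, not stated in this diff):

```python
from pydantic_ai.profiles import ModelProfile

profile = ModelProfile()
# 'auto' resolves to this per-model value; for most providers it is 'tool'.
print(profile.default_structured_output_mode)
```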
pydantic_ai/profiles/openai.py
CHANGED

@@ -62,7 +62,10 @@ class OpenAIModelProfile(ModelProfile):

 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""
-    is_reasoning_model = model_name.startswith('o') or model_name.startswith('gpt-5')
+    is_gpt_5 = model_name.startswith('gpt-5')
+    is_o_series = model_name.startswith('o')
+    is_reasoning_model = is_o_series or (is_gpt_5 and 'gpt-5-chat' not in model_name)
+
     # Check if the model supports web search (only specific search-preview models)
     supports_web_search = '-search-preview' in model_name

@@ -91,7 +94,7 @@ def openai_model_profile(model_name: str) -> ModelProfile:
         json_schema_transformer=OpenAIJsonSchemaTransformer,
         supports_json_schema_output=True,
         supports_json_object_output=True,
-        supports_image_output='gpt-5' in model_name or 'o3' in model_name or '4.1' in model_name or '4o' in model_name,
+        supports_image_output=is_gpt_5 or 'o3' in model_name or '4.1' in model_name or '4o' in model_name,
         openai_unsupported_model_settings=openai_unsupported_model_settings,
         openai_system_prompt_role=openai_system_prompt_role,
         openai_chat_supports_web_search=supports_web_search,
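A quick check of the refactored predicates; the model names below are examples, and `supports_image_output` is the only one of these flags surfaced directly on the returned profile (the reasoning-model distinction, under which gpt-5-chat is the one gpt-5 variant *not* treated as a reasoning model, only affects internal settings):

```python
from pydantic_ai.profiles.openai import openai_model_profile

# Per the new expression: gpt-5 and 4o family names report image output support.
assert openai_model_profile('gpt-5').supports_image_output
assert openai_model_profile('gpt-4o').supports_image_output
```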
pydantic_ai/result.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

 from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
 from copy import deepcopy
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import TYPE_CHECKING, Generic, cast, overload

@@ -117,7 +117,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         else:
             async for text in self._stream_response_text(delta=False, debounce_by=debounce_by):
                 for validator in self._output_validators:
-                    text = await validator.validate(text, self._run_ctx)
+                    text = await validator.validate(text, replace(self._run_ctx, partial_output=True))
                 yield text

     # TODO (v2): Drop in favor of `response` property
@@ -195,7 +195,9 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
             for validator in self._output_validators:
-                result_data = await validator.validate(result_data, self._run_ctx)
+                result_data = await validator.validate(
+                    result_data, replace(self._run_ctx, partial_output=allow_partial)
+                )
             return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
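The `replace(..., partial_output=...)` calls rely on the `partial_output` field added to `RunContext` in `_run_context.py` (+2 above). A sketch of a validator that uses it to skip checks on incomplete streamed text:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test')  # built-in TestModel, keeps the example self-contained


@agent.output_validator
def validate_output(ctx: RunContext[None], output: str) -> str:
    if ctx.partial_output:
        # Streamed, still-growing text: don't reject it yet.
        return output
    return output.strip()
```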
pydantic_ai/tools.py
CHANGED

@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations

 from collections.abc import Awaitable, Callable, Sequence
-from dataclasses import KW_ONLY, dataclass, field, replace
+from dataclasses import KW_ONLY, dataclass, field
 from typing import Annotated, Any, Concatenate, Generic, Literal, TypeAlias, cast

 from pydantic import Discriminator, Tag
@@ -415,6 +415,7 @@ class Tool(Generic[ToolAgentDepsT]):
             strict=self.strict,
             sequential=self.sequential,
             metadata=self.metadata,
+            kind='unapproved' if self.requires_approval else 'function',
         )

     async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
@@ -428,9 +429,6 @@ class Tool(Generic[ToolAgentDepsT]):
         """
         base_tool_def = self.tool_def

-        if self.requires_approval and not ctx.tool_call_approved:
-            base_tool_def = replace(base_tool_def, kind='unapproved')
-
         if self.prepare is not None:
             return await self.prepare(ctx, base_tool_def)
         else: