pydantic-ai-slim 0.0.55__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +70 -59
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +511 -161
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +24 -4
- pydantic_ai/models/_json_schema.py +160 -0
- pydantic_ai/models/anthropic.py +5 -3
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +82 -75
- pydantic_ai/models/groq.py +32 -28
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +62 -58
- pydantic_ai/models/openai.py +110 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +4 -4
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.1.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.55.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/mistral.py
CHANGED
|
@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime, timezone
|
|
8
|
-
from itertools import chain
|
|
9
8
|
from typing import Any, Literal, Union, cast
|
|
10
9
|
|
|
11
10
|
import pydantic_core
|
|
@@ -29,11 +28,12 @@ from ..messages import (
|
|
|
29
28
|
ToolCallPart,
|
|
30
29
|
ToolReturnPart,
|
|
31
30
|
UserPromptPart,
|
|
31
|
+
VideoUrl,
|
|
32
32
|
)
|
|
33
33
|
from ..providers import Provider, infer_provider
|
|
34
|
-
from ..result import Usage
|
|
35
34
|
from ..settings import ModelSettings
|
|
36
35
|
from ..tools import ToolDefinition
|
|
36
|
+
from ..usage import Usage
|
|
37
37
|
from . import (
|
|
38
38
|
Model,
|
|
39
39
|
ModelRequestParameters,
|
|
@@ -168,7 +168,7 @@ class MistralModel(Model):
|
|
|
168
168
|
messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
|
|
169
169
|
)
|
|
170
170
|
async with response:
|
|
171
|
-
yield await self._process_streamed_response(model_request_parameters.
|
|
171
|
+
yield await self._process_streamed_response(model_request_parameters.output_tools, response)
|
|
172
172
|
|
|
173
173
|
@property
|
|
174
174
|
def model_name(self) -> MistralModelName:
|
|
@@ -190,9 +190,9 @@ class MistralModel(Model):
|
|
|
190
190
|
try:
|
|
191
191
|
response = await self.client.chat.complete_async(
|
|
192
192
|
model=str(self._model_name),
|
|
193
|
-
messages=
|
|
193
|
+
messages=self._map_messages(messages),
|
|
194
194
|
n=1,
|
|
195
|
-
tools=self.
|
|
195
|
+
tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
|
|
196
196
|
tool_choice=self._get_tool_choice(model_request_parameters),
|
|
197
197
|
stream=False,
|
|
198
198
|
max_tokens=model_settings.get('max_tokens', UNSET),
|
|
@@ -219,10 +219,10 @@ class MistralModel(Model):
|
|
|
219
219
|
) -> MistralEventStreamAsync[MistralCompletionEvent]:
|
|
220
220
|
"""Create a streaming completion request to the Mistral model."""
|
|
221
221
|
response: MistralEventStreamAsync[MistralCompletionEvent] | None
|
|
222
|
-
mistral_messages =
|
|
222
|
+
mistral_messages = self._map_messages(messages)
|
|
223
223
|
|
|
224
224
|
if (
|
|
225
|
-
model_request_parameters.
|
|
225
|
+
model_request_parameters.output_tools
|
|
226
226
|
and model_request_parameters.function_tools
|
|
227
227
|
or model_request_parameters.function_tools
|
|
228
228
|
):
|
|
@@ -231,7 +231,7 @@ class MistralModel(Model):
|
|
|
231
231
|
model=str(self._model_name),
|
|
232
232
|
messages=mistral_messages,
|
|
233
233
|
n=1,
|
|
234
|
-
tools=self.
|
|
234
|
+
tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
|
|
235
235
|
tool_choice=self._get_tool_choice(model_request_parameters),
|
|
236
236
|
temperature=model_settings.get('temperature', UNSET),
|
|
237
237
|
top_p=model_settings.get('top_p', 1),
|
|
@@ -243,9 +243,9 @@ class MistralModel(Model):
|
|
|
243
243
|
http_headers={'User-Agent': get_user_agent()},
|
|
244
244
|
)
|
|
245
245
|
|
|
246
|
-
elif model_request_parameters.
|
|
246
|
+
elif model_request_parameters.output_tools:
|
|
247
247
|
# Json Mode
|
|
248
|
-
parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.
|
|
248
|
+
parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools]
|
|
249
249
|
user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
|
|
250
250
|
mistral_messages.append(user_output_format_message)
|
|
251
251
|
|
|
@@ -276,22 +276,22 @@ class MistralModel(Model):
|
|
|
276
276
|
- "none": Prevents tool use.
|
|
277
277
|
- "required": Forces tool use.
|
|
278
278
|
"""
|
|
279
|
-
if not model_request_parameters.function_tools and not model_request_parameters.
|
|
279
|
+
if not model_request_parameters.function_tools and not model_request_parameters.output_tools:
|
|
280
280
|
return None
|
|
281
|
-
elif not model_request_parameters.
|
|
281
|
+
elif not model_request_parameters.allow_text_output:
|
|
282
282
|
return 'required'
|
|
283
283
|
else:
|
|
284
284
|
return 'auto'
|
|
285
285
|
|
|
286
|
-
def
|
|
286
|
+
def _map_function_and_output_tools_definition(
|
|
287
287
|
self, model_request_parameters: ModelRequestParameters
|
|
288
288
|
) -> list[MistralTool] | None:
|
|
289
|
-
"""Map function and
|
|
289
|
+
"""Map function and output tools to MistralTool format.
|
|
290
290
|
|
|
291
|
-
Returns None if both function_tools and
|
|
291
|
+
Returns None if both function_tools and output_tools are empty.
|
|
292
292
|
"""
|
|
293
293
|
all_tools: list[ToolDefinition] = (
|
|
294
|
-
model_request_parameters.function_tools + model_request_parameters.
|
|
294
|
+
model_request_parameters.function_tools + model_request_parameters.output_tools
|
|
295
295
|
)
|
|
296
296
|
tools = [
|
|
297
297
|
MistralTool(
|
|
@@ -327,7 +327,7 @@ class MistralModel(Model):
|
|
|
327
327
|
|
|
328
328
|
async def _process_streamed_response(
|
|
329
329
|
self,
|
|
330
|
-
|
|
330
|
+
output_tools: list[ToolDefinition],
|
|
331
331
|
response: MistralEventStreamAsync[MistralCompletionEvent],
|
|
332
332
|
) -> StreamedResponse:
|
|
333
333
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
@@ -345,7 +345,7 @@ class MistralModel(Model):
|
|
|
345
345
|
_response=peekable_response,
|
|
346
346
|
_model_name=self._model_name,
|
|
347
347
|
_timestamp=timestamp,
|
|
348
|
-
|
|
348
|
+
_output_tools={c.name: c for c in output_tools},
|
|
349
349
|
)
|
|
350
350
|
|
|
351
351
|
@staticmethod
|
|
@@ -439,13 +439,12 @@ class MistralModel(Model):
|
|
|
439
439
|
return int(1000 * timeout)
|
|
440
440
|
raise NotImplementedError('Timeout object is not yet supported for MistralModel.')
|
|
441
441
|
|
|
442
|
-
|
|
443
|
-
def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
|
|
442
|
+
def _map_user_message(self, message: ModelRequest) -> Iterable[MistralMessages]:
|
|
444
443
|
for part in message.parts:
|
|
445
444
|
if isinstance(part, SystemPromptPart):
|
|
446
445
|
yield MistralSystemMessage(content=part.content)
|
|
447
446
|
elif isinstance(part, UserPromptPart):
|
|
448
|
-
yield
|
|
447
|
+
yield self._map_user_prompt(part)
|
|
449
448
|
elif isinstance(part, ToolReturnPart):
|
|
450
449
|
yield MistralToolMessage(
|
|
451
450
|
tool_call_id=part.tool_call_id,
|
|
@@ -462,28 +461,31 @@ class MistralModel(Model):
|
|
|
462
461
|
else:
|
|
463
462
|
assert_never(part)
|
|
464
463
|
|
|
465
|
-
|
|
466
|
-
def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
|
|
464
|
+
def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
|
|
467
465
|
"""Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
466
|
+
mistral_messages: list[MistralMessages] = []
|
|
467
|
+
for message in messages:
|
|
468
|
+
if isinstance(message, ModelRequest):
|
|
469
|
+
mistral_messages.extend(self._map_user_message(message))
|
|
470
|
+
elif isinstance(message, ModelResponse):
|
|
471
|
+
content_chunks: list[MistralContentChunk] = []
|
|
472
|
+
tool_calls: list[MistralToolCall] = []
|
|
473
|
+
|
|
474
|
+
for part in message.parts:
|
|
475
|
+
if isinstance(part, TextPart):
|
|
476
|
+
content_chunks.append(MistralTextChunk(text=part.content))
|
|
477
|
+
elif isinstance(part, ToolCallPart):
|
|
478
|
+
tool_calls.append(self._map_tool_call(part))
|
|
479
|
+
else:
|
|
480
|
+
assert_never(part)
|
|
481
|
+
mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
|
|
482
|
+
else:
|
|
483
|
+
assert_never(message)
|
|
484
|
+
if instructions := self._get_instructions(messages):
|
|
485
|
+
mistral_messages.insert(0, MistralSystemMessage(content=instructions))
|
|
486
|
+
return mistral_messages
|
|
484
487
|
|
|
485
|
-
|
|
486
|
-
def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
|
|
488
|
+
def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
|
|
487
489
|
content: str | list[MistralContentChunk]
|
|
488
490
|
if isinstance(part.content, str):
|
|
489
491
|
content = part.content
|
|
@@ -503,6 +505,8 @@ class MistralModel(Model):
|
|
|
503
505
|
raise RuntimeError('Only image binary content is supported for Mistral.')
|
|
504
506
|
elif isinstance(item, DocumentUrl):
|
|
505
507
|
raise RuntimeError('DocumentUrl is not supported in Mistral.')
|
|
508
|
+
elif isinstance(item, VideoUrl):
|
|
509
|
+
raise RuntimeError('VideoUrl is not supported in Mistral.')
|
|
506
510
|
else: # pragma: no cover
|
|
507
511
|
raise RuntimeError(f'Unsupported content type: {type(item)}')
|
|
508
512
|
return MistralUserMessage(content=content)
|
|
@@ -518,7 +522,7 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
518
522
|
_model_name: MistralModelName
|
|
519
523
|
_response: AsyncIterable[MistralCompletionEvent]
|
|
520
524
|
_timestamp: datetime
|
|
521
|
-
|
|
525
|
+
_output_tools: dict[str, ToolDefinition]
|
|
522
526
|
|
|
523
527
|
_delta_content: str = field(default='', init=False)
|
|
524
528
|
|
|
@@ -536,13 +540,13 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
536
540
|
content = choice.delta.content
|
|
537
541
|
text = _map_content(content)
|
|
538
542
|
if text:
|
|
539
|
-
# Attempt to produce
|
|
540
|
-
if self.
|
|
543
|
+
# Attempt to produce an output tool call from the received text
|
|
544
|
+
if self._output_tools:
|
|
541
545
|
self._delta_content += text
|
|
542
|
-
maybe_tool_call_part = self.
|
|
546
|
+
maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools)
|
|
543
547
|
if maybe_tool_call_part:
|
|
544
548
|
yield self._parts_manager.handle_tool_call_part(
|
|
545
|
-
vendor_part_id='
|
|
549
|
+
vendor_part_id='output',
|
|
546
550
|
tool_name=maybe_tool_call_part.tool_name,
|
|
547
551
|
args=maybe_tool_call_part.args_as_dict(),
|
|
548
552
|
tool_call_id=maybe_tool_call_part.tool_call_id,
|
|
@@ -568,20 +572,20 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
568
572
|
return self._timestamp
|
|
569
573
|
|
|
570
574
|
@staticmethod
|
|
571
|
-
def
|
|
575
|
+
def _try_get_output_tool_from_text(text: str, output_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
|
|
572
576
|
output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
|
|
573
577
|
if output_json:
|
|
574
|
-
for
|
|
575
|
-
# NOTE: Additional verification to prevent JSON validation to crash
|
|
578
|
+
for output_tool in output_tools.values():
|
|
579
|
+
# NOTE: Additional verification to prevent JSON validation to crash
|
|
576
580
|
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
|
|
577
581
|
# Example with BaseModel and required fields.
|
|
578
582
|
if not MistralStreamedResponse._validate_required_json_schema(
|
|
579
|
-
output_json,
|
|
583
|
+
output_json, output_tool.parameters_json_schema
|
|
580
584
|
):
|
|
581
585
|
continue
|
|
582
586
|
|
|
583
587
|
# The following part_id will be thrown away
|
|
584
|
-
return ToolCallPart(tool_name=
|
|
588
|
+
return ToolCallPart(tool_name=output_tool.name, args=output_json)
|
|
585
589
|
|
|
586
590
|
@staticmethod
|
|
587
591
|
def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
|
|
@@ -649,21 +653,21 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)
|
|
|
649
653
|
|
|
650
654
|
def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
|
|
651
655
|
"""Maps the delta content from a Mistral Completion Chunk to a string or None."""
|
|
652
|
-
|
|
656
|
+
output: str | None = None
|
|
653
657
|
|
|
654
658
|
if isinstance(content, MistralUnset) or not content:
|
|
655
|
-
|
|
659
|
+
output = None
|
|
656
660
|
elif isinstance(content, list):
|
|
657
661
|
for chunk in content:
|
|
658
662
|
if isinstance(chunk, MistralTextChunk):
|
|
659
|
-
|
|
663
|
+
output = output or '' + chunk.text
|
|
660
664
|
else:
|
|
661
665
|
assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
|
|
662
666
|
elif isinstance(content, str):
|
|
663
|
-
|
|
667
|
+
output = content
|
|
664
668
|
|
|
665
669
|
# Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
|
|
666
|
-
if
|
|
667
|
-
|
|
670
|
+
if output and len(output) == 0: # pragma: no cover
|
|
671
|
+
output = None
|
|
668
672
|
|
|
669
|
-
return
|
|
673
|
+
return output
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -30,6 +30,7 @@ from ..messages import (
|
|
|
30
30
|
ToolCallPart,
|
|
31
31
|
ToolReturnPart,
|
|
32
32
|
UserPromptPart,
|
|
33
|
+
VideoUrl,
|
|
33
34
|
)
|
|
34
35
|
from ..settings import ModelSettings
|
|
35
36
|
from ..tools import ToolDefinition
|
|
@@ -41,6 +42,7 @@ from . import (
|
|
|
41
42
|
check_allow_model_requests,
|
|
42
43
|
get_user_agent,
|
|
43
44
|
)
|
|
45
|
+
from ._json_schema import JsonSchema, WalkJsonSchema
|
|
44
46
|
|
|
45
47
|
try:
|
|
46
48
|
from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
|
|
@@ -252,15 +254,12 @@ class OpenAIModel(Model):
|
|
|
252
254
|
# standalone function to make it easier to override
|
|
253
255
|
if not tools:
|
|
254
256
|
tool_choice: Literal['none', 'required', 'auto'] | None = None
|
|
255
|
-
elif not model_request_parameters.
|
|
257
|
+
elif not model_request_parameters.allow_text_output:
|
|
256
258
|
tool_choice = 'required'
|
|
257
259
|
else:
|
|
258
260
|
tool_choice = 'auto'
|
|
259
261
|
|
|
260
|
-
openai_messages
|
|
261
|
-
for m in messages:
|
|
262
|
-
async for msg in self._map_message(m):
|
|
263
|
-
openai_messages.append(msg)
|
|
262
|
+
openai_messages = await self._map_messages(messages)
|
|
264
263
|
|
|
265
264
|
try:
|
|
266
265
|
return await self.client.chat.completions.create(
|
|
@@ -317,35 +316,40 @@ class OpenAIModel(Model):
|
|
|
317
316
|
|
|
318
317
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
|
|
319
318
|
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
320
|
-
if model_request_parameters.
|
|
321
|
-
tools += [self._map_tool_definition(r) for r in model_request_parameters.
|
|
319
|
+
if model_request_parameters.output_tools:
|
|
320
|
+
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
|
|
322
321
|
return tools
|
|
323
322
|
|
|
324
|
-
async def
|
|
323
|
+
async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
|
|
325
324
|
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
325
|
+
openai_messages: list[chat.ChatCompletionMessageParam] = []
|
|
326
|
+
for message in messages:
|
|
327
|
+
if isinstance(message, ModelRequest):
|
|
328
|
+
async for item in self._map_user_message(message):
|
|
329
|
+
openai_messages.append(item)
|
|
330
|
+
elif isinstance(message, ModelResponse):
|
|
331
|
+
texts: list[str] = []
|
|
332
|
+
tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
|
|
333
|
+
for item in message.parts:
|
|
334
|
+
if isinstance(item, TextPart):
|
|
335
|
+
texts.append(item.content)
|
|
336
|
+
elif isinstance(item, ToolCallPart):
|
|
337
|
+
tool_calls.append(self._map_tool_call(item))
|
|
338
|
+
else:
|
|
339
|
+
assert_never(item)
|
|
340
|
+
message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
|
|
341
|
+
if texts:
|
|
342
|
+
# Note: model responses from this model should only have one text item, so the following
|
|
343
|
+
# shouldn't merge multiple texts into one unless you switch models between runs:
|
|
344
|
+
message_param['content'] = '\n\n'.join(texts)
|
|
345
|
+
if tool_calls:
|
|
346
|
+
message_param['tool_calls'] = tool_calls
|
|
347
|
+
openai_messages.append(message_param)
|
|
348
|
+
else:
|
|
349
|
+
assert_never(message)
|
|
350
|
+
if instructions := self._get_instructions(messages):
|
|
351
|
+
openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system'))
|
|
352
|
+
return openai_messages
|
|
349
353
|
|
|
350
354
|
@staticmethod
|
|
351
355
|
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
|
|
@@ -448,6 +452,8 @@ class OpenAIModel(Model):
|
|
|
448
452
|
# file_data = f'data:{media_type};base64,{base64_encoded}'
|
|
449
453
|
# file = File(file={'file_data': file_data, 'file_name': item.url, 'file_id': item.url}, type='file')
|
|
450
454
|
# content.append(file)
|
|
455
|
+
elif isinstance(item, VideoUrl): # pragma: no cover
|
|
456
|
+
raise NotImplementedError('VideoUrl is not supported for OpenAI')
|
|
451
457
|
else:
|
|
452
458
|
assert_never(item)
|
|
453
459
|
return chat.ChatCompletionUserMessageParam(role='user', content=content)
|
|
@@ -591,19 +597,19 @@ class OpenAIResponsesModel(Model):
|
|
|
591
597
|
# standalone function to make it easier to override
|
|
592
598
|
if not tools:
|
|
593
599
|
tool_choice: Literal['none', 'required', 'auto'] | None = None
|
|
594
|
-
elif not model_request_parameters.
|
|
600
|
+
elif not model_request_parameters.allow_text_output:
|
|
595
601
|
tool_choice = 'required'
|
|
596
602
|
else:
|
|
597
603
|
tool_choice = 'auto'
|
|
598
604
|
|
|
599
|
-
|
|
605
|
+
instructions, openai_messages = await self._map_messages(messages)
|
|
600
606
|
reasoning = self._get_reasoning(model_settings)
|
|
601
607
|
|
|
602
608
|
try:
|
|
603
609
|
return await self.client.responses.create(
|
|
604
610
|
input=openai_messages,
|
|
605
611
|
model=self._model_name,
|
|
606
|
-
instructions=
|
|
612
|
+
instructions=instructions,
|
|
607
613
|
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
|
|
608
614
|
tools=tools or NOT_GIVEN,
|
|
609
615
|
tool_choice=tool_choice or NOT_GIVEN,
|
|
@@ -632,8 +638,8 @@ class OpenAIResponsesModel(Model):
|
|
|
632
638
|
|
|
633
639
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
|
|
634
640
|
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
635
|
-
if model_request_parameters.
|
|
636
|
-
tools += [self._map_tool_definition(r) for r in model_request_parameters.
|
|
641
|
+
if model_request_parameters.output_tools:
|
|
642
|
+
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
|
|
637
643
|
return tools
|
|
638
644
|
|
|
639
645
|
@staticmethod
|
|
@@ -647,15 +653,16 @@ class OpenAIResponsesModel(Model):
|
|
|
647
653
|
'strict': f.strict or False,
|
|
648
654
|
}
|
|
649
655
|
|
|
650
|
-
async def
|
|
656
|
+
async def _map_messages(
|
|
657
|
+
self, messages: list[ModelMessage]
|
|
658
|
+
) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
|
|
651
659
|
"""Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
|
|
652
|
-
system_prompt: str = ''
|
|
653
660
|
openai_messages: list[responses.ResponseInputItemParam] = []
|
|
654
661
|
for message in messages:
|
|
655
662
|
if isinstance(message, ModelRequest):
|
|
656
663
|
for part in message.parts:
|
|
657
664
|
if isinstance(part, SystemPromptPart):
|
|
658
|
-
|
|
665
|
+
openai_messages.append(responses.EasyInputMessageParam(role='system', content=part.content))
|
|
659
666
|
elif isinstance(part, UserPromptPart):
|
|
660
667
|
openai_messages.append(await self._map_user_prompt(part))
|
|
661
668
|
elif isinstance(part, ToolReturnPart):
|
|
@@ -692,7 +699,8 @@ class OpenAIResponsesModel(Model):
|
|
|
692
699
|
assert_never(item)
|
|
693
700
|
else:
|
|
694
701
|
assert_never(message)
|
|
695
|
-
|
|
702
|
+
instructions = self._get_instructions(messages) or NOT_GIVEN
|
|
703
|
+
return instructions, openai_messages
|
|
696
704
|
|
|
697
705
|
@staticmethod
|
|
698
706
|
def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
|
|
@@ -765,6 +773,8 @@ class OpenAIResponsesModel(Model):
|
|
|
765
773
|
filename=f'filename.{item.format}',
|
|
766
774
|
)
|
|
767
775
|
)
|
|
776
|
+
elif isinstance(item, VideoUrl): # pragma: no cover
|
|
777
|
+
raise NotImplementedError('VideoUrl is not supported for OpenAI.')
|
|
768
778
|
else:
|
|
769
779
|
assert_never(item)
|
|
770
780
|
return responses.EasyInputMessageParam(role='user', content=content)
|
|
@@ -922,137 +932,79 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
|
|
|
922
932
|
)
|
|
923
933
|
|
|
924
934
|
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
|
|
930
|
-
but this basically just requires:
|
|
931
|
-
* `additionalProperties` must be set to false for each object in the parameters
|
|
932
|
-
* all fields in properties must be marked as required
|
|
933
|
-
"""
|
|
934
|
-
assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
|
|
935
|
+
@dataclass
|
|
936
|
+
class _OpenAIJsonSchema(WalkJsonSchema):
|
|
937
|
+
"""Recursively handle the schema to make it compatible with OpenAI strict mode.
|
|
935
938
|
|
|
936
|
-
|
|
937
|
-
|
|
939
|
+
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
|
|
940
|
+
but this basically just requires:
|
|
941
|
+
* `additionalProperties` must be set to false for each object in the parameters
|
|
942
|
+
* all fields in properties must be marked as required
|
|
943
|
+
"""
|
|
938
944
|
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
945
|
+
def __init__(self, schema: JsonSchema, strict: bool | None):
|
|
946
|
+
super().__init__(schema)
|
|
947
|
+
self.strict = strict
|
|
948
|
+
self.is_strict_compatible = True
|
|
949
|
+
|
|
950
|
+
def transform(self, schema: JsonSchema) -> JsonSchema:
|
|
951
|
+
# Remove unnecessary keys
|
|
952
|
+
schema.pop('title', None)
|
|
953
|
+
schema.pop('default', None)
|
|
954
|
+
schema.pop('$schema', None)
|
|
955
|
+
schema.pop('discriminator', None)
|
|
956
|
+
|
|
957
|
+
# Remove incompatible keys, but note their impact in the description provided to the LLM
|
|
958
|
+
description = schema.get('description')
|
|
959
|
+
min_length = schema.pop('minLength', None)
|
|
960
|
+
max_length = schema.pop('minLength', None)
|
|
961
|
+
if description is not None:
|
|
962
|
+
notes = list[str]()
|
|
963
|
+
if min_length is not None: # pragma: no cover
|
|
964
|
+
notes.append(f'min_length={min_length}')
|
|
965
|
+
if max_length is not None: # pragma: no cover
|
|
966
|
+
notes.append(f'max_length={max_length}')
|
|
967
|
+
if notes: # pragma: no cover
|
|
968
|
+
schema['description'] = f'{description} ({", ".join(notes)})'
|
|
942
969
|
|
|
943
|
-
# Process schema based on its type
|
|
944
970
|
schema_type = schema.get('type')
|
|
945
971
|
if schema_type == 'object':
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
schema['
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
972
|
+
if self.strict is True:
|
|
973
|
+
# additional properties are disallowed
|
|
974
|
+
schema['additionalProperties'] = False
|
|
975
|
+
|
|
976
|
+
# all properties are required
|
|
977
|
+
if 'properties' not in schema:
|
|
978
|
+
schema['properties'] = dict[str, Any]()
|
|
979
|
+
schema['required'] = list(schema['properties'].keys())
|
|
980
|
+
|
|
981
|
+
elif self.strict is None:
|
|
982
|
+
if (
|
|
983
|
+
schema.get('additionalProperties') is not False
|
|
984
|
+
or 'properties' not in schema
|
|
985
|
+
or 'required' not in schema
|
|
986
|
+
):
|
|
987
|
+
self.is_strict_compatible = False
|
|
988
|
+
else:
|
|
989
|
+
required = schema['required']
|
|
990
|
+
for k in schema['properties'].keys():
|
|
991
|
+
if k not in required:
|
|
992
|
+
self.is_strict_compatible = False
|
|
965
993
|
return schema
|
|
966
994
|
|
|
967
|
-
def _make_object_schema_strict(self, schema: dict[str, Any]) -> None:
|
|
968
|
-
schema['additionalProperties'] = False
|
|
969
|
-
|
|
970
|
-
# Handle patternProperties; note this may not be compatible with strict mode but is included for completeness
|
|
971
|
-
if 'patternProperties' in schema and isinstance(schema['patternProperties'], dict):
|
|
972
|
-
pattern_props: dict[str, Any] = schema['patternProperties']
|
|
973
|
-
schema['patternProperties'] = {str(k): self.make_schema_strict(v) for k, v in pattern_props.items()}
|
|
974
|
-
|
|
975
|
-
# Handle properties — update their schemas recursively, and make all properties required
|
|
976
|
-
if 'properties' in schema and isinstance(schema['properties'], dict):
|
|
977
|
-
properties: dict[str, Any] = schema['properties']
|
|
978
|
-
schema['properties'] = {k: self.make_schema_strict(v) for k, v in properties.items()}
|
|
979
|
-
schema['required'] = list(properties.keys())
|
|
980
|
-
|
|
981
|
-
def is_schema_strict(self, schema: dict[str, Any]) -> bool:
|
|
982
|
-
"""Check if the schema is strict-mode-compatible.
|
|
983
|
-
|
|
984
|
-
A schema is compatible if:
|
|
985
|
-
* `additionalProperties` is set to false for each object in the parameters
|
|
986
|
-
* all fields in properties are marked as required
|
|
987
|
-
|
|
988
|
-
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details.
|
|
989
|
-
"""
|
|
990
|
-
assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
|
|
991
|
-
|
|
992
|
-
# Note that checking the defs first is usually the fastest way to proceed, but
|
|
993
|
-
# it makes it hard/impossible to hit coverage below, hence all the pragma no covers.
|
|
994
|
-
# I still included the handling below because I'm not _confident_ those code paths can't be hit.
|
|
995
|
-
if defs := schema.get('$defs'):
|
|
996
|
-
if not all(self.is_schema_strict(v) for v in defs.values()): # pragma: no branch
|
|
997
|
-
return False
|
|
998
|
-
|
|
999
|
-
schema_type = schema.get('type')
|
|
1000
|
-
if schema_type == 'object':
|
|
1001
|
-
if not self._is_object_schema_strict(schema):
|
|
1002
|
-
return False
|
|
1003
|
-
elif schema_type == 'array':
|
|
1004
|
-
if 'items' in schema:
|
|
1005
|
-
items: Any = schema['items']
|
|
1006
|
-
if not self.is_schema_strict(items): # pragma: no cover
|
|
1007
|
-
return False
|
|
1008
|
-
if 'prefixItems' in schema:
|
|
1009
|
-
prefix_items: list[Any] = schema['prefixItems']
|
|
1010
|
-
if not all(self.is_schema_strict(item) for item in prefix_items): # pragma: no cover
|
|
1011
|
-
return False
|
|
1012
|
-
elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
|
|
1013
|
-
pass
|
|
1014
|
-
elif 'oneOf' in schema: # pragma: no cover
|
|
1015
|
-
if not all(self.is_schema_strict(item) for item in schema['oneOf']):
|
|
1016
|
-
return False
|
|
1017
|
-
|
|
1018
|
-
elif 'anyOf' in schema: # pragma: no cover
|
|
1019
|
-
if not all(self.is_schema_strict(item) for item in schema['anyOf']):
|
|
1020
|
-
return False
|
|
1021
|
-
|
|
1022
|
-
return True
|
|
1023
|
-
|
|
1024
|
-
def _is_object_schema_strict(self, schema: dict[str, Any]) -> bool:
|
|
1025
|
-
"""Check if the schema is an object and has additionalProperties set to false."""
|
|
1026
|
-
if schema.get('additionalProperties') is not False:
|
|
1027
|
-
return False
|
|
1028
|
-
if 'properties' not in schema: # pragma: no cover
|
|
1029
|
-
return False
|
|
1030
|
-
if 'required' not in schema: # pragma: no cover
|
|
1031
|
-
return False
|
|
1032
|
-
|
|
1033
|
-
for k, v in schema['properties'].items():
|
|
1034
|
-
if k not in schema['required']:
|
|
1035
|
-
return False
|
|
1036
|
-
if not self.is_schema_strict(v): # pragma: no cover
|
|
1037
|
-
return False
|
|
1038
|
-
|
|
1039
|
-
return True
|
|
1040
|
-
|
|
1041
995
|
|
|
1042
996
|
def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
1043
997
|
"""Customize the request parameters for OpenAI models."""
|
|
1044
998
|
|
|
1045
999
|
def _customize_tool_def(t: ToolDefinition):
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
return replace(t, strict=strict)
|
|
1052
|
-
return t
|
|
1000
|
+
schema_transformer = _OpenAIJsonSchema(t.parameters_json_schema, strict=t.strict)
|
|
1001
|
+
parameters_json_schema = schema_transformer.walk()
|
|
1002
|
+
if t.strict is None:
|
|
1003
|
+
t = replace(t, strict=schema_transformer.is_strict_compatible)
|
|
1004
|
+
return replace(t, parameters_json_schema=parameters_json_schema)
|
|
1053
1005
|
|
|
1054
1006
|
return ModelRequestParameters(
|
|
1055
1007
|
function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
|
|
1056
|
-
|
|
1057
|
-
|
|
1008
|
+
allow_text_output=model_request_parameters.allow_text_output,
|
|
1009
|
+
output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
|
|
1058
1010
|
)
|