pydantic-ai-slim 0.0.55__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only. Both changed modules live under pydantic_ai/models/ (mistral.py and openai.py). The dominant theme is the 0.1.0 rename of result-oriented names to output-oriented ones: result_tools → output_tools, allow_text_result → allow_text_output, and Usage moving from pydantic_ai.result to pydantic_ai.usage. The release also adds explicit handling for VideoUrl inputs, prepends run-level instructions during message mapping, and replaces the hand-rolled OpenAI strict-mode schema helper with a WalkJsonSchema-based transformer.
pydantic_ai/models/mistral.py

@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from itertools import chain
 from typing import Any, Literal, Union, cast
 
 import pydantic_core
@@ -29,11 +28,12 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,
+    VideoUrl,
 )
 from ..providers import Provider, infer_provider
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
+from ..usage import Usage
 from . import (
     Model,
     ModelRequestParameters,
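
A note on the import change above: `Usage` now comes from `pydantic_ai.usage` rather than `pydantic_ai.result`. A minimal sketch of the corresponding update in downstream code; the field names used here are an assumption based on this release of the library, not something shown in the diff:

    # Before (0.0.55), a hypothetical downstream import:
    # from pydantic_ai.result import Usage

    # After (0.1.0):
    from pydantic_ai.usage import Usage

    # Assumed field names for this release: request_tokens, response_tokens, total_tokens.
    u = Usage(request_tokens=10, response_tokens=5, total_tokens=15)
    print(u.total_tokens)  # 15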
@@ -168,7 +168,7 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(model_request_parameters.result_tools, response)
+            yield await self._process_streamed_response(model_request_parameters.output_tools, response)
 
     @property
     def model_name(self) -> MistralModelName:
@@ -190,9 +190,9 @@ class MistralModel(Model):
         try:
             response = await self.client.chat.complete_async(
                 model=str(self._model_name),
-                messages=list(chain(*(self._map_message(m) for m in messages))),
+                messages=self._map_messages(messages),
                 n=1,
-                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                 tool_choice=self._get_tool_choice(model_request_parameters),
                 stream=False,
                 max_tokens=model_settings.get('max_tokens', UNSET),
@@ -219,10 +219,10 @@ class MistralModel(Model):
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
-        mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
+        mistral_messages = self._map_messages(messages)
 
         if (
-            model_request_parameters.result_tools
+            model_request_parameters.output_tools
             and model_request_parameters.function_tools
             or model_request_parameters.function_tools
         ):
@@ -231,7 +231,7 @@ class MistralModel(Model):
                 model=str(self._model_name),
                 messages=mistral_messages,
                 n=1,
-                tools=self._map_function_and_result_tools_definition(model_request_parameters) or UNSET,
+                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                 tool_choice=self._get_tool_choice(model_request_parameters),
                 temperature=model_settings.get('temperature', UNSET),
                 top_p=model_settings.get('top_p', 1),
@@ -243,9 +243,9 @@ class MistralModel(Model):
                 http_headers={'User-Agent': get_user_agent()},
             )
 
-        elif model_request_parameters.result_tools:
+        elif model_request_parameters.output_tools:
             # Json Mode
-            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.result_tools]
+            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools]
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
 
@@ -276,22 +276,22 @@ class MistralModel(Model):
         - "none": Prevents tool use.
        - "required": Forces tool use.
         """
-        if not model_request_parameters.function_tools and not model_request_parameters.result_tools:
+        if not model_request_parameters.function_tools and not model_request_parameters.output_tools:
             return None
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             return 'required'
         else:
             return 'auto'
 
-    def _map_function_and_result_tools_definition(
+    def _map_function_and_output_tools_definition(
         self, model_request_parameters: ModelRequestParameters
     ) -> list[MistralTool] | None:
-        """Map function and result tools to MistralTool format.
+        """Map function and output tools to MistralTool format.
 
-        Returns None if both function_tools and result_tools are empty.
+        Returns None if both function_tools and output_tools are empty.
         """
         all_tools: list[ToolDefinition] = (
-            model_request_parameters.function_tools + model_request_parameters.result_tools
+            model_request_parameters.function_tools + model_request_parameters.output_tools
        )
         tools = [
             MistralTool(
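
The `_get_tool_choice` docstring above pins down the mapping: no tools at all yields no tool_choice, tools without permission for plain-text output force a tool call, and anything else defers to the model. A standalone sketch of that decision table; the function and parameter names here are illustrative, not part of the diff:

    from __future__ import annotations

    from typing import Literal

    def choose_tool_choice(
        has_function_tools: bool,
        has_output_tools: bool,
        allow_text_output: bool,
    ) -> Literal['none', 'required', 'auto'] | None:
        if not has_function_tools and not has_output_tools:
            return None  # no tools registered: omit tool_choice from the request
        if not allow_text_output:
            return 'required'  # plain text is not an acceptable final answer: force a tool call
        return 'auto'  # let the model decide between text and tool calls

    assert choose_tool_choice(False, False, True) is None
    assert choose_tool_choice(True, False, False) == 'required'
    assert choose_tool_choice(True, True, True) == 'auto'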
@@ -327,7 +327,7 @@ class MistralModel(Model):
 
     async def _process_streamed_response(
         self,
-        result_tools: list[ToolDefinition],
+        output_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -345,7 +345,7 @@ class MistralModel(Model):
             _response=peekable_response,
             _model_name=self._model_name,
             _timestamp=timestamp,
-            _result_tools={c.name: c for c in result_tools},
+            _output_tools={c.name: c for c in output_tools},
         )
 
     @staticmethod
@@ -439,13 +439,12 @@ class MistralModel(Model):
             return int(1000 * timeout)
         raise NotImplementedError('Timeout object is not yet supported for MistralModel.')
 
-    @classmethod
-    def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
+    def _map_user_message(self, message: ModelRequest) -> Iterable[MistralMessages]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
                 yield MistralSystemMessage(content=part.content)
             elif isinstance(part, UserPromptPart):
-                yield cls._map_user_prompt(part)
+                yield self._map_user_prompt(part)
             elif isinstance(part, ToolReturnPart):
                 yield MistralToolMessage(
                     tool_call_id=part.tool_call_id,
@@ -462,28 +461,31 @@ class MistralModel(Model):
             else:
                 assert_never(part)
 
-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
+    def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
         """Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
-        if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
-        elif isinstance(message, ModelResponse):
-            content_chunks: list[MistralContentChunk] = []
-            tool_calls: list[MistralToolCall] = []
-
-            for part in message.parts:
-                if isinstance(part, TextPart):
-                    content_chunks.append(MistralTextChunk(text=part.content))
-                elif isinstance(part, ToolCallPart):
-                    tool_calls.append(cls._map_tool_call(part))
-                else:
-                    assert_never(part)
-            yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
-        else:
-            assert_never(message)
+        mistral_messages: list[MistralMessages] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                mistral_messages.extend(self._map_user_message(message))
+            elif isinstance(message, ModelResponse):
+                content_chunks: list[MistralContentChunk] = []
+                tool_calls: list[MistralToolCall] = []
+
+                for part in message.parts:
+                    if isinstance(part, TextPart):
+                        content_chunks.append(MistralTextChunk(text=part.content))
+                    elif isinstance(part, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(part))
+                    else:
+                        assert_never(part)
+                mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls))
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            mistral_messages.insert(0, MistralSystemMessage(content=instructions))
+        return mistral_messages
 
-    @staticmethod
-    def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
+    def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
         content: str | list[MistralContentChunk]
         if isinstance(part.content, str):
             content = part.content
@@ -503,6 +505,8 @@ class MistralModel(Model):
                     raise RuntimeError('Only image binary content is supported for Mistral.')
                 elif isinstance(item, DocumentUrl):
                     raise RuntimeError('DocumentUrl is not supported in Mistral.')
+                elif isinstance(item, VideoUrl):
+                    raise RuntimeError('VideoUrl is not supported in Mistral.')
                 else:  # pragma: no cover
                     raise RuntimeError(f'Unsupported content type: {type(item)}')
         return MistralUserMessage(content=content)
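
Both `_map_messages` rewrites in this release (the Mistral one above, and the OpenAI one later in this diff) share a shape: map the full history into provider messages, then prepend any run-level instructions as a single system message. A minimal sketch of that ordering, with plain dicts standing in for the provider message types:

    mapped = [
        {'role': 'user', 'content': 'What is the capital of France?'},
        {'role': 'assistant', 'content': 'Paris.'},
    ]

    instructions = 'Answer concisely.'  # stand-in for self._get_instructions(messages)
    if instructions:
        # Instructions land at index 0, ahead of the whole mapped history.
        mapped.insert(0, {'role': 'system', 'content': instructions})

    assert mapped[0] == {'role': 'system', 'content': 'Answer concisely.'}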
@@ -518,7 +522,7 @@ class MistralStreamedResponse(StreamedResponse):
     _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
-    _result_tools: dict[str, ToolDefinition]
+    _output_tools: dict[str, ToolDefinition]
 
     _delta_content: str = field(default='', init=False)
 
@@ -536,13 +540,13 @@ class MistralStreamedResponse(StreamedResponse):
             content = choice.delta.content
             text = _map_content(content)
             if text:
-                # Attempt to produce a result tool call from the received text
-                if self._result_tools:
+                # Attempt to produce an output tool call from the received text
+                if self._output_tools:
                     self._delta_content += text
-                    maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
+                    maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools)
                     if maybe_tool_call_part:
                         yield self._parts_manager.handle_tool_call_part(
-                            vendor_part_id='result',
+                            vendor_part_id='output',
                             tool_name=maybe_tool_call_part.tool_name,
                             args=maybe_tool_call_part.args_as_dict(),
                             tool_call_id=maybe_tool_call_part.tool_call_id,
@@ -568,20 +572,20 @@ class MistralStreamedResponse(StreamedResponse):
         return self._timestamp
 
     @staticmethod
-    def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
+    def _try_get_output_tool_from_text(text: str, output_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
         output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
         if output_json:
-            for result_tool in result_tools.values():
-                # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
+            for output_tool in output_tools.values():
+                # NOTE: Additional verification to prevent JSON validation to crash
                 # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
                 # Example with BaseModel and required fields.
                 if not MistralStreamedResponse._validate_required_json_schema(
-                    output_json, result_tool.parameters_json_schema
+                    output_json, output_tool.parameters_json_schema
                 ):
                     continue
 
                 # The following part_id will be thrown away
-                return ToolCallPart(tool_name=result_tool.name, args=output_json)
+                return ToolCallPart(tool_name=output_tool.name, args=output_json)
 
     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
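
The streaming path above hinges on `pydantic_core.from_json(text, allow_partial='trailing-strings')`, which parses as much of an incomplete JSON document as it can instead of raising. That lets the handler check each accumulated delta against the output tool's required fields and emit a `ToolCallPart` before the closing brace arrives. A self-contained sketch; the partial payload is invented for illustration:

    import pydantic_core

    # A JSON-mode answer that is still streaming in; the string value is cut off.
    partial = '{"city": "Par'

    # json.loads(partial) would raise; partial parsing recovers what is there so far,
    # including the incomplete trailing string.
    parsed = pydantic_core.from_json(partial, allow_partial='trailing-strings')
    assert parsed == {'city': 'Par'}

Once every field required by the tool's JSON schema is present in the partial parse, `_try_get_output_tool_from_text` returns the tool call; until then the handler keeps accumulating deltas.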
@@ -649,21 +653,21 @@ def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk)
 
 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
     """Maps the delta content from a Mistral Completion Chunk to a string or None."""
-    result: str | None = None
+    output: str | None = None
 
     if isinstance(content, MistralUnset) or not content:
-        result = None
+        output = None
     elif isinstance(content, list):
         for chunk in content:
             if isinstance(chunk, MistralTextChunk):
-                result = result or '' + chunk.text
+                output = output or '' + chunk.text
             else:
                 assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
     elif isinstance(content, str):
-        result = content
+        output = content
 
     # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
-    if result and len(result) == 0:
-        result = None
+    if output and len(output) == 0:  # pragma: no cover
+        output = None
 
-    return result
+    return output

pydantic_ai/models/openai.py

@@ -30,6 +30,7 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,
+    VideoUrl,
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -41,6 +42,7 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
+from ._json_schema import JsonSchema, WalkJsonSchema
 
 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
@@ -252,15 +254,12 @@ class OpenAIModel(Model):
         # standalone function to make it easier to override
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
 
-        openai_messages: list[chat.ChatCompletionMessageParam] = []
-        for m in messages:
-            async for msg in self._map_message(m):
-                openai_messages.append(msg)
+        openai_messages = await self._map_messages(messages)
 
         try:
             return await self.client.chat.completions.create(
@@ -317,35 +316,40 @@ class OpenAIModel(Model):
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
-    async def _map_message(self, message: ModelMessage) -> AsyncIterable[chat.ChatCompletionMessageParam]:
+    async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
-        if isinstance(message, ModelRequest):
-            async for item in self._map_user_message(message):
-                yield item
-        elif isinstance(message, ModelResponse):
-            texts: list[str] = []
-            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
-            for item in message.parts:
-                if isinstance(item, TextPart):
-                    texts.append(item.content)
-                elif isinstance(item, ToolCallPart):
-                    tool_calls.append(self._map_tool_call(item))
-                else:
-                    assert_never(item)
-            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
-            if texts:
-                # Note: model responses from this model should only have one text item, so the following
-                # shouldn't merge multiple texts into one unless you switch models between runs:
-                message_param['content'] = '\n\n'.join(texts)
-            if tool_calls:
-                message_param['tool_calls'] = tool_calls
-            yield message_param
-        else:
-            assert_never(message)
+        openai_messages: list[chat.ChatCompletionMessageParam] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                async for item in self._map_user_message(message):
+                    openai_messages.append(item)
+            elif isinstance(message, ModelResponse):
+                texts: list[str] = []
+                tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        texts.append(item.content)
+                    elif isinstance(item, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(item))
+                    else:
+                        assert_never(item)
+                message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+                if texts:
+                    # Note: model responses from this model should only have one text item, so the following
+                    # shouldn't merge multiple texts into one unless you switch models between runs:
+                    message_param['content'] = '\n\n'.join(texts)
+                if tool_calls:
+                    message_param['tool_calls'] = tool_calls
+                openai_messages.append(message_param)
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system'))
+        return openai_messages
 
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
@@ -448,6 +452,8 @@ class OpenAIModel(Model):
                 # file_data = f'data:{media_type};base64,{base64_encoded}'
                 # file = File(file={'file_data': file_data, 'file_name': item.url, 'file_id': item.url}, type='file')
                 # content.append(file)
+            elif isinstance(item, VideoUrl):  # pragma: no cover
+                raise NotImplementedError('VideoUrl is not supported for OpenAI')
             else:
                 assert_never(item)
         return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -591,19 +597,19 @@ class OpenAIResponsesModel(Model):
         # standalone function to make it easier to override
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_result:
+        elif not model_request_parameters.allow_text_output:
             tool_choice = 'required'
         else:
             tool_choice = 'auto'
 
-        system_prompt, openai_messages = await self._map_message(messages)
+        instructions, openai_messages = await self._map_messages(messages)
         reasoning = self._get_reasoning(model_settings)
 
         try:
             return await self.client.responses.create(
                 input=openai_messages,
                 model=self._model_name,
-                instructions=system_prompt,
+                instructions=instructions,
                 parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                 tools=tools or NOT_GIVEN,
                 tool_choice=tool_choice or NOT_GIVEN,
@@ -632,8 +638,8 @@ class OpenAIResponsesModel(Model):
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.result_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
     @staticmethod
@@ -647,15 +653,16 @@ class OpenAIResponsesModel(Model):
             'strict': f.strict or False,
         }
 
-    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[responses.ResponseInputItemParam]]:
+    async def _map_messages(
+        self, messages: list[ModelMessage]
+    ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
-        system_prompt: str = ''
         openai_messages: list[responses.ResponseInputItemParam] = []
         for message in messages:
             if isinstance(message, ModelRequest):
                 for part in message.parts:
                     if isinstance(part, SystemPromptPart):
-                        system_prompt += part.content
+                        openai_messages.append(responses.EasyInputMessageParam(role='system', content=part.content))
                     elif isinstance(part, UserPromptPart):
                         openai_messages.append(await self._map_user_prompt(part))
                     elif isinstance(part, ToolReturnPart):
@@ -692,7 +699,8 @@ class OpenAIResponsesModel(Model):
                     assert_never(item)
             else:
                 assert_never(message)
-        return system_prompt, openai_messages
+        instructions = self._get_instructions(messages) or NOT_GIVEN
+        return instructions, openai_messages
 
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
@@ -765,6 +773,8 @@ class OpenAIResponsesModel(Model):
                         filename=f'filename.{item.format}',
                     )
                 )
+            elif isinstance(item, VideoUrl):  # pragma: no cover
+                raise NotImplementedError('VideoUrl is not supported for OpenAI.')
             else:
                 assert_never(item)
         return responses.EasyInputMessageParam(role='user', content=content)
@@ -922,137 +932,79 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
     )
 
 
-class _StrictSchemaHelper:
-    def make_schema_strict(self, schema: dict[str, Any]) -> dict[str, Any]:
-        """Recursively handle the schema to make it compatible with OpenAI strict mode.
-
-        See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
-        but this basically just requires:
-        * `additionalProperties` must be set to false for each object in the parameters
-        * all fields in properties must be marked as required
-        """
-        assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
+@dataclass
+class _OpenAIJsonSchema(WalkJsonSchema):
+    """Recursively handle the schema to make it compatible with OpenAI strict mode.
 
-        # Create a copy to avoid modifying the original schema
-        schema = schema.copy()
+    See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
+    but this basically just requires:
+    * `additionalProperties` must be set to false for each object in the parameters
+    * all fields in properties must be marked as required
+    """
 
-        # Handle $defs
-        if defs := schema.get('$defs'):
-            schema['$defs'] = {k: self.make_schema_strict(v) for k, v in defs.items()}
+    def __init__(self, schema: JsonSchema, strict: bool | None):
+        super().__init__(schema)
+        self.strict = strict
+        self.is_strict_compatible = True
+
+    def transform(self, schema: JsonSchema) -> JsonSchema:
+        # Remove unnecessary keys
+        schema.pop('title', None)
+        schema.pop('default', None)
+        schema.pop('$schema', None)
+        schema.pop('discriminator', None)
+
+        # Remove incompatible keys, but note their impact in the description provided to the LLM
+        description = schema.get('description')
+        min_length = schema.pop('minLength', None)
+        max_length = schema.pop('maxLength', None)
+        if description is not None:
+            notes = list[str]()
+            if min_length is not None:  # pragma: no cover
+                notes.append(f'min_length={min_length}')
+            if max_length is not None:  # pragma: no cover
+                notes.append(f'max_length={max_length}')
+            if notes:  # pragma: no cover
+                schema['description'] = f'{description} ({", ".join(notes)})'
 
-        # Process schema based on its type
         schema_type = schema.get('type')
         if schema_type == 'object':
-            # Handle object type by setting additionalProperties to false
-            # and adding all properties to required list
-            self._make_object_schema_strict(schema)
-        elif schema_type == 'array':
-            # Handle array types by processing their items
-            if 'items' in schema:
-                items: Any = schema['items']
-                schema['items'] = self.make_schema_strict(items)
-            if 'prefixItems' in schema:
-                prefix_items: list[Any] = schema['prefixItems']
-                schema['prefixItems'] = [self.make_schema_strict(item) for item in prefix_items]
-
-        elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
-            pass  # Primitive types need no special handling
-        elif 'oneOf' in schema:
-            schema['oneOf'] = [self.make_schema_strict(item) for item in schema['oneOf']]
-        elif 'anyOf' in schema:
-            schema['anyOf'] = [self.make_schema_strict(item) for item in schema['anyOf']]
-
+            if self.strict is True:
+                # additional properties are disallowed
+                schema['additionalProperties'] = False
+
+                # all properties are required
+                if 'properties' not in schema:
+                    schema['properties'] = dict[str, Any]()
+                schema['required'] = list(schema['properties'].keys())
+
+            elif self.strict is None:
+                if (
+                    schema.get('additionalProperties') is not False
+                    or 'properties' not in schema
+                    or 'required' not in schema
+                ):
+                    self.is_strict_compatible = False
+                else:
+                    required = schema['required']
+                    for k in schema['properties'].keys():
+                        if k not in required:
+                            self.is_strict_compatible = False
         return schema
 
-    def _make_object_schema_strict(self, schema: dict[str, Any]) -> None:
-        schema['additionalProperties'] = False
-
-        # Handle patternProperties; note this may not be compatible with strict mode but is included for completeness
-        if 'patternProperties' in schema and isinstance(schema['patternProperties'], dict):
-            pattern_props: dict[str, Any] = schema['patternProperties']
-            schema['patternProperties'] = {str(k): self.make_schema_strict(v) for k, v in pattern_props.items()}
-
-        # Handle properties — update their schemas recursively, and make all properties required
-        if 'properties' in schema and isinstance(schema['properties'], dict):
-            properties: dict[str, Any] = schema['properties']
-            schema['properties'] = {k: self.make_schema_strict(v) for k, v in properties.items()}
-            schema['required'] = list(properties.keys())
-
-    def is_schema_strict(self, schema: dict[str, Any]) -> bool:
-        """Check if the schema is strict-mode-compatible.
-
-        A schema is compatible if:
-        * `additionalProperties` is set to false for each object in the parameters
-        * all fields in properties are marked as required
-
-        See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details.
-        """
-        assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
-
-        # Note that checking the defs first is usually the fastest way to proceed, but
-        # it makes it hard/impossible to hit coverage below, hence all the pragma no covers.
-        # I still included the handling below because I'm not _confident_ those code paths can't be hit.
-        if defs := schema.get('$defs'):
-            if not all(self.is_schema_strict(v) for v in defs.values()):  # pragma: no branch
-                return False
-
-        schema_type = schema.get('type')
-        if schema_type == 'object':
-            if not self._is_object_schema_strict(schema):
-                return False
-        elif schema_type == 'array':
-            if 'items' in schema:
-                items: Any = schema['items']
-                if not self.is_schema_strict(items):  # pragma: no cover
-                    return False
-            if 'prefixItems' in schema:
-                prefix_items: list[Any] = schema['prefixItems']
-                if not all(self.is_schema_strict(item) for item in prefix_items):  # pragma: no cover
-                    return False
-        elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
-            pass
-        elif 'oneOf' in schema:  # pragma: no cover
-            if not all(self.is_schema_strict(item) for item in schema['oneOf']):
-                return False
-
-        elif 'anyOf' in schema:  # pragma: no cover
-            if not all(self.is_schema_strict(item) for item in schema['anyOf']):
-                return False
-
-        return True
-
-    def _is_object_schema_strict(self, schema: dict[str, Any]) -> bool:
-        """Check if the schema is an object and has additionalProperties set to false."""
-        if schema.get('additionalProperties') is not False:
-            return False
-        if 'properties' not in schema:  # pragma: no cover
-            return False
-        if 'required' not in schema:  # pragma: no cover
-            return False
-
-        for k, v in schema['properties'].items():
-            if k not in schema['required']:
-                return False
-            if not self.is_schema_strict(v):  # pragma: no cover
-                return False
-
-        return True
-
 
 def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
     """Customize the request parameters for OpenAI models."""
 
     def _customize_tool_def(t: ToolDefinition):
-        if t.strict is True:
-            parameters_json_schema = _StrictSchemaHelper().make_schema_strict(t.parameters_json_schema)
-            return replace(t, parameters_json_schema=parameters_json_schema)
-        elif t.strict is None:
-            strict = _StrictSchemaHelper().is_schema_strict(t.parameters_json_schema)
-            return replace(t, strict=strict)
-        return t
+        schema_transformer = _OpenAIJsonSchema(t.parameters_json_schema, strict=t.strict)
+        parameters_json_schema = schema_transformer.walk()
+        if t.strict is None:
+            t = replace(t, strict=schema_transformer.is_strict_compatible)
+        return replace(t, parameters_json_schema=parameters_json_schema)
 
     return ModelRequestParameters(
         function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-        allow_text_result=model_request_parameters.allow_text_result,
-        result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
+        allow_text_output=model_request_parameters.allow_text_output,
+        output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
     )
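
Taken together, `_OpenAIJsonSchema` replaces the removed `_StrictSchemaHelper` with a single `WalkJsonSchema` pass that both cleans the schema and decides strict-mode compatibility. As a rough before/after illustration of the rules in its docstring for `strict=True` (hand-written schemas, not output captured from the class):

    # A typical Pydantic-generated parameters schema before the walk:
    schema_in = {
        'title': 'Args',
        'type': 'object',
        'properties': {
            'city': {'type': 'string', 'title': 'City'},
            'country': {'type': 'string', 'default': 'FR'},
        },
        'required': ['city'],
    }

    # After the strict transform: titles and defaults are dropped, every property
    # is required, and additionalProperties is pinned to False.
    schema_out = {
        'type': 'object',
        'properties': {
            'city': {'type': 'string'},
            'country': {'type': 'string'},
        },
        'required': ['city', 'country'],
        'additionalProperties': False,
    }

With `strict=None`, the walker instead leaves the structure alone and records in `is_strict_compatible` whether the schema already satisfies those two rules, which `_customize_tool_def` then writes back onto the tool definition.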