pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Note: this release of pydantic-ai-slim has been flagged as potentially problematic.

Files changed (51)
  1. pydantic_ai/_agent_graph.py +29 -35
  2. pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
  3. pydantic_ai/_output.py +266 -119
  4. pydantic_ai/agent.py +15 -15
  5. pydantic_ai/mcp.py +1 -1
  6. pydantic_ai/messages.py +2 -2
  7. pydantic_ai/models/__init__.py +39 -3
  8. pydantic_ai/models/anthropic.py +4 -0
  9. pydantic_ai/models/bedrock.py +43 -16
  10. pydantic_ai/models/cohere.py +4 -0
  11. pydantic_ai/models/gemini.py +78 -109
  12. pydantic_ai/models/google.py +47 -112
  13. pydantic_ai/models/groq.py +17 -2
  14. pydantic_ai/models/mistral.py +4 -0
  15. pydantic_ai/models/openai.py +25 -158
  16. pydantic_ai/profiles/__init__.py +39 -0
  17. pydantic_ai/{models → profiles}/_json_schema.py +23 -2
  18. pydantic_ai/profiles/amazon.py +9 -0
  19. pydantic_ai/profiles/anthropic.py +8 -0
  20. pydantic_ai/profiles/cohere.py +8 -0
  21. pydantic_ai/profiles/deepseek.py +8 -0
  22. pydantic_ai/profiles/google.py +100 -0
  23. pydantic_ai/profiles/grok.py +8 -0
  24. pydantic_ai/profiles/meta.py +9 -0
  25. pydantic_ai/profiles/mistral.py +8 -0
  26. pydantic_ai/profiles/openai.py +144 -0
  27. pydantic_ai/profiles/qwen.py +9 -0
  28. pydantic_ai/providers/__init__.py +18 -0
  29. pydantic_ai/providers/anthropic.py +5 -0
  30. pydantic_ai/providers/azure.py +34 -0
  31. pydantic_ai/providers/bedrock.py +60 -1
  32. pydantic_ai/providers/cohere.py +5 -0
  33. pydantic_ai/providers/deepseek.py +12 -0
  34. pydantic_ai/providers/fireworks.py +99 -0
  35. pydantic_ai/providers/google.py +5 -0
  36. pydantic_ai/providers/google_gla.py +5 -0
  37. pydantic_ai/providers/google_vertex.py +5 -0
  38. pydantic_ai/providers/grok.py +82 -0
  39. pydantic_ai/providers/groq.py +25 -0
  40. pydantic_ai/providers/mistral.py +5 -0
  41. pydantic_ai/providers/openai.py +5 -0
  42. pydantic_ai/providers/openrouter.py +36 -0
  43. pydantic_ai/providers/together.py +96 -0
  44. pydantic_ai/result.py +34 -103
  45. pydantic_ai/tools.py +29 -59
  46. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/METADATA +4 -4
  47. pydantic_ai_slim-0.2.13.dist-info/RECORD +73 -0
  48. pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
  49. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/WHEEL +0 -0
  50. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/entry_points.txt +0 -0
  51. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/licenses/LICENSE +0 -0
--- a/pydantic_ai/models/google.py
+++ b/pydantic_ai/models/google.py
@@ -1,19 +1,18 @@
 from __future__ import annotations as _annotations
 
 import base64
-import warnings
 from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4
 
 from typing_extensions import assert_never
 
 from pydantic_ai.providers import Provider
 
-from .. import UnexpectedModelBehavior, UserError, _utils, usage
+from .. import UnexpectedModelBehavior, _utils, usage
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -32,6 +31,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -42,7 +42,6 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 try:
     from google import genai
@@ -141,6 +140,7 @@ class GoogleModel(Model):
         model_name: GoogleModelName,
         *,
         provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a Gemini model.
 
@@ -149,6 +149,7 @@ class GoogleModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                 If not provided, a new provider will be created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
@@ -158,6 +159,7 @@ class GoogleModel(Model):
         self._provider = provider
         self._system = provider.name
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
@@ -186,16 +188,6 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response)  # type: ignore
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        def _customize_tool_def(t: ToolDefinition):
-            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
-
-        return ModelRequestParameters(
-            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-            allow_text_output=model_request_parameters.allow_text_output,
-            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-        )
-
     @property
     def model_name(self) -> GoogleModelName:
         """The model name."""
@@ -291,9 +283,16 @@ class GoogleModel(Model):
                 'Content field missing from Gemini response', str(response)
             )  # pragma: no cover
         parts = response.candidates[0].content.parts or []
+        vendor_id = response.response_id or None
+        vendor_details: dict[str, Any] | None = None
+        finish_reason = response.candidates[0].finish_reason
+        if finish_reason:  # pragma: no branch
+            vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
         usage.requests = 1
-        return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
+        return _process_response_from_parts(
+            parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+        )
 
     async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -392,7 +391,7 @@ class GeminiStreamedResponse(StreamedResponse):
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
-            self._usage += _metadata_as_usage(chunk)
+            self._usage = _metadata_as_usage(chunk)
 
             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
@@ -400,7 +399,7 @@ class GeminiStreamedResponse(StreamedResponse):
                 raise UnexpectedModelBehavior('Streamed response has no content field')  # pragma: no cover
             assert candidate.content.parts is not None
             for part in candidate.content.parts:
-                if part.text:
+                if part.text is not None:
                     yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
                 elif part.function_call:
                     maybe_event = self._parts_manager.handle_tool_call_delta(
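
The `+=` to `=` change in `_get_event_iterator` above is easy to read past. The likely rationale (my inference from the change, not stated in the diff) is that Gemini's streamed `usage_metadata` reports running totals on each chunk rather than per-chunk deltas, so accumulating it double-counts. A minimal sketch with invented values:

    # Assuming each streamed chunk's usage_metadata carries cumulative totals:
    cumulative_totals = [12, 30, 47]  # total_token_count seen on successive chunks

    old_behavior = sum(cumulative_totals)  # 89: summing running totals over-counts
    new_behavior = cumulative_totals[-1]   # 47: keep only the latest snapshot
    assert new_behavior == 47
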
@@ -439,10 +438,16 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
     return ContentDict(role='model', parts=parts)
 
 
-def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
+def _process_response_from_parts(
+    parts: list[Part],
+    model_name: GoogleModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    vendor_details: dict[str, Any] | None = None,
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
-        if part.text:
+        if part.text is not None:
             items.append(TextPart(content=part.text))
         elif part.function_call:
             assert part.function_call.name is not None
@@ -454,7 +459,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
             raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(parts=items, model_name=model_name, usage=usage)
+    return ModelResponse(
+        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+    )
 
 
 def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
@@ -475,99 +482,27 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
     metadata = response.usage_metadata
     if metadata is None:
         return usage.Usage()  # pragma: no cover
-    # TODO(Marcelo): We exclude the `prompt_tokens_details` and `candidate_token_details` fields because on
-    # `usage.Usage.incr``, it will try to sum non-integer values with integers, which will fail. We should probably
-    # handle this in the `Usage` class.
-    details = metadata.model_dump(
-        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
-        exclude_defaults=True,
-    )
-    return usage.Usage(
-        request_tokens=details.pop('prompt_token_count', 0),
-        response_tokens=details.pop('candidates_token_count', 0),
-        total_tokens=details.pop('total_token_count', 0),
-        details=details,
-    )
+    metadata = metadata.model_dump(exclude_defaults=True)
 
+    details: dict[str, int] = {}
+    if cached_content_token_count := metadata.get('cached_content_token_count'):
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
 
-class _GeminiJsonSchema(WalkJsonSchema):
-    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
 
-    Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
-    a subset of OpenAPI v3.0.3.
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
 
-    Specifically:
-    * gemini doesn't allow the `title` keyword to be set
-    * gemini doesn't allow `$defs` — we need to inline the definitions where possible
-    """
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
 
-    def __init__(self, schema: JsonSchema):
-        super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
-
-    def transform(self, schema: JsonSchema) -> JsonSchema:
-        # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
-        additional_properties = schema.pop(
-            'additionalProperties', None
-        )  # don't pop yet so it's included in the warning
-        if additional_properties:  # pragma: no cover
-            original_schema = {**schema, 'additionalProperties': additional_properties}
-            warnings.warn(
-                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
-                f' Full schema: {self.schema}\n\n'
-                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
-                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
-                ' and we will fix this behavior.',
-                UserWarning,
-            )
-
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
-            # Gemini doesn't support const, but it does support enum with a single value
-            schema['enum'] = [const]
-        schema.pop('discriminator', None)
-        schema.pop('examples', None)
-
-        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
-        # where we add notes about these properties to the field description?
-        schema.pop('exclusiveMaximum', None)
-        schema.pop('exclusiveMinimum', None)
-
-        type_ = schema.get('type')
-        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
-            # This gets hit when we have a discriminated union
-            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
-            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
-            schema['anyOf'] = schema.pop('oneOf')
-
-        if type_ == 'string' and (fmt := schema.pop('format', None)):
-            description = schema.get('description')
-            if description:
-                schema['description'] = f'{description} (format: {fmt})'
-            else:
-                schema['description'] = f'Format: {fmt}'
-
-        if '$ref' in schema:
-            raise UserError(  # pragma: no cover
-                f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}'
-            )
-
-        if 'prefixItems' in schema:  # pragma: lax no cover
-            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
-            prefix_items = schema.pop('prefixItems')
-            items = schema.get('items')
-            unique_items = [items] if items is not None else []
-            for item in prefix_items:
-                if item not in unique_items:
-                    unique_items.append(item)
-            if len(unique_items) > 1:  # pragma: no cover
-                schema['items'] = {'anyOf': unique_items}
-            elif len(unique_items) == 1:
-                schema['items'] = unique_items[0]
-            schema.setdefault('minItems', len(prefix_items))
-            if items is None:
-                schema.setdefault('maxItems', len(prefix_items))
-
-        return schema
+    return usage.Usage(
+        request_tokens=metadata.get('prompt_token_count', 0),
+        response_tokens=metadata.get('candidates_token_count', 0),
+        total_tokens=metadata.get('total_token_count', 0),
+        details=details,
+    )
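
To make the new detail mapping concrete, here is a small worked sketch of what the rewritten `_metadata_as_usage` extracts from a dumped `usage_metadata`; the values are invented, the key names come from the code above:

    # Invented metadata dump; keys mirror those handled in _metadata_as_usage.
    metadata = {
        'prompt_token_count': 10,
        'candidates_token_count': 5,
        'total_token_count': 18,
        'thoughts_token_count': 3,
        'prompt_tokens_details': [{'modality': 'TEXT', 'token_count': 10}],
    }

    details: dict[str, int] = {}
    if thoughts_token_count := metadata.get('thoughts_token_count'):
        details['thoughts_tokens'] = thoughts_token_count

    # Modality breakdowns such as prompt_tokens_details flatten to keys like
    # 'text_prompt_tokens'.
    for key, metadata_details in metadata.items():
        if key.endswith('_details') and metadata_details:
            suffix = key.removesuffix('_details')
            for detail in metadata_details:
                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']

    assert details == {'thoughts_tokens': 3, 'text_prompt_tokens': 10}
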
--- a/pydantic_ai/models/groq.py
+++ b/pydantic_ai/models/groq.py
@@ -27,10 +27,17 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, get_user_agent
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    check_allow_model_requests,
+    get_user_agent,
+)
 
 try:
     from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
@@ -105,7 +112,13 @@ class GroqModel(Model):
     _model_name: GroqModelName = field(repr=False)
     _system: str = field(default='groq', repr=False)
 
-    def __init__(self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq'):
+    def __init__(
+        self,
+        model_name: GroqModelName,
+        *,
+        provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
+        profile: ModelProfileSpec | None = None,
+    ):
         """Initialize a Groq model.
 
         Args:
@@ -114,12 +127,14 @@ class GroqModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
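
The same `profile` keyword argument recurs across GoogleModel, GroqModel, MistralModel, and both OpenAI model classes below. A hedged usage sketch follows; the model name is illustrative, and resolving a callable spec against the model name is assumed to happen in the shared `Model` base class (per the `models/__init__.py` changes in the file list above):

    from __future__ import annotations

    from pydantic_ai.models.groq import GroqModel
    from pydantic_ai.profiles import ModelProfile

    # A fixed ModelProfile overrides whatever profile the provider would pick:
    model = GroqModel('llama-3.3-70b-versatile', profile=ModelProfile())

    # A callable spec is consulted with the model name; returning None is
    # assumed to fall back to the provider's default profile.
    def pick_profile(model_name: str) -> ModelProfile | None:
        return ModelProfile() if 'llama' in model_name else None

    model = GroqModel('llama-3.3-70b-versatile', profile=pick_profile)
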
--- a/pydantic_ai/models/mistral.py
+++ b/pydantic_ai/models/mistral.py
@@ -30,6 +30,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -120,6 +121,7 @@ class MistralModel(Model):
         model_name: MistralModelName,
         *,
         provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
+        profile: ModelProfileSpec | None = None,
         json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
     ):
         """Initialize a Mistral model.
@@ -129,6 +131,7 @@ class MistralModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
         """
         self._model_name = model_name
@@ -137,6 +140,7 @@ class MistralModel(Model):
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
--- a/pydantic_ai/models/openai.py
+++ b/pydantic_ai/models/openai.py
@@ -1,16 +1,16 @@
 from __future__ import annotations as _annotations
 
 import base64
-import re
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
+from pydantic_ai.profiles.openai import OpenAIModelProfile
 from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
@@ -33,6 +33,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -43,7 +44,6 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
@@ -170,7 +170,9 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter'] | Provider[AsyncOpenAI] = 'openai',
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
     ):
         """Initialize an OpenAI model.
@@ -180,13 +182,17 @@ class OpenAIModel(Model):
                 [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                 (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
             provider: The provider to use. Defaults to `'openai'`.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
         """
         self._model_name = model_name
+
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
+
         self.system_prompt_role = system_prompt_role
 
     @property
@@ -221,9 +227,6 @@ class OpenAIModel(Model):
         async with response:
             yield await self._process_streamed_response(response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        return _customize_request_parameters(model_request_parameters)
-
     @property
     def model_name(self) -> OpenAIModelName:
         """The model name."""
@@ -331,7 +334,9 @@ class OpenAIModel(Model):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+                part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
+                part.tool_call_id = _guard_tool_call_id(part)
+                items.append(part)
         return ModelResponse(
             items,
             usage=_map_usage(response),
@@ -401,8 +406,7 @@ class OpenAIModel(Model):
             function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
         )
 
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+    def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolParam:
         tool_param: chat.ChatCompletionToolParam = {
             'type': 'function',
             'function': {
@@ -411,7 +415,7 @@ class OpenAIModel(Model):
                 'parameters': f.parameters_json_schema,
             },
         }
-        if f.strict:
+        if f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition:
             tool_param['function']['strict'] = f.strict
         return tool_param
 
@@ -533,18 +537,23 @@ class OpenAIResponsesModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize an OpenAI Responses model.
 
         Args:
             model_name: The name of the OpenAI model to use.
             provider: The provider to use. Defaults to `'openai'`.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
+
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def model_name(self) -> OpenAIModelName:
@@ -582,9 +591,6 @@ class OpenAIResponsesModel(Model):
         async with response:
             yield await self._process_streamed_response(response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        return _customize_request_parameters(model_request_parameters)
-
     def _process_response(self, response: responses.Response) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
@@ -690,15 +696,15 @@ class OpenAIResponsesModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
+    def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam:
         return {
             'name': f.name,
             'parameters': f.parameters_json_schema,
             'type': 'function',
             'description': f.description,
-            # NOTE: f.strict should already be a boolean thanks to customize_request_parameters
-            'strict': f.strict or False,
+            'strict': bool(
+                f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition
+            ),
         }
 
     async def _map_messages(
@@ -980,142 +986,3 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
         total_tokens=response_usage.total_tokens,
         details=details,
     )
-
-
-_STRICT_INCOMPATIBLE_KEYS = [
-    'minLength',
-    'maxLength',
-    'pattern',
-    'format',
-    'minimum',
-    'maximum',
-    'multipleOf',
-    'patternProperties',
-    'unevaluatedProperties',
-    'propertyNames',
-    'minProperties',
-    'maxProperties',
-    'unevaluatedItems',
-    'contains',
-    'minContains',
-    'maxContains',
-    'minItems',
-    'maxItems',
-    'uniqueItems',
-]
-
-_sentinel = object()
-
-
-@dataclass
-class _OpenAIJsonSchema(WalkJsonSchema):
-    """Recursively handle the schema to make it compatible with OpenAI strict mode.
-
-    See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
-    but this basically just requires:
-    * `additionalProperties` must be set to false for each object in the parameters
-    * all fields in properties must be marked as required
-    """
-
-    def __init__(self, schema: JsonSchema, strict: bool | None):
-        super().__init__(schema)
-        self.strict = strict
-        self.is_strict_compatible = True
-        self.root_ref = schema.get('$ref')
-
-    def walk(self) -> JsonSchema:
-        # Note: OpenAI does not support anyOf at the root in strict mode
-        # However, we don't need to check for it here because we ensure in pydantic_ai._utils.check_object_json_schema
-        # that the root schema either has type 'object' or is recursive.
-        result = super().walk()
-
-        # For recursive models, we need to tweak the schema to make it compatible with strict mode.
-        # Because the following should never change the semantics of the schema we apply it unconditionally.
-        if self.root_ref is not None:
-            result.pop('$ref', None)  # We replace references to the self.root_ref with just '#' in the transform method
-            root_key = re.sub(r'^#/\$defs/', '', self.root_ref)
-            result.update(self.defs.get(root_key) or {})
-
-        return result
-
-    def transform(self, schema: JsonSchema) -> JsonSchema:  # noqa C901
-        # Remove unnecessary keys
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        schema.pop('discriminator', None)
-
-        if schema_ref := schema.get('$ref'):
-            if schema_ref == self.root_ref:
-                schema['$ref'] = '#'
-            if len(schema) > 1:
-                # OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf".
-                # So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf":
-                schema['anyOf'] = [{'$ref': schema.pop('$ref')}]
-
-        # Track strict-incompatible keys
-        incompatible_values: dict[str, Any] = {}
-        for key in _STRICT_INCOMPATIBLE_KEYS:
-            value = schema.get(key, _sentinel)
-            if value is not _sentinel:
-                incompatible_values[key] = value
-        description = schema.get('description')
-        if incompatible_values:
-            if self.strict is True:
-                notes: list[str] = []
-                for key, value in incompatible_values.items():
-                    schema.pop(key)
-                    notes.append(f'{key}={value}')
-                notes_string = ', '.join(notes)
-                schema['description'] = notes_string if not description else f'{description} ({notes_string})'
-            elif self.strict is None:  # pragma: no branch
-                self.is_strict_compatible = False
-
-        schema_type = schema.get('type')
-        if 'oneOf' in schema:
-            # OpenAI does not support oneOf in strict mode
-            if self.strict is True:
-                schema['anyOf'] = schema.pop('oneOf')
-            else:
-                self.is_strict_compatible = False
-
-        if schema_type == 'object':
-            if self.strict is True:
-                # additional properties are disallowed
-                schema['additionalProperties'] = False
-
-                # all properties are required
-                if 'properties' not in schema:
-                    schema['properties'] = dict[str, Any]()
-                schema['required'] = list(schema['properties'].keys())
-
-            elif self.strict is None:
-                if (
-                    schema.get('additionalProperties') is not False
-                    or 'properties' not in schema
-                    or 'required' not in schema
-                ):
-                    self.is_strict_compatible = False
-                else:
-                    required = schema['required']
-                    for k in schema['properties'].keys():
-                        if k not in required:
-                            self.is_strict_compatible = False
-        return schema
-
-
-def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-    """Customize the request parameters for OpenAI models."""
-
-    def _customize_tool_def(t: ToolDefinition):
-        schema_transformer = _OpenAIJsonSchema(t.parameters_json_schema, strict=t.strict)
-        parameters_json_schema = schema_transformer.walk()
-        if t.strict is None:
-            t = replace(t, strict=schema_transformer.is_strict_compatible)
-        return replace(t, parameters_json_schema=parameters_json_schema)
-
-    return ModelRequestParameters(
-        function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-        allow_text_output=model_request_parameters.allow_text_output,
-        output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-    )
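
Both `_map_tool_definition` implementations now defer the strict-mode decision to the model profile instead of assuming OpenAI semantics, which matters now that `OpenAIModel` also fronts OpenAI-compatible providers like Grok, Fireworks, and Together. A sketch of the gate in isolation, assuming `openai_supports_strict_tool_definition` is a boolean field on `OpenAIModelProfile` as the calls above imply:

    from __future__ import annotations

    from pydantic_ai.profiles import ModelProfile
    from pydantic_ai.profiles.openai import OpenAIModelProfile

    def resolve_strict(requested: bool | None, profile: ModelProfile | None) -> bool:
        # from_profile promotes any ModelProfile to OpenAIModelProfile, so the
        # OpenAI-specific field is readable even when a generic profile is set.
        allowed = OpenAIModelProfile.from_profile(profile).openai_supports_strict_tool_definition
        return bool(requested and allowed)
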
--- /dev/null
+++ b/pydantic_ai/profiles/__init__.py
@@ -0,0 +1,39 @@
+from __future__ import annotations as _annotations
+
+from dataclasses import dataclass, fields, replace
+from typing import Callable, Union
+
+from typing_extensions import Self
+
+from ._json_schema import JsonSchemaTransformer
+
+
+@dataclass
+class ModelProfile:
+    """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used."""
+
+    json_schema_transformer: type[JsonSchemaTransformer] | None = None
+
+    @classmethod
+    def from_profile(cls, profile: ModelProfile | None) -> Self:
+        """Build a ModelProfile subclass instance from a ModelProfile instance."""
+        if isinstance(profile, cls):
+            return profile
+        return cls().update(profile)
+
+    def update(self, profile: ModelProfile | None) -> Self:
+        """Update this ModelProfile (subclass) instance with the non-default values from another ModelProfile instance."""
+        if not profile:
+            return self
+        field_names = set(f.name for f in fields(self))
+        non_default_attrs = {
+            f.name: getattr(profile, f.name)
+            for f in fields(profile)
+            if f.name in field_names and getattr(profile, f.name) != f.default
+        }
+        return replace(self, **non_default_attrs)
+
+
+ModelProfileSpec = Union[ModelProfile, Callable[[str], Union[ModelProfile, None]]]
+
+DEFAULT_PROFILE = ModelProfile()
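
The merge semantics here deserve a note: `update` copies across only the fields that were explicitly set away from their defaults on the incoming profile, and `from_profile` uses it to promote a generic profile to a subclass while preserving the subclass's own defaults. A runnable sketch against the class above; the subclass and its field are invented for illustration:

    from __future__ import annotations

    from dataclasses import dataclass

    from pydantic_ai.profiles import ModelProfile

    @dataclass
    class MyProfile(ModelProfile):
        supports_strict_mode: bool = True  # hypothetical field, not in this release

    base = MyProfile(supports_strict_mode=False)

    # An all-default ModelProfile has no non-default fields, so update is a no-op:
    assert base.update(ModelProfile()) == base

    # Promotion keeps subclass defaults where the generic profile is silent:
    assert MyProfile.from_profile(ModelProfile()).supports_strict_mode is True
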