pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_agent_graph.py +29 -35
- pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
- pydantic_ai/_output.py +266 -119
- pydantic_ai/agent.py +15 -15
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +39 -3
- pydantic_ai/models/anthropic.py +4 -0
- pydantic_ai/models/bedrock.py +43 -16
- pydantic_ai/models/cohere.py +4 -0
- pydantic_ai/models/gemini.py +78 -109
- pydantic_ai/models/google.py +47 -112
- pydantic_ai/models/groq.py +17 -2
- pydantic_ai/models/mistral.py +4 -0
- pydantic_ai/models/openai.py +25 -158
- pydantic_ai/profiles/__init__.py +39 -0
- pydantic_ai/{models → profiles}/_json_schema.py +23 -2
- pydantic_ai/profiles/amazon.py +9 -0
- pydantic_ai/profiles/anthropic.py +8 -0
- pydantic_ai/profiles/cohere.py +8 -0
- pydantic_ai/profiles/deepseek.py +8 -0
- pydantic_ai/profiles/google.py +100 -0
- pydantic_ai/profiles/grok.py +8 -0
- pydantic_ai/profiles/meta.py +9 -0
- pydantic_ai/profiles/mistral.py +8 -0
- pydantic_ai/profiles/openai.py +144 -0
- pydantic_ai/profiles/qwen.py +9 -0
- pydantic_ai/providers/__init__.py +18 -0
- pydantic_ai/providers/anthropic.py +5 -0
- pydantic_ai/providers/azure.py +34 -0
- pydantic_ai/providers/bedrock.py +60 -1
- pydantic_ai/providers/cohere.py +5 -0
- pydantic_ai/providers/deepseek.py +12 -0
- pydantic_ai/providers/fireworks.py +99 -0
- pydantic_ai/providers/google.py +5 -0
- pydantic_ai/providers/google_gla.py +5 -0
- pydantic_ai/providers/google_vertex.py +5 -0
- pydantic_ai/providers/grok.py +82 -0
- pydantic_ai/providers/groq.py +25 -0
- pydantic_ai/providers/mistral.py +5 -0
- pydantic_ai/providers/openai.py +5 -0
- pydantic_ai/providers/openrouter.py +36 -0
- pydantic_ai/providers/together.py +96 -0
- pydantic_ai/result.py +34 -103
- pydantic_ai/tools.py +29 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/METADATA +4 -4
- pydantic_ai_slim-0.2.13.dist-info/RECORD +73 -0
- pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py
CHANGED
@@ -1,19 +1,18 @@
 from __future__ import annotations as _annotations
 
 import base64
-import warnings
 from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4
 
 from typing_extensions import assert_never
 
 from pydantic_ai.providers import Provider
 
-from .. import UnexpectedModelBehavior, UserError, _utils, usage
+from .. import UnexpectedModelBehavior, _utils, usage
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -32,6 +31,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -42,7 +42,6 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 try:
     from google import genai
@@ -141,6 +140,7 @@ class GoogleModel(Model):
         model_name: GoogleModelName,
         *,
         provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a Gemini model.
 
@@ -149,6 +149,7 @@ class GoogleModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                 If not provided, a new provider will be created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
@@ -158,6 +159,7 @@ class GoogleModel(Model):
         self._provider = provider
         self._system = provider.name
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
@@ -186,16 +188,6 @@ class GoogleModel(Model):
             response = await self._generate_content(messages, True, model_settings, model_request_parameters)
             yield await self._process_streamed_response(response)  # type: ignore
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        def _customize_tool_def(t: ToolDefinition):
-            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
-
-        return ModelRequestParameters(
-            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-            allow_text_output=model_request_parameters.allow_text_output,
-            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-        )
-
     @property
     def model_name(self) -> GoogleModelName:
         """The model name."""
@@ -291,9 +283,16 @@ class GoogleModel(Model):
                 'Content field missing from Gemini response', str(response)
             )  # pragma: no cover
         parts = response.candidates[0].content.parts or []
+        vendor_id = response.response_id or None
+        vendor_details: dict[str, Any] | None = None
+        finish_reason = response.candidates[0].finish_reason
+        if finish_reason:  # pragma: no branch
+            vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
         usage.requests = 1
-        return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
+        return _process_response_from_parts(
+            parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+        )
 
     async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -392,7 +391,7 @@ class GeminiStreamedResponse(StreamedResponse):
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
-            self._usage += _metadata_as_usage(chunk)
+            self._usage = _metadata_as_usage(chunk)
 
             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
@@ -400,7 +399,7 @@ class GeminiStreamedResponse(StreamedResponse):
                 raise UnexpectedModelBehavior('Streamed response has no content field')  # pragma: no cover
             assert candidate.content.parts is not None
             for part in candidate.content.parts:
-                if part.text:
+                if part.text is not None:
                     yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
                 elif part.function_call:
                     maybe_event = self._parts_manager.handle_tool_call_delta(
@@ -439,10 +438,16 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
     return ContentDict(role='model', parts=parts)
 
 
-def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
+def _process_response_from_parts(
+    parts: list[Part],
+    model_name: GoogleModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    vendor_details: dict[str, Any] | None = None,
+) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
-        if part.text:
+        if part.text is not None:
             items.append(TextPart(content=part.text))
         elif part.function_call:
             assert part.function_call.name is not None
@@ -454,7 +459,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
             raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(parts=items, model_name=model_name, usage=usage)
+    return ModelResponse(
+        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+    )
 
 
 def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
@@ -475,99 +482,27 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
     metadata = response.usage_metadata
     if metadata is None:
         return usage.Usage()  # pragma: no cover
-
-    # `usage.Usage.incr``, it will try to sum non-integer values with integers, which will fail. We should probably
-    # handle this in the `Usage` class.
-    details = metadata.model_dump(
-        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
-        exclude_defaults=True,
-    )
-    return usage.Usage(
-        request_tokens=details.pop('prompt_token_count', 0),
-        response_tokens=details.pop('candidates_token_count', 0),
-        total_tokens=details.pop('total_token_count', 0),
-        details=details,
-    )
-
-
-class _GeminiJsonSchema(WalkJsonSchema):
-
-    def transform(self, schema: JsonSchema) -> JsonSchema:
-        additional_properties = schema.pop(
-            'additionalProperties', None
-        )  # don't pop yet so it's included in the warning
-        if additional_properties:  # pragma: no cover
-            original_schema = {**schema, 'additionalProperties': additional_properties}
-            warnings.warn(
-                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
-                f' Full schema: {self.schema}\n\n'
-                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
-                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
-                ' and we will fix this behavior.',
-                UserWarning,
-            )
-
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
-            # Gemini doesn't support const, but it does support enum with a single value
-            schema['enum'] = [const]
-        schema.pop('discriminator', None)
-        schema.pop('examples', None)
-
-        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
-        # where we add notes about these properties to the field description?
-        schema.pop('exclusiveMaximum', None)
-        schema.pop('exclusiveMinimum', None)
-
-        type_ = schema.get('type')
-        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
-            # This gets hit when we have a discriminated union
-            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
-            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
-            schema['anyOf'] = schema.pop('oneOf')
-
-        if type_ == 'string' and (fmt := schema.pop('format', None)):
-            description = schema.get('description')
-            if description:
-                schema['description'] = f'{description} (format: {fmt})'
-            else:
-                schema['description'] = f'Format: {fmt}'
-
-        if '$ref' in schema:
-            raise UserError(  # pragma: no cover
-                f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}'
-            )
-
-        if 'prefixItems' in schema:  # pragma: lax no cover
-            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
-            prefix_items = schema.pop('prefixItems')
-            items = schema.get('items')
-            unique_items = [items] if items is not None else []
-            for item in prefix_items:
-                if item not in unique_items:
-                    unique_items.append(item)
-            if len(unique_items) > 1:  # pragma: no cover
-                schema['items'] = {'anyOf': unique_items}
-            elif len(unique_items) == 1:
-                schema['items'] = unique_items[0]
-            schema.setdefault('minItems', len(prefix_items))
-            if items is None:
-                schema.setdefault('maxItems', len(prefix_items))
-
-        return schema
+    metadata = metadata.model_dump(exclude_defaults=True)
+
+    details: dict[str, int] = {}
+    if cached_content_token_count := metadata.get('cached_content_token_count'):
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
+    return usage.Usage(
+        request_tokens=metadata.get('prompt_token_count', 0),
+        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
        details=details,
    )
pydantic_ai/models/groq.py
CHANGED
@@ -27,10 +27,17 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, get_user_agent
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    check_allow_model_requests,
+    get_user_agent,
+)
 
 try:
     from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
@@ -105,7 +112,13 @@ class GroqModel(Model):
     _model_name: GroqModelName = field(repr=False)
     _system: str = field(default='groq', repr=False)
 
-    def __init__(self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq'):
+    def __init__(
+        self,
+        model_name: GroqModelName,
+        *,
+        provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
+        profile: ModelProfileSpec | None = None,
+    ):
         """Initialize a Groq model.
 
         Args:
@@ -114,12 +127,14 @@ class GroqModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
pydantic_ai/models/mistral.py
CHANGED
@@ -30,6 +30,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -120,6 +121,7 @@ class MistralModel(Model):
         model_name: MistralModelName,
         *,
         provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
+        profile: ModelProfileSpec | None = None,
         json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
     ):
         """Initialize a Mistral model.
@@ -129,6 +131,7 @@ class MistralModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
         """
         self._model_name = model_name
@@ -137,6 +140,7 @@ class MistralModel(Model):
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
pydantic_ai/models/openai.py
CHANGED
@@ -1,16 +1,16 @@
 from __future__ import annotations as _annotations
 
 import base64
-import re
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
+from pydantic_ai.profiles.openai import OpenAIModelProfile
 from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
@@ -33,6 +33,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -43,7 +44,6 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
@@ -170,7 +170,9 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter'] | Provider[AsyncOpenAI] = 'openai',
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
     ):
         """Initialize an OpenAI model.
@@ -180,13 +182,17 @@ class OpenAIModel(Model):
             [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
             (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
             provider: The provider to use. Defaults to `'openai'`.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
         """
         self._model_name = model_name
+
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
+
         self.system_prompt_role = system_prompt_role
 
     @property
@@ -221,9 +227,6 @@ class OpenAIModel(Model):
         async with response:
             yield await self._process_streamed_response(response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        return _customize_request_parameters(model_request_parameters)
-
     @property
     def model_name(self) -> OpenAIModelName:
         """The model name."""
@@ -331,7 +334,9 @@ class OpenAIModel(Model):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+                part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
+                part.tool_call_id = _guard_tool_call_id(part)
+                items.append(part)
         return ModelResponse(
             items,
             usage=_map_usage(response),
@@ -401,8 +406,7 @@ class OpenAIModel(Model):
                 function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
             )
 
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+    def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolParam:
         tool_param: chat.ChatCompletionToolParam = {
             'type': 'function',
             'function': {
@@ -411,7 +415,7 @@ class OpenAIModel(Model):
                 'parameters': f.parameters_json_schema,
             },
         }
-        if f.strict:
+        if f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition:
             tool_param['function']['strict'] = f.strict
         return tool_param
 
@@ -533,18 +537,23 @@ class OpenAIResponsesModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize an OpenAI Responses model.
 
         Args:
             model_name: The name of the OpenAI model to use.
             provider: The provider to use. Defaults to `'openai'`.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
+
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def model_name(self) -> OpenAIModelName:
@@ -582,9 +591,6 @@ class OpenAIResponsesModel(Model):
         async with response:
             yield await self._process_streamed_response(response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        return _customize_request_parameters(model_request_parameters)
-
     def _process_response(self, response: responses.Response) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
@@ -690,15 +696,15 @@ class OpenAIResponsesModel(Model):
             tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
         return tools
 
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
+    def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam:
         return {
             'name': f.name,
             'parameters': f.parameters_json_schema,
             'type': 'function',
             'description': f.description,
-            'strict': bool(f.strict),
+            'strict': bool(
+                f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition
+            ),
         }
 
     async def _map_messages(
@@ -980,142 +986,3 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
         total_tokens=response_usage.total_tokens,
         details=details,
     )
-
-
-_STRICT_INCOMPATIBLE_KEYS = [
-    'minLength',
-    'maxLength',
-    'pattern',
-    'format',
-    'minimum',
-    'maximum',
-    'multipleOf',
-    'patternProperties',
-    'unevaluatedProperties',
-    'propertyNames',
-    'minProperties',
-    'maxProperties',
-    'unevaluatedItems',
-    'contains',
-    'minContains',
-    'maxContains',
-    'minItems',
-    'maxItems',
-    'uniqueItems',
-]
-
-_sentinel = object()
-
-
-@dataclass
-class _OpenAIJsonSchema(WalkJsonSchema):
-    """Recursively handle the schema to make it compatible with OpenAI strict mode.
-
-    See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
-    but this basically just requires:
-    * `additionalProperties` must be set to false for each object in the parameters
-    * all fields in properties must be marked as required
-    """
-
-    def __init__(self, schema: JsonSchema, strict: bool | None):
-        super().__init__(schema)
-        self.strict = strict
-        self.is_strict_compatible = True
-        self.root_ref = schema.get('$ref')
-
-    def walk(self) -> JsonSchema:
-        # Note: OpenAI does not support anyOf at the root in strict mode
-        # However, we don't need to check for it here because we ensure in pydantic_ai._utils.check_object_json_schema
-        # that the root schema either has type 'object' or is recursive.
-        result = super().walk()
-
-        # For recursive models, we need to tweak the schema to make it compatible with strict mode.
-        # Because the following should never change the semantics of the schema we apply it unconditionally.
-        if self.root_ref is not None:
-            result.pop('$ref', None)  # We replace references to the self.root_ref with just '#' in the transform method
-            root_key = re.sub(r'^#/\$defs/', '', self.root_ref)
-            result.update(self.defs.get(root_key) or {})
-
-        return result
-
-    def transform(self, schema: JsonSchema) -> JsonSchema:  # noqa C901
-        # Remove unnecessary keys
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        schema.pop('discriminator', None)
-
-        if schema_ref := schema.get('$ref'):
-            if schema_ref == self.root_ref:
-                schema['$ref'] = '#'
-            if len(schema) > 1:
-                # OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf".
-                # So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf":
-                schema['anyOf'] = [{'$ref': schema.pop('$ref')}]
-
-        # Track strict-incompatible keys
-        incompatible_values: dict[str, Any] = {}
-        for key in _STRICT_INCOMPATIBLE_KEYS:
-            value = schema.get(key, _sentinel)
-            if value is not _sentinel:
-                incompatible_values[key] = value
-        description = schema.get('description')
-        if incompatible_values:
-            if self.strict is True:
-                notes: list[str] = []
-                for key, value in incompatible_values.items():
-                    schema.pop(key)
-                    notes.append(f'{key}={value}')
-                notes_string = ', '.join(notes)
-                schema['description'] = notes_string if not description else f'{description} ({notes_string})'
-            elif self.strict is None:  # pragma: no branch
-                self.is_strict_compatible = False
-
-        schema_type = schema.get('type')
-        if 'oneOf' in schema:
-            # OpenAI does not support oneOf in strict mode
-            if self.strict is True:
-                schema['anyOf'] = schema.pop('oneOf')
-            else:
-                self.is_strict_compatible = False
-
-        if schema_type == 'object':
-            if self.strict is True:
-                # additional properties are disallowed
-                schema['additionalProperties'] = False
-
-                # all properties are required
-                if 'properties' not in schema:
-                    schema['properties'] = dict[str, Any]()
-                schema['required'] = list(schema['properties'].keys())
-
-            elif self.strict is None:
-                if (
-                    schema.get('additionalProperties') is not False
-                    or 'properties' not in schema
-                    or 'required' not in schema
-                ):
-                    self.is_strict_compatible = False
-                else:
-                    required = schema['required']
-                    for k in schema['properties'].keys():
-                        if k not in required:
-                            self.is_strict_compatible = False
-        return schema
-
-
-def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-    """Customize the request parameters for OpenAI models."""
-
-    def _customize_tool_def(t: ToolDefinition):
-        schema_transformer = _OpenAIJsonSchema(t.parameters_json_schema, strict=t.strict)
-        parameters_json_schema = schema_transformer.walk()
-        if t.strict is None:
-            t = replace(t, strict=schema_transformer.is_strict_compatible)
-        return replace(t, parameters_json_schema=parameters_json_schema)
-
-    return ModelRequestParameters(
-        function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-        allow_text_output=model_request_parameters.allow_text_output,
-        output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-    )
pydantic_ai/profiles/__init__.py
ADDED
@@ -0,0 +1,39 @@
+from __future__ import annotations as _annotations
+
+from dataclasses import dataclass, fields, replace
+from typing import Callable, Union
+
+from typing_extensions import Self
+
+from ._json_schema import JsonSchemaTransformer
+
+
+@dataclass
+class ModelProfile:
+    """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used."""
+
+    json_schema_transformer: type[JsonSchemaTransformer] | None = None
+
+    @classmethod
+    def from_profile(cls, profile: ModelProfile | None) -> Self:
+        """Build a ModelProfile subclass instance from a ModelProfile instance."""
+        if isinstance(profile, cls):
+            return profile
+        return cls().update(profile)
+
+    def update(self, profile: ModelProfile | None) -> Self:
+        """Update this ModelProfile (subclass) instance with the non-default values from another ModelProfile instance."""
+        if not profile:
+            return self
+        field_names = set(f.name for f in fields(self))
+        non_default_attrs = {
+            f.name: getattr(profile, f.name)
+            for f in fields(profile)
+            if f.name in field_names and getattr(profile, f.name) != f.default
+        }
+        return replace(self, **non_default_attrs)
+
+
+ModelProfileSpec = Union[ModelProfile, Callable[[str], Union[ModelProfile, None]]]
+
+DEFAULT_PROFILE = ModelProfile()