pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +29 -35
- pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
- pydantic_ai/_output.py +265 -118
- pydantic_ai/agent.py +15 -15
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +39 -3
- pydantic_ai/models/anthropic.py +4 -0
- pydantic_ai/models/bedrock.py +43 -16
- pydantic_ai/models/cohere.py +4 -0
- pydantic_ai/models/gemini.py +68 -108
- pydantic_ai/models/google.py +45 -110
- pydantic_ai/models/groq.py +17 -2
- pydantic_ai/models/mistral.py +4 -0
- pydantic_ai/models/openai.py +22 -157
- pydantic_ai/profiles/__init__.py +39 -0
- pydantic_ai/{models → profiles}/_json_schema.py +23 -2
- pydantic_ai/profiles/amazon.py +9 -0
- pydantic_ai/profiles/anthropic.py +8 -0
- pydantic_ai/profiles/cohere.py +8 -0
- pydantic_ai/profiles/deepseek.py +8 -0
- pydantic_ai/profiles/google.py +100 -0
- pydantic_ai/profiles/grok.py +8 -0
- pydantic_ai/profiles/meta.py +9 -0
- pydantic_ai/profiles/mistral.py +8 -0
- pydantic_ai/profiles/openai.py +144 -0
- pydantic_ai/profiles/qwen.py +9 -0
- pydantic_ai/providers/__init__.py +18 -0
- pydantic_ai/providers/anthropic.py +5 -0
- pydantic_ai/providers/azure.py +34 -0
- pydantic_ai/providers/bedrock.py +60 -1
- pydantic_ai/providers/cohere.py +5 -0
- pydantic_ai/providers/deepseek.py +12 -0
- pydantic_ai/providers/fireworks.py +99 -0
- pydantic_ai/providers/google.py +5 -0
- pydantic_ai/providers/google_gla.py +5 -0
- pydantic_ai/providers/google_vertex.py +5 -0
- pydantic_ai/providers/grok.py +82 -0
- pydantic_ai/providers/groq.py +25 -0
- pydantic_ai/providers/mistral.py +5 -0
- pydantic_ai/providers/openai.py +5 -0
- pydantic_ai/providers/openrouter.py +36 -0
- pydantic_ai/providers/together.py +96 -0
- pydantic_ai/result.py +34 -103
- pydantic_ai/tools.py +28 -58
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/METADATA +4 -4
- pydantic_ai_slim-0.2.12.dist-info/RECORD +73 -0
- pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py
CHANGED
@@ -1,19 +1,18 @@
 from __future__ import annotations as _annotations
 
 import base64
-import warnings
 from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4
 
 from typing_extensions import assert_never
 
 from pydantic_ai.providers import Provider
 
-from .. import UnexpectedModelBehavior, UserError, _utils, usage
+from .. import UnexpectedModelBehavior, _utils, usage
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -32,6 +31,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -42,7 +42,6 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 try:
     from google import genai
@@ -141,6 +140,7 @@ class GoogleModel(Model):
         model_name: GoogleModelName,
         *,
         provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a Gemini model.
 
@@ -149,6 +149,7 @@ class GoogleModel(Model):
            provider: The provider to use for authentication and API access. Can be either the string
                'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                If not provided, a new provider will be created using the other parameters.
+           profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
        """
        self._model_name = model_name
 
@@ -158,6 +159,7 @@ class GoogleModel(Model):
        self._provider = provider
        self._system = provider.name
        self.client = provider.client
+       self._profile = profile or provider.model_profile
 
    @property
    def base_url(self) -> str:
@@ -186,16 +188,6 @@ class GoogleModel(Model):
            response = await self._generate_content(messages, True, model_settings, model_request_parameters)
            yield await self._process_streamed_response(response)  # type: ignore
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        def _customize_tool_def(t: ToolDefinition):
-            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
-
-        return ModelRequestParameters(
-            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-            allow_text_output=model_request_parameters.allow_text_output,
-            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-        )
-
    @property
    def model_name(self) -> GoogleModelName:
        """The model name."""
@@ -291,9 +283,16 @@ class GoogleModel(Model):
                'Content field missing from Gemini response', str(response)
            )  # pragma: no cover
        parts = response.candidates[0].content.parts or []
+       vendor_id = response.response_id or None
+       vendor_details: dict[str, Any] | None = None
+       finish_reason = response.candidates[0].finish_reason
+       if finish_reason:  # pragma: no branch
+           vendor_details = {'finish_reason': finish_reason.value}
        usage = _metadata_as_usage(response)
        usage.requests = 1
-       return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
+       return _process_response_from_parts(
+           parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+       )
 
    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
@@ -392,7 +391,7 @@ class GeminiStreamedResponse(StreamedResponse):
 
    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
-           self._usage += _metadata_as_usage(chunk)
+           self._usage = _metadata_as_usage(chunk)
 
            assert chunk.candidates is not None
            candidate = chunk.candidates[0]
@@ -439,7 +438,13 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
    return ContentDict(role='model', parts=parts)
 
 
-def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
+def _process_response_from_parts(
+    parts: list[Part],
+    model_name: GoogleModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    vendor_details: dict[str, Any] | None = None,
+) -> ModelResponse:
    items: list[ModelResponsePart] = []
    for part in parts:
        if part.text:
@@ -454,7 +459,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
            raise UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
-    return ModelResponse(parts=items, model_name=model_name, usage=usage)
+    return ModelResponse(
+        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+    )
 
 
 def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
@@ -475,99 +482,27 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
    metadata = response.usage_metadata
    if metadata is None:
        return usage.Usage()  # pragma: no cover
-
-    # `usage.Usage.incr``, it will try to sum non-integer values with integers, which will fail. We should probably
-    # handle this in the `Usage` class.
-    details = metadata.model_dump(
-        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
-        exclude_defaults=True,
-    )
-    return usage.Usage(
-        request_tokens=details.pop('prompt_token_count', 0),
-        response_tokens=details.pop('candidates_token_count', 0),
-        total_tokens=details.pop('total_token_count', 0),
-        details=details,
-    )
+    metadata = metadata.model_dump(exclude_defaults=True)
+
+    details: dict[str, int] = {}
+    if cached_content_token_count := metadata.get('cached_content_token_count'):
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
+    return usage.Usage(
+        request_tokens=metadata.get('prompt_token_count', 0),
+        response_tokens=metadata.get('candidates_token_count', 0),
+        total_tokens=metadata.get('total_token_count', 0),
+        details=details,
+    )
-
-
-class _GeminiJsonSchema(WalkJsonSchema):
-    def transform(self, schema: JsonSchema) -> JsonSchema:
-        additional_properties = schema.pop(
-            'additionalProperties', None
-        )  # don't pop yet so it's included in the warning
-        if additional_properties:  # pragma: no cover
-            original_schema = {**schema, 'additionalProperties': additional_properties}
-            warnings.warn(
-                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
-                f' Full schema: {self.schema}\n\n'
-                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
-                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
-                ' and we will fix this behavior.',
-                UserWarning,
-            )
-
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
-            # Gemini doesn't support const, but it does support enum with a single value
-            schema['enum'] = [const]
-        schema.pop('discriminator', None)
-        schema.pop('examples', None)
-
-        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
-        # where we add notes about these properties to the field description?
-        schema.pop('exclusiveMaximum', None)
-        schema.pop('exclusiveMinimum', None)
-
-        type_ = schema.get('type')
-        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
-            # This gets hit when we have a discriminated union
-            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
-            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
-            schema['anyOf'] = schema.pop('oneOf')
-
-        if type_ == 'string' and (fmt := schema.pop('format', None)):
-            description = schema.get('description')
-            if description:
-                schema['description'] = f'{description} (format: {fmt})'
-            else:
-                schema['description'] = f'Format: {fmt}'
-
-        if '$ref' in schema:
-            raise UserError(  # pragma: no cover
-                f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}'
-            )
-
-        if 'prefixItems' in schema:  # pragma: lax no cover
-            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
-            prefix_items = schema.pop('prefixItems')
-            items = schema.get('items')
-            unique_items = [items] if items is not None else []
-            for item in prefix_items:
-                if item not in unique_items:
-                    unique_items.append(item)
-            if len(unique_items) > 1:  # pragma: no cover
-                schema['items'] = {'anyOf': unique_items}
-            elif len(unique_items) == 1:
-                schema['items'] = unique_items[0]
-            schema.setdefault('minItems', len(prefix_items))
-            if items is None:
-                schema.setdefault('maxItems', len(prefix_items))
-
-        return schema
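The `_metadata_as_usage` rewrite above replaces the old wholesale `model_dump` of Gemini's usage metadata with an explicit mapping: cached-content, thoughts, and tool-use token counts get stable names, and per-modality `*_details` lists are flattened into keys like `text_prompt_tokens`. A minimal standalone sketch of that flattening (the function name and the sample payload are illustrative, not part of the library):

```python
from typing import Any


def flatten_gemini_usage(metadata: dict[str, Any]) -> dict[str, int]:
    """Sketch of the details-building logic in the new _metadata_as_usage."""
    details: dict[str, int] = {}
    # Named token counters, mapped only when present and non-zero:
    for source_key, detail_key in [
        ('cached_content_token_count', 'cached_content_tokens'),
        ('thoughts_token_count', 'thoughts_tokens'),
        ('tool_use_prompt_token_count', 'tool_use_prompt_tokens'),
    ]:
        if count := metadata.get(source_key):
            details[detail_key] = count
    # Per-modality breakdowns arrive as lists under '*_details' keys; each entry
    # is flattened to e.g. 'text_prompt_tokens' or 'audio_candidates_tokens'.
    for key, metadata_details in metadata.items():
        if key.endswith('_details') and metadata_details:
            suffix = key.removesuffix('_details')
            for detail in metadata_details:
                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
    return details


# Hypothetical dumped metadata, mirroring metadata.model_dump(exclude_defaults=True):
print(flatten_gemini_usage({
    'prompt_token_count': 10,
    'thoughts_token_count': 5,
    'prompt_tokens_details': [{'modality': 'TEXT', 'token_count': 10}],
}))
# {'thoughts_tokens': 5, 'text_prompt_tokens': 10}
```

The related streaming change, assigning `self._usage = _metadata_as_usage(chunk)` instead of incrementing, is consistent with each streamed chunk carrying cumulative usage totals, so the latest snapshot replaces the running value.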
pydantic_ai/models/groq.py
CHANGED
@@ -27,10 +27,17 @@ from ..messages import (
    ToolReturnPart,
    UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, get_user_agent
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    check_allow_model_requests,
+    get_user_agent,
+)
 
 try:
    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
@@ -105,7 +112,13 @@ class GroqModel(Model):
    _model_name: GroqModelName = field(repr=False)
    _system: str = field(default='groq', repr=False)
 
-    def __init__(self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq'):
+    def __init__(
+        self,
+        model_name: GroqModelName,
+        *,
+        provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
+        profile: ModelProfileSpec | None = None,
+    ):
        """Initialize a Groq model.
 
        Args:
@@ -114,12 +127,14 @@ class GroqModel(Model):
            provider: The provider to use for authentication and API access. Can be either the string
                'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be
                created using the other parameters.
+           profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
        """
        self._model_name = model_name
 
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client
+       self._profile = profile or provider.model_profile
 
    @property
    def base_url(self) -> str:
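The same three lines now appear in every model constructor touched by this release (Google, Groq, Mistral, OpenAI): an optional `profile` argument that falls back to `provider.model_profile`. Under the new `ModelProfileSpec` alias (see `pydantic_ai/profiles/__init__.py` below), a spec is either a ready `ModelProfile` or a callable from a model name to an optional profile. A hedged sketch of how such a spec could be resolved, using stand-in types rather than the library's own:

```python
from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass
class ModelProfile:  # trimmed stand-in for pydantic_ai.profiles.ModelProfile
    json_schema_transformer: Optional[type] = None


# Same shape as the new pydantic_ai.profiles.ModelProfileSpec:
ModelProfileSpec = Union[ModelProfile, Callable[[str], Optional[ModelProfile]]]


def resolve_profile(spec: Optional[ModelProfileSpec], model_name: str) -> Optional[ModelProfile]:
    """Hypothetical helper: a spec is either a ready profile or a per-model-name factory."""
    if spec is None or isinstance(spec, ModelProfile):
        return spec
    return spec(model_name)


def example_provider_profiles(model_name: str) -> Optional[ModelProfile]:
    """Illustrative provider-style lookup keyed on the model name."""
    return ModelProfile() if model_name.startswith('llama') else None


print(resolve_profile(example_provider_profiles, 'llama-3.3-70b'))
# ModelProfile(json_schema_transformer=None)
```

Keying resolution on the model name is presumably what lets a single provider hand out different profiles for, say, Llama and Qwen checkpoints served behind the same API, which matches the new per-family `pydantic_ai/profiles/*.py` modules in the file list.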
pydantic_ai/models/mistral.py
CHANGED
@@ -30,6 +30,7 @@ from ..messages import (
    UserPromptPart,
    VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -120,6 +121,7 @@ class MistralModel(Model):
        model_name: MistralModelName,
        *,
        provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
+       profile: ModelProfileSpec | None = None,
        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
    ):
        """Initialize a Mistral model.
@@ -129,6 +131,7 @@ class MistralModel(Model):
            provider: The provider to use for authentication and API access. Can be either the string
                'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
                created using the other parameters.
+           profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
            json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
        """
        self._model_name = model_name
@@ -137,6 +140,7 @@ class MistralModel(Model):
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client
+       self._profile = profile or provider.model_profile
 
    @property
    def base_url(self) -> str:
pydantic_ai/models/openai.py
CHANGED
@@ -1,16 +1,16 @@
 from __future__ import annotations as _annotations
 
 import base64
-import re
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
+from pydantic_ai.profiles.openai import OpenAIModelProfile
 from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
@@ -33,6 +33,7 @@ from ..messages import (
    UserPromptPart,
    VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -43,7 +44,6 @@ from . import (
    check_allow_model_requests,
    get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 try:
    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
@@ -170,7 +170,9 @@ class OpenAIModel(Model):
        self,
        model_name: OpenAIModelName,
        *,
-       provider: Literal['openai', 'deepseek', 'azure', 'openrouter'] | Provider[AsyncOpenAI] = 'openai',
+       provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+       | Provider[AsyncOpenAI] = 'openai',
+       profile: ModelProfileSpec | None = None,
        system_prompt_role: OpenAISystemPromptRole | None = None,
    ):
        """Initialize an OpenAI model.
@@ -180,13 +182,17 @@ class OpenAIModel(Model):
                [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
            provider: The provider to use. Defaults to `'openai'`.
+           profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                In the future, this may be inferred from the model name.
        """
        self._model_name = model_name
+
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client
+       self._profile = profile or provider.model_profile
+
        self.system_prompt_role = system_prompt_role
 
    @property
@@ -221,9 +227,6 @@ class OpenAIModel(Model):
        async with response:
            yield await self._process_streamed_response(response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        return _customize_request_parameters(model_request_parameters)
-
    @property
    def model_name(self) -> OpenAIModelName:
        """The model name."""
@@ -401,8 +404,7 @@ class OpenAIModel(Model):
                function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
            )
 
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+    def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolParam:
        tool_param: chat.ChatCompletionToolParam = {
            'type': 'function',
            'function': {
@@ -411,7 +413,7 @@ class OpenAIModel(Model):
                'parameters': f.parameters_json_schema,
            },
        }
-       if f.strict:
+       if f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition:
            tool_param['function']['strict'] = f.strict
        return tool_param
 
@@ -533,18 +535,23 @@ class OpenAIResponsesModel(Model):
        self,
        model_name: OpenAIModelName,
        *,
-       provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
+       provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+       | Provider[AsyncOpenAI] = 'openai',
+       profile: ModelProfileSpec | None = None,
    ):
        """Initialize an OpenAI Responses model.
 
        Args:
            model_name: The name of the OpenAI model to use.
            provider: The provider to use. Defaults to `'openai'`.
+           profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
        """
        self._model_name = model_name
+
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client
+       self._profile = profile or provider.model_profile
 
    @property
    def model_name(self) -> OpenAIModelName:
@@ -582,9 +589,6 @@ class OpenAIResponsesModel(Model):
        async with response:
            yield await self._process_streamed_response(response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        return _customize_request_parameters(model_request_parameters)
-
    def _process_response(self, response: responses.Response) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
@@ -690,15 +694,15 @@ class OpenAIResponsesModel(Model):
            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
        return tools
 
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
+    def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam:
        return {
            'name': f.name,
            'parameters': f.parameters_json_schema,
            'type': 'function',
            'description': f.description,
-           'strict': bool(f.strict),
+           'strict': bool(
+               f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition
+           ),
        }
 
    async def _map_messages(
@@ -980,142 +984,3 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response)
        total_tokens=response_usage.total_tokens,
        details=details,
    )
-
-
-_STRICT_INCOMPATIBLE_KEYS = [
-    'minLength',
-    'maxLength',
-    'pattern',
-    'format',
-    'minimum',
-    'maximum',
-    'multipleOf',
-    'patternProperties',
-    'unevaluatedProperties',
-    'propertyNames',
-    'minProperties',
-    'maxProperties',
-    'unevaluatedItems',
-    'contains',
-    'minContains',
-    'maxContains',
-    'minItems',
-    'maxItems',
-    'uniqueItems',
-]
-
-_sentinel = object()
-
-
-@dataclass
-class _OpenAIJsonSchema(WalkJsonSchema):
-    """Recursively handle the schema to make it compatible with OpenAI strict mode.
-
-    See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
-    but this basically just requires:
-    * `additionalProperties` must be set to false for each object in the parameters
-    * all fields in properties must be marked as required
-    """
-
-    def __init__(self, schema: JsonSchema, strict: bool | None):
-        super().__init__(schema)
-        self.strict = strict
-        self.is_strict_compatible = True
-        self.root_ref = schema.get('$ref')
-
-    def walk(self) -> JsonSchema:
-        # Note: OpenAI does not support anyOf at the root in strict mode
-        # However, we don't need to check for it here because we ensure in pydantic_ai._utils.check_object_json_schema
-        # that the root schema either has type 'object' or is recursive.
-        result = super().walk()
-
-        # For recursive models, we need to tweak the schema to make it compatible with strict mode.
-        # Because the following should never change the semantics of the schema we apply it unconditionally.
-        if self.root_ref is not None:
-            result.pop('$ref', None)  # We replace references to the self.root_ref with just '#' in the transform method
-            root_key = re.sub(r'^#/\$defs/', '', self.root_ref)
-            result.update(self.defs.get(root_key) or {})
-
-        return result
-
-    def transform(self, schema: JsonSchema) -> JsonSchema:  # noqa C901
-        # Remove unnecessary keys
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        schema.pop('discriminator', None)
-
-        if schema_ref := schema.get('$ref'):
-            if schema_ref == self.root_ref:
-                schema['$ref'] = '#'
-            if len(schema) > 1:
-                # OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf".
-                # So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf":
-                schema['anyOf'] = [{'$ref': schema.pop('$ref')}]
-
-        # Track strict-incompatible keys
-        incompatible_values: dict[str, Any] = {}
-        for key in _STRICT_INCOMPATIBLE_KEYS:
-            value = schema.get(key, _sentinel)
-            if value is not _sentinel:
-                incompatible_values[key] = value
-        description = schema.get('description')
-        if incompatible_values:
-            if self.strict is True:
-                notes: list[str] = []
-                for key, value in incompatible_values.items():
-                    schema.pop(key)
-                    notes.append(f'{key}={value}')
-                notes_string = ', '.join(notes)
-                schema['description'] = notes_string if not description else f'{description} ({notes_string})'
-            elif self.strict is None:  # pragma: no branch
-                self.is_strict_compatible = False
-
-        schema_type = schema.get('type')
-        if 'oneOf' in schema:
-            # OpenAI does not support oneOf in strict mode
-            if self.strict is True:
-                schema['anyOf'] = schema.pop('oneOf')
-            else:
-                self.is_strict_compatible = False
-
-        if schema_type == 'object':
-            if self.strict is True:
-                # additional properties are disallowed
-                schema['additionalProperties'] = False
-
-                # all properties are required
-                if 'properties' not in schema:
-                    schema['properties'] = dict[str, Any]()
-                schema['required'] = list(schema['properties'].keys())
-
-            elif self.strict is None:
-                if (
-                    schema.get('additionalProperties') is not False
-                    or 'properties' not in schema
-                    or 'required' not in schema
-                ):
-                    self.is_strict_compatible = False
-                else:
-                    required = schema['required']
-                    for k in schema['properties'].keys():
-                        if k not in required:
-                            self.is_strict_compatible = False
-        return schema
-
-
-def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-    """Customize the request parameters for OpenAI models."""
-
-    def _customize_tool_def(t: ToolDefinition):
-        schema_transformer = _OpenAIJsonSchema(t.parameters_json_schema, strict=t.strict)
-        parameters_json_schema = schema_transformer.walk()
-        if t.strict is None:
-            t = replace(t, strict=schema_transformer.is_strict_compatible)
-        return replace(t, parameters_json_schema=parameters_json_schema)
-
-    return ModelRequestParameters(
-        function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-        allow_text_output=model_request_parameters.allow_text_output,
-        output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-    )
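Both `_map_tool_definition` rewrites above now consult the model profile before emitting OpenAI's `strict` flag, instead of relying on the removed `_customize_request_parameters`/`_OpenAIJsonSchema` pass. A self-contained sketch of that gating decision (the classes here are stand-ins; only the `openai_supports_strict_tool_definition` field name comes from the diff):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class OpenAIModelProfile:  # stand-in carrying the one field the diff reads
    openai_supports_strict_tool_definition: bool = True


def build_tool_param(
    name: str,
    parameters_json_schema: dict[str, Any],
    strict: Optional[bool],
    profile: OpenAIModelProfile,
) -> dict[str, Any]:
    """Sketch of the gating in OpenAIModel._map_tool_definition, not the library code."""
    tool_param: dict[str, Any] = {
        'type': 'function',
        'function': {'name': name, 'parameters': parameters_json_schema},
    }
    # 'strict' is forwarded only when the tool asks for it AND the profile allows it:
    if strict and profile.openai_supports_strict_tool_definition:
        tool_param['function']['strict'] = strict
    return tool_param


lenient = OpenAIModelProfile(openai_supports_strict_tool_definition=False)
print(build_tool_param('get_weather', {'type': 'object'}, strict=True, profile=lenient))
# {'type': 'function', 'function': {'name': 'get_weather', 'parameters': {'type': 'object'}}}
```

This moves an OpenAI-specific capability decision out of a request-rewriting pass and into data, so OpenAI-compatible providers that reject `strict` (the new fireworks, together, grok, and openrouter providers in this release are candidates) can simply ship a profile that disables it.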
pydantic_ai/profiles/__init__.py
ADDED

@@ -0,0 +1,39 @@
+from __future__ import annotations as _annotations
+
+from dataclasses import dataclass, fields, replace
+from typing import Callable, Union
+
+from typing_extensions import Self
+
+from ._json_schema import JsonSchemaTransformer
+
+
+@dataclass
+class ModelProfile:
+    """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used."""
+
+    json_schema_transformer: type[JsonSchemaTransformer] | None = None
+
+    @classmethod
+    def from_profile(cls, profile: ModelProfile | None) -> Self:
+        """Build a ModelProfile subclass instance from a ModelProfile instance."""
+        if isinstance(profile, cls):
+            return profile
+        return cls().update(profile)
+
+    def update(self, profile: ModelProfile | None) -> Self:
+        """Update this ModelProfile (subclass) instance with the non-default values from another ModelProfile instance."""
+        if not profile:
+            return self
+        field_names = set(f.name for f in fields(self))
+        non_default_attrs = {
+            f.name: getattr(profile, f.name)
+            for f in fields(profile)
+            if f.name in field_names and getattr(profile, f.name) != f.default
+        }
+        return replace(self, **non_default_attrs)
+
+
+ModelProfileSpec = Union[ModelProfile, Callable[[str], Union[ModelProfile, None]]]
+
+DEFAULT_PROFILE = ModelProfile()