pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim has been flagged as potentially problematic; see the registry's release page for more details.
- pydantic_ai/_agent_graph.py +29 -35
- pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
- pydantic_ai/_output.py +265 -118
- pydantic_ai/agent.py +15 -15
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +39 -3
- pydantic_ai/models/anthropic.py +4 -0
- pydantic_ai/models/bedrock.py +43 -16
- pydantic_ai/models/cohere.py +4 -0
- pydantic_ai/models/gemini.py +68 -108
- pydantic_ai/models/google.py +45 -110
- pydantic_ai/models/groq.py +17 -2
- pydantic_ai/models/mistral.py +4 -0
- pydantic_ai/models/openai.py +22 -157
- pydantic_ai/profiles/__init__.py +39 -0
- pydantic_ai/{models → profiles}/_json_schema.py +23 -2
- pydantic_ai/profiles/amazon.py +9 -0
- pydantic_ai/profiles/anthropic.py +8 -0
- pydantic_ai/profiles/cohere.py +8 -0
- pydantic_ai/profiles/deepseek.py +8 -0
- pydantic_ai/profiles/google.py +100 -0
- pydantic_ai/profiles/grok.py +8 -0
- pydantic_ai/profiles/meta.py +9 -0
- pydantic_ai/profiles/mistral.py +8 -0
- pydantic_ai/profiles/openai.py +144 -0
- pydantic_ai/profiles/qwen.py +9 -0
- pydantic_ai/providers/__init__.py +18 -0
- pydantic_ai/providers/anthropic.py +5 -0
- pydantic_ai/providers/azure.py +34 -0
- pydantic_ai/providers/bedrock.py +60 -1
- pydantic_ai/providers/cohere.py +5 -0
- pydantic_ai/providers/deepseek.py +12 -0
- pydantic_ai/providers/fireworks.py +99 -0
- pydantic_ai/providers/google.py +5 -0
- pydantic_ai/providers/google_gla.py +5 -0
- pydantic_ai/providers/google_vertex.py +5 -0
- pydantic_ai/providers/grok.py +82 -0
- pydantic_ai/providers/groq.py +25 -0
- pydantic_ai/providers/mistral.py +5 -0
- pydantic_ai/providers/openai.py +5 -0
- pydantic_ai/providers/openrouter.py +36 -0
- pydantic_ai/providers/together.py +96 -0
- pydantic_ai/result.py +34 -103
- pydantic_ai/tools.py +28 -58
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/METADATA +4 -4
- pydantic_ai_slim-0.2.12.dist-info/RECORD +73 -0
- pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
1
3
|
import re
|
|
2
4
|
from abc import ABC, abstractmethod
|
|
3
5
|
from copy import deepcopy
|
|
@@ -10,7 +12,7 @@ JsonSchema = dict[str, Any]
|
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
@dataclass(init=False)
|
|
13
|
-
class
|
|
15
|
+
class JsonSchemaTransformer(ABC):
|
|
14
16
|
"""Walks a JSON schema, applying transformations to it at each level.
|
|
15
17
|
|
|
16
18
|
Note: We may eventually want to rework tools to build the JSON schema from the type directly, using a subclass of
|
|
@@ -18,9 +20,18 @@ class WalkJsonSchema(ABC):
|
|
|
18
20
|
"""
|
|
19
21
|
|
|
20
22
|
def __init__(
|
|
21
|
-
self,
|
|
23
|
+
self,
|
|
24
|
+
schema: JsonSchema,
|
|
25
|
+
*,
|
|
26
|
+
strict: bool | None = None,
|
|
27
|
+
prefer_inlined_defs: bool = False,
|
|
28
|
+
simplify_nullable_unions: bool = False,
|
|
22
29
|
):
|
|
23
30
|
self.schema = schema
|
|
31
|
+
|
|
32
|
+
self.strict = strict
|
|
33
|
+
self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly
|
|
34
|
+
|
|
24
35
|
self.prefer_inlined_defs = prefer_inlined_defs
|
|
25
36
|
self.simplify_nullable_unions = simplify_nullable_unions
|
|
26
37
|
|
|
@@ -164,3 +175,13 @@ class WalkJsonSchema(ABC):
|
|
|
164
175
|
return [cases[0]]
|
|
165
176
|
|
|
166
177
|
return cases # pragma: no cover
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class InlineDefsJsonSchemaTransformer(JsonSchemaTransformer):
    """JSON Schema transformer whose only purpose is to inline `$defs`.

    Individual schema nodes are passed through untouched; inlining is requested from the
    base class by passing `prefer_inlined_defs=True`.
    """

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        # Only the inlining flag differs from the base-class defaults.
        super().__init__(
            schema,
            strict=strict,
            prefer_inlined_defs=True,
        )

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Return `schema` as-is: no per-node rewriting is needed."""
        unchanged = schema
        return unchanged
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
from . import ModelProfile
|
|
4
|
+
from ._json_schema import InlineDefsJsonSchemaTransformer
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def amazon_model_profile(model_name: str) -> ModelProfile | None:
    """Return the profile used for Amazon models.

    `model_name` is not consulted: every Amazon model gets a profile whose JSON schemas are
    processed by `InlineDefsJsonSchemaTransformer`.
    """
    transformer = InlineDefsJsonSchemaTransformer
    return ModelProfile(json_schema_transformer=transformer)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
from pydantic_ai.exceptions import UserError
|
|
6
|
+
|
|
7
|
+
from . import ModelProfile
|
|
8
|
+
from ._json_schema import JsonSchema, JsonSchemaTransformer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def google_model_profile(model_name: str) -> ModelProfile | None:
    """Return the profile used for Google (Gemini) models.

    `model_name` is not consulted: every Google model gets a profile whose JSON schemas are
    processed by `GoogleJsonSchemaTransformer`.
    """
    transformer = GoogleJsonSchemaTransformer
    return ModelProfile(json_schema_transformer=transformer)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GoogleJsonSchemaTransformer(JsonSchemaTransformer):
    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.

    Gemini [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
    only a subset of OpenAPI v3.0.3.

    Specifically:
    * Gemini doesn't allow the `title` keyword to be set
    * Gemini doesn't allow `$defs` — we need to inline the definitions where possible
    """

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        # Inlining of `$defs` and simplification of nullable unions are requested from the base class.
        super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=True)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Rewrite a single schema node in place into a form Gemini accepts, and return it.

        Raises:
            UserError: If the node still contains a `$ref` (a recursive reference that was not
                inlined), which Gemini does not support.
        """
        # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini.
        # The key is popped right away. A falsy value (e.g. `False`) is dropped silently. A truthy
        # value also triggers a warning, which re-attaches the key to a copy of this node so the
        # user can see where it came from.
        additional_properties = schema.pop('additionalProperties', None)
        if additional_properties:
            original_schema = {**schema, 'additionalProperties': additional_properties}
            warnings.warn(
                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
                f' Full schema: {self.schema}\n\n'
                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
                ' and we will fix this behavior.',
                UserWarning,
            )

        schema.pop('title', None)
        schema.pop('default', None)
        schema.pop('$schema', None)
        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
            # Gemini doesn't support const, but it does support enum with a single value
            schema['enum'] = [const]
        schema.pop('discriminator', None)
        schema.pop('examples', None)

        # TODO: Should we use the trick from pydantic_ai.profiles.openai.OpenAIJsonSchemaTransformer
        # where we add notes about these properties to the field description?
        schema.pop('exclusiveMaximum', None)
        schema.pop('exclusiveMinimum', None)

        # Gemini only supports string enums, so we need to convert any enum values to strings.
        # Pydantic will take care of transforming the transformed string values to the correct type.
        if enum := schema.get('enum'):
            schema['type'] = 'string'
            schema['enum'] = [str(val) for val in enum]

        # Read after the enum conversion above, so enum nodes are treated as strings below.
        type_ = schema.get('type')
        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
            # This gets hit when we have a discriminated union
            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
            schema['anyOf'] = schema.pop('oneOf')

        if type_ == 'string' and (fmt := schema.pop('format', None)):
            # `format` is dropped from string nodes; keep the information by appending it to the description.
            description = schema.get('description')
            if description:
                schema['description'] = f'{description} (format: {fmt})'
            else:
                schema['description'] = f'Format: {fmt}'

        if '$ref' in schema:
            raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')

        if 'prefixItems' in schema:
            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
            prefix_items = schema.pop('prefixItems')
            items = schema.get('items')
            # Collect the distinct item schemas (existing `items` first, then each prefix item).
            unique_items = [items] if items is not None else []
            for item in prefix_items:
                if item not in unique_items:
                    unique_items.append(item)
            if len(unique_items) > 1:  # pragma: no cover
                schema['items'] = {'anyOf': unique_items}
            elif len(unique_items) == 1:  # pragma: no branch
                schema['items'] = unique_items[0]
            # Preserve the tuple length as array bounds, without overriding explicit bounds.
            schema.setdefault('minItems', len(prefix_items))
            if items is None:  # pragma: no branch
                schema.setdefault('maxItems', len(prefix_items))

        return schema
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
from . import ModelProfile
|
|
4
|
+
from ._json_schema import InlineDefsJsonSchemaTransformer
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def meta_model_profile(model_name: str) -> ModelProfile | None:
    """Return the profile used for Meta (Llama) models.

    `model_name` is not consulted: every Meta model gets a profile whose JSON schemas are
    processed by `InlineDefsJsonSchemaTransformer`.
    """
    transformer = InlineDefsJsonSchemaTransformer
    return ModelProfile(json_schema_transformer=transformer)
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from . import ModelProfile
|
|
8
|
+
from ._json_schema import JsonSchema, JsonSchemaTransformer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class OpenAIModelProfile(ModelProfile):
    """Model profile holding settings that only `OpenAIModel` understands.

    Every field declared here MUST start with `openai_`, so that this profile can be merged
    with profiles meant for other model classes.
    """

    # Providers or users can switch this off when an OpenAI-"compatible" API
    # does not accept strict tool definitions.
    openai_supports_strict_tool_definition: bool = True
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def openai_model_profile(model_name: str) -> ModelProfile:
    """Return the profile used for OpenAI models.

    `model_name` is not consulted: every OpenAI model gets an `OpenAIModelProfile` whose JSON
    schemas are processed by `OpenAIJsonSchemaTransformer`.
    """
    profile = OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
    return profile
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
_STRICT_INCOMPATIBLE_KEYS = [
|
|
28
|
+
'minLength',
|
|
29
|
+
'maxLength',
|
|
30
|
+
'pattern',
|
|
31
|
+
'format',
|
|
32
|
+
'minimum',
|
|
33
|
+
'maximum',
|
|
34
|
+
'multipleOf',
|
|
35
|
+
'patternProperties',
|
|
36
|
+
'unevaluatedProperties',
|
|
37
|
+
'propertyNames',
|
|
38
|
+
'minProperties',
|
|
39
|
+
'maxProperties',
|
|
40
|
+
'unevaluatedItems',
|
|
41
|
+
'contains',
|
|
42
|
+
'minContains',
|
|
43
|
+
'maxContains',
|
|
44
|
+
'minItems',
|
|
45
|
+
'maxItems',
|
|
46
|
+
'uniqueItems',
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
_sentinel = object()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class OpenAIJsonSchemaTransformer(JsonSchemaTransformer):
    """Recursively handle the schema to make it compatible with OpenAI strict mode.

    See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
    but this basically just requires:
    * `additionalProperties` must be set to false for each object in the parameters
    * all fields in properties must be marked as required

    What happens to each node depends on `strict`:
    * `True`: the node is rewritten to satisfy strict mode. Incompatible keywords are moved into
      the description, `oneOf` becomes `anyOf`, and objects get `additionalProperties: False`
      with every property required.
    * `None`: no strict-mode rewriting. `is_strict_compatible` is set to `False` when a node would
      violate strict mode.
    * `False`: no strict-mode rewriting.

    In every case `title`, `default`, `$schema` and `discriminator` are removed, and `$ref`s are
    normalized (see `_rewrite_ref`).
    """

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        super().__init__(schema, strict=strict)
        # Remember the root `$ref` (if any) so references back to it can be rewritten to '#'.
        self.root_ref = schema.get('$ref')

    def walk(self) -> JsonSchema:
        # Note: OpenAI does not support anyOf at the root in strict mode
        # However, we don't need to check for it here because we ensure in pydantic_ai._utils.check_object_json_schema
        # that the root schema either has type 'object' or is recursive.
        result = super().walk()

        # For recursive models, we need to tweak the schema to make it compatible with strict mode.
        # Because the following should never change the semantics of the schema we apply it unconditionally.
        if self.root_ref is not None:
            result.pop('$ref', None)  # We replace references to the self.root_ref with just '#' in the transform method
            root_key = re.sub(r'^#/\$defs/', '', self.root_ref)
            result.update(self.defs.get(root_key) or {})

        return result

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Rewrite a single schema node in place for OpenAI and return it."""
        # Remove unnecessary keys
        schema.pop('title', None)
        schema.pop('default', None)
        schema.pop('$schema', None)
        schema.pop('discriminator', None)

        self._rewrite_ref(schema)
        self._handle_strict_incompatible_keys(schema)
        self._handle_one_of(schema)
        if schema.get('type') == 'object':
            self._handle_object(schema)
        return schema

    def _rewrite_ref(self, schema: JsonSchema) -> None:
        """Point references to the root at '#', and move a `$ref` that has sibling keys into `anyOf`."""
        schema_ref = schema.get('$ref')
        if not schema_ref:
            return
        if schema_ref == self.root_ref:
            schema['$ref'] = '#'
        if len(schema) > 1:
            # OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf".
            # So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf":
            schema['anyOf'] = [{'$ref': schema.pop('$ref')}]

    def _handle_strict_incompatible_keys(self, schema: JsonSchema) -> None:
        """Move keywords unsupported by strict mode into the description, or flag incompatibility."""
        incompatible_values: dict[str, Any] = {}
        for key in _STRICT_INCOMPATIBLE_KEYS:
            value = schema.get(key, _sentinel)
            if value is not _sentinel:
                incompatible_values[key] = value
        if not incompatible_values:
            return

        if self.strict is True:
            # Keep the constraint information for the model by appending it to the description.
            description = schema.get('description')
            notes: list[str] = []
            for key, value in incompatible_values.items():
                schema.pop(key)
                notes.append(f'{key}={value}')
            notes_string = ', '.join(notes)
            schema['description'] = notes_string if not description else f'{description} ({notes_string})'
        elif self.strict is None:  # pragma: no branch
            self.is_strict_compatible = False

    def _handle_one_of(self, schema: JsonSchema) -> None:
        """Replace `oneOf` with `anyOf` in strict mode; otherwise mark the schema as not strict-compatible."""
        if 'oneOf' not in schema:
            return
        # OpenAI does not support oneOf in strict mode
        if self.strict is True:
            schema['anyOf'] = schema.pop('oneOf')
        else:
            self.is_strict_compatible = False

    def _handle_object(self, schema: JsonSchema) -> None:
        """Enforce (strict=True) or check (strict=None) the strict-mode rules for an object node."""
        if self.strict is True:
            # additional properties are disallowed
            schema['additionalProperties'] = False

            # all properties are required
            properties = schema.setdefault('properties', {})
            schema['required'] = list(properties.keys())

        elif self.strict is None:
            if (
                schema.get('additionalProperties') is not False
                or 'properties' not in schema
                or 'required' not in schema
            ):
                self.is_strict_compatible = False
            else:
                required = schema['required']
                if any(k not in required for k in schema['properties']):
                    self.is_strict_compatible = False
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
from . import ModelProfile
|
|
4
|
+
from ._json_schema import InlineDefsJsonSchemaTransformer
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def qwen_model_profile(model_name: str) -> ModelProfile | None:
    """Return the profile used for Qwen models.

    `model_name` is not consulted: every Qwen model gets a profile whose JSON schemas are
    processed by `InlineDefsJsonSchemaTransformer`.
    """
    transformer = InlineDefsJsonSchemaTransformer
    return ModelProfile(json_schema_transformer=transformer)
|
|
@@ -8,6 +8,8 @@ from __future__ import annotations as _annotations
|
|
|
8
8
|
from abc import ABC, abstractmethod
|
|
9
9
|
from typing import Any, Generic, TypeVar
|
|
10
10
|
|
|
11
|
+
from pydantic_ai.profiles import ModelProfile
|
|
12
|
+
|
|
11
13
|
InterfaceClient = TypeVar('InterfaceClient')
|
|
12
14
|
|
|
13
15
|
|
|
@@ -41,6 +43,10 @@ class Provider(ABC, Generic[InterfaceClient]):
|
|
|
41
43
|
"""The client for the provider."""
|
|
42
44
|
raise NotImplementedError()
|
|
43
45
|
|
|
46
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Return the profile for `model_name`, or `None` if this provider has none.

    This default implementation recognizes no models and always returns `None`;
    concrete providers override it.
    """
    no_profile = None
    return no_profile  # pragma: no cover
|
|
49
|
+
|
|
44
50
|
|
|
45
51
|
def infer_provider(provider: str) -> Provider[Any]:
|
|
46
52
|
"""Infer the provider from the provider name."""
|
|
@@ -89,5 +95,17 @@ def infer_provider(provider: str) -> Provider[Any]:
|
|
|
89
95
|
from .cohere import CohereProvider
|
|
90
96
|
|
|
91
97
|
return CohereProvider()
|
|
98
|
+
elif provider == 'grok':
|
|
99
|
+
from .grok import GrokProvider
|
|
100
|
+
|
|
101
|
+
return GrokProvider()
|
|
102
|
+
elif provider == 'fireworks':
|
|
103
|
+
from .fireworks import FireworksProvider
|
|
104
|
+
|
|
105
|
+
return FireworksProvider()
|
|
106
|
+
elif provider == 'together':
|
|
107
|
+
from .together import TogetherProvider
|
|
108
|
+
|
|
109
|
+
return TogetherProvider()
|
|
92
110
|
else: # pragma: no cover
|
|
93
111
|
raise ValueError(f'Unknown provider: {provider}')
|
|
@@ -7,6 +7,8 @@ import httpx
|
|
|
7
7
|
|
|
8
8
|
from pydantic_ai.exceptions import UserError
|
|
9
9
|
from pydantic_ai.models import cached_async_http_client
|
|
10
|
+
from pydantic_ai.profiles import ModelProfile
|
|
11
|
+
from pydantic_ai.profiles.anthropic import anthropic_model_profile
|
|
10
12
|
from pydantic_ai.providers import Provider
|
|
11
13
|
|
|
12
14
|
try:
|
|
@@ -33,6 +35,9 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
|
|
|
33
35
|
def client(self) -> AsyncAnthropic:
|
|
34
36
|
return self._client
|
|
35
37
|
|
|
38
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Look up the profile for `model_name` via `anthropic_model_profile`."""
    profile = anthropic_model_profile(model_name)
    return profile
|
|
40
|
+
|
|
36
41
|
@overload
|
|
37
42
|
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
|
|
38
43
|
|
pydantic_ai/providers/azure.py
CHANGED
|
@@ -8,6 +8,13 @@ from openai import AsyncOpenAI
|
|
|
8
8
|
|
|
9
9
|
from pydantic_ai.exceptions import UserError
|
|
10
10
|
from pydantic_ai.models import cached_async_http_client
|
|
11
|
+
from pydantic_ai.profiles import ModelProfile
|
|
12
|
+
from pydantic_ai.profiles.cohere import cohere_model_profile
|
|
13
|
+
from pydantic_ai.profiles.deepseek import deepseek_model_profile
|
|
14
|
+
from pydantic_ai.profiles.grok import grok_model_profile
|
|
15
|
+
from pydantic_ai.profiles.meta import meta_model_profile
|
|
16
|
+
from pydantic_ai.profiles.mistral import mistral_model_profile
|
|
17
|
+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
|
|
11
18
|
from pydantic_ai.providers import Provider
|
|
12
19
|
|
|
13
20
|
try:
|
|
@@ -38,6 +45,33 @@ class AzureProvider(Provider[AsyncOpenAI]):
|
|
|
38
45
|
def client(self) -> AsyncOpenAI:
|
|
39
46
|
return self._client
|
|
40
47
|
|
|
48
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Choose a model profile from the lower-cased model name.

    Names starting with a known vendor prefix are routed to that vendor's profile function.
    Prefixes ending in '-' are stripped before the lookup. The vendor profile is layered over an
    `OpenAIModelProfile` that uses `OpenAIJsonSchemaTransformer`. Any other name is treated as an
    OpenAI model.
    """
    normalized = model_name.lower()

    # Checked in order and the first match wins, so 'mistralai-' must come before 'mistral'.
    prefix_table = (
        ('llama', meta_model_profile),
        ('meta-', meta_model_profile),
        ('deepseek', deepseek_model_profile),
        ('mistralai-', mistral_model_profile),
        ('mistral', mistral_model_profile),
        ('cohere-', cohere_model_profile),
        ('grok', grok_model_profile),
    )

    for prefix, lookup in prefix_table:
        if not normalized.startswith(prefix):
            continue

        bare_name = normalized[len(prefix) :] if prefix.endswith('-') else normalized
        vendor_profile = lookup(bare_name)

        # AzureProvider is always used with OpenAIModel, which used to apply OpenAIJsonSchemaTransformer
        # unconditionally; keep that behavior unless json_schema_transformer is set explicitly.
        base = OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
        return base.update(vendor_profile)

    # OpenAI models are unprefixed
    return openai_model_profile(normalized)
|
|
74
|
+
|
|
41
75
|
@overload
|
|
42
76
|
def __init__(self, *, openai_client: AsyncAzureOpenAI) -> None: ...
|
|
43
77
|
|
pydantic_ai/providers/bedrock.py
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
|
|
4
|
+
import re
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Callable, Literal, overload
|
|
5
7
|
|
|
6
8
|
from pydantic_ai.exceptions import UserError
|
|
9
|
+
from pydantic_ai.profiles import ModelProfile
|
|
10
|
+
from pydantic_ai.profiles.amazon import amazon_model_profile
|
|
11
|
+
from pydantic_ai.profiles.anthropic import anthropic_model_profile
|
|
12
|
+
from pydantic_ai.profiles.cohere import cohere_model_profile
|
|
13
|
+
from pydantic_ai.profiles.deepseek import deepseek_model_profile
|
|
14
|
+
from pydantic_ai.profiles.meta import meta_model_profile
|
|
15
|
+
from pydantic_ai.profiles.mistral import mistral_model_profile
|
|
7
16
|
from pydantic_ai.providers import Provider
|
|
8
17
|
|
|
9
18
|
try:
|
|
@@ -18,6 +27,17 @@ except ImportError as _import_error:
|
|
|
18
27
|
) from _import_error
|
|
19
28
|
|
|
20
29
|
|
|
30
|
+
@dataclass
class BedrockModelProfile(ModelProfile):
    """Model profile holding settings that only `BedrockModel` understands.

    Every field declared here MUST start with `bedrock_`, so that this profile can be merged
    with profiles meant for other model classes.
    """

    # Presumably whether the model accepts a tool-choice setting — confirm against BedrockModel.
    bedrock_supports_tool_choice: bool = True
    # Format used for tool results; one of 'text' (default) or 'json'.
    bedrock_tool_result_format: Literal['text', 'json'] = 'text'
|
|
39
|
+
|
|
40
|
+
|
|
21
41
|
class BedrockProvider(Provider[BaseClient]):
|
|
22
42
|
"""Provider for AWS Bedrock."""
|
|
23
43
|
|
|
@@ -33,6 +53,45 @@ class BedrockProvider(Provider[BaseClient]):
|
|
|
33
53
|
def client(self) -> BaseClient:
|
|
34
54
|
return self._client
|
|
35
55
|
|
|
56
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Resolve a model profile from a Bedrock model ID.

    Bedrock model IDs look like ``<provider>.<model>[-v<N>[:<M>]]``. They may be preceded by a
    cross-region inference prefix such as ``us.``, ``eu.``, ``apac.`` or ``global.``. The provider
    segment selects the profile function, and the model part, minus any trailing version suffix,
    is passed to it.

    Returns:
        The profile for the model, or `None` if the ID has no provider segment or the provider
        is not one of those listed below.
    """
    provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
        'anthropic': lambda model_name: BedrockModelProfile(bedrock_supports_tool_choice=False).update(
            anthropic_model_profile(model_name)
        ),
        'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
            mistral_model_profile(model_name)
        ),
        'cohere': cohere_model_profile,
        'amazon': amazon_model_profile,
        'meta': meta_model_profile,
        'deepseek': deepseek_model_profile,
    }

    provider, sep, remainder = model_name.partition('.')

    # Handle regional prefixes (e.g. "us.", "eu.", "apac.", "global."). The leading segment is
    # treated as a region when it is not a known provider and more segments follow. Checking
    # against known providers, rather than the segment length, also covers prefixes that are
    # not two characters long.
    if sep and provider not in provider_to_profile and '.' in remainder:
        provider, sep, remainder = remainder.partition('.')

    if not sep:
        return None

    # Remove version suffix if it matches the format (e.g. "-v1:0" or "-v14")
    version_match = re.match(r'(.+)-v\d+(?::\d+)?$', remainder)
    model_name = version_match.group(1) if version_match else remainder

    profile_func = provider_to_profile.get(provider)
    if profile_func is None:
        return None
    return profile_func(model_name)
|
|
94
|
+
|
|
36
95
|
@overload
|
|
37
96
|
def __init__(self, *, bedrock_client: BaseClient) -> None: ...
|
|
38
97
|
|
pydantic_ai/providers/cohere.py
CHANGED
|
@@ -6,6 +6,8 @@ from httpx import AsyncClient as AsyncHTTPClient
|
|
|
6
6
|
|
|
7
7
|
from pydantic_ai.exceptions import UserError
|
|
8
8
|
from pydantic_ai.models import cached_async_http_client
|
|
9
|
+
from pydantic_ai.profiles import ModelProfile
|
|
10
|
+
from pydantic_ai.profiles.cohere import cohere_model_profile
|
|
9
11
|
from pydantic_ai.providers import Provider
|
|
10
12
|
|
|
11
13
|
try:
|
|
@@ -33,6 +35,9 @@ class CohereProvider(Provider[AsyncClientV2]):
|
|
|
33
35
|
def client(self) -> AsyncClientV2:
|
|
34
36
|
return self._client
|
|
35
37
|
|
|
38
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Look up the profile for `model_name` via `cohere_model_profile`."""
    profile = cohere_model_profile(model_name)
    return profile
|
|
40
|
+
|
|
36
41
|
def __init__(
|
|
37
42
|
self,
|
|
38
43
|
*,
|
|
@@ -8,6 +8,9 @@ from openai import AsyncOpenAI
|
|
|
8
8
|
|
|
9
9
|
from pydantic_ai.exceptions import UserError
|
|
10
10
|
from pydantic_ai.models import cached_async_http_client
|
|
11
|
+
from pydantic_ai.profiles import ModelProfile
|
|
12
|
+
from pydantic_ai.profiles.deepseek import deepseek_model_profile
|
|
13
|
+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
|
|
11
14
|
from pydantic_ai.providers import Provider
|
|
12
15
|
|
|
13
16
|
try:
|
|
@@ -34,6 +37,15 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
|
|
|
34
37
|
def client(self) -> AsyncOpenAI:
|
|
35
38
|
return self._client
|
|
36
39
|
|
|
40
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Return the DeepSeek profile layered over an `OpenAIModelProfile` that uses `OpenAIJsonSchemaTransformer`."""
    deepseek_profile = deepseek_model_profile(model_name)

    # DeepSeekProvider is always used with OpenAIModel, which used to apply OpenAIJsonSchemaTransformer
    # unconditionally, so keep that behavior unless json_schema_transformer is set explicitly.
    # Other model classes (e.g. BedrockConverseModel or GroqModel) never did this for DeepSeek models,
    # so it lives here rather than in `deepseek_model_profile` unless it turns out to be needed everywhere.
    openai_base = OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
    return openai_base.update(deepseek_profile)
|
|
48
|
+
|
|
37
49
|
@overload
|
|
38
50
|
def __init__(self) -> None: ...
|
|
39
51
|
|