pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (51):
  1. pydantic_ai/_agent_graph.py +29 -35
  2. pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
  3. pydantic_ai/_output.py +266 -119
  4. pydantic_ai/agent.py +15 -15
  5. pydantic_ai/mcp.py +1 -1
  6. pydantic_ai/messages.py +2 -2
  7. pydantic_ai/models/__init__.py +39 -3
  8. pydantic_ai/models/anthropic.py +4 -0
  9. pydantic_ai/models/bedrock.py +43 -16
  10. pydantic_ai/models/cohere.py +4 -0
  11. pydantic_ai/models/gemini.py +78 -109
  12. pydantic_ai/models/google.py +47 -112
  13. pydantic_ai/models/groq.py +17 -2
  14. pydantic_ai/models/mistral.py +4 -0
  15. pydantic_ai/models/openai.py +25 -158
  16. pydantic_ai/profiles/__init__.py +39 -0
  17. pydantic_ai/{models → profiles}/_json_schema.py +23 -2
  18. pydantic_ai/profiles/amazon.py +9 -0
  19. pydantic_ai/profiles/anthropic.py +8 -0
  20. pydantic_ai/profiles/cohere.py +8 -0
  21. pydantic_ai/profiles/deepseek.py +8 -0
  22. pydantic_ai/profiles/google.py +100 -0
  23. pydantic_ai/profiles/grok.py +8 -0
  24. pydantic_ai/profiles/meta.py +9 -0
  25. pydantic_ai/profiles/mistral.py +8 -0
  26. pydantic_ai/profiles/openai.py +144 -0
  27. pydantic_ai/profiles/qwen.py +9 -0
  28. pydantic_ai/providers/__init__.py +18 -0
  29. pydantic_ai/providers/anthropic.py +5 -0
  30. pydantic_ai/providers/azure.py +34 -0
  31. pydantic_ai/providers/bedrock.py +60 -1
  32. pydantic_ai/providers/cohere.py +5 -0
  33. pydantic_ai/providers/deepseek.py +12 -0
  34. pydantic_ai/providers/fireworks.py +99 -0
  35. pydantic_ai/providers/google.py +5 -0
  36. pydantic_ai/providers/google_gla.py +5 -0
  37. pydantic_ai/providers/google_vertex.py +5 -0
  38. pydantic_ai/providers/grok.py +82 -0
  39. pydantic_ai/providers/groq.py +25 -0
  40. pydantic_ai/providers/mistral.py +5 -0
  41. pydantic_ai/providers/openai.py +5 -0
  42. pydantic_ai/providers/openrouter.py +36 -0
  43. pydantic_ai/providers/together.py +96 -0
  44. pydantic_ai/result.py +34 -103
  45. pydantic_ai/tools.py +29 -59
  46. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/METADATA +4 -4
  47. pydantic_ai_slim-0.2.13.dist-info/RECORD +73 -0
  48. pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
  49. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/WHEEL +0 -0
  50. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/entry_points.txt +0 -0
  51. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations as _annotations
2
+
1
3
  import re
2
4
  from abc import ABC, abstractmethod
3
5
  from copy import deepcopy
@@ -10,7 +12,7 @@ JsonSchema = dict[str, Any]
10
12
 
11
13
 
12
14
  @dataclass(init=False)
13
- class WalkJsonSchema(ABC):
15
+ class JsonSchemaTransformer(ABC):
14
16
  """Walks a JSON schema, applying transformations to it at each level.
15
17
 
16
18
  Note: We may eventually want to rework tools to build the JSON schema from the type directly, using a subclass of
@@ -18,9 +20,18 @@ class WalkJsonSchema(ABC):
18
20
  """
19
21
 
20
22
  def __init__(
21
- self, schema: JsonSchema, *, prefer_inlined_defs: bool = False, simplify_nullable_unions: bool = False
23
+ self,
24
+ schema: JsonSchema,
25
+ *,
26
+ strict: bool | None = None,
27
+ prefer_inlined_defs: bool = False,
28
+ simplify_nullable_unions: bool = False,
22
29
  ):
23
30
  self.schema = schema
31
+
32
+ self.strict = strict
33
+ self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly
34
+
24
35
  self.prefer_inlined_defs = prefer_inlined_defs
25
36
  self.simplify_nullable_unions = simplify_nullable_unions
26
37
 
@@ -164,3 +175,13 @@ class WalkJsonSchema(ABC):
164
175
  return [cases[0]]
165
176
 
166
177
  return cases # pragma: no cover
178
+
179
+
180
class InlineDefsJsonSchemaTransformer(JsonSchemaTransformer):
    """JSON Schema transformer whose only job is to inline `$defs`; each sub-schema is otherwise left unchanged."""

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        # Inlining definitions is the single option this transformer enables on the base walker.
        super().__init__(schema, strict=strict, prefer_inlined_defs=True)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        # Identity: the visited sub-schema is handed back exactly as received.
        return schema
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+ from ._json_schema import InlineDefsJsonSchemaTransformer
5
+
6
+
7
+ def amazon_model_profile(model_name: str) -> ModelProfile | None:
8
+ """Get the model profile for an Amazon model."""
9
+ return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer)
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+
5
+
6
def anthropic_model_profile(model_name: str) -> ModelProfile | None:
    """Look up the profile for an Anthropic model.

    No Anthropic-specific profile exists, so `None` is returned for every `model_name`.
    """
    return None
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+
5
+
6
def cohere_model_profile(model_name: str) -> ModelProfile | None:
    """Look up the profile for a Cohere model.

    No Cohere-specific profile exists, so `None` is returned for every `model_name`.
    """
    return None
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+
5
+
6
def deepseek_model_profile(model_name: str) -> ModelProfile | None:
    """Look up the profile for a DeepSeek model.

    No DeepSeek-specific profile exists, so `None` is returned for every `model_name`.
    """
    return None
@@ -0,0 +1,100 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import warnings
4
+
5
+ from pydantic_ai.exceptions import UserError
6
+
7
+ from . import ModelProfile
8
+ from ._json_schema import JsonSchema, JsonSchemaTransformer
9
+
10
+
11
def google_model_profile(model_name: str) -> ModelProfile | None:
    """Return the profile shared by all Google (Gemini) models.

    `model_name` is not consulted: the returned profile always uses `GoogleJsonSchemaTransformer`.
    """
    return ModelProfile(json_schema_transformer=GoogleJsonSchemaTransformer)
14
+
15
+
16
class GoogleJsonSchemaTransformer(JsonSchemaTransformer):
    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.

    Gemini [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
    only a subset of OpenAPI v3.0.3.

    Specifically:
    * Gemini doesn't allow the `title` keyword to be set
    * Gemini doesn't allow `$defs` — we need to inline the definitions where possible
    """

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        # `$defs` are inlined (Gemini rejects them) and nullable unions are simplified by the base walker.
        super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=True)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Rewrite one (sub-)schema in place so it only uses keywords Gemini accepts.

        Args:
            schema: The sub-schema being visited by the walker; it is mutated and returned.

        Returns:
            The same `schema` object after modification.

        Raises:
            UserError: If the schema still contains a `$ref` (reported as an unsupported recursive reference).
        """
        # `additionalProperties` is always removed because it is currently mishandled by Gemini; this includes
        # `additionalProperties: False`, which is dropped without a warning. A truthy value is reported in a
        # warning. The key has already been popped at that point, so it is re-attached to show the original schema.
        additional_properties = schema.pop('additionalProperties', None)
        if additional_properties:
            original_schema = {**schema, 'additionalProperties': additional_properties}
            warnings.warn(
                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
                f' Full schema: {self.schema}\n\n'
                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
                ' and we will fix this behavior.',
                UserWarning,
            )

        schema.pop('title', None)
        schema.pop('default', None)
        schema.pop('$schema', None)
        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
            # Gemini doesn't support const, but it does support enum with a single value
            schema['enum'] = [const]
        schema.pop('discriminator', None)
        schema.pop('examples', None)

        # TODO: Should we use the trick from `OpenAIJsonSchemaTransformer` (pydantic_ai.profiles.openai)
        # where we add notes about these properties to the field description?
        schema.pop('exclusiveMaximum', None)
        schema.pop('exclusiveMinimum', None)

        # Gemini only supports string enums, so we need to convert any enum values to strings.
        # Pydantic will take care of transforming the transformed string values to the correct type.
        if enum := schema.get('enum'):
            schema['type'] = 'string'
            schema['enum'] = [str(val) for val in enum]

        type_ = schema.get('type')
        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
            # This gets hit when we have a discriminated union
            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
            schema['anyOf'] = schema.pop('oneOf')

        if type_ == 'string' and (fmt := schema.pop('format', None)):
            # `format` is dropped, so its value is preserved as a hint in the description instead.
            description = schema.get('description')
            if description:
                schema['description'] = f'{description} (format: {fmt})'
            else:
                schema['description'] = f'Format: {fmt}'

        if '$ref' in schema:
            raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')

        if 'prefixItems' in schema:
            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility:
            # the distinct item schemas become a single `items` (an `anyOf` when there is more than one), and the
            # tuple length is approximated with minItems/maxItems.
            prefix_items = schema.pop('prefixItems')
            items = schema.get('items')
            unique_items = [items] if items is not None else []
            for item in prefix_items:
                if item not in unique_items:
                    unique_items.append(item)
            if len(unique_items) > 1:  # pragma: no cover
                schema['items'] = {'anyOf': unique_items}
            elif len(unique_items) == 1:  # pragma: no branch
                schema['items'] = unique_items[0]
            schema.setdefault('minItems', len(prefix_items))
            if items is None:  # pragma: no branch
                # Without a trailing `items` schema the tuple cannot be longer than its prefix.
                schema.setdefault('maxItems', len(prefix_items))

        return schema
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+
5
+
6
def grok_model_profile(model_name: str) -> ModelProfile | None:
    """Look up the profile for a Grok model.

    No Grok-specific profile exists, so `None` is returned for every `model_name`.
    """
    return None
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+ from ._json_schema import InlineDefsJsonSchemaTransformer
5
+
6
+
7
def meta_model_profile(model_name: str) -> ModelProfile | None:
    """Return the profile shared by all Meta (Llama) models.

    `model_name` is not consulted: the returned profile always uses `InlineDefsJsonSchemaTransformer`.
    """
    return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer)
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+
5
+
6
def mistral_model_profile(model_name: str) -> ModelProfile | None:
    """Look up the profile for a Mistral model.

    No Mistral-specific profile exists, so `None` is returned for every `model_name`.
    """
    return None
@@ -0,0 +1,144 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ from . import ModelProfile
8
+ from ._json_schema import JsonSchema, JsonSchemaTransformer
9
+
10
+
11
@dataclass
class OpenAIModelProfile(ModelProfile):
    """Model profile carrying options that only `OpenAIModel` understands.

    Every field declared here MUST start with `openai_` so that this profile can be merged with profiles
    for other models.
    """

    # Whether `strict` may be sent on tool definitions. A provider or user can turn this off for an
    # OpenAI-"compatible" API that doesn't support strict tool definitions.
    openai_supports_strict_tool_definition: bool = True
20
+
21
+
22
def openai_model_profile(model_name: str) -> ModelProfile:
    """Return the profile shared by all OpenAI models.

    `model_name` is not consulted: the result is always an `OpenAIModelProfile` that uses
    `OpenAIJsonSchemaTransformer`.
    """
    return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
25
+
26
+
27
+ _STRICT_INCOMPATIBLE_KEYS = [
28
+ 'minLength',
29
+ 'maxLength',
30
+ 'pattern',
31
+ 'format',
32
+ 'minimum',
33
+ 'maximum',
34
+ 'multipleOf',
35
+ 'patternProperties',
36
+ 'unevaluatedProperties',
37
+ 'propertyNames',
38
+ 'minProperties',
39
+ 'maxProperties',
40
+ 'unevaluatedItems',
41
+ 'contains',
42
+ 'minContains',
43
+ 'maxContains',
44
+ 'minItems',
45
+ 'maxItems',
46
+ 'uniqueItems',
47
+ ]
48
+
49
+ _sentinel = object()
50
+
51
+
52
@dataclass
class OpenAIJsonSchemaTransformer(JsonSchemaTransformer):
    """Recursively handle the schema to make it compatible with OpenAI strict mode.

    See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
    but this basically just requires:
    * `additionalProperties` must be set to false for each object in the parameters
    * all fields in properties must be marked as required
    """

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        super().__init__(schema, strict=strict)
        # Root-level `$ref` (present for recursive models); `transform` rewrites references to it as '#'.
        self.root_ref = schema.get('$ref')

    def walk(self) -> JsonSchema:
        """Walk the schema; for recursive models, replace the root `$ref` with the referenced definition's keys."""
        # Note: OpenAI does not support anyOf at the root in strict mode
        # However, we don't need to check for it here because we ensure in pydantic_ai._utils.check_object_json_schema
        # that the root schema either has type 'object' or is recursive.
        result = super().walk()

        # For recursive models, we need to tweak the schema to make it compatible with strict mode.
        # Because the following should never change the semantics of the schema we apply it unconditionally.
        if self.root_ref is not None:
            result.pop('$ref', None)  # We replace references to the self.root_ref with just '#' in the transform method
            root_key = re.sub(r'^#/\$defs/', '', self.root_ref)
            result.update(self.defs.get(root_key) or {})

        return result

    def transform(self, schema: JsonSchema) -> JsonSchema:  # noqa C901
        """Adjust one (sub-)schema for OpenAI strict mode, mutating and returning it.

        Regardless of `strict`, the unnecessary keys are removed and `$ref` siblings are handled.
        With `strict=True`, strict-incompatible keywords are moved into `description`, `oneOf` becomes `anyOf`,
        and objects get `additionalProperties: False` with every property marked required.
        With `strict=None`, those rewrites are skipped and `is_strict_compatible` is set to False
        whenever something would prevent strict mode.
        """
        # Remove unnecessary keys
        for key in ('title', 'default', '$schema', 'discriminator'):
            schema.pop(key, None)

        if schema_ref := schema.get('$ref'):
            if schema_ref == self.root_ref:
                schema['$ref'] = '#'
            if len(schema) > 1:
                # OpenAI Strict mode doesn't support siblings to "$ref", but _does_ allow siblings to "anyOf".
                # So if there is a "description" field or any other extra info, we move the "$ref" into an "anyOf":
                schema['anyOf'] = [{'$ref': schema.pop('$ref')}]

        # Track strict-incompatible keys; a key that is present with value None still counts.
        incompatible_values: dict[str, Any] = {
            key: schema[key] for key in _STRICT_INCOMPATIBLE_KEYS if key in schema
        }
        description = schema.get('description')
        if incompatible_values:
            if self.strict is True:
                # Strict mode rejects these keywords, so their information is kept as a note in the description.
                for key in incompatible_values:
                    schema.pop(key)
                notes_string = ', '.join(f'{key}={value}' for key, value in incompatible_values.items())
                schema['description'] = notes_string if not description else f'{description} ({notes_string})'
            elif self.strict is None:  # pragma: no branch
                self.is_strict_compatible = False

        schema_type = schema.get('type')
        if 'oneOf' in schema:
            # OpenAI does not support oneOf in strict mode
            if self.strict is True:
                schema['anyOf'] = schema.pop('oneOf')
            else:
                self.is_strict_compatible = False

        if schema_type == 'object':
            if self.strict is True:
                # additional properties are disallowed
                schema['additionalProperties'] = False

                # all properties are required
                properties = schema.setdefault('properties', {})
                schema['required'] = list(properties)

            elif self.strict is None:
                if (
                    schema.get('additionalProperties') is not False
                    or 'properties' not in schema
                    or 'required' not in schema
                ):
                    self.is_strict_compatible = False
                else:
                    required = schema['required']
                    if any(k not in required for k in schema['properties']):
                        self.is_strict_compatible = False
        return schema
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from . import ModelProfile
4
+ from ._json_schema import InlineDefsJsonSchemaTransformer
5
+
6
+
7
def qwen_model_profile(model_name: str) -> ModelProfile | None:
    """Return the profile shared by all Qwen models.

    `model_name` is not consulted: the returned profile always uses `InlineDefsJsonSchemaTransformer`.
    """
    return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer)
@@ -8,6 +8,8 @@ from __future__ import annotations as _annotations
8
8
  from abc import ABC, abstractmethod
9
9
  from typing import Any, Generic, TypeVar
10
10
 
11
+ from pydantic_ai.profiles import ModelProfile
12
+
11
13
  InterfaceClient = TypeVar('InterfaceClient')
12
14
 
13
15
 
@@ -41,6 +43,10 @@ class Provider(ABC, Generic[InterfaceClient]):
41
43
  """The client for the provider."""
42
44
  raise NotImplementedError()
43
45
 
46
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Return the profile for `model_name` if this provider knows one.

    This default knows no profiles and always answers `None`; concrete providers override it.
    """
    return None  # pragma: no cover
49
+
44
50
 
45
51
  def infer_provider(provider: str) -> Provider[Any]:
46
52
  """Infer the provider from the provider name."""
@@ -89,5 +95,17 @@ def infer_provider(provider: str) -> Provider[Any]:
89
95
  from .cohere import CohereProvider
90
96
 
91
97
  return CohereProvider()
98
+ elif provider == 'grok':
99
+ from .grok import GrokProvider
100
+
101
+ return GrokProvider()
102
+ elif provider == 'fireworks':
103
+ from .fireworks import FireworksProvider
104
+
105
+ return FireworksProvider()
106
+ elif provider == 'together':
107
+ from .together import TogetherProvider
108
+
109
+ return TogetherProvider()
92
110
  else: # pragma: no cover
93
111
  raise ValueError(f'Unknown provider: {provider}')
@@ -7,6 +7,8 @@ import httpx
7
7
 
8
8
  from pydantic_ai.exceptions import UserError
9
9
  from pydantic_ai.models import cached_async_http_client
10
+ from pydantic_ai.profiles import ModelProfile
11
+ from pydantic_ai.profiles.anthropic import anthropic_model_profile
10
12
  from pydantic_ai.providers import Provider
11
13
 
12
14
  try:
@@ -33,6 +35,9 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
33
35
  def client(self) -> AsyncAnthropic:
34
36
  return self._client
35
37
 
38
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Delegate profile lookup to `anthropic_model_profile`."""
    profile = anthropic_model_profile(model_name)
    return profile
40
+
36
41
  @overload
37
42
  def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
38
43
 
@@ -8,6 +8,13 @@ from openai import AsyncOpenAI
8
8
 
9
9
  from pydantic_ai.exceptions import UserError
10
10
  from pydantic_ai.models import cached_async_http_client
11
+ from pydantic_ai.profiles import ModelProfile
12
+ from pydantic_ai.profiles.cohere import cohere_model_profile
13
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
14
+ from pydantic_ai.profiles.grok import grok_model_profile
15
+ from pydantic_ai.profiles.meta import meta_model_profile
16
+ from pydantic_ai.profiles.mistral import mistral_model_profile
17
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
11
18
  from pydantic_ai.providers import Provider
12
19
 
13
20
  try:
@@ -38,6 +45,33 @@ class AzureProvider(Provider[AsyncOpenAI]):
38
45
  def client(self) -> AsyncOpenAI:
39
46
  return self._client
40
47
 
48
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Choose a model profile for an Azure deployment based on the (lower-cased) model name's prefix.

    Known non-OpenAI families are matched by prefix. Prefixes that end in '-' are removed before the
    family's profile function is called. The family profile is layered on top of an `OpenAIModelProfile`
    that uses `OpenAIJsonSchemaTransformer`. Names matching no prefix are treated as OpenAI models.
    """
    lowered = model_name.lower()

    # Checked in order: 'mistralai-' must be tried before the broader 'mistral'.
    prefix_table = (
        ('llama', meta_model_profile),
        ('meta-', meta_model_profile),
        ('deepseek', deepseek_model_profile),
        ('mistralai-', mistral_model_profile),
        ('mistral', mistral_model_profile),
        ('cohere-', cohere_model_profile),
        ('grok', grok_model_profile),
    )

    for prefix, profile_func in prefix_table:
        if not lowered.startswith(prefix):
            continue

        family_name = lowered.removeprefix(prefix) if prefix.endswith('-') else lowered
        family_profile = profile_func(family_name)

        # As AzureProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
        # we need to maintain that behavior unless json_schema_transformer is set explicitly
        base_profile = OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
        return base_profile.update(family_profile)

    # OpenAI models are unprefixed
    return openai_model_profile(lowered)
74
+
41
75
  @overload
42
76
  def __init__(self, *, openai_client: AsyncAzureOpenAI) -> None: ...
43
77
 
@@ -1,9 +1,18 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import os
4
- from typing import overload
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Callable, Literal, overload
5
7
 
6
8
  from pydantic_ai.exceptions import UserError
9
+ from pydantic_ai.profiles import ModelProfile
10
+ from pydantic_ai.profiles.amazon import amazon_model_profile
11
+ from pydantic_ai.profiles.anthropic import anthropic_model_profile
12
+ from pydantic_ai.profiles.cohere import cohere_model_profile
13
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
14
+ from pydantic_ai.profiles.meta import meta_model_profile
15
+ from pydantic_ai.profiles.mistral import mistral_model_profile
7
16
  from pydantic_ai.providers import Provider
8
17
 
9
18
  try:
@@ -18,6 +27,17 @@ except ImportError as _import_error:
18
27
  ) from _import_error
19
28
 
20
29
 
30
@dataclass
class BedrockModelProfile(ModelProfile):
    """Model profile carrying options that only `BedrockModel` understands.

    Every field declared here MUST start with `bedrock_` so that this profile can be merged with profiles
    for other models.
    """

    # Whether tool choice is supported; `BedrockProvider.model_profile` sets this to False for Anthropic models.
    bedrock_supports_tool_choice: bool = True
    # Tool result format, 'text' by default; `BedrockProvider.model_profile` selects 'json' for Mistral models.
    bedrock_tool_result_format: Literal['text', 'json'] = 'text'
39
+
40
+
21
41
  class BedrockProvider(Provider[BaseClient]):
22
42
  """Provider for AWS Bedrock."""
23
43
 
@@ -33,6 +53,45 @@ class BedrockProvider(Provider[BaseClient]):
33
53
  def client(self) -> BaseClient:
34
54
  return self._client
35
55
 
56
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Return the model profile for a Bedrock model ID, or `None` if its vendor is not recognised.

    Model IDs look like `<vendor>.<model>` (e.g. `anthropic.claude-3-5-sonnet-20240620-v1:0`), optionally
    preceded by a cross-region inference prefix such as `us.`, `eu.`, `apac.` or `us-gov.`. The regional
    prefix is skipped and a trailing version suffix (`-v1:0`, `-v14`, ...) is removed from the model part.
    The vendor-specific profile function is then called with what remains.
    """
    provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
        'anthropic': lambda model_name: BedrockModelProfile(bedrock_supports_tool_choice=False).update(
            anthropic_model_profile(model_name)
        ),
        'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
            mistral_model_profile(model_name)
        ),
        'cohere': cohere_model_profile,
        'amazon': amazon_model_profile,
        'meta': meta_model_profile,
        'deepseek': deepseek_model_profile,
    }

    # Regional prefixes that are not two characters long. Two-character prefixes ("us", "eu", "jp", ...) are
    # recognised by length; without this list, IDs like "apac.anthropic.claude-..." would resolve to no profile.
    long_regional_prefixes = ('apac', 'us-gov', 'global')

    # Split the model name into at most three parts: [region?, vendor, model]
    parts = model_name.split('.', 2)

    # Handle regional prefixes (e.g. "us." or "apac.")
    if len(parts) > 2 and (len(parts[0]) == 2 or parts[0] in long_regional_prefixes):
        parts = parts[1:]

    if len(parts) < 2:
        return None

    provider = parts[0]
    model_name_with_version = parts[1]

    # Remove version suffix if it matches the format (e.g. "-v1:0" or "-v14")
    version_match = re.match(r'(.+)-v\d+(?::\d+)?$', model_name_with_version)
    model_name = version_match.group(1) if version_match else model_name_with_version

    profile_func = provider_to_profile.get(provider)
    if profile_func is None:
        return None
    return profile_func(model_name)
94
+
36
95
  @overload
37
96
  def __init__(self, *, bedrock_client: BaseClient) -> None: ...
38
97
 
@@ -6,6 +6,8 @@ from httpx import AsyncClient as AsyncHTTPClient
6
6
 
7
7
  from pydantic_ai.exceptions import UserError
8
8
  from pydantic_ai.models import cached_async_http_client
9
+ from pydantic_ai.profiles import ModelProfile
10
+ from pydantic_ai.profiles.cohere import cohere_model_profile
9
11
  from pydantic_ai.providers import Provider
10
12
 
11
13
  try:
@@ -33,6 +35,9 @@ class CohereProvider(Provider[AsyncClientV2]):
33
35
  def client(self) -> AsyncClientV2:
34
36
  return self._client
35
37
 
38
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Delegate profile lookup to `cohere_model_profile`."""
    profile = cohere_model_profile(model_name)
    return profile
40
+
36
41
  def __init__(
37
42
  self,
38
43
  *,
@@ -8,6 +8,9 @@ from openai import AsyncOpenAI
8
8
 
9
9
  from pydantic_ai.exceptions import UserError
10
10
  from pydantic_ai.models import cached_async_http_client
11
+ from pydantic_ai.profiles import ModelProfile
12
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
13
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
11
14
  from pydantic_ai.providers import Provider
12
15
 
13
16
  try:
@@ -34,6 +37,15 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
34
37
  def client(self) -> AsyncOpenAI:
35
38
  return self._client
36
39
 
40
def model_profile(self, model_name: str) -> ModelProfile | None:
    """Return the profile for a DeepSeek model served through this provider.

    The DeepSeek profile is layered on top of an `OpenAIModelProfile` that uses `OpenAIJsonSchemaTransformer`.
    """
    family_profile = deepseek_model_profile(model_name)

    # DeepSeekProvider is always used with OpenAIModel, which used to apply OpenAIJsonSchemaTransformer
    # unconditionally, so that default is kept unless json_schema_transformer is set explicitly.
    # Other model classes (e.g. BedrockConverseModel or GroqModel) never did this for DeepSeek models,
    # so it stays here rather than in `deepseek_model_profile` unless we learn it's always needed.
    base_profile = OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
    return base_profile.update(family_profile)
48
+
37
49
  @overload
38
50
  def __init__(self) -> None: ...
39
51