pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +29 -35
- pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
- pydantic_ai/_output.py +265 -118
- pydantic_ai/agent.py +15 -15
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +2 -2
- pydantic_ai/models/__init__.py +39 -3
- pydantic_ai/models/anthropic.py +4 -0
- pydantic_ai/models/bedrock.py +43 -16
- pydantic_ai/models/cohere.py +4 -0
- pydantic_ai/models/gemini.py +68 -108
- pydantic_ai/models/google.py +45 -110
- pydantic_ai/models/groq.py +17 -2
- pydantic_ai/models/mistral.py +4 -0
- pydantic_ai/models/openai.py +22 -157
- pydantic_ai/profiles/__init__.py +39 -0
- pydantic_ai/{models → profiles}/_json_schema.py +23 -2
- pydantic_ai/profiles/amazon.py +9 -0
- pydantic_ai/profiles/anthropic.py +8 -0
- pydantic_ai/profiles/cohere.py +8 -0
- pydantic_ai/profiles/deepseek.py +8 -0
- pydantic_ai/profiles/google.py +100 -0
- pydantic_ai/profiles/grok.py +8 -0
- pydantic_ai/profiles/meta.py +9 -0
- pydantic_ai/profiles/mistral.py +8 -0
- pydantic_ai/profiles/openai.py +144 -0
- pydantic_ai/profiles/qwen.py +9 -0
- pydantic_ai/providers/__init__.py +18 -0
- pydantic_ai/providers/anthropic.py +5 -0
- pydantic_ai/providers/azure.py +34 -0
- pydantic_ai/providers/bedrock.py +60 -1
- pydantic_ai/providers/cohere.py +5 -0
- pydantic_ai/providers/deepseek.py +12 -0
- pydantic_ai/providers/fireworks.py +99 -0
- pydantic_ai/providers/google.py +5 -0
- pydantic_ai/providers/google_gla.py +5 -0
- pydantic_ai/providers/google_vertex.py +5 -0
- pydantic_ai/providers/grok.py +82 -0
- pydantic_ai/providers/groq.py +25 -0
- pydantic_ai/providers/mistral.py +5 -0
- pydantic_ai/providers/openai.py +5 -0
- pydantic_ai/providers/openrouter.py +36 -0
- pydantic_ai/providers/together.py +96 -0
- pydantic_ai/result.py +34 -103
- pydantic_ai/tools.py +28 -58
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/METADATA +4 -4
- pydantic_ai_slim-0.2.12.dist-info/RECORD +73 -0
- pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.12.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/__init__.py
CHANGED
@@ -9,16 +9,19 @@ from __future__ import annotations as _annotations
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
-from functools import cache
+from functools import cache, cached_property
 
 import httpx
 from typing_extensions import Literal, TypeAliasType
 
+from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
+
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
+from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from ..usage import Usage
@@ -68,6 +71,10 @@ KnownModelName = TypeAliasType(
         'bedrock:us.anthropic.claude-3-5-sonnet-20240620-v1:0',
         'bedrock:anthropic.claude-3-7-sonnet-20250219-v1:0',
         'bedrock:us.anthropic.claude-3-7-sonnet-20250219-v1:0',
+        'bedrock:anthropic.claude-opus-4-20250514-v1:0',
+        'bedrock:us.anthropic.claude-opus-4-20250514-v1:0',
+        'bedrock:anthropic.claude-sonnet-4-20250514-v1:0',
+        'bedrock:us.anthropic.claude-sonnet-4-20250514-v1:0',
         'bedrock:cohere.command-text-v14',
         'bedrock:cohere.command-r-v1:0',
         'bedrock:cohere.command-r-plus-v1:0',
@@ -292,6 +299,8 @@ class ModelRequestParameters:
 class Model(ABC):
     """Abstract class for a model."""
 
+    _profile: ModelProfileSpec | None = None
+
     @abstractmethod
     async def request(
         self,
@@ -323,6 +332,13 @@ class Model(ABC):
         In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary
         for vendor/model-specific reasons.
         """
+        if transformer := self.profile.json_schema_transformer:
+            model_request_parameters = replace(
+                model_request_parameters,
+                function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
+                output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
+            )
+
         return model_request_parameters
 
     @property
@@ -331,6 +347,18 @@
         """The model name."""
         raise NotImplementedError()
 
+    @cached_property
+    def profile(self) -> ModelProfile:
+        """The model profile."""
+        _profile = self._profile
+        if callable(_profile):
+            _profile = _profile(self.model_name)
+
+        if _profile is None:
+            return DEFAULT_PROFILE
+
+        return _profile
+
     @property
     @abstractmethod
     def system(self) -> str:
@@ -515,7 +543,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
        from .cohere import CohereModel

        return CohereModel(model_name, provider=provider)
-    elif provider in ('
+    elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'):
        from .openai import OpenAIModel

        return OpenAIModel(model_name, provider=provider)
@@ -584,3 +612,11 @@ def get_user_agent() -> str:
     from .. import __version__
 
     return f'pydantic-ai/{__version__}'
+
+
+def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinition):
+    schema_transformer = transformer(t.parameters_json_schema, strict=t.strict)
+    parameters_json_schema = schema_transformer.walk()
+    if t.strict is None:
+        t = replace(t, strict=schema_transformer.is_strict_compatible)
+    return replace(t, parameters_json_schema=parameters_json_schema)
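Taken together, the new `Model._profile` attribute, `Model.profile` property, and module-level `_customize_tool_def` helper give every model the same resolution order: a `ModelProfileSpec` may be a ready-made profile or a callable keyed by the model name, with `DEFAULT_PROFILE` as the fallback. A minimal self-contained sketch of that resolution logic (the names below are simplified stand-ins, not the library's own classes):

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass(frozen=True)
class Profile:
    """Stand-in for pydantic_ai.profiles.ModelProfile (hypothetical, simplified)."""

    supports_tool_choice: bool = True


# Mirrors ModelProfileSpec: a ready-made profile, or a callable keyed by model name.
ProfileSpec = Union[Profile, Callable[[str], Optional[Profile]]]

DEFAULT_PROFILE = Profile()


def resolve_profile(spec: Optional[ProfileSpec], model_name: str) -> Profile:
    """Resolution order used by Model.profile in the diff above:
    call the spec if it's callable, then fall back to the default profile."""
    profile = spec
    if callable(profile):
        profile = profile(model_name)
    return DEFAULT_PROFILE if profile is None else profile


def provider_spec(model_name: str) -> Optional[Profile]:
    """A callable spec, as a provider's `model_profile` might supply."""
    if model_name.startswith('anthropic.'):
        return Profile(supports_tool_choice=True)
    return None  # unknown model -> DEFAULT_PROFILE


assert resolve_profile(provider_spec, 'anthropic.claude-sonnet-4-20250514-v1:0').supports_tool_choice
assert resolve_profile(None, 'some-model') is DEFAULT_PROFILE
```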
pydantic_ai/models/anthropic.py
CHANGED
@@ -27,6 +27,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -118,6 +119,7 @@ class AnthropicModel(Model):
         model_name: AnthropicModelName,
         *,
         provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize an Anthropic model.
 
@@ -126,12 +128,14 @@
                 [here](https://docs.anthropic.com/en/docs/about-claude/models).
             provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
                 instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
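For `AnthropicModel` (and the other model classes below) the visible API change is just the new `profile` keyword, which takes precedence over the provider's `model_profile`. A usage sketch, assuming the `anthropic` extra is installed and `ANTHROPIC_API_KEY` is set; the callable spec leans on the `Model.profile` fallback shown above:

```python
from pydantic_ai.models.anthropic import AnthropicModel


def my_profile(model_name: str):
    # Hypothetical per-model hook; returning None falls back to DEFAULT_PROFILE
    # via the Model.profile property added in this release.
    return None


model = AnthropicModel('claude-sonnet-4-20250514', profile=my_profile)
```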
pydantic_ai/models/bedrock.py
CHANGED
@@ -32,8 +32,15 @@ from pydantic_ai.messages import (
     UserPromptPart,
     VideoUrl,
 )
-from pydantic_ai.models import
+from pydantic_ai.models import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    cached_async_http_client,
+)
+from pydantic_ai.profiles import ModelProfileSpec
 from pydantic_ai.providers import Provider, infer_provider
+from pydantic_ai.providers.bedrock import BedrockModelProfile
 from pydantic_ai.settings import ModelSettings
 from pydantic_ai.tools import ToolDefinition
 
@@ -56,6 +63,7 @@ if TYPE_CHECKING:
         PromptVariableValuesTypeDef,
         SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
+        ToolConfigurationTypeDef,
         ToolTypeDef,
         VideoBlockTypeDef,
     )
@@ -85,6 +93,10 @@ LatestBedrockModelNames = Literal[
     'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
     'anthropic.claude-3-7-sonnet-20250219-v1:0',
     'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
+    'anthropic.claude-opus-4-20250514-v1:0',
+    'us.anthropic.claude-opus-4-20250514-v1:0',
+    'anthropic.claude-sonnet-4-20250514-v1:0',
+    'us.anthropic.claude-sonnet-4-20250514-v1:0',
     'cohere.command-text-v14',
     'cohere.command-r-v1:0',
     'cohere.command-r-plus-v1:0',
@@ -190,6 +202,7 @@ class BedrockConverseModel(Model):
         model_name: BedrockModelName,
         *,
         provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a Bedrock model.
 
@@ -200,12 +213,14 @@
             provider: The provider to use for authentication and API access. Can be either the string
                 'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = cast('BedrockRuntimeClient', provider.client)
+        self._profile = profile or provider.model_profile
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
@@ -301,15 +316,6 @@
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
-        tools = self._get_tools(model_request_parameters)
-        support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
-        if not tools or not support_tools_choice:
-            tool_choice: ToolChoiceTypeDef = {}
-        elif not model_request_parameters.allow_text_output:
-            tool_choice = {'any': {}}  # pragma: no cover
-        else:
-            tool_choice = {'auto': {}}
-
         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)
 
@@ -320,6 +326,10 @@
             'inferenceConfig': inference_config,
         }
 
+        tool_config = self._map_tool_config(model_request_parameters)
+        if tool_config:
+            params['toolConfig'] = tool_config
+
         # Bedrock supports a set of specific extra parameters
         if model_settings:
             if guardrail_config := model_settings.get('bedrock_guardrail_config', None):
@@ -337,11 +347,6 @@
             if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
                 params['promptVariables'] = prompt_variables
 
-        if tools:
-            params['toolConfig'] = {'tools': tools}
-            if tool_choice:
-                params['toolConfig']['toolChoice'] = tool_choice
-
         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
             model_response = model_response['stream']
@@ -367,6 +372,23 @@
 
         return inference_config
 
+    def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
+        tools = self._get_tools(model_request_parameters)
+        if not tools:
+            return None
+
+        tool_choice: ToolChoiceTypeDef
+        if not model_request_parameters.allow_text_output:
+            tool_choice = {'any': {}}
+        else:
+            tool_choice = {'auto': {}}
+
+        tool_config: ToolConfigurationTypeDef = {'tools': tools}
+        if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice:
+            tool_config['toolChoice'] = tool_choice
+
+        return tool_config
+
     async def _map_messages(
         self, messages: list[ModelMessage]
     ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
@@ -374,6 +396,7 @@
 
         Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova models.
         """
+        profile = BedrockModelProfile.from_profile(self.profile)
         system_prompt: list[SystemContentBlockTypeDef] = []
         bedrock_messages: list[MessageUnionTypeDef] = []
         document_count: Iterator[int] = count(1)
@@ -393,7 +416,11 @@
                         {
                             'toolResult': {
                                 'toolUseId': part.tool_call_id,
-                                'content': [
+                                'content': [
+                                    {'text': part.model_response_str()}
+                                    if profile.bedrock_tool_result_format == 'text'
+                                    else {'json': part.model_response_object()}
+                                ],
                                 'status': 'success',
                             }
                         }
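The net effect of the Bedrock changes: tool configuration moves out of `_messages_create` into `_map_tool_config`, and `toolChoice` support is decided by the model profile rather than a hard-coded `anthropic`/`us.anthropic` prefix check. A self-contained mirror of that decision table, with plain dicts standing in for the boto3 TypeDefs:

```python
from __future__ import annotations

from typing import Any, Optional


def map_tool_config(
    tools: list[dict[str, Any]],
    allow_text_output: bool,
    bedrock_supports_tool_choice: bool,
) -> Optional[dict[str, Any]]:
    """Sketch of BedrockConverseModel._map_tool_config from the diff above."""
    if not tools:
        return None  # no toolConfig is sent at all

    # 'any' forces some tool call when plain-text output is not allowed.
    tool_choice = {'any': {}} if not allow_text_output else {'auto': {}}

    tool_config: dict[str, Any] = {'tools': tools}
    if bedrock_supports_tool_choice:  # profile flag, not a model-name prefix check
        tool_config['toolChoice'] = tool_choice
    return tool_config


assert map_tool_config([], allow_text_output=True, bedrock_supports_tool_choice=True) is None
assert map_tool_config([{'toolSpec': {}}], False, True) == {
    'tools': [{'toolSpec': {}}],
    'toolChoice': {'any': {}},
}
assert 'toolChoice' not in map_tool_config([{'toolSpec': {}}], True, False)
```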
pydantic_ai/models/cohere.py
CHANGED
@@ -20,6 +20,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -107,6 +108,7 @@ class CohereModel(Model):
         model_name: CohereModelName,
         *,
         provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize an Cohere model.
 
@@ -116,12 +118,14 @@
             provider: The provider to use for authentication and API access. Can be either the string
                 'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
pydantic_ai/models/gemini.py
CHANGED
@@ -1,10 +1,9 @@
 from __future__ import annotations as _annotations
 
 import base64
-import warnings
 from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
@@ -16,7 +15,7 @@ from typing_extensions import NotRequired, TypedDict, assert_never
 
 from pydantic_ai.providers import Provider, infer_provider
 
-from .. import ModelHTTPError, UnexpectedModelBehavior,
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -35,6 +34,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -45,7 +45,6 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 LatestGeminiModelNames = Literal[
     'gemini-1.5-flash',
@@ -121,6 +120,7 @@ class GeminiModel(Model):
         model_name: GeminiModelName,
         *,
         provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a Gemini model.
 
@@ -129,6 +129,7 @@
             provider: The provider to use for authentication and API access. Can be either the string
                 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                 If not provided, a new provider will be created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
         self._provider = provider
@@ -138,6 +139,7 @@
         self._system = provider.name
         self.client = provider.client
         self._url = str(self.client.base_url)
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
@@ -171,16 +173,6 @@
         ) as http_response:
             yield await self._process_streamed_response(http_response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        def _customize_tool_def(t: ToolDefinition):
-            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
-
-        return ModelRequestParameters(
-            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-            allow_text_output=model_request_parameters.allow_text_output,
-            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-        )
-
     @property
     def model_name(self) -> GeminiModelName:
         """The model name."""
@@ -259,6 +251,8 @@
             yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
+        vendor_details: dict[str, Any] | None = None
+
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
         if 'content' not in response['candidates'][0]:
@@ -269,9 +263,19 @@
                 'Content field missing from Gemini response', str(response)
             )
         parts = response['candidates'][0]['content']['parts']
+        vendor_id = response.get('vendor_id', None)
+        finish_reason = response['candidates'][0].get('finish_reason')
+        if finish_reason:
+            vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
         usage.requests = 1
-        return _process_response_from_parts(
+        return _process_response_from_parts(
+            parts,
+            response.get('model_version', self._model_name),
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
+        )
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -451,13 +455,12 @@ class GeminiStreamedResponse(StreamedResponse):
             responses_to_yield = gemini_responses[:-1]
             for r in responses_to_yield[current_gemini_response_index:]:
                 current_gemini_response_index += 1
-                self._usage += _metadata_as_usage(r)
                 yield r
 
         # Now yield the final response, which should be complete
         if gemini_responses:  # pragma: no branch
             r = gemini_responses[-1]
-            self._usage
+            self._usage = _metadata_as_usage(r)
             yield r
 
     @property
@@ -611,7 +614,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
 
 
 def _process_response_from_parts(
-    parts: Sequence[_GeminiPartUnion],
+    parts: Sequence[_GeminiPartUnion],
+    model_name: GeminiModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
@@ -623,7 +630,9 @@
         raise UnexpectedModelBehavior(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
-    return ModelResponse(
+    return ModelResponse(
+        parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
+    )
 
 
 class _GeminiFunctionCall(TypedDict):
@@ -735,6 +744,7 @@ class _GeminiResponse(TypedDict):
     usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
     model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
+    vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]
 
 
 class _GeminiCandidates(TypedDict):
@@ -751,8 +761,17 @@ class _GeminiCandidates(TypedDict):
     safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
+class _GeminiModalityTokenCount(TypedDict):
+    """See <https://ai.google.dev/api/generate-content#modalitytokencount>."""
+
+    modality: Annotated[
+        Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality')
+    ]
+    token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)]
+
+
 class _GeminiUsageMetaData(TypedDict, total=False):
-    """See <https://ai.google.dev/api/generate-content#
+    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.
 
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
@@ -761,6 +780,20 @@
     candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
     total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
+    thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]]
+    tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]]
+    prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')]
+    ]
+    cache_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')]
+    ]
+    candidates_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')]
+    ]
+    tool_use_prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')]
+    ]
 
 
 def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
@@ -769,7 +802,21 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         return usage.Usage()  # pragma: no cover
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
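`_metadata_as_usage` now also flattens every `*_details` list of per-modality token counts into flat `<modality>_<kind>` counters in the usage details. A minimal sketch of just that flattening loop, on hand-written metadata (the dict literal is illustrative, not a real API response):

```python
# Snake-case metadata as the pydantic TypeAdapter would produce after alias handling.
metadata = {
    'prompt_token_count': 10,
    'prompt_tokens_details': [
        {'modality': 'TEXT', 'token_count': 8},
        {'modality': 'IMAGE', 'token_count': 2},
    ],
}

details: dict[str, int] = {}
for key, value in metadata.items():
    # Only the '*_details' lists are flattened; scalar counts are handled separately.
    if key.endswith('_details') and value:
        suffix = key.removesuffix('_details')  # e.g. 'prompt_tokens'
        for detail in value:
            details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']

assert details == {'text_prompt_tokens': 8, 'image_prompt_tokens': 2}
```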
@@ -806,93 +853,6 @@ _gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
 _gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
 
 
-class _GeminiJsonSchema(WalkJsonSchema):
-    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
-
-    Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
-    a subset of OpenAPI v3.0.3.
-
-    Specifically:
-    * gemini doesn't allow the `title` keyword to be set
-    * gemini doesn't allow `$defs` — we need to inline the definitions where possible
-    """
-
-    def __init__(self, schema: JsonSchema):
-        super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
-
-    def transform(self, schema: JsonSchema) -> JsonSchema:
-        # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
-        additional_properties = schema.pop(
-            'additionalProperties', None
-        )  # don't pop yet so it's included in the warning
-        if additional_properties:
-            original_schema = {**schema, 'additionalProperties': additional_properties}
-            warnings.warn(
-                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
-                f' Full schema: {self.schema}\n\n'
-                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
-                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
-                ' and we will fix this behavior.',
-                UserWarning,
-            )
-
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
-            # Gemini doesn't support const, but it does support enum with a single value
-            schema['enum'] = [const]
-        schema.pop('discriminator', None)
-        schema.pop('examples', None)
-
-        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
-        # where we add notes about these properties to the field description?
-        schema.pop('exclusiveMaximum', None)
-        schema.pop('exclusiveMinimum', None)
-
-        # Gemini only supports string enums, so we need to convert any enum values to strings.
-        # Pydantic will take care of transforming the transformed string values to the correct type.
-        if enum := schema.get('enum'):
-            schema['type'] = 'string'
-            schema['enum'] = [str(val) for val in enum]
-
-        type_ = schema.get('type')
-        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
-            # This gets hit when we have a discriminated union
-            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
-            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
-            schema['anyOf'] = schema.pop('oneOf')
-
-        if type_ == 'string' and (fmt := schema.pop('format', None)):
-            description = schema.get('description')
-            if description:
-                schema['description'] = f'{description} (format: {fmt})'
-            else:
-                schema['description'] = f'Format: {fmt}'
-
-        if '$ref' in schema:
-            raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')
-
-        if 'prefixItems' in schema:
-            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
-            prefix_items = schema.pop('prefixItems')
-            items = schema.get('items')
-            unique_items = [items] if items is not None else []
-            for item in prefix_items:
-                if item not in unique_items:
-                    unique_items.append(item)
-            if len(unique_items) > 1:  # pragma: no cover
-                schema['items'] = {'anyOf': unique_items}
-            elif len(unique_items) == 1:  # pragma: no branch
-                schema['items'] = unique_items[0]
-            schema.setdefault('minItems', len(prefix_items))
-            if items is None:  # pragma: no branch
-                schema.setdefault('maxItems', len(prefix_items))
-
-        return schema
-
-
 def _ensure_decodeable(content: bytearray) -> bytearray:
     """Trim any invalid unicode point bytes off the end of a bytearray.