pydantic-ai-slim 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (51):
  1. pydantic_ai/_agent_graph.py +29 -35
  2. pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
  3. pydantic_ai/_output.py +265 -118
  4. pydantic_ai/agent.py +15 -15
  5. pydantic_ai/mcp.py +1 -1
  6. pydantic_ai/messages.py +2 -2
  7. pydantic_ai/models/__init__.py +39 -3
  8. pydantic_ai/models/anthropic.py +6 -1
  9. pydantic_ai/models/bedrock.py +43 -16
  10. pydantic_ai/models/cohere.py +4 -0
  11. pydantic_ai/models/gemini.py +68 -108
  12. pydantic_ai/models/google.py +45 -110
  13. pydantic_ai/models/groq.py +17 -2
  14. pydantic_ai/models/mistral.py +4 -0
  15. pydantic_ai/models/openai.py +22 -157
  16. pydantic_ai/profiles/__init__.py +39 -0
  17. pydantic_ai/{models → profiles}/_json_schema.py +23 -2
  18. pydantic_ai/profiles/amazon.py +9 -0
  19. pydantic_ai/profiles/anthropic.py +8 -0
  20. pydantic_ai/profiles/cohere.py +8 -0
  21. pydantic_ai/profiles/deepseek.py +8 -0
  22. pydantic_ai/profiles/google.py +100 -0
  23. pydantic_ai/profiles/grok.py +8 -0
  24. pydantic_ai/profiles/meta.py +9 -0
  25. pydantic_ai/profiles/mistral.py +8 -0
  26. pydantic_ai/profiles/openai.py +144 -0
  27. pydantic_ai/profiles/qwen.py +9 -0
  28. pydantic_ai/providers/__init__.py +18 -0
  29. pydantic_ai/providers/anthropic.py +5 -0
  30. pydantic_ai/providers/azure.py +34 -0
  31. pydantic_ai/providers/bedrock.py +60 -1
  32. pydantic_ai/providers/cohere.py +5 -0
  33. pydantic_ai/providers/deepseek.py +12 -0
  34. pydantic_ai/providers/fireworks.py +99 -0
  35. pydantic_ai/providers/google.py +5 -0
  36. pydantic_ai/providers/google_gla.py +5 -0
  37. pydantic_ai/providers/google_vertex.py +5 -0
  38. pydantic_ai/providers/grok.py +82 -0
  39. pydantic_ai/providers/groq.py +25 -0
  40. pydantic_ai/providers/mistral.py +5 -0
  41. pydantic_ai/providers/openai.py +5 -0
  42. pydantic_ai/providers/openrouter.py +36 -0
  43. pydantic_ai/providers/together.py +96 -0
  44. pydantic_ai/result.py +34 -103
  45. pydantic_ai/tools.py +28 -58
  46. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/METADATA +5 -5
  47. pydantic_ai_slim-0.2.12.dist-info/RECORD +73 -0
  48. pydantic_ai_slim-0.2.10.dist-info/RECORD +0 -59
  49. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/WHEEL +0 -0
  50. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/entry_points.txt +0 -0
  51. {pydantic_ai_slim-0.2.10.dist-info → pydantic_ai_slim-0.2.12.dist-info}/licenses/LICENSE +0 -0
@@ -9,16 +9,19 @@ from __future__ import annotations as _annotations
9
9
  from abc import ABC, abstractmethod
10
10
  from collections.abc import AsyncIterator, Iterator
11
11
  from contextlib import asynccontextmanager, contextmanager
12
- from dataclasses import dataclass, field
12
+ from dataclasses import dataclass, field, replace
13
13
  from datetime import datetime
14
- from functools import cache
14
+ from functools import cache, cached_property
15
15
 
16
16
  import httpx
17
17
  from typing_extensions import Literal, TypeAliasType
18
18
 
19
+ from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
20
+
19
21
  from .._parts_manager import ModelResponsePartsManager
20
22
  from ..exceptions import UserError
21
23
  from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
24
+ from ..profiles._json_schema import JsonSchemaTransformer
22
25
  from ..settings import ModelSettings
23
26
  from ..tools import ToolDefinition
24
27
  from ..usage import Usage
@@ -68,6 +71,10 @@ KnownModelName = TypeAliasType(
68
71
  'bedrock:us.anthropic.claude-3-5-sonnet-20240620-v1:0',
69
72
  'bedrock:anthropic.claude-3-7-sonnet-20250219-v1:0',
70
73
  'bedrock:us.anthropic.claude-3-7-sonnet-20250219-v1:0',
74
+ 'bedrock:anthropic.claude-opus-4-20250514-v1:0',
75
+ 'bedrock:us.anthropic.claude-opus-4-20250514-v1:0',
76
+ 'bedrock:anthropic.claude-sonnet-4-20250514-v1:0',
77
+ 'bedrock:us.anthropic.claude-sonnet-4-20250514-v1:0',
71
78
  'bedrock:cohere.command-text-v14',
72
79
  'bedrock:cohere.command-r-v1:0',
73
80
  'bedrock:cohere.command-r-plus-v1:0',
@@ -292,6 +299,8 @@ class ModelRequestParameters:
292
299
  class Model(ABC):
293
300
  """Abstract class for a model."""
294
301
 
302
+ _profile: ModelProfileSpec | None = None
303
+
295
304
  @abstractmethod
296
305
  async def request(
297
306
  self,
@@ -323,6 +332,13 @@ class Model(ABC):
323
332
  In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary
324
333
  for vendor/model-specific reasons.
325
334
  """
335
+ if transformer := self.profile.json_schema_transformer:
336
+ model_request_parameters = replace(
337
+ model_request_parameters,
338
+ function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
339
+ output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
340
+ )
341
+
326
342
  return model_request_parameters
327
343
 
328
344
  @property
@@ -331,6 +347,18 @@ class Model(ABC):
331
347
  """The model name."""
332
348
  raise NotImplementedError()
333
349
 
350
+ @cached_property
351
+ def profile(self) -> ModelProfile:
352
+ """The model profile."""
353
+ _profile = self._profile
354
+ if callable(_profile):
355
+ _profile = _profile(self.model_name)
356
+
357
+ if _profile is None:
358
+ return DEFAULT_PROFILE
359
+
360
+ return _profile
361
+
334
362
  @property
335
363
  @abstractmethod
336
364
  def system(self) -> str:
@@ -515,7 +543,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
515
543
  from .cohere import CohereModel
516
544
 
517
545
  return CohereModel(model_name, provider=provider)
518
- elif provider in ('deepseek', 'openai', 'azure', 'openrouter'):
546
+ elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'):
519
547
  from .openai import OpenAIModel
520
548
 
521
549
  return OpenAIModel(model_name, provider=provider)
@@ -584,3 +612,11 @@ def get_user_agent() -> str:
584
612
  from .. import __version__
585
613
 
586
614
  return f'pydantic-ai/{__version__}'
615
+
616
+
617
+ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinition):
618
+ schema_transformer = transformer(t.parameters_json_schema, strict=t.strict)
619
+ parameters_json_schema = schema_transformer.walk()
620
+ if t.strict is None:
621
+ t = replace(t, strict=schema_transformer.is_strict_compatible)
622
+ return replace(t, parameters_json_schema=parameters_json_schema)
@@ -27,6 +27,7 @@ from ..messages import (
27
27
  ToolReturnPart,
28
28
  UserPromptPart,
29
29
  )
30
+ from ..profiles import ModelProfileSpec
30
31
  from ..providers import Provider, infer_provider
31
32
  from ..settings import ModelSettings
32
33
  from ..tools import ToolDefinition
@@ -118,6 +119,7 @@ class AnthropicModel(Model):
118
119
  model_name: AnthropicModelName,
119
120
  *,
120
121
  provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
122
+ profile: ModelProfileSpec | None = None,
121
123
  ):
122
124
  """Initialize an Anthropic model.
123
125
 
@@ -126,12 +128,14 @@ class AnthropicModel(Model):
126
128
  [here](https://docs.anthropic.com/en/docs/about-claude/models).
127
129
  provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
128
130
  instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
131
+ profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
129
132
  """
130
133
  self._model_name = model_name
131
134
 
132
135
  if isinstance(provider, str):
133
136
  provider = infer_provider(provider)
134
137
  self.client = provider.client
138
+ self._profile = profile or provider.model_profile
135
139
 
136
140
  @property
137
141
  def base_url(self) -> str:
@@ -312,7 +316,8 @@ class AnthropicModel(Model):
312
316
  is_error=True,
313
317
  )
314
318
  user_content_params.append(retry_param)
315
- anthropic_messages.append(BetaMessageParam(role='user', content=user_content_params))
319
+ if len(user_content_params) > 0:
320
+ anthropic_messages.append(BetaMessageParam(role='user', content=user_content_params))
316
321
  elif isinstance(m, ModelResponse):
317
322
  assistant_content_params: list[BetaTextBlockParam | BetaToolUseBlockParam] = []
318
323
  for response_part in m.parts:
@@ -32,8 +32,15 @@ from pydantic_ai.messages import (
32
32
  UserPromptPart,
33
33
  VideoUrl,
34
34
  )
35
- from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client
35
+ from pydantic_ai.models import (
36
+ Model,
37
+ ModelRequestParameters,
38
+ StreamedResponse,
39
+ cached_async_http_client,
40
+ )
41
+ from pydantic_ai.profiles import ModelProfileSpec
36
42
  from pydantic_ai.providers import Provider, infer_provider
43
+ from pydantic_ai.providers.bedrock import BedrockModelProfile
37
44
  from pydantic_ai.settings import ModelSettings
38
45
  from pydantic_ai.tools import ToolDefinition
39
46
 
@@ -56,6 +63,7 @@ if TYPE_CHECKING:
56
63
  PromptVariableValuesTypeDef,
57
64
  SystemContentBlockTypeDef,
58
65
  ToolChoiceTypeDef,
66
+ ToolConfigurationTypeDef,
59
67
  ToolTypeDef,
60
68
  VideoBlockTypeDef,
61
69
  )
@@ -85,6 +93,10 @@ LatestBedrockModelNames = Literal[
85
93
  'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
86
94
  'anthropic.claude-3-7-sonnet-20250219-v1:0',
87
95
  'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
96
+ 'anthropic.claude-opus-4-20250514-v1:0',
97
+ 'us.anthropic.claude-opus-4-20250514-v1:0',
98
+ 'anthropic.claude-sonnet-4-20250514-v1:0',
99
+ 'us.anthropic.claude-sonnet-4-20250514-v1:0',
88
100
  'cohere.command-text-v14',
89
101
  'cohere.command-r-v1:0',
90
102
  'cohere.command-r-plus-v1:0',
@@ -190,6 +202,7 @@ class BedrockConverseModel(Model):
190
202
  model_name: BedrockModelName,
191
203
  *,
192
204
  provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
205
+ profile: ModelProfileSpec | None = None,
193
206
  ):
194
207
  """Initialize a Bedrock model.
195
208
 
@@ -200,12 +213,14 @@ class BedrockConverseModel(Model):
200
213
  provider: The provider to use for authentication and API access. Can be either the string
201
214
  'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
202
215
  created using the other parameters.
216
+ profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
203
217
  """
204
218
  self._model_name = model_name
205
219
 
206
220
  if isinstance(provider, str):
207
221
  provider = infer_provider(provider)
208
222
  self.client = cast('BedrockRuntimeClient', provider.client)
223
+ self._profile = profile or provider.model_profile
209
224
 
210
225
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
211
226
  tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
@@ -301,15 +316,6 @@ class BedrockConverseModel(Model):
301
316
  model_settings: BedrockModelSettings | None,
302
317
  model_request_parameters: ModelRequestParameters,
303
318
  ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
304
- tools = self._get_tools(model_request_parameters)
305
- support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
306
- if not tools or not support_tools_choice:
307
- tool_choice: ToolChoiceTypeDef = {}
308
- elif not model_request_parameters.allow_text_output:
309
- tool_choice = {'any': {}} # pragma: no cover
310
- else:
311
- tool_choice = {'auto': {}}
312
-
313
319
  system_prompt, bedrock_messages = await self._map_messages(messages)
314
320
  inference_config = self._map_inference_config(model_settings)
315
321
 
@@ -320,6 +326,10 @@ class BedrockConverseModel(Model):
320
326
  'inferenceConfig': inference_config,
321
327
  }
322
328
 
329
+ tool_config = self._map_tool_config(model_request_parameters)
330
+ if tool_config:
331
+ params['toolConfig'] = tool_config
332
+
323
333
  # Bedrock supports a set of specific extra parameters
324
334
  if model_settings:
325
335
  if guardrail_config := model_settings.get('bedrock_guardrail_config', None):
@@ -337,11 +347,6 @@ class BedrockConverseModel(Model):
337
347
  if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
338
348
  params['promptVariables'] = prompt_variables
339
349
 
340
- if tools:
341
- params['toolConfig'] = {'tools': tools}
342
- if tool_choice:
343
- params['toolConfig']['toolChoice'] = tool_choice
344
-
345
350
  if stream:
346
351
  model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
347
352
  model_response = model_response['stream']
@@ -367,6 +372,23 @@ class BedrockConverseModel(Model):
367
372
 
368
373
  return inference_config
369
374
 
375
+ def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
376
+ tools = self._get_tools(model_request_parameters)
377
+ if not tools:
378
+ return None
379
+
380
+ tool_choice: ToolChoiceTypeDef
381
+ if not model_request_parameters.allow_text_output:
382
+ tool_choice = {'any': {}}
383
+ else:
384
+ tool_choice = {'auto': {}}
385
+
386
+ tool_config: ToolConfigurationTypeDef = {'tools': tools}
387
+ if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice:
388
+ tool_config['toolChoice'] = tool_choice
389
+
390
+ return tool_config
391
+
370
392
  async def _map_messages(
371
393
  self, messages: list[ModelMessage]
372
394
  ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
@@ -374,6 +396,7 @@ class BedrockConverseModel(Model):
374
396
 
375
397
  Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova models.
376
398
  """
399
+ profile = BedrockModelProfile.from_profile(self.profile)
377
400
  system_prompt: list[SystemContentBlockTypeDef] = []
378
401
  bedrock_messages: list[MessageUnionTypeDef] = []
379
402
  document_count: Iterator[int] = count(1)
@@ -393,7 +416,11 @@ class BedrockConverseModel(Model):
393
416
  {
394
417
  'toolResult': {
395
418
  'toolUseId': part.tool_call_id,
396
- 'content': [{'text': part.model_response_str()}],
419
+ 'content': [
420
+ {'text': part.model_response_str()}
421
+ if profile.bedrock_tool_result_format == 'text'
422
+ else {'json': part.model_response_object()}
423
+ ],
397
424
  'status': 'success',
398
425
  }
399
426
  }
@@ -20,6 +20,7 @@ from ..messages import (
20
20
  ToolReturnPart,
21
21
  UserPromptPart,
22
22
  )
23
+ from ..profiles import ModelProfileSpec
23
24
  from ..providers import Provider, infer_provider
24
25
  from ..settings import ModelSettings
25
26
  from ..tools import ToolDefinition
@@ -107,6 +108,7 @@ class CohereModel(Model):
107
108
  model_name: CohereModelName,
108
109
  *,
109
110
  provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
111
+ profile: ModelProfileSpec | None = None,
110
112
  ):
111
113
  """Initialize an Cohere model.
112
114
 
@@ -116,12 +118,14 @@ class CohereModel(Model):
116
118
  provider: The provider to use for authentication and API access. Can be either the string
117
119
  'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
118
120
  created using the other parameters.
121
+ profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
119
122
  """
120
123
  self._model_name = model_name
121
124
 
122
125
  if isinstance(provider, str):
123
126
  provider = infer_provider(provider)
124
127
  self.client = provider.client
128
+ self._profile = profile or provider.model_profile
125
129
 
126
130
  @property
127
131
  def base_url(self) -> str:
@@ -1,10 +1,9 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import base64
4
- import warnings
5
4
  from collections.abc import AsyncIterator, Sequence
6
5
  from contextlib import asynccontextmanager
7
- from dataclasses import dataclass, field, replace
6
+ from dataclasses import dataclass, field
8
7
  from datetime import datetime
9
8
  from typing import Annotated, Any, Literal, Protocol, Union, cast
10
9
  from uuid import uuid4
@@ -16,7 +15,7 @@ from typing_extensions import NotRequired, TypedDict, assert_never
16
15
 
17
16
  from pydantic_ai.providers import Provider, infer_provider
18
17
 
19
- from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
18
+ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
20
19
  from ..messages import (
21
20
  AudioUrl,
22
21
  BinaryContent,
@@ -35,6 +34,7 @@ from ..messages import (
35
34
  UserPromptPart,
36
35
  VideoUrl,
37
36
  )
37
+ from ..profiles import ModelProfileSpec
38
38
  from ..settings import ModelSettings
39
39
  from ..tools import ToolDefinition
40
40
  from . import (
@@ -45,7 +45,6 @@ from . import (
45
45
  check_allow_model_requests,
46
46
  get_user_agent,
47
47
  )
48
- from ._json_schema import JsonSchema, WalkJsonSchema
49
48
 
50
49
  LatestGeminiModelNames = Literal[
51
50
  'gemini-1.5-flash',
@@ -121,6 +120,7 @@ class GeminiModel(Model):
121
120
  model_name: GeminiModelName,
122
121
  *,
123
122
  provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
123
+ profile: ModelProfileSpec | None = None,
124
124
  ):
125
125
  """Initialize a Gemini model.
126
126
 
@@ -129,6 +129,7 @@ class GeminiModel(Model):
129
129
  provider: The provider to use for authentication and API access. Can be either the string
130
130
  'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
131
131
  If not provided, a new provider will be created using the other parameters.
132
+ profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
132
133
  """
133
134
  self._model_name = model_name
134
135
  self._provider = provider
@@ -138,6 +139,7 @@ class GeminiModel(Model):
138
139
  self._system = provider.name
139
140
  self.client = provider.client
140
141
  self._url = str(self.client.base_url)
142
+ self._profile = profile or provider.model_profile
141
143
 
142
144
  @property
143
145
  def base_url(self) -> str:
@@ -171,16 +173,6 @@ class GeminiModel(Model):
171
173
  ) as http_response:
172
174
  yield await self._process_streamed_response(http_response)
173
175
 
174
- def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
175
- def _customize_tool_def(t: ToolDefinition):
176
- return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
177
-
178
- return ModelRequestParameters(
179
- function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
180
- allow_text_output=model_request_parameters.allow_text_output,
181
- output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
182
- )
183
-
184
176
  @property
185
177
  def model_name(self) -> GeminiModelName:
186
178
  """The model name."""
@@ -259,6 +251,8 @@ class GeminiModel(Model):
259
251
  yield r
260
252
 
261
253
  def _process_response(self, response: _GeminiResponse) -> ModelResponse:
254
+ vendor_details: dict[str, Any] | None = None
255
+
262
256
  if len(response['candidates']) != 1:
263
257
  raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover
264
258
  if 'content' not in response['candidates'][0]:
@@ -269,9 +263,19 @@ class GeminiModel(Model):
269
263
  'Content field missing from Gemini response', str(response)
270
264
  )
271
265
  parts = response['candidates'][0]['content']['parts']
266
+ vendor_id = response.get('vendor_id', None)
267
+ finish_reason = response['candidates'][0].get('finish_reason')
268
+ if finish_reason:
269
+ vendor_details = {'finish_reason': finish_reason}
272
270
  usage = _metadata_as_usage(response)
273
271
  usage.requests = 1
274
- return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
272
+ return _process_response_from_parts(
273
+ parts,
274
+ response.get('model_version', self._model_name),
275
+ usage,
276
+ vendor_id=vendor_id,
277
+ vendor_details=vendor_details,
278
+ )
275
279
 
276
280
  async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
277
281
  """Process a streamed response, and prepare a streaming response to return."""
@@ -451,13 +455,12 @@ class GeminiStreamedResponse(StreamedResponse):
451
455
  responses_to_yield = gemini_responses[:-1]
452
456
  for r in responses_to_yield[current_gemini_response_index:]:
453
457
  current_gemini_response_index += 1
454
- self._usage += _metadata_as_usage(r)
455
458
  yield r
456
459
 
457
460
  # Now yield the final response, which should be complete
458
461
  if gemini_responses: # pragma: no branch
459
462
  r = gemini_responses[-1]
460
- self._usage += _metadata_as_usage(r)
463
+ self._usage = _metadata_as_usage(r)
461
464
  yield r
462
465
 
463
466
  @property
@@ -611,7 +614,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
611
614
 
612
615
 
613
616
  def _process_response_from_parts(
614
- parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
617
+ parts: Sequence[_GeminiPartUnion],
618
+ model_name: GeminiModelName,
619
+ usage: usage.Usage,
620
+ vendor_id: str | None,
621
+ vendor_details: dict[str, Any] | None = None,
615
622
  ) -> ModelResponse:
616
623
  items: list[ModelResponsePart] = []
617
624
  for part in parts:
@@ -623,7 +630,9 @@ def _process_response_from_parts(
623
630
  raise UnexpectedModelBehavior(
624
631
  f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
625
632
  )
626
- return ModelResponse(parts=items, usage=usage, model_name=model_name)
633
+ return ModelResponse(
634
+ parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
635
+ )
627
636
 
628
637
 
629
638
  class _GeminiFunctionCall(TypedDict):
@@ -735,6 +744,7 @@ class _GeminiResponse(TypedDict):
735
744
  usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
736
745
  prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
737
746
  model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
747
+ vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]
738
748
 
739
749
 
740
750
  class _GeminiCandidates(TypedDict):
@@ -751,8 +761,17 @@ class _GeminiCandidates(TypedDict):
751
761
  safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
752
762
 
753
763
 
764
+ class _GeminiModalityTokenCount(TypedDict):
765
+ """See <https://ai.google.dev/api/generate-content#modalitytokencount>."""
766
+
767
+ modality: Annotated[
768
+ Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality')
769
+ ]
770
+ token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)]
771
+
772
+
754
773
  class _GeminiUsageMetaData(TypedDict, total=False):
755
- """See <https://ai.google.dev/api/generate-content#FinishReason>.
774
+ """See <https://ai.google.dev/api/generate-content#UsageMetadata>.
756
775
 
757
776
  The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
758
777
  """
@@ -761,6 +780,20 @@ class _GeminiUsageMetaData(TypedDict, total=False):
761
780
  candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
762
781
  total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
763
782
  cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
783
+ thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]]
784
+ tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]]
785
+ prompt_tokens_details: NotRequired[
786
+ Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')]
787
+ ]
788
+ cache_tokens_details: NotRequired[
789
+ Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')]
790
+ ]
791
+ candidates_tokens_details: NotRequired[
792
+ Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')]
793
+ ]
794
+ tool_use_prompt_tokens_details: NotRequired[
795
+ Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')]
796
+ ]
764
797
 
765
798
 
766
799
  def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
@@ -769,7 +802,21 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
769
802
  return usage.Usage() # pragma: no cover
770
803
  details: dict[str, int] = {}
771
804
  if cached_content_token_count := metadata.get('cached_content_token_count'):
772
- details['cached_content_token_count'] = cached_content_token_count # pragma: no cover
805
+ details['cached_content_tokens'] = cached_content_token_count # pragma: no cover
806
+
807
+ if thoughts_token_count := metadata.get('thoughts_token_count'):
808
+ details['thoughts_tokens'] = thoughts_token_count
809
+
810
+ if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
811
+ details['tool_use_prompt_tokens'] = tool_use_prompt_token_count # pragma: no cover
812
+
813
+ for key, metadata_details in metadata.items():
814
+ if key.endswith('_details') and metadata_details:
815
+ metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
816
+ suffix = key.removesuffix('_details')
817
+ for detail in metadata_details:
818
+ details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
819
+
773
820
  return usage.Usage(
774
821
  request_tokens=metadata.get('prompt_token_count', 0),
775
822
  response_tokens=metadata.get('candidates_token_count', 0),
@@ -806,93 +853,6 @@ _gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
806
853
  _gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
807
854
 
808
855
 
809
- class _GeminiJsonSchema(WalkJsonSchema):
810
- """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
811
-
812
- Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
813
- a subset of OpenAPI v3.0.3.
814
-
815
- Specifically:
816
- * gemini doesn't allow the `title` keyword to be set
817
- * gemini doesn't allow `$defs` — we need to inline the definitions where possible
818
- """
819
-
820
- def __init__(self, schema: JsonSchema):
821
- super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
822
-
823
- def transform(self, schema: JsonSchema) -> JsonSchema:
824
- # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
825
- additional_properties = schema.pop(
826
- 'additionalProperties', None
827
- ) # don't pop yet so it's included in the warning
828
- if additional_properties:
829
- original_schema = {**schema, 'additionalProperties': additional_properties}
830
- warnings.warn(
831
- '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
832
- f' Full schema: {self.schema}\n\n'
833
- f'Source of additionalProperties within the full schema: {original_schema}\n\n'
834
- 'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
835
- "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
836
- ' and we will fix this behavior.',
837
- UserWarning,
838
- )
839
-
840
- schema.pop('title', None)
841
- schema.pop('default', None)
842
- schema.pop('$schema', None)
843
- if (const := schema.pop('const', None)) is not None: # pragma: no cover
844
- # Gemini doesn't support const, but it does support enum with a single value
845
- schema['enum'] = [const]
846
- schema.pop('discriminator', None)
847
- schema.pop('examples', None)
848
-
849
- # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
850
- # where we add notes about these properties to the field description?
851
- schema.pop('exclusiveMaximum', None)
852
- schema.pop('exclusiveMinimum', None)
853
-
854
- # Gemini only supports string enums, so we need to convert any enum values to strings.
855
- # Pydantic will take care of transforming the transformed string values to the correct type.
856
- if enum := schema.get('enum'):
857
- schema['type'] = 'string'
858
- schema['enum'] = [str(val) for val in enum]
859
-
860
- type_ = schema.get('type')
861
- if 'oneOf' in schema and 'type' not in schema: # pragma: no cover
862
- # This gets hit when we have a discriminated union
863
- # Gemini returns an API error in this case even though it says in its error message it shouldn't...
864
- # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
865
- schema['anyOf'] = schema.pop('oneOf')
866
-
867
- if type_ == 'string' and (fmt := schema.pop('format', None)):
868
- description = schema.get('description')
869
- if description:
870
- schema['description'] = f'{description} (format: {fmt})'
871
- else:
872
- schema['description'] = f'Format: {fmt}'
873
-
874
- if '$ref' in schema:
875
- raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')
876
-
877
- if 'prefixItems' in schema:
878
- # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
879
- prefix_items = schema.pop('prefixItems')
880
- items = schema.get('items')
881
- unique_items = [items] if items is not None else []
882
- for item in prefix_items:
883
- if item not in unique_items:
884
- unique_items.append(item)
885
- if len(unique_items) > 1: # pragma: no cover
886
- schema['items'] = {'anyOf': unique_items}
887
- elif len(unique_items) == 1: # pragma: no branch
888
- schema['items'] = unique_items[0]
889
- schema.setdefault('minItems', len(prefix_items))
890
- if items is None: # pragma: no branch
891
- schema.setdefault('maxItems', len(prefix_items))
892
-
893
- return schema
894
-
895
-
896
856
  def _ensure_decodeable(content: bytearray) -> bytearray:
897
857
  """Trim any invalid unicode point bytes off the end of a bytearray.
898
858