pydantic-ai-slim 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pydantic-ai-slim might be problematic.

Files changed (51)
  1. pydantic_ai/_agent_graph.py +29 -35
  2. pydantic_ai/{_pydantic.py → _function_schema.py} +48 -8
  3. pydantic_ai/_output.py +266 -119
  4. pydantic_ai/agent.py +15 -15
  5. pydantic_ai/mcp.py +1 -1
  6. pydantic_ai/messages.py +2 -2
  7. pydantic_ai/models/__init__.py +39 -3
  8. pydantic_ai/models/anthropic.py +4 -0
  9. pydantic_ai/models/bedrock.py +43 -16
  10. pydantic_ai/models/cohere.py +4 -0
  11. pydantic_ai/models/gemini.py +78 -109
  12. pydantic_ai/models/google.py +47 -112
  13. pydantic_ai/models/groq.py +17 -2
  14. pydantic_ai/models/mistral.py +4 -0
  15. pydantic_ai/models/openai.py +25 -158
  16. pydantic_ai/profiles/__init__.py +39 -0
  17. pydantic_ai/{models → profiles}/_json_schema.py +23 -2
  18. pydantic_ai/profiles/amazon.py +9 -0
  19. pydantic_ai/profiles/anthropic.py +8 -0
  20. pydantic_ai/profiles/cohere.py +8 -0
  21. pydantic_ai/profiles/deepseek.py +8 -0
  22. pydantic_ai/profiles/google.py +100 -0
  23. pydantic_ai/profiles/grok.py +8 -0
  24. pydantic_ai/profiles/meta.py +9 -0
  25. pydantic_ai/profiles/mistral.py +8 -0
  26. pydantic_ai/profiles/openai.py +144 -0
  27. pydantic_ai/profiles/qwen.py +9 -0
  28. pydantic_ai/providers/__init__.py +18 -0
  29. pydantic_ai/providers/anthropic.py +5 -0
  30. pydantic_ai/providers/azure.py +34 -0
  31. pydantic_ai/providers/bedrock.py +60 -1
  32. pydantic_ai/providers/cohere.py +5 -0
  33. pydantic_ai/providers/deepseek.py +12 -0
  34. pydantic_ai/providers/fireworks.py +99 -0
  35. pydantic_ai/providers/google.py +5 -0
  36. pydantic_ai/providers/google_gla.py +5 -0
  37. pydantic_ai/providers/google_vertex.py +5 -0
  38. pydantic_ai/providers/grok.py +82 -0
  39. pydantic_ai/providers/groq.py +25 -0
  40. pydantic_ai/providers/mistral.py +5 -0
  41. pydantic_ai/providers/openai.py +5 -0
  42. pydantic_ai/providers/openrouter.py +36 -0
  43. pydantic_ai/providers/together.py +96 -0
  44. pydantic_ai/result.py +34 -103
  45. pydantic_ai/tools.py +29 -59
  46. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/METADATA +4 -4
  47. pydantic_ai_slim-0.2.13.dist-info/RECORD +73 -0
  48. pydantic_ai_slim-0.2.11.dist-info/RECORD +0 -59
  49. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/WHEEL +0 -0
  50. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/entry_points.txt +0 -0
  51. {pydantic_ai_slim-0.2.11.dist-info → pydantic_ai_slim-0.2.13.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/__init__.py

@@ -9,16 +9,19 @@ from __future__ import annotations as _annotations
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
-from functools import cache
+from functools import cache, cached_property
 
 import httpx
 from typing_extensions import Literal, TypeAliasType
 
+from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
+
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
+from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from ..usage import Usage
@@ -68,6 +71,10 @@ KnownModelName = TypeAliasType(
     'bedrock:us.anthropic.claude-3-5-sonnet-20240620-v1:0',
     'bedrock:anthropic.claude-3-7-sonnet-20250219-v1:0',
     'bedrock:us.anthropic.claude-3-7-sonnet-20250219-v1:0',
+    'bedrock:anthropic.claude-opus-4-20250514-v1:0',
+    'bedrock:us.anthropic.claude-opus-4-20250514-v1:0',
+    'bedrock:anthropic.claude-sonnet-4-20250514-v1:0',
+    'bedrock:us.anthropic.claude-sonnet-4-20250514-v1:0',
     'bedrock:cohere.command-text-v14',
     'bedrock:cohere.command-r-v1:0',
     'bedrock:cohere.command-r-plus-v1:0',
@@ -292,6 +299,8 @@ class ModelRequestParameters:
 class Model(ABC):
     """Abstract class for a model."""
 
+    _profile: ModelProfileSpec | None = None
+
     @abstractmethod
     async def request(
         self,
@@ -323,6 +332,13 @@ class Model(ABC):
         In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary
         for vendor/model-specific reasons.
         """
+        if transformer := self.profile.json_schema_transformer:
+            model_request_parameters = replace(
+                model_request_parameters,
+                function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
+                output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
+            )
+
         return model_request_parameters
 
     @property
@@ -331,6 +347,18 @@ class Model(ABC):
         """The model name."""
         raise NotImplementedError()
 
+    @cached_property
+    def profile(self) -> ModelProfile:
+        """The model profile."""
+        _profile = self._profile
+        if callable(_profile):
+            _profile = _profile(self.model_name)
+
+        if _profile is None:
+            return DEFAULT_PROFILE
+
+        return _profile
+
     @property
     @abstractmethod
     def system(self) -> str:
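
The new `profile` property resolves a `ModelProfileSpec`, which per the hunk above can be a `ModelProfile` instance, a callable taking the model name, or `None` (falling back to `DEFAULT_PROFILE`). A minimal sketch of supplying a callable spec; the picker function and its condition are invented for illustration:

    from __future__ import annotations

    from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile

    def pick_profile(model_name: str) -> ModelProfile | None:
        # Hypothetical picker: only override the profile for Claude 4 models.
        if 'claude-opus-4' in model_name or 'claude-sonnet-4' in model_name:
            return ModelProfile()  # assumes ModelProfile is default-constructible
        return None

    # Mirrors Model.profile: a callable spec is called with the model name,
    # and a None result falls back to DEFAULT_PROFILE.
    resolved = pick_profile('gemini-1.5-flash') or DEFAULT_PROFILE
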
@@ -515,7 +543,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
         from .cohere import CohereModel
 
         return CohereModel(model_name, provider=provider)
-    elif provider in ('deepseek', 'openai', 'azure', 'openrouter'):
+    elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'):
         from .openai import OpenAIModel
 
         return OpenAIModel(model_name, provider=provider)
@@ -584,3 +612,11 @@ def get_user_agent() -> str:
     from .. import __version__
 
     return f'pydantic-ai/{__version__}'
+
+
+def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinition):
+    schema_transformer = transformer(t.parameters_json_schema, strict=t.strict)
+    parameters_json_schema = schema_transformer.walk()
+    if t.strict is None:
+        t = replace(t, strict=schema_transformer.is_strict_compatible)
+    return replace(t, parameters_json_schema=parameters_json_schema)
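
`customize_request_parameters` now instantiates the profile's `json_schema_transformer` class with each tool's schema (see `_customize_tool_def` above), so a transformer is a `JsonSchemaTransformer` subclass. A rough sketch, assuming the subclass hook is a `transform` method as on the former `WalkJsonSchema`; the class name and behavior below are invented:

    from typing import Any

    from pydantic_ai.profiles import ModelProfile
    from pydantic_ai.profiles._json_schema import JsonSchemaTransformer

    class StripTitles(JsonSchemaTransformer):
        # Hypothetical transformer: drop 'title' from every schema node it walks.
        def transform(self, schema: dict[str, Any]) -> dict[str, Any]:
            schema.pop('title', None)
            return schema

    # A profile would then carry the class itself (assuming a dataclass field):
    profile = ModelProfile(json_schema_transformer=StripTitles)
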
pydantic_ai/models/anthropic.py

@@ -27,6 +27,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -118,6 +119,7 @@ class AnthropicModel(Model):
         model_name: AnthropicModelName,
         *,
         provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize an Anthropic model.
 
@@ -126,12 +128,14 @@ class AnthropicModel(Model):
                 [here](https://docs.anthropic.com/en/docs/about-claude/models).
             provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
                 instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
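
With the new `profile` parameter, an explicit profile takes precedence over the provider's `model_profile` lookup. A minimal usage sketch; the model name is illustrative and Anthropic credentials are assumed to be configured:

    from pydantic_ai.models.anthropic import AnthropicModel
    from pydantic_ai.profiles import ModelProfile

    model = AnthropicModel(
        'claude-3-5-sonnet-latest',
        profile=ModelProfile(),  # assumed default-constructible; set fields as needed
    )
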
pydantic_ai/models/bedrock.py

@@ -32,8 +32,15 @@ from pydantic_ai.messages import (
     UserPromptPart,
     VideoUrl,
 )
-from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client
+from pydantic_ai.models import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    cached_async_http_client,
+)
+from pydantic_ai.profiles import ModelProfileSpec
 from pydantic_ai.providers import Provider, infer_provider
+from pydantic_ai.providers.bedrock import BedrockModelProfile
 from pydantic_ai.settings import ModelSettings
 from pydantic_ai.tools import ToolDefinition
 
@@ -56,6 +63,7 @@ if TYPE_CHECKING:
         PromptVariableValuesTypeDef,
         SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
+        ToolConfigurationTypeDef,
         ToolTypeDef,
         VideoBlockTypeDef,
     )
@@ -85,6 +93,10 @@ LatestBedrockModelNames = Literal[
     'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
     'anthropic.claude-3-7-sonnet-20250219-v1:0',
     'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
+    'anthropic.claude-opus-4-20250514-v1:0',
+    'us.anthropic.claude-opus-4-20250514-v1:0',
+    'anthropic.claude-sonnet-4-20250514-v1:0',
+    'us.anthropic.claude-sonnet-4-20250514-v1:0',
     'cohere.command-text-v14',
     'cohere.command-r-v1:0',
     'cohere.command-r-plus-v1:0',
@@ -190,6 +202,7 @@ class BedrockConverseModel(Model):
         model_name: BedrockModelName,
         *,
         provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a Bedrock model.
 
@@ -200,12 +213,14 @@ class BedrockConverseModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = cast('BedrockRuntimeClient', provider.client)
+        self._profile = profile or provider.model_profile
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
@@ -301,15 +316,6 @@ class BedrockConverseModel(Model):
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
-        tools = self._get_tools(model_request_parameters)
-        support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
-        if not tools or not support_tools_choice:
-            tool_choice: ToolChoiceTypeDef = {}
-        elif not model_request_parameters.allow_text_output:
-            tool_choice = {'any': {}}  # pragma: no cover
-        else:
-            tool_choice = {'auto': {}}
-
         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)
 
@@ -320,6 +326,10 @@ class BedrockConverseModel(Model):
             'inferenceConfig': inference_config,
         }
 
+        tool_config = self._map_tool_config(model_request_parameters)
+        if tool_config:
+            params['toolConfig'] = tool_config
+
         # Bedrock supports a set of specific extra parameters
         if model_settings:
             if guardrail_config := model_settings.get('bedrock_guardrail_config', None):
@@ -337,11 +347,6 @@ class BedrockConverseModel(Model):
             if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
                 params['promptVariables'] = prompt_variables
 
-        if tools:
-            params['toolConfig'] = {'tools': tools}
-            if tool_choice:
-                params['toolConfig']['toolChoice'] = tool_choice
-
         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
             model_response = model_response['stream']
@@ -367,6 +372,23 @@ class BedrockConverseModel(Model):
 
         return inference_config
 
+    def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
+        tools = self._get_tools(model_request_parameters)
+        if not tools:
+            return None
+
+        tool_choice: ToolChoiceTypeDef
+        if not model_request_parameters.allow_text_output:
+            tool_choice = {'any': {}}
+        else:
+            tool_choice = {'auto': {}}
+
+        tool_config: ToolConfigurationTypeDef = {'tools': tools}
+        if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice:
+            tool_config['toolChoice'] = tool_choice
+
+        return tool_config
+
     async def _map_messages(
         self, messages: list[ModelMessage]
     ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
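
The extracted `_map_tool_config` replaces the old inline name-prefix check (`model_name.startswith(('anthropic', 'us.anthropic'))`) with the `bedrock_supports_tool_choice` profile flag. The resulting Converse `toolConfig` payload has this shape; the tool name and schema are invented, while the keys follow the Bedrock Converse API:

    tool_config = {
        'tools': [
            {'toolSpec': {'name': 'get_weather', 'inputSchema': {'json': {'type': 'object'}}}},
        ],
        # Included only when the resolved BedrockModelProfile has
        # bedrock_supports_tool_choice=True:
        'toolChoice': {'any': {}},  # {'auto': {}} when text output is allowed
    }
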
@@ -374,6 +396,7 @@ class BedrockConverseModel(Model):
 
         Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova models.
         """
+        profile = BedrockModelProfile.from_profile(self.profile)
         system_prompt: list[SystemContentBlockTypeDef] = []
         bedrock_messages: list[MessageUnionTypeDef] = []
         document_count: Iterator[int] = count(1)
@@ -393,7 +416,11 @@ class BedrockConverseModel(Model):
                     {
                         'toolResult': {
                             'toolUseId': part.tool_call_id,
-                            'content': [{'text': part.model_response_str()}],
+                            'content': [
+                                {'text': part.model_response_str()}
+                                if profile.bedrock_tool_result_format == 'text'
+                                else {'json': part.model_response_object()}
+                            ],
                             'status': 'success',
                         }
                     }
pydantic_ai/models/cohere.py

@@ -20,6 +20,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -107,6 +108,7 @@ class CohereModel(Model):
         model_name: CohereModelName,
         *,
         provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize an Cohere model.
 
@@ -116,12 +118,14 @@ class CohereModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
                 created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
pydantic_ai/models/gemini.py

@@ -1,10 +1,9 @@
 from __future__ import annotations as _annotations
 
 import base64
-import warnings
 from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
@@ -16,7 +15,7 @@ from typing_extensions import NotRequired, TypedDict, assert_never
 
 from pydantic_ai.providers import Provider, infer_provider
 
-from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -35,6 +34,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -45,7 +45,6 @@ from . import (
     check_allow_model_requests,
     get_user_agent,
 )
-from ._json_schema import JsonSchema, WalkJsonSchema
 
 LatestGeminiModelNames = Literal[
     'gemini-1.5-flash',
@@ -121,6 +120,7 @@ class GeminiModel(Model):
         model_name: GeminiModelName,
         *,
         provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a Gemini model.
 
@@ -129,6 +129,7 @@ class GeminiModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                 If not provided, a new provider will be created using the other parameters.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
         """
         self._model_name = model_name
         self._provider = provider
@@ -138,6 +139,7 @@ class GeminiModel(Model):
             self._system = provider.name
         self.client = provider.client
         self._url = str(self.client.base_url)
+        self._profile = profile or provider.model_profile
 
     @property
     def base_url(self) -> str:
@@ -171,16 +173,6 @@ class GeminiModel(Model):
         ) as http_response:
             yield await self._process_streamed_response(http_response)
 
-    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
-        def _customize_tool_def(t: ToolDefinition):
-            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk())
-
-        return ModelRequestParameters(
-            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
-            allow_text_output=model_request_parameters.allow_text_output,
-            output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
-        )
-
     @property
     def model_name(self) -> GeminiModelName:
         """The model name."""
@@ -259,6 +251,8 @@ class GeminiModel(Model):
                 yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
+        vendor_details: dict[str, Any] | None = None
+
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
         if 'content' not in response['candidates'][0]:
@@ -269,9 +263,19 @@ class GeminiModel(Model):
                 'Content field missing from Gemini response', str(response)
             )
         parts = response['candidates'][0]['content']['parts']
+        vendor_id = response.get('vendor_id', None)
+        finish_reason = response['candidates'][0].get('finish_reason')
+        if finish_reason:
+            vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
         usage.requests = 1
-        return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
+        return _process_response_from_parts(
+            parts,
+            response.get('model_version', self._model_name),
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
+        )
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
@@ -427,7 +431,8 @@ class GeminiStreamedResponse(StreamedResponse):
                     if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
                 else:
-                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'  # pragma: no cover
+                    if not any([key in gemini_part for key in ['function_response', 'thought']]):
+                        raise AssertionError(f'Unexpected part: {gemini_part}')  # pragma: no cover
 
     async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
         # This method exists to ensure we only yield completed items, so we don't need to worry about
@@ -451,13 +456,12 @@ class GeminiStreamedResponse(StreamedResponse):
             responses_to_yield = gemini_responses[:-1]
             for r in responses_to_yield[current_gemini_response_index:]:
                 current_gemini_response_index += 1
-                self._usage += _metadata_as_usage(r)
                 yield r
 
         # Now yield the final response, which should be complete
         if gemini_responses:  # pragma: no branch
             r = gemini_responses[-1]
-            self._usage += _metadata_as_usage(r)
+            self._usage = _metadata_as_usage(r)
             yield r
 
     @property
@@ -602,6 +606,11 @@ class _GeminiFileDataPart(TypedDict):
     file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
 
 
+class _GeminiThoughtPart(TypedDict):
+    thought: bool
+    thought_signature: Annotated[str, pydantic.Field(alias='thoughtSignature')]
+
+
 class _GeminiFunctionCallPart(TypedDict):
     function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
 
@@ -611,7 +620,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
 
 
 def _process_response_from_parts(
-    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
+    parts: Sequence[_GeminiPartUnion],
+    model_name: GeminiModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
@@ -623,7 +636,9 @@ def _process_response_from_parts(
             raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(parts=items, usage=usage, model_name=model_name)
+    return ModelResponse(
+        parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
+    )
 
 
 class _GeminiFunctionCall(TypedDict):
@@ -656,6 +671,8 @@ def _part_discriminator(v: Any) -> str:
         return 'inline_data'  # pragma: no cover
     elif 'fileData' in v:
         return 'file_data'  # pragma: no cover
+    elif 'thought' in v:
+        return 'thought'
     elif 'functionCall' in v or 'function_call' in v:
         return 'function_call'
     elif 'functionResponse' in v or 'function_response' in v:
@@ -673,6 +690,7 @@ _GeminiPartUnion = Annotated[
         Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
         Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
         Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
+        Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
     ],
    pydantic.Discriminator(_part_discriminator),
 ]
@@ -735,6 +753,7 @@ class _GeminiResponse(TypedDict):
     usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
     model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
+    vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]
 
 
 class _GeminiCandidates(TypedDict):
@@ -751,8 +770,17 @@ class _GeminiCandidates(TypedDict):
     safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
+class _GeminiModalityTokenCount(TypedDict):
+    """See <https://ai.google.dev/api/generate-content#modalitytokencount>."""
+
+    modality: Annotated[
+        Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality')
+    ]
+    token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)]
+
+
 class _GeminiUsageMetaData(TypedDict, total=False):
-    """See <https://ai.google.dev/api/generate-content#FinishReason>.
+    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.
 
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
@@ -761,6 +789,20 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
     total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
+    thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]]
+    tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]]
+    prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')]
+    ]
+    cache_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')]
+    ]
+    candidates_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')]
+    ]
+    tool_use_prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')]
+    ]
 
 
 def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
@@ -769,7 +811,21 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         return usage.Usage()  # pragma: no cover
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_token_count'] = cached_content_token_count  # pragma: no cover
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
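
The new `*_details` loop turns each modality entry into a `<modality>_<prefix>` details key, alongside the dedicated thought and tool-use counters. A worked example with invented numbers:

    metadata = {
        'prompt_token_count': 10,
        'candidates_token_count': 5,
        'thoughts_token_count': 7,
        'prompt_tokens_details': [{'modality': 'TEXT', 'token_count': 10}],
    }
    # _metadata_as_usage would then produce:
    # details == {'thoughts_tokens': 7, 'text_prompt_tokens': 10}
    # ('prompt_tokens_details' loses '_details' to give the suffix 'prompt_tokens',
    #  and modality 'TEXT' lowercased gives the 'text_' prefix)
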
@@ -806,93 +862,6 @@ _gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
 _gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
 
 
-class _GeminiJsonSchema(WalkJsonSchema):
-    """Transforms the JSON Schema from Pydantic to be suitable for Gemini.
-
-    Gemini which [supports](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
-    a subset of OpenAPI v3.0.3.
-
-    Specifically:
-    * gemini doesn't allow the `title` keyword to be set
-    * gemini doesn't allow `$defs` — we need to inline the definitions where possible
-    """
-
-    def __init__(self, schema: JsonSchema):
-        super().__init__(schema, prefer_inlined_defs=True, simplify_nullable_unions=True)
-
-    def transform(self, schema: JsonSchema) -> JsonSchema:
-        # Note: we need to remove `additionalProperties: False` since it is currently mishandled by Gemini
-        additional_properties = schema.pop(
-            'additionalProperties', None
-        )  # don't pop yet so it's included in the warning
-        if additional_properties:
-            original_schema = {**schema, 'additionalProperties': additional_properties}
-            warnings.warn(
-                '`additionalProperties` is not supported by Gemini; it will be removed from the tool JSON schema.'
-                f' Full schema: {self.schema}\n\n'
-                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
-                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-                "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub"
-                ' and we will fix this behavior.',
-                UserWarning,
-            )
-
-        schema.pop('title', None)
-        schema.pop('default', None)
-        schema.pop('$schema', None)
-        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
-            # Gemini doesn't support const, but it does support enum with a single value
-            schema['enum'] = [const]
-        schema.pop('discriminator', None)
-        schema.pop('examples', None)
-
-        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
-        # where we add notes about these properties to the field description?
-        schema.pop('exclusiveMaximum', None)
-        schema.pop('exclusiveMinimum', None)
-
-        # Gemini only supports string enums, so we need to convert any enum values to strings.
-        # Pydantic will take care of transforming the transformed string values to the correct type.
-        if enum := schema.get('enum'):
-            schema['type'] = 'string'
-            schema['enum'] = [str(val) for val in enum]
-
-        type_ = schema.get('type')
-        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
-            # This gets hit when we have a discriminated union
-            # Gemini returns an API error in this case even though it says in its error message it shouldn't...
-            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent
-            schema['anyOf'] = schema.pop('oneOf')
-
-        if type_ == 'string' and (fmt := schema.pop('format', None)):
-            description = schema.get('description')
-            if description:
-                schema['description'] = f'{description} (format: {fmt})'
-            else:
-                schema['description'] = f'Format: {fmt}'
-
-        if '$ref' in schema:
-            raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Gemini: {schema["$ref"]}')
-
-        if 'prefixItems' in schema:
-            # prefixItems is not currently supported in Gemini, so we convert it to items for best compatibility
-            prefix_items = schema.pop('prefixItems')
-            items = schema.get('items')
-            unique_items = [items] if items is not None else []
-            for item in prefix_items:
-                if item not in unique_items:
-                    unique_items.append(item)
-            if len(unique_items) > 1:  # pragma: no cover
-                schema['items'] = {'anyOf': unique_items}
-            elif len(unique_items) == 1:  # pragma: no branch
-                schema['items'] = unique_items[0]
-            schema.setdefault('minItems', len(prefix_items))
-            if items is None:  # pragma: no branch
-                schema.setdefault('maxItems', len(prefix_items))
-
-        return schema
-
-
 def _ensure_decodeable(content: bytearray) -> bytearray:
     """Trim any invalid unicode point bytes off the end of a bytearray.