pydantic-ai-slim 0.7.2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.

Files changed (39)
  1. pydantic_ai/__init__.py +2 -1
  2. pydantic_ai/_agent_graph.py +2 -2
  3. pydantic_ai/_cli.py +18 -3
  4. pydantic_ai/_run_context.py +2 -2
  5. pydantic_ai/ag_ui.py +4 -4
  6. pydantic_ai/agent/__init__.py +7 -9
  7. pydantic_ai/agent/abstract.py +16 -18
  8. pydantic_ai/agent/wrapper.py +4 -6
  9. pydantic_ai/builtin_tools.py +9 -1
  10. pydantic_ai/direct.py +4 -4
  11. pydantic_ai/durable_exec/temporal/_agent.py +13 -15
  12. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  13. pydantic_ai/messages.py +16 -6
  14. pydantic_ai/models/__init__.py +5 -5
  15. pydantic_ai/models/anthropic.py +27 -26
  16. pydantic_ai/models/bedrock.py +24 -26
  17. pydantic_ai/models/cohere.py +20 -25
  18. pydantic_ai/models/fallback.py +15 -15
  19. pydantic_ai/models/function.py +7 -9
  20. pydantic_ai/models/gemini.py +43 -39
  21. pydantic_ai/models/google.py +76 -50
  22. pydantic_ai/models/groq.py +22 -19
  23. pydantic_ai/models/huggingface.py +18 -21
  24. pydantic_ai/models/instrumented.py +4 -4
  25. pydantic_ai/models/mcp_sampling.py +1 -2
  26. pydantic_ai/models/mistral.py +24 -22
  27. pydantic_ai/models/openai.py +98 -44
  28. pydantic_ai/models/test.py +4 -5
  29. pydantic_ai/profiles/openai.py +13 -3
  30. pydantic_ai/providers/openai.py +1 -1
  31. pydantic_ai/result.py +5 -5
  32. pydantic_ai/run.py +4 -11
  33. pydantic_ai/tools.py +5 -2
  34. pydantic_ai/usage.py +230 -68
  35. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/METADATA +10 -4
  36. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/RECORD +39 -39
  37. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/WHEEL +0 -0
  38. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/entry_points.txt +0 -0
  39. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.4.dist-info}/licenses/LICENSE +0 -0
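
Most of the changes below are a mechanical migration of the per-request usage API: `usage.Usage` becomes `usage.RequestUsage` (`request_tokens`/`response_tokens` become `input_tokens`/`output_tokens`, and the explicit `total_tokens`/`requests` bookkeeping is dropped), while `ModelResponse` takes `provider_request_id`/`provider_details` in place of `vendor_id`/`vendor_details`. A minimal before/after sketch of a provider response mapping, assuming only the fields shown in this diff:

from pydantic_ai import usage
from pydantic_ai.messages import ModelResponse, TextPart

# 0.7.2 style (kept as comments, since Usage is superseded):
# u = usage.Usage(request_tokens=100, response_tokens=20, total_tokens=120)
# resp = ModelResponse([TextPart('hi')], usage=u, model_name='my-model', vendor_id='req-123')
# resp.usage.requests = 1

# 0.7.4 style: RequestUsage holds per-request token counts only; run-level
# aggregation (totals, request counts) happens outside the model classes.
u = usage.RequestUsage(input_tokens=100, output_tokens=20)
resp = ModelResponse(
    [TextPart('hi')],
    usage=u,
    model_name='my-model',
    provider_request_id='req-123',  # renamed from vendor_id
)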
pydantic_ai/models/__init__.py

@@ -42,7 +42,7 @@ from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
 from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage

 KnownModelName = TypeAliasType(
     'KnownModelName',
@@ -418,7 +418,7 @@ class Model(ABC):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> Usage:
+    ) -> RequestUsage:
         """Make a request to the model for counting tokens."""
         # This method is not required, but you need to implement it if you want to support `UsageLimits.count_tokens_before_request`.
         raise NotImplementedError(f'Token counting ahead of the request is not supported by {self.__class__.__name__}')
@@ -480,7 +480,7 @@ class Model(ABC):
     @property
     @abstractmethod
     def system(self) -> str:
-        """The system / model provider, ex: openai.
+        """The model provider, ex: openai.

         Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
         so should use well-known values listed in
@@ -547,7 +547,7 @@ class StreamedResponse(ABC):

     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
-    _usage: Usage = field(default_factory=Usage, init=False)
+    _usage: RequestUsage = field(default_factory=RequestUsage, init=False)

     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
@@ -600,7 +600,7 @@ class StreamedResponse(ABC):
             usage=self.usage(),
         )

-    def usage(self) -> Usage:
+    def usage(self) -> RequestUsage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
         return self._usage

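In `pydantic_ai/models/__init__.py` the abstract token-counting hook and `StreamedResponse` now traffic in `RequestUsage`. A sketch of a custom model override after this change; the method name `count_tokens` and the `_count_via_provider` helper are assumptions, since the hunk above only shows the signature:

from pydantic_ai.messages import ModelMessage
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import RequestUsage


class MyModel(Model):
    # request(), model_name and system omitted for brevity.

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> RequestUsage:  # returned Usage in 0.7.2
        # _count_via_provider is a hypothetical helper; only input tokens are
        # known before the request is made.
        return RequestUsage(input_tokens=await self._count_via_provider(messages))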
pydantic_ai/models/anthropic.py

@@ -137,7 +137,7 @@ class AnthropicModel(Model):
     client: AsyncAnthropic = field(repr=False)

     _model_name: AnthropicModelName = field(repr=False)
-    _system: str = field(default='anthropic', repr=False)
+    _provider: Provider[AsyncAnthropic] = field(repr=False)

     def __init__(
         self,
@@ -161,6 +161,7 @@ class AnthropicModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -169,6 +170,16 @@ class AnthropicModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -180,7 +191,6 @@ class AnthropicModel(Model):
             messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -198,16 +208,6 @@ class AnthropicModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> AnthropicModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _messages_create(
         self,
@@ -325,7 +325,9 @@ class AnthropicModel(Model):
                 )
             )

-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
+        return ModelResponse(
+            items, usage=_map_usage(response), model_name=response.model, provider_request_id=response.id
+        )

     async def _process_streamed_response(
         self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
@@ -528,7 +530,7 @@ class AnthropicModel(Model):
     }


-def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
+def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
     if isinstance(message, BetaMessage):
         response_usage = message.usage
     elif isinstance(message, BetaRawMessageStartEvent):
@@ -541,7 +543,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
         # - RawContentBlockStartEvent
         # - RawContentBlockDeltaEvent
         # - RawContentBlockStopEvent
-        return usage.Usage()
+        return usage.RequestUsage()

     # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
     # `response_tokens`
@@ -552,17 +554,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
     # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
     # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
     # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
-    request_tokens = (
-        details.get('input_tokens', 0)
-        + details.get('cache_creation_input_tokens', 0)
-        + details.get('cache_read_input_tokens', 0)
-    )
-
-    return usage.Usage(
-        request_tokens=request_tokens or None,
-        response_tokens=response_usage.output_tokens,
-        total_tokens=request_tokens + response_usage.output_tokens,
-        details=details or None,
+    cache_write_tokens = details.get('cache_creation_input_tokens', 0)
+    cache_read_tokens = details.get('cache_read_input_tokens', 0)
+    request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
+
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        cache_read_tokens=cache_read_tokens,
+        cache_write_tokens=cache_write_tokens,
+        output_tokens=response_usage.output_tokens,
+        details=details,
     )

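The Anthropic mapping keeps rolling cache-write and cache-read tokens into the input total, but now also surfaces them as dedicated `RequestUsage` fields instead of leaving them only in `details`. A worked example with made-up numbers, following the `_map_usage` logic above:

from pydantic_ai import usage

# Hypothetical Anthropic usage: 40 uncached input tokens, 1000 written to the
# prompt cache, 200 read from it, 50 output tokens.
mapped = usage.RequestUsage(
    input_tokens=40 + 1000 + 200,  # uncached + cache_write + cache_read
    cache_write_tokens=1000,
    cache_read_tokens=200,
    output_tokens=50,
    details={'input_tokens': 40, 'cache_creation_input_tokens': 1000, 'cache_read_input_tokens': 200},
)
assert mapped.input_tokens == 1240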
pydantic_ai/models/bedrock.py

@@ -190,17 +190,7 @@ class BedrockConverseModel(Model):
     client: BedrockRuntimeClient

     _model_name: BedrockModelName = field(repr=False)
-    _system: str = field(default='bedrock', repr=False)
-
-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider, ex: openai."""
-        return self._system
+    _provider: Provider[BaseClient] = field(repr=False)

     def __init__(
         self,
@@ -226,10 +216,25 @@ class BedrockConverseModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)

         super().__init__(settings=settings, profile=profile or provider.model_profile)

+    @property
+    def base_url(self) -> str:
+        return str(self.client.meta.endpoint_url)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

@@ -245,10 +250,6 @@ class BedrockConverseModel(Model):

         return {'toolSpec': tool_spec}

-    @property
-    def base_url(self) -> str:
-        return str(self.client.meta.endpoint_url)
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -258,7 +259,6 @@ class BedrockConverseModel(Model):
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, False, settings, model_request_parameters)
         model_response = await self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -299,13 +299,12 @@ class BedrockConverseModel(Model):
                     tool_call_id=tool_use['toolUseId'],
                 ),
             )
-        u = usage.Usage(
-            request_tokens=response['usage']['inputTokens'],
-            response_tokens=response['usage']['outputTokens'],
-            total_tokens=response['usage']['totalTokens'],
+        u = usage.RequestUsage(
+            input_tokens=response['usage']['inputTokens'],
+            output_tokens=response['usage']['outputTokens'],
         )
         vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
-        return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id)
+        return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id)

     @overload
     async def _messages_create(
@@ -670,11 +669,10 @@ class BedrockStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

-    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage:
-        return usage.Usage(
-            request_tokens=metadata['usage']['inputTokens'],
-            response_tokens=metadata['usage']['outputTokens'],
-            total_tokens=metadata['usage']['totalTokens'],
+    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
+        return usage.RequestUsage(
+            input_tokens=metadata['usage']['inputTokens'],
+            output_tokens=metadata['usage']['outputTokens'],
         )

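The Bedrock mapping no longer forwards `totalTokens` from the Converse response; with `RequestUsage` the total is derived from its parts (assuming `total_tokens` remains a computed property, as in current pydantic-ai releases):

from pydantic_ai import usage

u = usage.RequestUsage(input_tokens=120, output_tokens=30)
# No explicit total is stored; it is computed from input + output tokens.
assert u.total_tokens == 150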
pydantic_ai/models/cohere.py

@@ -30,11 +30,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    check_allow_model_requests,
-)
+from . import Model, ModelRequestParameters, check_allow_model_requests

 try:
     from cohere import (
@@ -106,7 +102,7 @@ class CohereModel(Model):
     client: AsyncClientV2 = field(repr=False)

     _model_name: CohereModelName = field(repr=False)
-    _system: str = field(default='cohere', repr=False)
+    _provider: Provider[AsyncClientV2] = field(repr=False)

     def __init__(
         self,
@@ -131,6 +127,7 @@ class CohereModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -140,6 +137,16 @@ class CohereModel(Model):
         client_wrapper = self.client._client_wrapper  # type: ignore
         return str(client_wrapper.get_base_url())

+    @property
+    def model_name(self) -> CohereModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -149,19 +156,8 @@ class CohereModel(Model):
         check_allow_model_requests()
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

-    @property
-    def model_name(self) -> CohereModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _chat(
         self,
         messages: list[ModelMessage],
@@ -301,10 +297,10 @@ class CohereModel(Model):
                 assert_never(part)


-def _map_usage(response: V2ChatResponse) -> usage.Usage:
+def _map_usage(response: V2ChatResponse) -> usage.RequestUsage:
     u = response.usage
     if u is None:
-        return usage.Usage()
+        return usage.RequestUsage()
     else:
         details: dict[str, int] = {}
         if u.billed_units is not None:
@@ -317,11 +313,10 @@ def _map_usage(response: V2ChatResponse) -> usage.Usage:
             if u.billed_units.classifications:  # pragma: no cover
                 details['classifications'] = int(u.billed_units.classifications)

-        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None
-        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None
-        return usage.Usage(
-            request_tokens=request_tokens,
-            response_tokens=response_tokens,
-            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else 0
+        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else 0
+        return usage.RequestUsage(
+            input_tokens=request_tokens,
+            output_tokens=response_tokens,
             details=details,
         )
pydantic_ai/models/fallback.py

@@ -33,8 +33,8 @@ class FallbackModel(Model):

     def __init__(
         self,
-        default_model: Model | KnownModelName,
-        *fallback_models: Model | KnownModelName,
+        default_model: Model | KnownModelName | str,
+        *fallback_models: Model | KnownModelName | str,
         fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
     ):
         """Initialize a fallback model instance.
@@ -52,6 +52,19 @@ class FallbackModel(Model):
         else:
             self._fallback_on = fallback_on

+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return f'fallback:{",".join(model.model_name for model in self.models)}'
+
+    @property
+    def system(self) -> str:
+        return f'fallback:{",".join(model.system for model in self.models)}'
+
+    @property
+    def base_url(self) -> str | None:
+        return self.models[0].base_url
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -121,19 +134,6 @@ class FallbackModel(Model):
             if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
                 span.set_attributes(InstrumentedModel.model_attributes(model))

-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return f'fallback:{",".join(model.model_name for model in self.models)}'
-
-    @property
-    def system(self) -> str:
-        return f'fallback:{",".join(model.system for model in self.models)}'
-
-    @property
-    def base_url(self) -> str | None:
-        return self.models[0].base_url
-

 def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
     """Create a default fallback condition for the given exceptions."""
pydantic_ai/models/function.py

@@ -138,7 +138,6 @@ class FunctionModel(Model):
         # Add usage data if not already present
         if not response.usage.has_values():  # pragma: no branch
             response.usage = _estimate_usage(chain(messages, [response]))
-        response.usage.requests = 1
         return response

     @asynccontextmanager
@@ -270,7 +269,7 @@ class FunctionStreamedResponse(StreamedResponse):
         async for item in self._iter:
             if isinstance(item, str):
                 response_tokens = _estimate_string_tokens(item)
-                self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                self._usage += usage.RequestUsage(output_tokens=response_tokens)
                 maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -279,7 +278,7 @@ class FunctionStreamedResponse(StreamedResponse):
                 if isinstance(delta, DeltaThinkingPart):
                     if delta.content:  # pragma: no branch
                         response_tokens = _estimate_string_tokens(delta.content)
-                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        self._usage += usage.RequestUsage(output_tokens=response_tokens)
                         yield self._parts_manager.handle_thinking_delta(
                             vendor_part_id=dtc_index,
                             content=delta.content,
@@ -288,7 +287,7 @@ class FunctionStreamedResponse(StreamedResponse):
                 elif isinstance(delta, DeltaToolCall):
                     if delta.json_args:
                         response_tokens = _estimate_string_tokens(delta.json_args)
-                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        self._usage += usage.RequestUsage(output_tokens=response_tokens)
                         maybe_event = self._parts_manager.handle_tool_call_delta(
                             vendor_part_id=dtc_index,
                             tool_name=delta.name,
@@ -311,7 +310,7 @@ class FunctionStreamedResponse(StreamedResponse):
         return self._timestamp


-def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
     """Very rough guesstimate of the token usage associated with a series of messages.

     This is designed to be used solely to give plausible numbers for testing!
@@ -349,10 +348,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
                 assert_never(part)
         else:
             assert_never(message)
-    return usage.Usage(
-        request_tokens=request_tokens,
-        response_tokens=response_tokens,
-        total_tokens=request_tokens + response_tokens,
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        output_tokens=response_tokens,
     )

pydantic_ai/models/gemini.py

@@ -40,14 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

 LatestGeminiModelNames = Literal[
     'gemini-2.0-flash',
@@ -108,10 +101,9 @@ class GeminiModel(Model):
     client: httpx.AsyncClient = field(repr=False)

     _model_name: GeminiModelName = field(repr=False)
-    _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
+    _provider: Provider[httpx.AsyncClient] = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='gemini', repr=False)

     def __init__(
         self,
@@ -132,11 +124,10 @@ class GeminiModel(Model):
             settings: Default model settings for this model instance.
         """
         self._model_name = model_name
-        self._provider = provider

         if isinstance(provider, str):
             provider = infer_provider(provider)
-        self._system = provider.name
+        self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)

@@ -147,6 +138,16 @@ class GeminiModel(Model):
         assert self._url is not None, 'URL not initialized'  # pragma: no cover
         return self._url  # pragma: no cover

+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -175,16 +176,6 @@ class GeminiModel(Model):
         ) as http_response:
             yield await self._process_streamed_response(http_response, model_request_parameters)

-    @property
-    def model_name(self) -> GeminiModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
         tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None
@@ -237,7 +228,7 @@ class GeminiModel(Model):
             request_data['safetySettings'] = gemini_safety_settings

         if gemini_labels := model_settings.get('gemini_labels'):
-            if self._system == 'google-vertex':
+            if self._provider.name == 'google-vertex':
                 request_data['labels'] = gemini_labels  # pragma: lax no cover

         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
@@ -278,7 +269,6 @@ class GeminiModel(Model):
         if finish_reason:
             vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
             parts,
             response.get('model_version', self._model_name),
@@ -673,7 +663,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
 def _process_response_from_parts(
     parts: Sequence[_GeminiPartUnion],
     model_name: GeminiModelName,
-    usage: usage.Usage,
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -693,7 +683,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
+        parts=items, usage=usage, model_name=model_name, provider_request_id=vendor_id, provider_details=vendor_details
     )

@@ -859,31 +849,45 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     ]


-def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
+def _metadata_as_usage(response: _GeminiResponse) -> usage.RequestUsage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return usage.Usage()  # pragma: no cover
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+    if cached_content_token_count := metadata.get('cached_content_token_count', 0):
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.get('thoughts_token_count', 0):
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
-        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count', 0):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
     for key, metadata_details in metadata.items():
         if key.endswith('_details') and metadata_details:
             metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
-
-    return usage.Usage(
-        request_tokens=metadata.get('prompt_token_count', 0),
-        response_tokens=metadata.get('candidates_token_count', 0),
-        total_tokens=metadata.get('total_token_count', 0),
+                modality = detail['modality']
+                details[f'{modality.lower()}_{suffix}'] = value = detail.get('token_count', 0)
+                if value and modality == 'AUDIO':
+                    if key == 'prompt_tokens_details':
+                        input_audio_tokens = value
+                    elif key == 'candidates_tokens_details':
+                        output_audio_tokens = value
+                    elif key == 'cache_tokens_details':  # pragma: no branch
+                        cache_audio_read_tokens = value
+
+    return usage.RequestUsage(
+        input_tokens=metadata.get('prompt_token_count', 0),
+        output_tokens=metadata.get('candidates_token_count', 0),
+        cache_read_tokens=cached_content_token_count,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
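
The Gemini metadata mapping now promotes cached tokens to `cache_read_tokens` and extracts audio token counts from the per-modality `*_details` lists. An illustrative `usage_metadata` payload and the `RequestUsage` it would map to under the logic above (field names follow the Gemini API; the numbers are made up):

metadata = {
    'prompt_token_count': 130,
    'candidates_token_count': 12,
    'cached_content_token_count': 25,
    'prompt_tokens_details': [
        {'modality': 'TEXT', 'token_count': 100},
        {'modality': 'AUDIO', 'token_count': 30},
    ],
}

# Expected result of the mapping above (approximately):
# RequestUsage(
#     input_tokens=130,
#     output_tokens=12,
#     cache_read_tokens=25,
#     input_audio_tokens=30,
#     details={
#         'cached_content_tokens': 25,
#         'text_prompt_tokens': 100,
#         'audio_prompt_tokens': 30,
#     },
# )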