pydantic-ai-slim 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. pydantic_ai/_agent_graph.py +60 -57
  2. pydantic_ai/_cli.py +18 -3
  3. pydantic_ai/_parts_manager.py +5 -4
  4. pydantic_ai/_run_context.py +2 -2
  5. pydantic_ai/_tool_manager.py +50 -29
  6. pydantic_ai/ag_ui.py +4 -4
  7. pydantic_ai/agent/__init__.py +69 -84
  8. pydantic_ai/agent/abstract.py +16 -18
  9. pydantic_ai/agent/wrapper.py +4 -6
  10. pydantic_ai/direct.py +4 -4
  11. pydantic_ai/durable_exec/temporal/_agent.py +13 -15
  12. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  13. pydantic_ai/messages.py +16 -6
  14. pydantic_ai/models/__init__.py +5 -5
  15. pydantic_ai/models/anthropic.py +47 -46
  16. pydantic_ai/models/bedrock.py +25 -27
  17. pydantic_ai/models/cohere.py +20 -25
  18. pydantic_ai/models/fallback.py +15 -15
  19. pydantic_ai/models/function.py +7 -9
  20. pydantic_ai/models/gemini.py +43 -39
  21. pydantic_ai/models/google.py +59 -40
  22. pydantic_ai/models/groq.py +23 -19
  23. pydantic_ai/models/huggingface.py +27 -23
  24. pydantic_ai/models/instrumented.py +4 -4
  25. pydantic_ai/models/mcp_sampling.py +1 -2
  26. pydantic_ai/models/mistral.py +24 -22
  27. pydantic_ai/models/openai.py +101 -45
  28. pydantic_ai/models/test.py +4 -5
  29. pydantic_ai/profiles/__init__.py +10 -1
  30. pydantic_ai/profiles/deepseek.py +1 -1
  31. pydantic_ai/profiles/moonshotai.py +1 -1
  32. pydantic_ai/profiles/openai.py +13 -3
  33. pydantic_ai/profiles/qwen.py +4 -1
  34. pydantic_ai/providers/__init__.py +4 -0
  35. pydantic_ai/providers/huggingface.py +27 -0
  36. pydantic_ai/providers/ollama.py +105 -0
  37. pydantic_ai/providers/openai.py +1 -1
  38. pydantic_ai/providers/openrouter.py +2 -0
  39. pydantic_ai/result.py +6 -6
  40. pydantic_ai/run.py +4 -11
  41. pydantic_ai/tools.py +9 -9
  42. pydantic_ai/usage.py +229 -67
  43. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
  44. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +47 -46
  45. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
  46. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
  47. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/__init__.py

@@ -42,7 +42,7 @@ from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
 from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage
 
 KnownModelName = TypeAliasType(
     'KnownModelName',
@@ -418,7 +418,7 @@ class Model(ABC):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> Usage:
+    ) -> RequestUsage:
         """Make a request to the model for counting tokens."""
         # This method is not required, but you need to implement it if you want to support `UsageLimits.count_tokens_before_request`.
         raise NotImplementedError(f'Token counting ahead of the request is not supported by {self.__class__.__name__}')
@@ -480,7 +480,7 @@ class Model(ABC):
     @property
     @abstractmethod
     def system(self) -> str:
-        """The system / model provider, ex: openai.
+        """The model provider, ex: openai.
 
         Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
         so should use well-known values listed in
@@ -547,7 +547,7 @@ class StreamedResponse(ABC):
 
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
-    _usage: Usage = field(default_factory=Usage, init=False)
+    _usage: RequestUsage = field(default_factory=RequestUsage, init=False)
 
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
@@ -600,7 +600,7 @@ class StreamedResponse(ABC):
             usage=self.usage(),
         )
 
-    def usage(self) -> Usage:
+    def usage(self) -> RequestUsage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
         return self._usage
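A change that recurs throughout this diff is the switch from `Usage` to `RequestUsage` for per-request token accounting. A minimal sketch of the new constructor, using only field names that appear elsewhere in this diff; the token counts are invented for illustration:

from pydantic_ai import usage

# 0.7.1 style (removed): usage.Usage(request_tokens=100, response_tokens=20, total_tokens=120)
# 0.7.3 style: per-request usage, with cache counts as first-class fields.
request_usage = usage.RequestUsage(
    input_tokens=100,     # was `request_tokens`
    output_tokens=20,     # was `response_tokens`
    cache_read_tokens=0,  # previously only surfaced via `details`
    cache_write_tokens=0,
    details={},           # provider-specific integer counters
)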
 
pydantic_ai/models/anthropic.py

@@ -8,14 +8,6 @@ from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal, Union, cast, overload
 
-from anthropic.types.beta import (
-    BetaCitationsDelta,
-    BetaCodeExecutionToolResultBlock,
-    BetaCodeExecutionToolResultBlockParam,
-    BetaInputJSONDelta,
-    BetaServerToolUseBlockParam,
-    BetaWebSearchToolResultBlockParam,
-)
 from typing_extensions import assert_never
 
 from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
@@ -47,24 +39,21 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
 
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types.beta import (
         BetaBase64PDFBlockParam,
         BetaBase64PDFSourceParam,
+        BetaCitationsDelta,
         BetaCodeExecutionTool20250522Param,
+        BetaCodeExecutionToolResultBlock,
+        BetaCodeExecutionToolResultBlockParam,
         BetaContentBlock,
         BetaContentBlockParam,
         BetaImageBlockParam,
+        BetaInputJSONDelta,
         BetaMessage,
         BetaMessageParam,
         BetaMetadataParam,
@@ -78,6 +67,7 @@ try:
         BetaRawMessageStreamEvent,
         BetaRedactedThinkingBlock,
         BetaServerToolUseBlock,
+        BetaServerToolUseBlockParam,
         BetaSignatureDelta,
         BetaTextBlock,
         BetaTextBlockParam,
@@ -94,6 +84,7 @@ try:
         BetaToolUseBlockParam,
         BetaWebSearchTool20250305Param,
         BetaWebSearchToolResultBlock,
+        BetaWebSearchToolResultBlockParam,
     )
     from anthropic.types.beta.beta_web_search_tool_20250305_param import UserLocation
     from anthropic.types.model_param import ModelParam
@@ -146,7 +137,7 @@ class AnthropicModel(Model):
     client: AsyncAnthropic = field(repr=False)
 
     _model_name: AnthropicModelName = field(repr=False)
-    _system: str = field(default='anthropic', repr=False)
+    _provider: Provider[AsyncAnthropic] = field(repr=False)
 
     def __init__(
         self,
@@ -170,6 +161,7 @@ class AnthropicModel(Model):
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -178,6 +170,16 @@ class AnthropicModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)
 
+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -189,7 +191,6 @@ class AnthropicModel(Model):
             messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
     @asynccontextmanager
@@ -207,16 +208,6 @@ class AnthropicModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)
 
-    @property
-    def model_name(self) -> AnthropicModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _messages_create(
         self,
@@ -246,7 +237,9 @@ class AnthropicModel(Model):
     ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]:
         # standalone function to make it easier to override
         tools = self._get_tools(model_request_parameters)
-        tools += self._get_builtin_tools(model_request_parameters)
+        builtin_tools, tool_headers = self._get_builtin_tools(model_request_parameters)
+        tools += builtin_tools
+
         tool_choice: BetaToolChoiceParam | None
 
         if not tools:
@@ -264,8 +257,10 @@ class AnthropicModel(Model):
 
         try:
             extra_headers = model_settings.get('extra_headers', {})
+            for k, v in tool_headers.items():
+                extra_headers.setdefault(k, v)
             extra_headers.setdefault('User-Agent', get_user_agent())
-            extra_headers.setdefault('anthropic-beta', 'code-execution-2025-05-22')
+
             return await self.client.beta.messages.create(
                 max_tokens=model_settings.get('max_tokens', 4096),
                 system=system_prompt or NOT_GIVEN,
@@ -330,7 +325,9 @@ class AnthropicModel(Model):
                 )
             )
 
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
+        return ModelResponse(
+            items, usage=_map_usage(response), model_name=response.model, provider_request_id=response.id
+        )
 
     async def _process_streamed_response(
         self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
@@ -352,8 +349,11 @@ class AnthropicModel(Model):
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
-    def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
+    def _get_builtin_tools(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> tuple[list[BetaToolUnionParam], dict[str, str]]:
         tools: list[BetaToolUnionParam] = []
+        extra_headers: dict[str, str] = {}
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None
@@ -361,18 +361,20 @@ class AnthropicModel(Model):
                     BetaWebSearchTool20250305Param(
                         name='web_search',
                         type='web_search_20250305',
+                        max_uses=tool.max_uses,
                         allowed_domains=tool.allowed_domains,
                         blocked_domains=tool.blocked_domains,
                         user_location=user_location,
                     )
                 )
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
+                extra_headers['anthropic-beta'] = 'code-execution-2025-05-22'
                 tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
             else:  # pragma: no cover
                 raise UserError(
                     f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
                 )
-        return tools
+        return tools, extra_headers
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
@@ -528,7 +530,7 @@ class AnthropicModel(Model):
         }
 
 
-def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
+def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
     if isinstance(message, BetaMessage):
         response_usage = message.usage
     elif isinstance(message, BetaRawMessageStartEvent):
@@ -541,7 +543,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
        # - RawContentBlockStartEvent
        # - RawContentBlockDeltaEvent
        # - RawContentBlockStopEvent
-        return usage.Usage()
+        return usage.RequestUsage()
 
     # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
     # `response_tokens`
@@ -552,17 +554,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Usage:
     # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
     # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
     # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
-    request_tokens = (
-        details.get('input_tokens', 0)
-        + details.get('cache_creation_input_tokens', 0)
-        + details.get('cache_read_input_tokens', 0)
-    )
-
-    return usage.Usage(
-        request_tokens=request_tokens or None,
-        response_tokens=response_usage.output_tokens,
-        total_tokens=request_tokens + response_usage.output_tokens,
-        details=details or None,
+    cache_write_tokens = details.get('cache_creation_input_tokens', 0)
+    cache_read_tokens = details.get('cache_read_input_tokens', 0)
+    request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
+
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        cache_read_tokens=cache_read_tokens,
+        cache_write_tokens=cache_write_tokens,
+        output_tokens=response_usage.output_tokens,
+        details=details,
     )
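In the Anthropic usage mapping above, cached input tokens are now folded into `input_tokens` while also being reported through the dedicated cache fields. A rough sketch of the same arithmetic, with invented token counts:

# Hypothetical Anthropic usage details (numbers invented for illustration).
details = {'input_tokens': 40, 'cache_creation_input_tokens': 10, 'cache_read_input_tokens': 50}

cache_write_tokens = details.get('cache_creation_input_tokens', 0)  # 10
cache_read_tokens = details.get('cache_read_input_tokens', 0)       # 50
input_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens  # 40 + 10 + 50 = 100
# These values feed usage.RequestUsage(input_tokens=..., cache_read_tokens=..., cache_write_tokens=..., ...)
# exactly as in the hunk above.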
pydantic_ai/models/bedrock.py

@@ -190,17 +190,7 @@ class BedrockConverseModel(Model):
     client: BedrockRuntimeClient
 
     _model_name: BedrockModelName = field(repr=False)
-    _system: str = field(default='bedrock', repr=False)
-
-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider, ex: openai."""
-        return self._system
+    _provider: Provider[BaseClient] = field(repr=False)
 
     def __init__(
         self,
@@ -226,10 +216,25 @@ class BedrockConverseModel(Model):
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
 
+    @property
+    def base_url(self) -> str:
+        return str(self.client.meta.endpoint_url)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
@@ -245,10 +250,6 @@ class BedrockConverseModel(Model):
 
         return {'toolSpec': tool_spec}
 
-    @property
-    def base_url(self) -> str:
-        return str(self.client.meta.endpoint_url)
-
     async def request(
         self,
         messages: list[ModelMessage],
@@ -258,7 +259,6 @@ class BedrockConverseModel(Model):
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, False, settings, model_request_parameters)
         model_response = await self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
     @asynccontextmanager
@@ -299,13 +299,12 @@ class BedrockConverseModel(Model):
                         tool_call_id=tool_use['toolUseId'],
                     ),
                 )
-        u = usage.Usage(
-            request_tokens=response['usage']['inputTokens'],
-            response_tokens=response['usage']['outputTokens'],
-            total_tokens=response['usage']['totalTokens'],
+        u = usage.RequestUsage(
+            input_tokens=response['usage']['inputTokens'],
+            output_tokens=response['usage']['outputTokens'],
         )
         vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
-        return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id)
+        return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id)
 
     @overload
     async def _messages_create(
@@ -648,7 +647,7 @@ class BedrockStreamedResponse(StreamedResponse):
                 )
                 if 'text' in delta:
                     maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
-                    if maybe_event is not None:
+                    if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
                 if 'toolUse' in delta:
                     tool_use = delta['toolUse']
@@ -670,11 +669,10 @@ class BedrockStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
-    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.Usage:
-        return usage.Usage(
-            request_tokens=metadata['usage']['inputTokens'],
-            response_tokens=metadata['usage']['outputTokens'],
-            total_tokens=metadata['usage']['totalTokens'],
+    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
+        return usage.RequestUsage(
+            input_tokens=metadata['usage']['inputTokens'],
+            output_tokens=metadata['usage']['outputTokens'],
        )
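Both the Anthropic and Bedrock hunks also rename the `vendor_id` keyword on `ModelResponse` to `provider_request_id`. A minimal sketch of constructing a response with the new keyword; the part content, model name, and request ID below are placeholders, and the other field names are as I understand the library's messages API:

from pydantic_ai import usage
from pydantic_ai.messages import ModelResponse, TextPart

response = ModelResponse(
    [TextPart(content='hello world')],
    usage=usage.RequestUsage(input_tokens=10, output_tokens=3),
    model_name='example-model',
    provider_request_id='req-123',  # was `vendor_id` in 0.7.1
)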
pydantic_ai/models/cohere.py

@@ -30,11 +30,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    check_allow_model_requests,
-)
+from . import Model, ModelRequestParameters, check_allow_model_requests
 
 try:
     from cohere import (
@@ -106,7 +102,7 @@ class CohereModel(Model):
     client: AsyncClientV2 = field(repr=False)
 
     _model_name: CohereModelName = field(repr=False)
-    _system: str = field(default='cohere', repr=False)
+    _provider: Provider[AsyncClientV2] = field(repr=False)
 
     def __init__(
         self,
@@ -131,6 +127,7 @@ class CohereModel(Model):
 
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client
 
         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -140,6 +137,16 @@ class CohereModel(Model):
         client_wrapper = self.client._client_wrapper  # type: ignore
         return str(client_wrapper.get_base_url())
 
+    @property
+    def model_name(self) -> CohereModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -149,19 +156,8 @@ class CohereModel(Model):
         check_allow_model_requests()
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response
 
-    @property
-    def model_name(self) -> CohereModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _chat(
         self,
         messages: list[ModelMessage],
@@ -301,10 +297,10 @@ class CohereModel(Model):
             assert_never(part)
 
 
-def _map_usage(response: V2ChatResponse) -> usage.Usage:
+def _map_usage(response: V2ChatResponse) -> usage.RequestUsage:
     u = response.usage
     if u is None:
-        return usage.Usage()
+        return usage.RequestUsage()
     else:
         details: dict[str, int] = {}
         if u.billed_units is not None:
@@ -317,11 +313,10 @@ def _map_usage(response: V2ChatResponse) -> usage.Usage:
             if u.billed_units.classifications:  # pragma: no cover
                 details['classifications'] = int(u.billed_units.classifications)
 
-        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else None
-        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else None
-        return usage.Usage(
-            request_tokens=request_tokens,
-            response_tokens=response_tokens,
-            total_tokens=(request_tokens or 0) + (response_tokens or 0),
+        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else 0
+        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else 0
+        return usage.RequestUsage(
+            input_tokens=request_tokens,
+            output_tokens=response_tokens,
             details=details,
         )
pydantic_ai/models/fallback.py

@@ -33,8 +33,8 @@ class FallbackModel(Model):
 
     def __init__(
         self,
-        default_model: Model | KnownModelName,
-        *fallback_models: Model | KnownModelName,
+        default_model: Model | KnownModelName | str,
+        *fallback_models: Model | KnownModelName | str,
         fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
     ):
         """Initialize a fallback model instance.
@@ -52,6 +52,19 @@ class FallbackModel(Model):
         else:
             self._fallback_on = fallback_on
 
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return f'fallback:{",".join(model.model_name for model in self.models)}'
+
+    @property
+    def system(self) -> str:
+        return f'fallback:{",".join(model.system for model in self.models)}'
+
+    @property
+    def base_url(self) -> str | None:
+        return self.models[0].base_url
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -121,19 +134,6 @@ class FallbackModel(Model):
         if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
             span.set_attributes(InstrumentedModel.model_attributes(model))
 
-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return f'fallback:{",".join(model.model_name for model in self.models)}'
-
-    @property
-    def system(self) -> str:
-        return f'fallback:{",".join(model.system for model in self.models)}'
-
-    @property
-    def base_url(self) -> str | None:
-        return self.models[0].base_url
-
 
 def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
     """Create a default fallback condition for the given exceptions."""
pydantic_ai/models/function.py

@@ -138,7 +138,6 @@ class FunctionModel(Model):
         # Add usage data if not already present
         if not response.usage.has_values():  # pragma: no branch
             response.usage = _estimate_usage(chain(messages, [response]))
-        response.usage.requests = 1
         return response
 
     @asynccontextmanager
@@ -270,7 +269,7 @@ class FunctionStreamedResponse(StreamedResponse):
         async for item in self._iter:
             if isinstance(item, str):
                 response_tokens = _estimate_string_tokens(item)
-                self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                self._usage += usage.RequestUsage(output_tokens=response_tokens)
                 maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -279,7 +278,7 @@ class FunctionStreamedResponse(StreamedResponse):
                 if isinstance(delta, DeltaThinkingPart):
                     if delta.content:  # pragma: no branch
                         response_tokens = _estimate_string_tokens(delta.content)
-                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        self._usage += usage.RequestUsage(output_tokens=response_tokens)
                         yield self._parts_manager.handle_thinking_delta(
                             vendor_part_id=dtc_index,
                             content=delta.content,
@@ -288,7 +287,7 @@ class FunctionStreamedResponse(StreamedResponse):
                 elif isinstance(delta, DeltaToolCall):
                     if delta.json_args:
                         response_tokens = _estimate_string_tokens(delta.json_args)
-                        self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+                        self._usage += usage.RequestUsage(output_tokens=response_tokens)
                         maybe_event = self._parts_manager.handle_tool_call_delta(
                             vendor_part_id=dtc_index,
                             tool_name=delta.name,
@@ -311,7 +310,7 @@ class FunctionStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
+def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
     """Very rough guesstimate of the token usage associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
@@ -349,10 +348,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
                     assert_never(part)
         else:
             assert_never(message)
-    return usage.Usage(
-        request_tokens=request_tokens,
-        response_tokens=response_tokens,
-        total_tokens=request_tokens + response_tokens,
+    return usage.RequestUsage(
+        input_tokens=request_tokens,
+        output_tokens=response_tokens,
     )