pydantic-ai-slim: 0.0.13-py3-none-any.whl → 0.0.15-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Note: this version of pydantic-ai-slim has been flagged as a potentially problematic release.

pydantic_ai/models/mistral.py

@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from itertools import chain
 from typing import Any, Callable, Literal, Union
 
+import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never
 
@@ -26,7 +27,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -39,7 +40,6 @@ from . import (
 )
 
 try:
-    from json_repair import repair_json
     from mistralai import (
         UNSET,
         CompletionChunk as MistralCompletionChunk,
@@ -156,10 +156,10 @@ class MistralAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
+    ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
         response = await self._completions_create(messages, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
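The headline change in this diff is the rename of `Cost` to `Usage` across every model adapter. For orientation, here is a minimal sketch of what such a container needs to support, inferred from the hunks themselves: the field names are taken from the `_map_usage` constructors below, and the `__add__` is assumed from the `start_usage += _map_usage(chunk)` accumulation pattern — it is illustrative, not copied from pydantic-ai.

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Usage:
    # Field names mirror the _map_usage() constructors in this diff.
    request_tokens: int | None = None
    response_tokens: int | None = None
    total_tokens: int | None = None
    details: dict[str, int] | None = None

    def __add__(self, other: Usage) -> Usage:
        # None means "not reported"; only sum when at least one side reported.
        def add(a: int | None, b: int | None) -> int | None:
            return None if a is None and b is None else (a or 0) + (b or 0)

        return Usage(
            add(self.request_tokens, other.request_tokens),
            add(self.response_tokens, other.response_tokens),
            add(self.total_tokens, other.total_tokens),
            self.details or other.details,
        )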
@@ -198,11 +198,10 @@ class MistralAgentModel(AgentModel):
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-
         model_settings = model_settings or {}
 
         if self.result_tools and self.function_tools or self.function_tools:
-            # Function Calling Mode
+            # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -218,9 +217,9 @@ class MistralAgentModel(AgentModel):
         elif self.result_tools:
             # Json Mode
             parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
-
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
+
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -270,12 +269,13 @@ class MistralAgentModel(AgentModel):
     @staticmethod
     def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        assert response.choices, 'Unexpected empty response choice.'
+
         if response.created:
             timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         else:
             timestamp = _now_utc()
 
-        assert response.choices, 'Unexpected empty response choice.'
         choice = response.choices[0]
         content = choice.message.content
         tool_calls = choice.message.tool_calls
@@ -297,7 +297,7 @@ class MistralAgentModel(AgentModel):
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        start_cost = Cost()
+        start_usage = Usage()
 
         # Iterate until we get either `tool_calls` or `content` from the first chunk.
         while True:
@@ -307,7 +307,7 @@ class MistralAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
 
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)
 
             if chunk.created:
                 timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
@@ -329,11 +329,11 @@ class MistralAgentModel(AgentModel):
                 response,
                 content,
                 timestamp,
-                start_cost,
+                start_usage,
             )
 
         elif content:
-            return MistralStreamTextResponse(content, response, timestamp, start_cost)
+            return MistralStreamTextResponse(content, response, timestamp, start_usage)
 
     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
@@ -474,7 +474,7 @@ class MistralStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: MistralEventStreamAsync[MistralCompletionEvent]
     _timestamp: datetime
-    _cost: Cost
+    _usage: Usage
     _buffer: list[str] = field(default_factory=list, init=False)
 
     async def __anext__(self) -> None:
@@ -484,7 +484,7 @@ class MistralStreamTextResponse(StreamTextResponse):
             return None
 
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk.data)
+        self._usage += _map_usage(chunk.data)
 
         try:
             choice = chunk.data.choices[0]
@@ -502,8 +502,8 @@ class MistralStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -518,11 +518,11 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
     _response: MistralEventStreamAsync[MistralCompletionEvent]
     _delta_content: str | None
     _timestamp: datetime
-    _cost: Cost
+    _usage: Usage
 
     async def __anext__(self) -> None:
        chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk.data)
+        self._usage += _map_usage(chunk.data)
 
         try:
             choice = chunk.data.choices[0]
@@ -546,39 +546,31 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
             calls.append(tool)
 
         elif self._delta_content and self._result_tools:
-            # NOTE: Params set for the most efficient and fastest way.
-            output_json = repair_json(self._delta_content, return_objects=True, skip_json_loads=True)
-            assert isinstance(
-                output_json, dict
-            ), f'Expected repair_json as type dict, invalid type: {type(output_json)}'
+            output_json: dict[str, Any] | None = pydantic_core.from_json(
+                self._delta_content, allow_partial='trailing-strings'
+            )
 
             if output_json:
                 for result_tool in self._result_tools.values():
-                    # NOTE: Additional verification to prevent JSON validation to crash in `result.py`
+                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
                     # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                    # For example, `return_type=list[str]` expects a 'response' key with value type array of str.
-                    # when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
-                    # when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
-                    # This ensures it's corrected to `{"response": {}}` and other required parameters and type.
-                    if not self._validate_required_json_shema(output_json, result_tool.parameters_json_schema):
+                    # Example with BaseModel and required fields.
+                    if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
                         continue
 
-                    tool = ToolCallPart.from_dict(
-                        tool_name=result_tool.name,
-                        args_dict=output_json,
-                    )
+                    tool = ToolCallPart.from_raw_args(result_tool.name, output_json)
                     calls.append(tool)
 
         return ModelResponse(calls, timestamp=self._timestamp)
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
     @staticmethod
-    def _validate_required_json_shema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
+    def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
         """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
         required_params = json_schema.get('required', [])
         properties = json_schema.get('properties', {})
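The swap from `json_repair.repair_json` to `pydantic_core.from_json` is the functional core of this hunk. Passing `allow_partial='trailing-strings'` makes a truncated streaming chunk parse into a usable dict, keeping an unterminated trailing string instead of dropping it. A small self-contained demonstration (pydantic-core only, no pydantic-ai code involved):

import pydantic_core

# A chunk cut off mid-string, as happens while a response is streaming.
chunk = '{"name": "Alice", "bio": "likes hik'

print(pydantic_core.from_json(chunk, allow_partial=True))
#> {'name': 'Alice'}

print(pydantic_core.from_json(chunk, allow_partial='trailing-strings'))
#> {'name': 'Alice', 'bio': 'likes hik'}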
@@ -602,7 +594,7 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
 
             if isinstance(json_dict[param], dict) and 'properties' in param_schema:
                 nested_schema = param_schema
-                if not MistralStreamStructuredResponse._validate_required_json_shema(json_dict[param], nested_schema):
+                if not MistralStreamStructuredResponse._validate_required_json_schema(json_dict[param], nested_schema):
                     return False
 
         return True
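For context, this (now correctly spelled) `_validate_required_json_schema` gate exists because a partially parsed dict can be missing required keys, or can hold a value whose type has not fully streamed yet. Here is a simplified standalone approximation of the check — the real method also recurses into nested object properties, as the hunk above shows; this version only inspects top-level required keys and their declared types:

from typing import Any

_JSON_TYPES: dict[str, Any] = {
    'string': str,
    'integer': int,
    'number': (int, float),
    'boolean': bool,
    'array': list,
    'object': dict,
    'null': type(None),
}


def validate_required(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
    """Return False while required keys are absent or hold a not-yet-complete type."""
    properties = json_schema.get('properties', {})
    for param in json_schema.get('required', []):
        if param not in json_dict:
            return False
        expected = properties.get(param, {}).get('type')
        if expected in _JSON_TYPES and not isinstance(json_dict[param], _JSON_TYPES[expected]):
            return False
    return True


schema = {'required': ['response'], 'properties': {'response': {'type': 'array'}}}
assert not validate_required({}, schema)  # nothing streamed yet
assert validate_required({'response': ['a']}, schema)  # required key complete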
@@ -633,29 +625,20 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
     tool_call_id = tool_call.id or None
     func_call = tool_call.function
 
-    if isinstance(func_call.arguments, str):
-        return ToolCallPart.from_json(
-            tool_name=func_call.name,
-            args_json=func_call.arguments,
-            tool_call_id=tool_call_id,
-        )
-    else:
-        return ToolCallPart.from_dict(
-            tool_name=func_call.name, args_dict=func_call.arguments, tool_call_id=tool_call_id
-        )
+    return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)
 
 
-def _map_cost(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Cost:
-    """Maps a Mistral Completion Chunk or Chat Completion Response to a Cost."""
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+    """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return Cost(
+        return Usage(
             request_tokens=response.usage.prompt_tokens,
             response_tokens=response.usage.completion_tokens,
             total_tokens=response.usage.total_tokens,
             details=None,
         )
     else:
-        return Cost()
+        return Usage()
 
 
 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
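Both files in this diff collapse the `ToolCallPart.from_json`/`from_dict` pair into a single `from_raw_args` constructor that accepts either a raw JSON string or an already-parsed dict. A hedged sketch of that pattern with an illustrative stand-in class — the names and internals below are assumptions, not pydantic-ai's actual implementation; `args_as_json_str` mirrors the method used in the OpenAI hunks further down:

from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Any, Union


@dataclass
class StubToolCallPart:
    tool_name: str
    raw_args: Union[str, dict[str, Any]]
    tool_call_id: str | None = None

    @classmethod
    def from_raw_args(
        cls, tool_name: str, args: Union[str, dict[str, Any]], tool_call_id: str | None = None
    ) -> StubToolCallPart:
        # One constructor for both shapes: a JSON string or a parsed dict.
        return cls(tool_name, args, tool_call_id)

    def args_as_json_str(self) -> str:
        # Serialize dict args; pass JSON-string args through unchanged.
        return self.raw_args if isinstance(self.raw_args, str) else json.dumps(self.raw_args)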
pydantic_ai/models/openai.py

@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -25,7 +24,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -147,9 +146,9 @@ class OpenAIAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._completions_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
@@ -211,14 +210,14 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -227,19 +226,19 @@
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
 
             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)
 
             if chunk.choices:
                 delta = chunk.choices[0].delta
 
                 if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return OpenAIStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )
                 # else continue until we get either delta.content or delta.tool_calls
 
@@ -302,7 +301,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)
 
     async def __anext__(self) -> None:
@@ -312,7 +311,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
             return None
 
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -328,8 +327,8 @@ class OpenAIStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -342,11 +341,11 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
 
     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -372,37 +371,36 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
 
         return ModelResponse(items, timestamp=self._timestamp)
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
 def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
         id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )
 
 
-def _map_cost(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Cost:
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
     usage = response.usage
     if usage is None:
-        return result.Cost()
+        return result.Usage()
     else:
         details: dict[str, int] = {}
         if usage.completion_tokens_details is not None:
             details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
         if usage.prompt_tokens_details is not None:
             details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.Cost(
+        return result.Usage(
             request_tokens=usage.prompt_tokens,
             response_tokens=usage.completion_tokens,
             total_tokens=usage.total_tokens,
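One detail worth noting in `_map_usage` above: `model_dump(exclude_none=True)` means only counters the API actually reported end up in `details`. A quick illustration with a stand-in pydantic model (the real detail types come from the openai SDK; the field names here are illustrative):

from __future__ import annotations

from pydantic import BaseModel


class StubTokenDetails(BaseModel):
    reasoning_tokens: int | None = None
    audio_tokens: int | None = None


details: dict[str, int] = {}
details.update(StubTokenDetails(reasoning_tokens=42).model_dump(exclude_none=True))
print(details)
#> {'reasoning_tokens': 42}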
pydantic_ai/models/test.py

@@ -21,7 +21,7 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -31,6 +31,7 @@ from . import (
     StreamStructuredResponse,
     StreamTextResponse,
 )
+from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]
 
 
 @dataclass
@@ -131,15 +132,17 @@ class TestAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
-        return self._request(messages, model_settings), Cost()
+    ) -> tuple[ModelResponse, Usage]:
+        model_response = self._request(messages, model_settings)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[EitherStreamedResponse]:
         msg = self._request(messages, model_settings)
-        cost = Cost()
+        usage = _estimate_usage(messages)
 
         # TODO: Rework this once we make StreamTextResponse more general
         texts: list[str] = []
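`TestAgentModel` previously returned an empty `Cost()`; it now reports estimated usage via private helpers imported from `pydantic_ai.models.function`. Those helpers are not shown in this diff, so the following is only a rough sketch of what a whitespace/punctuation-based token estimate could look like — an assumption for illustration; the real `_estimate_string_usage` may differ:

import re


def estimate_string_tokens(text: str) -> int:
    # Crude proxy: count runs separated by whitespace and common punctuation.
    # This is NOT pydantic-ai's implementation, just an illustrative stand-in.
    if not text.strip():
        return 0
    return len(re.split(r'[\s",.:]+', text.strip()))


print(estimate_string_tokens('The quick brown fox'))
#> 4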
@@ -153,9 +156,9 @@ class TestAgentModel(AgentModel):
                 assert_never(item)
 
         if texts:
-            yield TestStreamTextResponse('\n\n'.join(texts), cost)
+            yield TestStreamTextResponse('\n\n'.join(texts), usage)
         else:
-            yield TestStreamStructuredResponse(msg, cost)
+            yield TestStreamStructuredResponse(msg, usage)
 
     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -164,7 +167,7 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
             )
 
         if messages:
@@ -176,7 +179,7 @@ class TestAgentModel(AgentModel):
             if new_retry_names:
                 return ModelResponse(
                     parts=[
-                        ToolCallPart.from_dict(name, self.gen_tool_args(args))
+                        ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
                         for name, args in self.tool_calls
                         if name in new_retry_names
                     ]
@@ -202,10 +205,10 @@ class TestAgentModel(AgentModel):
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])
 
 
 @dataclass
  @dataclass
@@ -213,7 +216,7 @@ class TestStreamTextResponse(StreamTextResponse):
213
216
  """A text response that streams test data."""
214
217
 
215
218
  _text: str
216
- _cost: Cost
219
+ _usage: Usage
217
220
  _iter: Iterator[str] = field(init=False)
218
221
  _timestamp: datetime = field(default_factory=_utils.now_utc)
219
222
  _buffer: list[str] = field(default_factory=list, init=False)
@@ -228,14 +231,17 @@ class TestStreamTextResponse(StreamTextResponse):
         self._iter = iter(words)
 
     async def __anext__(self) -> None:
-        self._buffer.append(_utils.sync_anext(self._iter))
+        next_str = _utils.sync_anext(self._iter)
+        response_tokens = _estimate_string_usage(next_str)
+        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+        self._buffer.append(next_str)
 
     def get(self, *, final: bool = False) -> Iterable[str]:
         yield from self._buffer
         self._buffer.clear()
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -246,7 +252,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""
 
     _structured_response: ModelResponse
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -256,8 +262,8 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp