pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

pydantic_ai/exceptions.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import json
 
-__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
+__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
 
 
 class ModelRetry(Exception):
@@ -30,7 +30,25 @@ class UserError(RuntimeError):
         super().__init__(message)
 
 
-class UnexpectedModelBehavior(RuntimeError):
+class AgentRunError(RuntimeError):
+    """Base class for errors occurring during an agent run."""
+
+    message: str
+    """The error message."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return self.message
+
+
+class UsageLimitExceeded(AgentRunError):
+    """Error raised when a Model's usage exceeds the specified limits."""
+
+
+class UnexpectedModelBehavior(AgentRunError):
     """Error caused by unexpected Model behavior, e.g. an unexpected response code."""
 
     message: str
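
The practical upshot: `UnexpectedModelBehavior` and the new `UsageLimitExceeded` now share the `AgentRunError` base class, so callers can catch either one individually or both at once. A minimal sketch (the failing run is simulated here, since real agent setup is out of scope for this diff):

    from pydantic_ai.exceptions import AgentRunError, UsageLimitExceeded

    def fake_run() -> None:
        # stand-in for an agent run that blows through its usage limit
        raise UsageLimitExceeded('request_tokens limit exceeded')

    try:
        fake_run()
    except AgentRunError as exc:
        # catches UsageLimitExceeded and UnexpectedModelBehavior alike
        print(type(exc).__name__, '-', exc)  # UsageLimitExceeded - request_tokens limit exceeded
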
pydantic_ai/messages.py CHANGED
@@ -2,11 +2,11 @@ from __future__ import annotations as _annotations
 
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Union, cast
 
 import pydantic
 import pydantic_core
-from typing_extensions import Self
+from typing_extensions import Self, assert_never
 
 from ._utils import now_utc as _now_utc
 
@@ -190,12 +190,34 @@ class ToolCallPart:
     """Part type identifier, this is available on all parts as a discriminator."""
 
     @classmethod
-    def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
-        return cls(tool_name, ArgsJson(args_json), tool_call_id)
+    def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
+        """Create a `ToolCallPart` from raw arguments."""
+        if isinstance(args, str):
+            return cls(tool_name, ArgsJson(args), tool_call_id)
+        elif isinstance(args, dict):
+            return cls(tool_name, ArgsDict(args), tool_call_id)
+        else:
+            assert_never(args)
 
-    @classmethod
-    def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
-        return cls(tool_name, ArgsDict(args_dict), tool_call_id)
+    def args_as_dict(self) -> dict[str, Any]:
+        """Return the arguments as a Python dictionary.
+
+        This is just for convenience with models that require dicts as input.
+        """
+        if isinstance(self.args, ArgsDict):
+            return self.args.args_dict
+        args = pydantic_core.from_json(self.args.args_json)
+        assert isinstance(args, dict), 'args should be a dict'
+        return cast(dict[str, Any], args)
+
+    def args_as_json_str(self) -> str:
+        """Return the arguments as a JSON string.
+
+        This is just for convenience with models that require JSON strings as input.
+        """
+        if isinstance(self.args, ArgsJson):
+            return self.args.args_json
+        return pydantic_core.to_json(self.args.args_dict).decode()
 
     def has_content(self) -> bool:
         if isinstance(self.args, ArgsDict):
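
`from_json`/`from_dict` are gone in favour of a single constructor, and the two new accessors convert between representations on demand. A quick sketch (tool name and arguments invented):

    from pydantic_ai.messages import ToolCallPart

    # one constructor now accepts either a JSON string or a dict
    from_json = ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}')
    from_dict = ToolCallPart.from_raw_args('get_weather', {'city': 'Paris'})

    # and either representation can be recovered from either part
    assert from_json.args_as_dict() == {'city': 'Paris'}
    assert from_dict.args_as_json_str() == '{"city":"Paris"}'
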
pydantic_ai/models/__init__.py CHANGED
@@ -20,7 +20,7 @@ from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
 
 if TYPE_CHECKING:
-    from ..result import Cost
+    from ..result import Usage
     from ..tools import ToolDefinition
 
 
@@ -31,6 +31,7 @@ KnownModelName = Literal[
     'openai:gpt-4',
     'openai:o1-preview',
     'openai:o1-mini',
+    'openai:o1',
     'openai:gpt-3.5-turbo',
     'groq:llama-3.3-70b-versatile',
     'groq:llama-3.1-70b-versatile',
@@ -122,7 +123,7 @@ class AgentModel(ABC):
     @abstractmethod
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
+    ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()
 
@@ -164,10 +165,10 @@ class StreamTextResponse(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def cost(self) -> Cost:
-        """Return the cost of the request.
+    def usage(self) -> Usage:
+        """Return the usage of the request.
 
-        NOTE: this won't return the ful cost until the stream is finished.
+        NOTE: this won't return the full usage until the stream is finished.
         """
         raise NotImplementedError()
 
@@ -205,10 +206,10 @@ class StreamStructuredResponse(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def cost(self) -> Cost:
-        """Get the cost of the request.
+    def usage(self) -> Usage:
+        """Get the usage of the request.
 
-        NOTE: this won't return the full cost until the stream is finished.
+        NOTE: this won't return the full usage until the stream is finished.
         """
         raise NotImplementedError()
 
@@ -235,7 +236,7 @@ The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
 def check_allow_model_requests() -> None:
     """Check if model requests are allowed.
 
-    If you're defining your own models that have cost or latency associated with their use, you should call this in
+    If you're defining your own models that have costs or latency associated with their use, you should call this in
     [`Model.agent_model`][pydantic_ai.models.Model.agent_model].
 
     Raises:
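
Every `AgentModel.request` implementation now returns a `(ModelResponse, Usage)` pair. A minimal custom-model sketch under the new signature (the class name and token counts are invented, and this assumes `request_stream` keeps a default non-streaming implementation):

    from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
    from pydantic_ai.models import AgentModel
    from pydantic_ai.result import Usage
    from pydantic_ai.settings import ModelSettings

    class CannedAgentModel(AgentModel):
        # hypothetical backend that always answers 'hello'
        async def request(
            self, messages: list[ModelMessage], model_settings: ModelSettings | None
        ) -> tuple[ModelResponse, Usage]:
            response = ModelResponse([TextPart('hello')])
            usage = Usage(request_tokens=10, response_tokens=1, total_tokens=11)
            return response, usage
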
pydantic_ai/models/anthropic.py CHANGED
@@ -158,9 +158,9 @@ class AnthropicAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._messages_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
@@ -220,7 +220,7 @@ class AnthropicAgentModel(AgentModel):
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
-                    ToolCallPart.from_dict(
+                    ToolCallPart.from_raw_args(
                         item.name,
                         cast(dict[str, Any], item.input),
                         item.id,
@@ -282,11 +282,11 @@ class AnthropicAgentModel(AgentModel):
                 MessageParam(
                     role='user',
                     content=[
-                        ToolUseBlockParam(
-                            id=_guard_tool_call_id(t=part, model_source='Anthropic'),
-                            input=part.model_response(),
-                            name=part.tool_name,
-                            type='tool_use',
+                        ToolResultBlockParam(
+                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                            type='tool_result',
+                            content=part.model_response(),
+                            is_error=True,
                         ),
                     ],
                 )
@@ -311,11 +311,11 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
         name=t.tool_name,
-        input=t.args.args_dict,
+        input=t.args_as_dict(),
     )
 
 
-def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
+def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
     if isinstance(message, AnthropicMessage):
         usage = message.usage
     else:
@@ -332,11 +332,11 @@ def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
             usage = None
 
     if usage is None:
-        return result.Cost()
+        return result.Usage()
 
     request_tokens = getattr(usage, 'input_tokens', None)
 
-    return result.Cost(
+    return result.Usage(
         # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
         request_tokens=request_tokens,
         response_tokens=usage.output_tokens,
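
Beyond the rename, this fixes a real bug: a retry prompt was previously echoed back to Anthropic as a `tool_use` block, when it should be a `tool_result` marked as an error. Roughly, the corrected payload now looks like this (id and content invented for illustration):

    # shape of the message sent back after a failed tool call
    retry_message = {
        'role': 'user',
        'content': [
            {
                'type': 'tool_result',          # was incorrectly 'tool_use' in 0.0.13
                'tool_use_id': 'toolu_abc123',  # hypothetical tool call id
                'content': 'city is required, fix the errors and try again',
                'is_error': True,               # tells the model its previous call failed
            }
        ],
    }
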
pydantic_ai/models/function.py CHANGED
@@ -9,12 +9,10 @@ from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast
 
-import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload
 
 from .. import _utils, result
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -144,7 +142,7 @@ class FunctionAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         agent_info = replace(self.agent_info, model_settings=model_settings)
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -155,7 +153,7 @@ class FunctionAgentModel(AgentModel):
         assert isinstance(response_, ModelResponse), response_
         response = response_
         # TODO is `messages` right here? Should it just be new messages?
-        return response, _estimate_cost(chain(messages, [response]))
+        return response, _estimate_usage(chain(messages, [response]))
 
     @asynccontextmanager
     async def request_stream(
@@ -198,8 +196,8 @@ class FunctionStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()
 
-    def cost(self) -> result.Cost:
-        return result.Cost()
+    def usage(self) -> result.Usage:
+        return result.Usage()
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -232,19 +230,19 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(ToolCallPart.from_json(c.name, c.json_args))
+                calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))
 
         return ModelResponse(calls, timestamp=self._timestamp)
 
-    def cost(self) -> result.Cost:
-        return result.Cost()
+    def usage(self) -> result.Usage:
+        return _estimate_usage([self.get()])
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
-def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
-    """Very rough guesstimate of the number of tokens associate with a series of messages.
+def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+    """Very rough guesstimate of the token usage associated with a series of messages.
 
     This is designed to be used solely to give plausible numbers for testing!
     """
@@ -255,32 +253,28 @@ def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
         if isinstance(message, ModelRequest):
             for part in message.parts:
                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
-                    request_tokens += _string_cost(part.content)
+                    request_tokens += _estimate_string_usage(part.content)
                 elif isinstance(part, ToolReturnPart):
-                    request_tokens += _string_cost(part.model_response_str())
+                    request_tokens += _estimate_string_usage(part.model_response_str())
                 elif isinstance(part, RetryPromptPart):
-                    request_tokens += _string_cost(part.model_response())
+                    request_tokens += _estimate_string_usage(part.model_response())
                 else:
                     assert_never(part)
         elif isinstance(message, ModelResponse):
             for part in message.parts:
                 if isinstance(part, TextPart):
-                    response_tokens += _string_cost(part.content)
+                    response_tokens += _estimate_string_usage(part.content)
                 elif isinstance(part, ToolCallPart):
                     call = part
-                    if isinstance(call.args, ArgsJson):
-                        args_str = call.args.args_json
-                    else:
-                        args_str = pydantic_core.to_json(call.args.args_dict).decode()
-                    response_tokens += 1 + _string_cost(args_str)
+                    response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
                 else:
                     assert_never(part)
         else:
             assert_never(message)
-    return result.Cost(
+    return result.Usage(
         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
     )
 
 
-def _string_cost(content: str) -> int:
+def _estimate_string_usage(content: str) -> int:
     return len(re.split(r'[\s",.:]+', content))
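
`_estimate_usage` is deliberately crude: it splits each string on whitespace and a little punctuation and counts the pieces, purely to give plausible numbers in tests. For instance:

    import re

    def _estimate_string_usage(content: str) -> int:
        # FunctionModel's token guesstimate: count chunks between whitespace/punctuation
        return len(re.split(r'[\s",.:]+', content))

    print(_estimate_string_usage('What is the weather in Paris?'))  # 6
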
pydantic_ai/models/gemini.py CHANGED
@@ -16,7 +16,6 @@ from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
 
 from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -172,10 +171,10 @@ class GeminiAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
-        return self._process_response(response), _metadata_as_cost(response)
+        return self._process_response(response), _metadata_as_usage(response)
 
     @asynccontextmanager
     async def request_stream(
@@ -301,7 +300,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
     _stream: AsyncIterator[bytes]
     _position: int = 0
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)
 
     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -321,7 +320,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
                 new_items, experimental_allow_partial='trailing-strings'
             )
             for r in new_responses:
-                self._cost += _metadata_as_cost(r)
+                self._usage += _metadata_as_usage(r)
                 parts = r['candidates'][0]['content']['parts']
                 if _all_text_parts(parts):
                     for part in parts:
@@ -331,8 +330,8 @@ class GeminiStreamTextResponse(StreamTextResponse):
                         'Streamed response with unexpected content, expected all parts to be text'
                     )
 
-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -345,7 +344,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)
 
     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -365,15 +364,15 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
         combined_parts: list[_GeminiPartUnion] = []
-        self._cost = result.Cost()
+        self._usage = result.Usage()
         for r in responses:
-            self._cost += _metadata_as_cost(r)
+            self._usage += _metadata_as_usage(r)
             candidate = r['candidates'][0]
             combined_parts.extend(candidate['content']['parts'])
         return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
 
-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -460,8 +459,7 @@ class _GeminiFunctionCallPart(TypedDict):
 
 
 def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
-    assert isinstance(tool.args, ArgsDict), f'Expected ArgsObject, got {tool.args}'
-    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
+    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
 
 
 def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
@@ -470,7 +468,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
         if 'text' in part:
             items.append(TextPart(part['text']))
         elif 'function_call' in part:
-            items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
+            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
@@ -640,14 +638,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
-def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
+def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Cost()
+        return result.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.Cost(
+    return result.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
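
Note that the streaming classes accumulate usage with `+=` across chunks, which relies on `Usage` supporting addition; a hedged illustration (numbers invented):

    from pydantic_ai.result import Usage

    # per-chunk accumulation, as the Gemini stream classes do it
    total = Usage()
    total += Usage(request_tokens=12, response_tokens=3, total_tokens=15)
    total += Usage(request_tokens=0, response_tokens=5, total_tokens=5)
    print(total.total_tokens)  # 20, assuming Usage sums token fields on +
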
pydantic_ai/models/groq.py CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -25,7 +24,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -158,9 +157,9 @@ class GroqAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._completions_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
@@ -221,14 +220,14 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -236,19 +235,19 @@ class GroqAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)
 
             if chunk.choices:
                 delta = chunk.choices[0].delta
 
                 if delta.content is not None:
-                    return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return GroqStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )
 
     @classmethod
@@ -308,7 +307,7 @@ class GroqStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)
 
     async def __anext__(self) -> None:
@@ -318,7 +317,7 @@ class GroqStreamTextResponse(StreamTextResponse):
             return None
 
         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)
 
         try:
             choice = chunk.choices[0]
@@ -335,8 +334,8 @@ class GroqStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -349,11 +348,11 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
 
     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)
 
         try:
             choice = chunk.choices[0]
@@ -380,27 +379,26 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
 
         return ModelResponse(items, timestamp=self._timestamp)
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
 
 
 def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
         id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )
 
 
-def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
     usage = None
     if isinstance(completion, ChatCompletion):
         usage = completion.usage
@@ -408,9 +406,9 @@ def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
         usage = completion.x_groq.usage
 
     if usage is None:
-        return result.Cost()
+        return result.Usage()
 
-    return result.Cost(
+    return result.Usage(
         request_tokens=usage.prompt_tokens,
         response_tokens=usage.completion_tokens,
         total_tokens=usage.total_tokens,
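
Across all four backends the caller-visible contract is now the same: `request` returns `(ModelResponse, Usage)` rather than `(ModelResponse, Cost)`, with each provider's `_map_usage` translating its own usage block into the shared type. For Groq that's a direct field mapping; a sketch with invented numbers:

    from pydantic_ai.result import Usage

    # what _map_usage produces for a Groq completion whose usage block reports
    # prompt_tokens=50, completion_tokens=12, total_tokens=62 (numbers invented)
    usage = Usage(request_tokens=50, response_tokens=12, total_tokens=62)
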