pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -2,24 +2,25 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
+from uuid import uuid4
 
 import pydantic
-import pydantic_core
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
+from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior, _utils, exceptions, result
+from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -31,10 +32,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
@@ -171,7 +170,7 @@ class GeminiAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
@@ -179,7 +178,7 @@ class GeminiAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
 
@@ -238,7 +237,7 @@ class GeminiAgentModel(AgentModel):
         return _process_response_from_parts(parts)
 
     @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
+    async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -259,11 +258,7 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        # TODO: Update this once we rework stream responses to be more flexible
-        if _extract_response_parts(start_response).is_left():
-            return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
-        else:
-            return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -302,86 +297,69 @@ class GeminiAgentModel(AgentModel):
 
 
 @dataclass
-class GeminiStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for the Gemini model."""
-
-    _json_content: bytearray
-    _stream: AsyncIterator[bytes]
-    _position: int = 0
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._json_content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        if final:
-            all_items = pydantic_core.from_json(self._json_content)
-            new_items = all_items[self._position :]
-            self._position = len(all_items)
-            new_responses = _gemini_streamed_response_ta.validate_python(new_items)
-        else:
-            all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
-            new_items = all_items[self._position : -1]
-            self._position = len(all_items) - 1
-            new_responses = _gemini_streamed_response_ta.validate_python(
-                new_items, experimental_allow_partial='trailing-strings'
-            )
-        for r in new_responses:
-            self._usage += _metadata_as_usage(r)
-            parts = r['candidates'][0]['content']['parts']
-            if _all_text_parts(parts):
-                for part in parts:
-                    yield part['text']
-            else:
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be text'
-                )
-
-    def usage(self) -> result.Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class GeminiStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for the Gemini model."""
+class GeminiStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for the Gemini model."""
 
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        """Get the `ModelResponse` at this point.
 
-        NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
-        reply with a single response, when returning a structured data.
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for gemini_response in self._get_gemini_responses():
+            candidate = gemini_response['candidates'][0]
+            gemini_part: _GeminiPartUnion
+            for gemini_part in candidate['content']['parts']:
+                if 'text' in gemini_part:
+                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
+                    # amongst the tool call deltas
+                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
+
+                elif 'function_call' in gemini_part:
+                    # Here, we assume all function_call parts are complete and don't have deltas.
+                    # We do this by assigning a unique randomly generated "vendor_part_id".
+                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
+                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=uuid4(),
+                        tool_name=gemini_part['function_call']['name'],
+                        args=gemini_part['function_call']['args'],
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+                else:
+                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'
+
+    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
+        # This method exists to ensure we only yield completed items, so we don't need to worry about
+        # partial gemini responses, which would make everything more complicated
+
+        gemini_responses: list[_GeminiResponse] = []
+        current_gemini_response_index = 0
+        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
+        # But changing that would make things a lot more complicated.
+        async for chunk in self._stream:
+            self._content.extend(chunk)
+
+            gemini_responses = _gemini_streamed_response_ta.validate_json(
+                self._content,
+                experimental_allow_partial='trailing-strings',
+            )
 
-        I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from
-        separate parts.
-        """
-        responses = _gemini_streamed_response_ta.validate_json(
-            self._content,
-            experimental_allow_partial='off' if final else 'trailing-strings',
-        )
-        combined_parts: list[_GeminiPartUnion] = []
-        self._usage = result.Usage()
-        for r in responses:
+            # The idea: yield only up to the latest response, which might still be partial.
+            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
+            # allow_partial API to determine if the last item in the list is complete.
+            responses_to_yield = gemini_responses[:-1]
+            for r in responses_to_yield[current_gemini_response_index:]:
+                current_gemini_response_index += 1
+                self._usage += _metadata_as_usage(r)
+                yield r
+
+        # Now yield the final response, which should be complete
+        if gemini_responses:
+            r = gemini_responses[-1]
             self._usage += _metadata_as_usage(r)
-            candidate = r['candidates'][0]
-            combined_parts.extend(candidate['content']['parts'])
-        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return self._usage
+            yield r
 
     def timestamp(self) -> datetime:
         return self._timestamp
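The `_get_gemini_responses` generator added above leans on pydantic's partial JSON validation: while bytes are still arriving it re-validates the whole buffer with `experimental_allow_partial='trailing-strings'` and treats only everything except the last array element as complete. A minimal standalone sketch of that pattern follows; the `_items_ta` adapter and `_chunks` stream are illustrative stand-ins, not part of pydantic-ai, and it assumes a pydantic version that supports `experimental_allow_partial` (as used in the hunk above).

```python
import asyncio
from collections.abc import AsyncIterator

import pydantic

# Illustrative type: a JSON array of small objects, standing in for the streamed Gemini responses.
_items_ta = pydantic.TypeAdapter(list[dict[str, str]])


async def _chunks() -> AsyncIterator[bytes]:
    # A JSON array arriving in arbitrary byte chunks, as from an HTTP byte stream.
    for chunk in (b'[{"text": "he', b'llo"}, {"text": "wor', b'ld"}]'):
        yield chunk


async def completed_items() -> AsyncIterator[dict[str, str]]:
    buffer = bytearray()
    yielded = 0
    async for chunk in _chunks():
        buffer.extend(chunk)
        # Partial validation tolerates a truncated trailing element...
        items = _items_ta.validate_json(buffer, experimental_allow_partial='trailing-strings')
        # ...so only elements before the last one are known to be complete.
        for item in items[yielded:-1]:
            yielded += 1
            yield item
    # Once the stream ends, the buffer is complete JSON and the final element can be yielded too.
    for item in _items_ta.validate_json(buffer)[yielded:]:
        yield item


async def main() -> None:
    async for item in completed_items():
        print(item)  # {'text': 'hello'} then {'text': 'world'}


asyncio.run(main())
```

This is the same reason the hunk above only yields `gemini_responses[:-1]` inside the loop and emits the final response after the byte stream is exhausted.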
@@ -458,9 +436,14 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
-            items.append(TextPart(part['text']))
+            items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
-            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
+            items.append(
+                ToolCallPart.from_raw_args(
+                    tool_name=part['function_call']['name'],
+                    args=part['function_call']['args'],
+                )
+            )
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
@@ -575,35 +558,6 @@ class _GeminiResponse(TypedDict):
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
-# TODO: Delete the next three functions once we've reworked streams to be more flexible
-def _extract_response_parts(
-    response: _GeminiResponse,
-) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
-    """Extract the parts of the response from the Gemini API.
-
-    Returns Either a list of function calls (Either.left) or a list of text parts (Either.right).
-    """
-    if len(response['candidates']) != 1:
-        raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
-    parts = response['candidates'][0]['content']['parts']
-    if _all_function_call_parts(parts):
-        return _utils.Either(left=parts)
-    elif _all_text_parts(parts):
-        return _utils.Either(right=parts)
-    else:
-        raise exceptions.UnexpectedModelBehavior(
-            f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
-        )
-
-
-def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
-    return all('function_call' in part for part in parts)
-
-
-def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
-    return all('text' in part for part in parts)
-
-
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
@@ -630,14 +584,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
-def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
+def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Usage()
+        return usage.Usage()
     details: dict[str, int] = {}
    if cached_content_token_count := metadata.get('cached_content_token_count'):
        details['cached_content_token_count'] = cached_content_token_count
-    return result.Usage(
+    return usage.Usage(
        request_tokens=metadata.get('prompt_token_count', 0),
        response_tokens=metadata.get('candidates_token_count', 0),
        total_tokens=metadata.get('total_token_count', 0),
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils, result
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
@@ -41,7 +39,6 @@ try:
     from groq import NOT_GIVEN, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
-    from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -157,14 +154,14 @@ class GroqAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
@@ -217,38 +214,23 @@ class GroqAgentModel(AgentModel):
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
-            items.append(TextPart(choice.message.content))
+            items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+                items.append(
+                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
+                )
         return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
+    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        timestamp: datetime | None = None
-        start_usage = Usage()
-        # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
-        while True:
-            try:
-                chunk = await response.__anext__()
-            except StopAsyncIteration as e:
-                raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_usage += _map_usage(chunk)
-
-            if chunk.choices:
-                delta = chunk.choices[0].delta
-
-                if delta.content is not None:
-                    return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
-                elif delta.tool_calls is not None:
-                    return GroqStreamStructuredResponse(
-                        response,
-                        {c.index: c for c in delta.tool_calls},
-                        timestamp,
-                        start_usage,
-                    )
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        return GroqStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
 
     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
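`_process_streamed_response` now reads the first chunk only to learn the response timestamp, then hands a still-complete stream to `GroqStreamedResponse`. The sketch below shows that peek-then-replay idea in isolation; the toy class is an illustrative stand-in, simpler than pydantic-ai's `_utils.PeekableAsyncStream`, which signals an empty stream with an `Unset` sentinel rather than `None`.

```python
import asyncio
from collections.abc import AsyncIterator


class PeekableStream:
    """Wraps an async iterator so its first item can be inspected without being lost."""

    def __init__(self, source: AsyncIterator[int]) -> None:
        self._source = source
        self._peeked: list[int] = []

    async def peek(self) -> int | None:
        # Pull one item ahead of time and stash it; return None if the stream is empty.
        if not self._peeked:
            try:
                self._peeked.append(await self._source.__anext__())
            except StopAsyncIteration:
                return None
        return self._peeked[0]

    async def __aiter__(self) -> AsyncIterator[int]:
        # Replay the stashed item first, then continue with the rest of the stream.
        for item in self._peeked:
            yield item
        async for item in self._source:
            yield item


async def _numbers() -> AsyncIterator[int]:
    for n in (1, 2, 3):
        yield n


async def main() -> None:
    stream = PeekableStream(_numbers())
    print(await stream.peek())         # 1, without consuming it
    print([n async for n in stream])   # [1, 2, 3]


asyncio.run(main())
```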
@@ -301,90 +283,36 @@ class GroqAgentModel(AgentModel):
 
 
 @dataclass
-class GroqStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for Groq models."""
+class GroqStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Groq models."""
 
-    _first: str | None
-    _response: AsyncStream[ChatCompletionChunk]
+    _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
-    _usage: result.Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage = _map_usage(chunk)
-
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
 
-        # we don't raise StopAsyncIteration on the last chunk because usage comes after this
-        if choice.finish_reason is None:
-            assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-        if choice.delta.content is not None:
-            self._buffer.append(choice.delta.content)
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)
 
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class GroqStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for Groq models."""
-
-    _response: AsyncStream[ChatCompletionChunk]
-    _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
-    _timestamp: datetime
-    _usage: result.Usage
-
-    async def __anext__(self) -> None:
-        chunk = await self._response.__anext__()
-        self._usage = _map_usage(chunk)
-
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        if choice.finish_reason is not None:
-            raise StopAsyncIteration()
-
-        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
-        for new in choice.delta.tool_calls or []:
-            if current := self._delta_tool_calls.get(new.index):
-                if current.function is None:
-                    current.function = new.function
-                elif new.function is not None:
-                    current.function.name = _utils.add_optional(current.function.name, new.function.name)
-                    current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
-            else:
-                self._delta_tool_calls[new.index] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        items: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if f := c.function:
-                if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
-
-        return ModelResponse(items, timestamp=self._timestamp)
-
-    def usage(self) -> Usage:
-        return self._usage
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            # Handle the tool calls
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -398,18 +326,18 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     )
 
 
-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
-    usage = None
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+    response_usage = None
     if isinstance(completion, ChatCompletion):
-        usage = completion.usage
+        response_usage = completion.usage
     elif completion.x_groq is not None:
-        usage = completion.x_groq.usage
+        response_usage = completion.x_groq.usage
 
-    if usage is None:
-        return result.Usage()
+    if response_usage is None:
+        return usage.Usage()
 
-    return result.Usage(
-        request_tokens=usage.prompt_tokens,
-        response_tokens=usage.completion_tokens,
-        total_tokens=usage.total_tokens,
+    return usage.Usage(
+        request_tokens=response_usage.prompt_tokens,
+        response_tokens=response_usage.completion_tokens,
+        total_tokens=response_usage.total_tokens,
     )