pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +82 -74
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +218 -9
- pydantic_ai/models/__init__.py +31 -72
- pydantic_ai/models/anthropic.py +21 -21
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +76 -122
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
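The common thread in the model diffs below is a new streaming abstraction: the per-model StreamTextResponse/StreamStructuredResponse pairs (and EitherStreamedResponse) are replaced by a single StreamedResponse base class, and each vendor adapter now implements _get_event_iterator, feeding raw vendor deltas into a parts manager (see the new pydantic_ai/_parts_manager.py above) that turns them into ModelResponseStreamEvent objects keyed by a vendor_part_id. The toy below is only an illustration of that bookkeeping idea; it is not the real ModelResponsePartsManager API, the event strings and the ToyPartsManager name are invented, and the exact start/continue rules of the real manager may differ.

# Toy sketch of the parts-manager idea (not pydantic_ai's real API): deltas
# arrive keyed by a vendor-specific part id, and the manager decides whether a
# delta starts a new part or extends an existing one, returning an event only
# when there is something to report.
from collections.abc import Hashable
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyPartsManager:
    text_parts: dict[Hashable, str] = field(default_factory=dict)
    tool_calls: dict[Hashable, tuple[Optional[str], str]] = field(default_factory=dict)

    def handle_text_delta(self, vendor_part_id: Hashable, content: str) -> str:
        started = vendor_part_id not in self.text_parts
        self.text_parts[vendor_part_id] = self.text_parts.get(vendor_part_id, '') + content
        return ('start' if started else 'delta') + f' text {vendor_part_id!r}: {content!r}'

    def handle_tool_call_delta(
        self, vendor_part_id: Hashable, tool_name: Optional[str], args: Optional[str]
    ) -> Optional[str]:
        name, acc_args = self.tool_calls.get(vendor_part_id, (None, ''))
        name = name or tool_name
        acc_args += args or ''
        self.tool_calls[vendor_part_id] = (name, acc_args)
        if name is None:
            return None  # nothing to report until we at least know the tool name
        return f'tool call {vendor_part_id!r}: {name} {acc_args!r}'


manager = ToyPartsManager()
print(manager.handle_text_delta('content', 'Hello'))           # start text 'content': 'Hello'
print(manager.handle_text_delta('content', ' world'))          # delta text 'content': ' world'
print(manager.handle_tool_call_delta(0, 'get_weather', None))  # name arrives first
print(manager.handle_tool_call_delta(0, None, '{"city": "P'))  # argument fragments accumulate
print(manager.handle_tool_call_delta(0, None, 'aris"}'))

In the diffs that follow, Groq uses stable vendor ids (the 'content' key and the tool-call index), while Gemini passes vendor_part_id=None for text and a fresh uuid4() per complete function call.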
pydantic_ai/models/gemini.py
CHANGED
@@ -2,24 +2,25 @@ from __future__ import annotations as _annotations
 
 import os
 import re
-from collections.abc import AsyncIterator,
+from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union
+from uuid import uuid4
 
 import pydantic
-import pydantic_core
 from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict,
+from typing_extensions import NotRequired, TypedDict, assert_never
 
-from .. import UnexpectedModelBehavior, _utils, exceptions,
+from .. import UnexpectedModelBehavior, _utils, exceptions, usage
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -31,10 +32,8 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
     get_user_agent,
@@ -171,7 +170,7 @@ class GeminiAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse,
+    ) -> tuple[ModelResponse, usage.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
@@ -179,7 +178,7 @@ class GeminiAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[
+    ) -> AsyncIterator[StreamedResponse]:
         async with self._make_request(messages, True, model_settings) as http_response:
             yield await self._process_streamed_response(http_response)
 
@@ -238,7 +237,7 @@ class GeminiAgentModel(AgentModel):
         return _process_response_from_parts(parts)
 
     @staticmethod
-    async def _process_streamed_response(http_response: HTTPResponse) ->
+    async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
@@ -259,11 +258,7 @@ class GeminiAgentModel(AgentModel):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-
-        if _extract_response_parts(start_response).is_left():
-            return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
-        else:
-            return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
 
     @classmethod
     def _message_to_gemini_content(
@@ -302,86 +297,69 @@ class GeminiAgentModel(AgentModel):
 
 
 @dataclass
-class
-    """Implementation of `
-
-    _json_content: bytearray
-    _stream: AsyncIterator[bytes]
-    _position: int = 0
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._json_content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        if final:
-            all_items = pydantic_core.from_json(self._json_content)
-            new_items = all_items[self._position :]
-            self._position = len(all_items)
-            new_responses = _gemini_streamed_response_ta.validate_python(new_items)
-        else:
-            all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
-            new_items = all_items[self._position : -1]
-            self._position = len(all_items) - 1
-            new_responses = _gemini_streamed_response_ta.validate_python(
-                new_items, experimental_allow_partial='trailing-strings'
-            )
-        for r in new_responses:
-            self._usage += _metadata_as_usage(r)
-            parts = r['candidates'][0]['content']['parts']
-            if _all_text_parts(parts):
-                for part in parts:
-                    yield part['text']
-            else:
-                raise UnexpectedModelBehavior(
-                    'Streamed response with unexpected content, expected all parts to be text'
-                )
-
-    def usage(self) -> result.Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class GeminiStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for the Gemini model."""
+class GeminiStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for the Gemini model."""
 
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _usage: result.Usage = field(default_factory=result.Usage, init=False)
-
-    async def __anext__(self) -> None:
-        chunk = await self._stream.__anext__()
-        self._content.extend(chunk)
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        """Get the `ModelResponse` at this point.
 
-
-
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for gemini_response in self._get_gemini_responses():
+            candidate = gemini_response['candidates'][0]
+            gemini_part: _GeminiPartUnion
+            for gemini_part in candidate['content']['parts']:
+                if 'text' in gemini_part:
+                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
+                    # amongst the tool call deltas
+                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
+
+                elif 'function_call' in gemini_part:
+                    # Here, we assume all function_call parts are complete and don't have deltas.
+                    # We do this by assigning a unique randomly generated "vendor_part_id".
+                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
+                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=uuid4(),
+                        tool_name=gemini_part['function_call']['name'],
+                        args=gemini_part['function_call']['args'],
+                        tool_call_id=None,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+                else:
+                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'
+
+    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
+        # This method exists to ensure we only yield completed items, so we don't need to worry about
+        # partial gemini responses, which would make everything more complicated
+
+        gemini_responses: list[_GeminiResponse] = []
+        current_gemini_response_index = 0
+        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
+        # But changing that would make things a lot more complicated.
+        async for chunk in self._stream:
+            self._content.extend(chunk)
+
+            gemini_responses = _gemini_streamed_response_ta.validate_json(
+                self._content,
+                experimental_allow_partial='trailing-strings',
+            )
 
-
-
-
-
-
-
-
-
-
+            # The idea: yield only up to the latest response, which might still be partial.
+            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
+            # allow_partial API to determine if the last item in the list is complete.
+            responses_to_yield = gemini_responses[:-1]
+            for r in responses_to_yield[current_gemini_response_index:]:
+                current_gemini_response_index += 1
+                self._usage += _metadata_as_usage(r)
+                yield r
+
+        # Now yield the final response, which should be complete
+        if gemini_responses:
+            r = gemini_responses[-1]
             self._usage += _metadata_as_usage(r)
-
-            combined_parts.extend(candidate['content']['parts'])
-        return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
-
-    def usage(self) -> result.Usage:
-        return self._usage
+            yield r
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -458,9 +436,14 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
     items: list[ModelResponsePart] = []
     for part in parts:
         if 'text' in part:
-            items.append(TextPart(part['text']))
+            items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
-            items.append(
+            items.append(
+                ToolCallPart.from_raw_args(
+                    tool_name=part['function_call']['name'],
+                    args=part['function_call']['args'],
+                )
+            )
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
@@ -575,35 +558,6 @@ class _GeminiResponse(TypedDict):
     prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
 
 
-# TODO: Delete the next three functions once we've reworked streams to be more flexible
-def _extract_response_parts(
-    response: _GeminiResponse,
-) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
-    """Extract the parts of the response from the Gemini API.
-
-    Returns Either a list of function calls (Either.left) or a list of text parts (Either.right).
-    """
-    if len(response['candidates']) != 1:
-        raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
-    parts = response['candidates'][0]['content']['parts']
-    if _all_function_call_parts(parts):
-        return _utils.Either(left=parts)
-    elif _all_text_parts(parts):
-        return _utils.Either(right=parts)
-    else:
-        raise exceptions.UnexpectedModelBehavior(
-            f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
-        )
-
-
-def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
-    return all('function_call' in part for part in parts)
-
-
-def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
-    return all('text' in part for part in parts)
-
-
 class _GeminiCandidates(TypedDict):
     """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
 
@@ -630,14 +584,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
 
 
-def _metadata_as_usage(response: _GeminiResponse) ->
+def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return
+        return usage.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return
+    return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
pydantic_ai/models/groq.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
-from .. import UnexpectedModelBehavior, _utils,
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
@@ -41,7 +39,6 @@ try:
     from groq import NOT_GIVEN, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
-    from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -157,14 +154,14 @@ class GroqAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse,
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
@@ -217,38 +214,23 @@ class GroqAgentModel(AgentModel):
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
-            items.append(TextPart(choice.message.content))
+            items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(
+                items.append(
+                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
+                )
         return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) ->
+    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-
-
-
-
-
-
-        except StopAsyncIteration as e:
-            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-        timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-        start_usage += _map_usage(chunk)
-
-        if chunk.choices:
-            delta = chunk.choices[0].delta
-
-            if delta.content is not None:
-                return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
-            elif delta.tool_calls is not None:
-                return GroqStreamStructuredResponse(
-                    response,
-                    {c.index: c for c in delta.tool_calls},
-                    timestamp,
-                    start_usage,
-                )
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        return GroqStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
 
     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
@@ -301,90 +283,36 @@ class GroqAgentModel(AgentModel):
 
 
 @dataclass
-class
-    """Implementation of `
+class GroqStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Groq models."""
 
-
-    _response: AsyncStream[ChatCompletionChunk]
+    _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
-    _usage: result.Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage = _map_usage(chunk)
-
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
 
-
-
-
-        if choice.delta.content is not None:
-            self._buffer.append(choice.delta.content)
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    async def __anext__(self) -> None:
-        chunk = await self._response.__anext__()
-        self._usage = _map_usage(chunk)
-
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        if choice.finish_reason is not None:
-            raise StopAsyncIteration()
-
-        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
-        for new in choice.delta.tool_calls or []:
-            if current := self._delta_tool_calls.get(new.index):
-                if current.function is None:
-                    current.function = new.function
-                elif new.function is not None:
-                    current.function.name = _utils.add_optional(current.function.name, new.function.name)
-                    current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
-            else:
-                self._delta_tool_calls[new.index] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        items: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if f := c.function:
-                if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
-
-        return ModelResponse(items, timestamp=self._timestamp)
-
-    def usage(self) -> Usage:
-        return self._usage
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            # Handle the tool calls
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
                if maybe_event is not None:
                    yield maybe_event
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -398,18 +326,18 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     )
 
 
-def _map_usage(completion: ChatCompletionChunk | ChatCompletion) ->
-
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
+    response_usage = None
     if isinstance(completion, ChatCompletion):
-
+        response_usage = completion.usage
     elif completion.x_groq is not None:
-
+        response_usage = completion.x_groq.usage
 
-    if
-        return
+    if response_usage is None:
+        return usage.Usage()
 
-    return
-        request_tokens=
-        response_tokens=
-        total_tokens=
+    return usage.Usage(
+        request_tokens=response_usage.prompt_tokens,
+        response_tokens=response_usage.completion_tokens,
+        total_tokens=response_usage.total_tokens,
     )