pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +82 -74
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +218 -9
- pydantic_ai/models/__init__.py +31 -72
- pydantic_ai/models/anthropic.py +21 -21
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +76 -122
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/models/mistral.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
from collections.abc import AsyncIterator, Iterable
|
|
4
|
+
from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime, timezone
|
|
@@ -12,7 +12,7 @@ import pydantic_core
|
|
|
12
12
|
from httpx import AsyncClient as AsyncHTTPClient, Timeout
|
|
13
13
|
from typing_extensions import assert_never
|
|
14
14
|
|
|
15
|
-
from .. import UnexpectedModelBehavior
|
|
15
|
+
from .. import UnexpectedModelBehavior, _utils
|
|
16
16
|
from .._utils import now_utc as _now_utc
|
|
17
17
|
from ..messages import (
|
|
18
18
|
ArgsJson,
|
|
@@ -20,6 +20,7 @@ from ..messages import (
|
|
|
20
20
|
ModelRequest,
|
|
21
21
|
ModelResponse,
|
|
22
22
|
ModelResponsePart,
|
|
23
|
+
ModelResponseStreamEvent,
|
|
23
24
|
RetryPromptPart,
|
|
24
25
|
SystemPromptPart,
|
|
25
26
|
TextPart,
|
|
@@ -32,10 +33,8 @@ from ..settings import ModelSettings
|
|
|
32
33
|
from ..tools import ToolDefinition
|
|
33
34
|
from . import (
|
|
34
35
|
AgentModel,
|
|
35
|
-
EitherStreamedResponse,
|
|
36
36
|
Model,
|
|
37
|
-
|
|
38
|
-
StreamTextResponse,
|
|
37
|
+
StreamedResponse,
|
|
39
38
|
cached_async_http_client,
|
|
40
39
|
)
|
|
41
40
|
|
|
@@ -164,7 +163,7 @@ class MistralAgentModel(AgentModel):
|
|
|
164
163
|
@asynccontextmanager
|
|
165
164
|
async def request_stream(
|
|
166
165
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
167
|
-
) -> AsyncIterator[
|
|
166
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
168
167
|
"""Make a streaming request to the model from Pydantic AI call."""
|
|
169
168
|
response = await self._stream_completions_create(messages, model_settings)
|
|
170
169
|
async with response:
|
|
@@ -282,11 +281,11 @@ class MistralAgentModel(AgentModel):
|
|
|
282
281
|
|
|
283
282
|
parts: list[ModelResponsePart] = []
|
|
284
283
|
if text := _map_content(content):
|
|
285
|
-
parts.append(TextPart(text))
|
|
284
|
+
parts.append(TextPart(content=text))
|
|
286
285
|
|
|
287
286
|
if isinstance(tool_calls, list):
|
|
288
287
|
for tool_call in tool_calls:
|
|
289
|
-
tool = _map_mistral_to_pydantic_tool_call(tool_call)
|
|
288
|
+
tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
|
|
290
289
|
parts.append(tool)
|
|
291
290
|
|
|
292
291
|
return ModelResponse(parts, timestamp=timestamp)
|
|
@@ -295,45 +294,19 @@ class MistralAgentModel(AgentModel):
|
|
|
295
294
|
async def _process_streamed_response(
|
|
296
295
|
result_tools: list[ToolDefinition],
|
|
297
296
|
response: MistralEventStreamAsync[MistralCompletionEvent],
|
|
298
|
-
) ->
|
|
297
|
+
) -> StreamedResponse:
|
|
299
298
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
300
|
-
|
|
299
|
+
peekable_response = _utils.PeekableAsyncStream(response)
|
|
300
|
+
first_chunk = await peekable_response.peek()
|
|
301
|
+
if isinstance(first_chunk, _utils.Unset):
|
|
302
|
+
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
|
|
301
303
|
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
chunk = event.data
|
|
307
|
-
except StopAsyncIteration as e:
|
|
308
|
-
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
|
|
309
|
-
|
|
310
|
-
start_usage += _map_usage(chunk)
|
|
311
|
-
|
|
312
|
-
if chunk.created:
|
|
313
|
-
timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
|
|
314
|
-
else:
|
|
315
|
-
timestamp = _now_utc()
|
|
316
|
-
|
|
317
|
-
if chunk.choices:
|
|
318
|
-
delta = chunk.choices[0].delta
|
|
319
|
-
content = _map_content(delta.content)
|
|
320
|
-
|
|
321
|
-
tool_calls: list[MistralToolCall] | None = None
|
|
322
|
-
if delta.tool_calls:
|
|
323
|
-
tool_calls = delta.tool_calls
|
|
324
|
-
|
|
325
|
-
if tool_calls or content and result_tools:
|
|
326
|
-
return MistralStreamStructuredResponse(
|
|
327
|
-
{c.id if c.id else 'null': c for c in tool_calls or []},
|
|
328
|
-
{c.name: c for c in result_tools},
|
|
329
|
-
response,
|
|
330
|
-
content,
|
|
331
|
-
timestamp,
|
|
332
|
-
start_usage,
|
|
333
|
-
)
|
|
304
|
+
if first_chunk.data.created:
|
|
305
|
+
timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
|
|
306
|
+
else:
|
|
307
|
+
timestamp = datetime.now(tz=timezone.utc)
|
|
334
308
|
|
|
335
|
-
|
|
336
|
-
return MistralStreamTextResponse(content, response, timestamp, start_usage)
|
|
309
|
+
return MistralStreamedResponse(peekable_response, timestamp, {c.name: c for c in result_tools})
|
|
337
310
|
|
|
338
311
|
@staticmethod
|
|
339
312
|
def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
|
|
@@ -467,108 +440,73 @@ class MistralAgentModel(AgentModel):
|
|
|
467
440
|
assert_never(message)
|
|
468
441
|
|
|
469
442
|
|
|
470
|
-
|
|
471
|
-
class MistralStreamTextResponse(StreamTextResponse):
|
|
472
|
-
"""Implementation of `StreamTextResponse` for Mistral models."""
|
|
473
|
-
|
|
474
|
-
_first: str | None
|
|
475
|
-
_response: MistralEventStreamAsync[MistralCompletionEvent]
|
|
476
|
-
_timestamp: datetime
|
|
477
|
-
_usage: Usage
|
|
478
|
-
_buffer: list[str] = field(default_factory=list, init=False)
|
|
479
|
-
|
|
480
|
-
async def __anext__(self) -> None:
|
|
481
|
-
if self._first is not None and len(self._first) > 0:
|
|
482
|
-
self._buffer.append(self._first)
|
|
483
|
-
self._first = None
|
|
484
|
-
return None
|
|
485
|
-
|
|
486
|
-
chunk = await self._response.__anext__()
|
|
487
|
-
self._usage += _map_usage(chunk.data)
|
|
488
|
-
|
|
489
|
-
try:
|
|
490
|
-
choice = chunk.data.choices[0]
|
|
491
|
-
except IndexError:
|
|
492
|
-
raise StopAsyncIteration()
|
|
493
|
-
|
|
494
|
-
content = choice.delta.content
|
|
495
|
-
if choice.finish_reason is None:
|
|
496
|
-
assert content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
|
|
497
|
-
|
|
498
|
-
if text := _map_content(content):
|
|
499
|
-
self._buffer.append(text)
|
|
500
|
-
|
|
501
|
-
def get(self, *, final: bool = False) -> Iterable[str]:
|
|
502
|
-
yield from self._buffer
|
|
503
|
-
self._buffer.clear()
|
|
504
|
-
|
|
505
|
-
def usage(self) -> Usage:
|
|
506
|
-
return self._usage
|
|
507
|
-
|
|
508
|
-
def timestamp(self) -> datetime:
|
|
509
|
-
return self._timestamp
|
|
443
|
+
MistralToolCallId = Union[str, None]
|
|
510
444
|
|
|
511
445
|
|
|
512
446
|
@dataclass
|
|
513
|
-
class
|
|
514
|
-
"""Implementation of `
|
|
447
|
+
class MistralStreamedResponse(StreamedResponse):
|
|
448
|
+
"""Implementation of `StreamedResponse` for Mistral models."""
|
|
515
449
|
|
|
516
|
-
|
|
517
|
-
_result_tools: dict[str, ToolDefinition]
|
|
518
|
-
_response: MistralEventStreamAsync[MistralCompletionEvent]
|
|
519
|
-
_delta_content: str | None
|
|
450
|
+
_response: AsyncIterable[MistralCompletionEvent]
|
|
520
451
|
_timestamp: datetime
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
async def __anext__(self) -> None:
|
|
524
|
-
chunk = await self._response.__anext__()
|
|
525
|
-
self._usage += _map_usage(chunk.data)
|
|
526
|
-
|
|
527
|
-
try:
|
|
528
|
-
choice = chunk.data.choices[0]
|
|
529
|
-
|
|
530
|
-
except IndexError:
|
|
531
|
-
raise StopAsyncIteration()
|
|
532
|
-
|
|
533
|
-
if choice.finish_reason is not None:
|
|
534
|
-
raise StopAsyncIteration()
|
|
535
|
-
|
|
536
|
-
content = choice.delta.content
|
|
537
|
-
if self._result_tools:
|
|
538
|
-
if text := _map_content(content):
|
|
539
|
-
self._delta_content = (self._delta_content or '') + text
|
|
540
|
-
|
|
541
|
-
def get(self, *, final: bool = False) -> ModelResponse:
|
|
542
|
-
calls: list[ModelResponsePart] = []
|
|
543
|
-
if self._function_tools and self._result_tools or self._function_tools:
|
|
544
|
-
for tool_call in self._function_tools.values():
|
|
545
|
-
tool = _map_mistral_to_pydantic_tool_call(tool_call)
|
|
546
|
-
calls.append(tool)
|
|
547
|
-
|
|
548
|
-
elif self._delta_content and self._result_tools:
|
|
549
|
-
output_json: dict[str, Any] | None = pydantic_core.from_json(
|
|
550
|
-
self._delta_content, allow_partial='trailing-strings'
|
|
551
|
-
)
|
|
452
|
+
_result_tools: dict[str, ToolDefinition]
|
|
552
453
|
|
|
553
|
-
|
|
554
|
-
for result_tool in self._result_tools.values():
|
|
555
|
-
# NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
|
|
556
|
-
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
|
|
557
|
-
# Example with BaseModel and required fields.
|
|
558
|
-
if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
|
|
559
|
-
continue
|
|
454
|
+
_delta_content: str = field(default='', init=False)
|
|
560
455
|
|
|
561
|
-
|
|
562
|
-
|
|
456
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
457
|
+
chunk: MistralCompletionEvent
|
|
458
|
+
async for chunk in self._response:
|
|
459
|
+
self._usage += _map_usage(chunk.data)
|
|
563
460
|
|
|
564
|
-
|
|
461
|
+
try:
|
|
462
|
+
choice = chunk.data.choices[0]
|
|
463
|
+
except IndexError:
|
|
464
|
+
continue
|
|
465
|
+
|
|
466
|
+
# Handle the text part of the response
|
|
467
|
+
content = choice.delta.content
|
|
468
|
+
text = _map_content(content)
|
|
469
|
+
if text:
|
|
470
|
+
# Attempt to produce a result tool call from the received text
|
|
471
|
+
if self._result_tools:
|
|
472
|
+
self._delta_content += text
|
|
473
|
+
maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
|
|
474
|
+
if maybe_tool_call_part:
|
|
475
|
+
yield self._parts_manager.handle_tool_call_part(
|
|
476
|
+
vendor_part_id='result',
|
|
477
|
+
tool_name=maybe_tool_call_part.tool_name,
|
|
478
|
+
args=maybe_tool_call_part.args_as_dict(),
|
|
479
|
+
tool_call_id=maybe_tool_call_part.tool_call_id,
|
|
480
|
+
)
|
|
481
|
+
else:
|
|
482
|
+
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
|
|
565
483
|
|
|
566
|
-
|
|
567
|
-
|
|
484
|
+
# Handle the explicit tool calls
|
|
485
|
+
for index, dtc in enumerate(choice.delta.tool_calls or []):
|
|
486
|
+
# It seems that mistral just sends full tool calls, so we just use them directly, rather than building
|
|
487
|
+
yield self._parts_manager.handle_tool_call_part(
|
|
488
|
+
vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id
|
|
489
|
+
)
|
|
568
490
|
|
|
569
491
|
def timestamp(self) -> datetime:
|
|
570
492
|
return self._timestamp
|
|
571
493
|
|
|
494
|
+
@staticmethod
|
|
495
|
+
def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
|
|
496
|
+
output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
|
|
497
|
+
if output_json:
|
|
498
|
+
for result_tool in result_tools.values():
|
|
499
|
+
# NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
|
|
500
|
+
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
|
|
501
|
+
# Example with BaseModel and required fields.
|
|
502
|
+
if not MistralStreamedResponse._validate_required_json_schema(
|
|
503
|
+
output_json, result_tool.parameters_json_schema
|
|
504
|
+
):
|
|
505
|
+
continue
|
|
506
|
+
|
|
507
|
+
# The following part_id will be thrown away
|
|
508
|
+
return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json)
|
|
509
|
+
|
|
572
510
|
@staticmethod
|
|
573
511
|
def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
|
|
574
512
|
"""Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
|
|
@@ -587,20 +525,20 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
|
|
|
587
525
|
if not isinstance(json_dict[param], list):
|
|
588
526
|
return False
|
|
589
527
|
for item in json_dict[param]:
|
|
590
|
-
if not isinstance(item,
|
|
528
|
+
if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]):
|
|
591
529
|
return False
|
|
592
|
-
elif param_type and not isinstance(json_dict[param],
|
|
530
|
+
elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]):
|
|
593
531
|
return False
|
|
594
532
|
|
|
595
533
|
if isinstance(json_dict[param], dict) and 'properties' in param_schema:
|
|
596
534
|
nested_schema = param_schema
|
|
597
|
-
if not
|
|
535
|
+
if not MistralStreamedResponse._validate_required_json_schema(json_dict[param], nested_schema):
|
|
598
536
|
return False
|
|
599
537
|
|
|
600
538
|
return True
|
|
601
539
|
|
|
602
540
|
|
|
603
|
-
|
|
541
|
+
VALID_JSON_TYPE_MAPPING: dict[str, Any] = {
|
|
604
542
|
'string': str,
|
|
605
543
|
'integer': int,
|
|
606
544
|
'number': float,
|
pydantic_ai/models/ollama.py
CHANGED
pydantic_ai/models/openai.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator, Iterable
|
|
3
|
+
from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
4
4
|
from contextlib import asynccontextmanager
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime, timezone
|
|
@@ -10,13 +10,14 @@ from typing import Literal, Union, overload
|
|
|
10
10
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
11
11
|
from typing_extensions import assert_never
|
|
12
12
|
|
|
13
|
-
from .. import UnexpectedModelBehavior, _utils,
|
|
13
|
+
from .. import UnexpectedModelBehavior, _utils, usage
|
|
14
14
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
15
15
|
from ..messages import (
|
|
16
16
|
ModelMessage,
|
|
17
17
|
ModelRequest,
|
|
18
18
|
ModelResponse,
|
|
19
19
|
ModelResponsePart,
|
|
20
|
+
ModelResponseStreamEvent,
|
|
20
21
|
RetryPromptPart,
|
|
21
22
|
SystemPromptPart,
|
|
22
23
|
TextPart,
|
|
@@ -24,15 +25,12 @@ from ..messages import (
|
|
|
24
25
|
ToolReturnPart,
|
|
25
26
|
UserPromptPart,
|
|
26
27
|
)
|
|
27
|
-
from ..result import Usage
|
|
28
28
|
from ..settings import ModelSettings
|
|
29
29
|
from ..tools import ToolDefinition
|
|
30
30
|
from . import (
|
|
31
31
|
AgentModel,
|
|
32
|
-
EitherStreamedResponse,
|
|
33
32
|
Model,
|
|
34
|
-
|
|
35
|
-
StreamTextResponse,
|
|
33
|
+
StreamedResponse,
|
|
36
34
|
cached_async_http_client,
|
|
37
35
|
check_allow_model_requests,
|
|
38
36
|
)
|
|
@@ -41,7 +39,6 @@ try:
|
|
|
41
39
|
from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
|
|
42
40
|
from openai.types import ChatModel, chat
|
|
43
41
|
from openai.types.chat import ChatCompletionChunk
|
|
44
|
-
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
|
|
45
42
|
except ImportError as _import_error:
|
|
46
43
|
raise ImportError(
|
|
47
44
|
'Please install `openai` to use the OpenAI model, '
|
|
@@ -146,14 +143,14 @@ class OpenAIAgentModel(AgentModel):
|
|
|
146
143
|
|
|
147
144
|
async def request(
|
|
148
145
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
149
|
-
) -> tuple[ModelResponse,
|
|
146
|
+
) -> tuple[ModelResponse, usage.Usage]:
|
|
150
147
|
response = await self._completions_create(messages, False, model_settings)
|
|
151
148
|
return self._process_response(response), _map_usage(response)
|
|
152
149
|
|
|
153
150
|
@asynccontextmanager
|
|
154
151
|
async def request_stream(
|
|
155
152
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
156
|
-
) -> AsyncIterator[
|
|
153
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
157
154
|
response = await self._completions_create(messages, True, model_settings)
|
|
158
155
|
async with response:
|
|
159
156
|
yield await self._process_streamed_response(response)
|
|
@@ -214,33 +211,14 @@ class OpenAIAgentModel(AgentModel):
|
|
|
214
211
|
return ModelResponse(items, timestamp=timestamp)
|
|
215
212
|
|
|
216
213
|
@staticmethod
|
|
217
|
-
async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) ->
|
|
214
|
+
async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
|
|
218
215
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
except StopAsyncIteration as e:
|
|
226
|
-
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
|
|
227
|
-
|
|
228
|
-
timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
|
|
229
|
-
start_usage += _map_usage(chunk)
|
|
230
|
-
|
|
231
|
-
if chunk.choices:
|
|
232
|
-
delta = chunk.choices[0].delta
|
|
233
|
-
|
|
234
|
-
if delta.content is not None:
|
|
235
|
-
return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
|
|
236
|
-
elif delta.tool_calls is not None:
|
|
237
|
-
return OpenAIStreamStructuredResponse(
|
|
238
|
-
response,
|
|
239
|
-
{c.index: c for c in delta.tool_calls},
|
|
240
|
-
timestamp,
|
|
241
|
-
start_usage,
|
|
242
|
-
)
|
|
243
|
-
# else continue until we get either delta.content or delta.tool_calls
|
|
216
|
+
peekable_response = _utils.PeekableAsyncStream(response)
|
|
217
|
+
first_chunk = await peekable_response.peek()
|
|
218
|
+
if isinstance(first_chunk, _utils.Unset):
|
|
219
|
+
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
|
|
220
|
+
|
|
221
|
+
return OpenAIStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc))
|
|
244
222
|
|
|
245
223
|
@classmethod
|
|
246
224
|
def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
|
|
@@ -295,88 +273,35 @@ class OpenAIAgentModel(AgentModel):
|
|
|
295
273
|
|
|
296
274
|
|
|
297
275
|
@dataclass
|
|
298
|
-
class
|
|
299
|
-
"""Implementation of `
|
|
300
|
-
|
|
301
|
-
_first: str | None
|
|
302
|
-
_response: AsyncStream[ChatCompletionChunk]
|
|
303
|
-
_timestamp: datetime
|
|
304
|
-
_usage: result.Usage
|
|
305
|
-
_buffer: list[str] = field(default_factory=list, init=False)
|
|
306
|
-
|
|
307
|
-
async def __anext__(self) -> None:
|
|
308
|
-
if self._first is not None:
|
|
309
|
-
self._buffer.append(self._first)
|
|
310
|
-
self._first = None
|
|
311
|
-
return None
|
|
312
|
-
|
|
313
|
-
chunk = await self._response.__anext__()
|
|
314
|
-
self._usage += _map_usage(chunk)
|
|
315
|
-
try:
|
|
316
|
-
choice = chunk.choices[0]
|
|
317
|
-
except IndexError:
|
|
318
|
-
raise StopAsyncIteration()
|
|
319
|
-
|
|
320
|
-
# we don't raise StopAsyncIteration on the last chunk because usage comes after this
|
|
321
|
-
if choice.finish_reason is None:
|
|
322
|
-
assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
|
|
323
|
-
if choice.delta.content is not None:
|
|
324
|
-
self._buffer.append(choice.delta.content)
|
|
325
|
-
|
|
326
|
-
def get(self, *, final: bool = False) -> Iterable[str]:
|
|
327
|
-
yield from self._buffer
|
|
328
|
-
self._buffer.clear()
|
|
329
|
-
|
|
330
|
-
def usage(self) -> Usage:
|
|
331
|
-
return self._usage
|
|
332
|
-
|
|
333
|
-
def timestamp(self) -> datetime:
|
|
334
|
-
return self._timestamp
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
@dataclass
|
|
338
|
-
class OpenAIStreamStructuredResponse(StreamStructuredResponse):
|
|
339
|
-
"""Implementation of `StreamStructuredResponse` for OpenAI models."""
|
|
276
|
+
class OpenAIStreamedResponse(StreamedResponse):
|
|
277
|
+
"""Implementation of `StreamedResponse` for OpenAI models."""
|
|
340
278
|
|
|
341
|
-
_response:
|
|
342
|
-
_delta_tool_calls: dict[int, ChoiceDeltaToolCall]
|
|
279
|
+
_response: AsyncIterable[ChatCompletionChunk]
|
|
343
280
|
_timestamp: datetime
|
|
344
|
-
_usage: result.Usage
|
|
345
|
-
|
|
346
|
-
async def __anext__(self) -> None:
|
|
347
|
-
chunk = await self._response.__anext__()
|
|
348
|
-
self._usage += _map_usage(chunk)
|
|
349
|
-
try:
|
|
350
|
-
choice = chunk.choices[0]
|
|
351
|
-
except IndexError:
|
|
352
|
-
raise StopAsyncIteration()
|
|
353
|
-
|
|
354
|
-
if choice.finish_reason is not None:
|
|
355
|
-
raise StopAsyncIteration()
|
|
356
|
-
|
|
357
|
-
assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
|
|
358
|
-
|
|
359
|
-
for new in choice.delta.tool_calls or []:
|
|
360
|
-
if current := self._delta_tool_calls.get(new.index):
|
|
361
|
-
if current.function is None:
|
|
362
|
-
current.function = new.function
|
|
363
|
-
elif new.function is not None:
|
|
364
|
-
current.function.name = _utils.add_optional(current.function.name, new.function.name)
|
|
365
|
-
current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
|
|
366
|
-
else:
|
|
367
|
-
self._delta_tool_calls[new.index] = new
|
|
368
281
|
|
|
369
|
-
def
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
if f := c.function:
|
|
373
|
-
if f.name is not None and f.arguments is not None:
|
|
374
|
-
items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
|
|
375
|
-
|
|
376
|
-
return ModelResponse(items, timestamp=self._timestamp)
|
|
282
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
283
|
+
async for chunk in self._response:
|
|
284
|
+
self._usage += _map_usage(chunk)
|
|
377
285
|
|
|
378
|
-
|
|
379
|
-
|
|
286
|
+
try:
|
|
287
|
+
choice = chunk.choices[0]
|
|
288
|
+
except IndexError:
|
|
289
|
+
continue
|
|
290
|
+
|
|
291
|
+
# Handle the text part of the response
|
|
292
|
+
content = choice.delta.content
|
|
293
|
+
if content is not None:
|
|
294
|
+
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
|
|
295
|
+
|
|
296
|
+
for dtc in choice.delta.tool_calls or []:
|
|
297
|
+
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
298
|
+
vendor_part_id=dtc.index,
|
|
299
|
+
tool_name=dtc.function and dtc.function.name,
|
|
300
|
+
args=dtc.function and dtc.function.arguments,
|
|
301
|
+
tool_call_id=dtc.id,
|
|
302
|
+
)
|
|
303
|
+
if maybe_event is not None:
|
|
304
|
+
yield maybe_event
|
|
380
305
|
|
|
381
306
|
def timestamp(self) -> datetime:
|
|
382
307
|
return self._timestamp
|
|
@@ -390,19 +315,19 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
|
|
|
390
315
|
)
|
|
391
316
|
|
|
392
317
|
|
|
393
|
-
def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) ->
|
|
394
|
-
|
|
395
|
-
if
|
|
396
|
-
return
|
|
318
|
+
def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
|
|
319
|
+
response_usage = response.usage
|
|
320
|
+
if response_usage is None:
|
|
321
|
+
return usage.Usage()
|
|
397
322
|
else:
|
|
398
323
|
details: dict[str, int] = {}
|
|
399
|
-
if
|
|
400
|
-
details.update(
|
|
401
|
-
if
|
|
402
|
-
details.update(
|
|
403
|
-
return
|
|
404
|
-
request_tokens=
|
|
405
|
-
response_tokens=
|
|
406
|
-
total_tokens=
|
|
324
|
+
if response_usage.completion_tokens_details is not None:
|
|
325
|
+
details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
|
|
326
|
+
if response_usage.prompt_tokens_details is not None:
|
|
327
|
+
details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
|
|
328
|
+
return usage.Usage(
|
|
329
|
+
request_tokens=response_usage.prompt_tokens,
|
|
330
|
+
response_tokens=response_usage.completion_tokens,
|
|
331
|
+
total_tokens=response_usage.total_tokens,
|
|
407
332
|
details=details,
|
|
408
333
|
)
|