pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: the registry has flagged this version of pydantic-ai-slim; see the registry page for details.
- pydantic_ai/__init__.py +14 -3
- pydantic_ai/_result.py +6 -9
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/agent.py +154 -90
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +29 -7
- pydantic_ai/models/__init__.py +10 -9
- pydantic_ai/models/anthropic.py +12 -12
- pydantic_ai/models/function.py +16 -22
- pydantic_ai/models/gemini.py +16 -18
- pydantic_ai/models/groq.py +21 -23
- pydantic_ai/models/mistral.py +34 -51
- pydantic_ai/models/openai.py +21 -23
- pydantic_ai/models/test.py +23 -17
- pydantic_ai/result.py +82 -35
- pydantic_ai/settings.py +69 -0
- pydantic_ai/tools.py +22 -28
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/METADATA +1 -2
- pydantic_ai_slim-0.0.15.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.13.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/WHEEL +0 -0
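The theme running through these changes is the rename of `result.Cost` to `result.Usage` (and of the `cost()` accessors to `usage()`), together with the consolidation of the `ToolCallPart.from_json`/`from_dict` constructors into a single `ToolCallPart.from_raw_args`. A minimal sketch of the renamed type, using only the constructor fields and the in-place addition visible in the hunks below:

    from pydantic_ai.result import Usage

    # Usage supports `+=`, which is how the streaming classes in this diff
    # accumulate per-chunk token counts (`start_usage += _map_usage(chunk)`).
    total = Usage()
    total += Usage(request_tokens=10, response_tokens=3, total_tokens=13)
    total += Usage(response_tokens=5, total_tokens=5)
    print(total.total_tokens)  # 18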
pydantic_ai/models/mistral.py
CHANGED
@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from itertools import chain
 from typing import Any, Callable, Literal, Union

+import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never

@@ -26,7 +27,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -39,7 +40,6 @@ from . import (
 )

 try:
-    from json_repair import repair_json
     from mistralai import (
         UNSET,
         CompletionChunk as MistralCompletionChunk,
@@ -156,10 +156,10 @@ class MistralAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
+    ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
         response = await self._completions_create(messages, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -198,11 +198,10 @@ class MistralAgentModel(AgentModel):
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-
         model_settings = model_settings or {}

         if self.result_tools and self.function_tools or self.function_tools:
-            # Function Calling
+            # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -218,9 +217,9 @@ class MistralAgentModel(AgentModel):
         elif self.result_tools:
             # Json Mode
             parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
-
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
+
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -270,12 +269,13 @@ class MistralAgentModel(AgentModel):
     @staticmethod
     def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        assert response.choices, 'Unexpected empty response choice.'
+
         if response.created:
             timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         else:
             timestamp = _now_utc()

-        assert response.choices, 'Unexpected empty response choice.'
         choice = response.choices[0]
         content = choice.message.content
         tool_calls = choice.message.tool_calls
@@ -297,7 +297,7 @@ class MistralAgentModel(AgentModel):
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        start_cost = Cost()
+        start_usage = Usage()

         # Iterate until we get either `tool_calls` or `content` from the first chunk.
         while True:
@@ -307,7 +307,7 @@ class MistralAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.created:
                 timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
@@ -329,11 +329,11 @@ class MistralAgentModel(AgentModel):
                     response,
                     content,
                     timestamp,
-                    start_cost,
+                    start_usage,
                 )

             elif content:
-                return MistralStreamTextResponse(content, response, timestamp, start_cost)
+                return MistralStreamTextResponse(content, response, timestamp, start_usage)

     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
@@ -474,7 +474,7 @@ class MistralStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: MistralEventStreamAsync[MistralCompletionEvent]
     _timestamp: datetime
-    _cost: Cost
+    _usage: Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -484,7 +484,7 @@ class MistralStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk.data)
+        self._usage += _map_usage(chunk.data)

         try:
             choice = chunk.data.choices[0]
@@ -502,8 +502,8 @@ class MistralStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -518,11 +518,11 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
     _response: MistralEventStreamAsync[MistralCompletionEvent]
     _delta_content: str | None
     _timestamp: datetime
-    _cost: Cost
+    _usage: Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk.data)
+        self._usage += _map_usage(chunk.data)

         try:
             choice = chunk.data.choices[0]
@@ -546,39 +546,31 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
             calls.append(tool)

         elif self._delta_content and self._result_tools:
-            output_json = repair_json(self._delta_content, return_objects=True)
-            assert isinstance(
-                output_json, dict
-            ), f'Expected repair_json as type dict, invalid type: {type(output_json)}'
+            output_json: dict[str, Any] | None = pydantic_core.from_json(
+                self._delta_content, allow_partial='trailing-strings'
+            )

             if output_json:
                 for result_tool in self._result_tools.values():
-                    # NOTE: Additional verification to prevent JSON validation to crash in `result.py`
+                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
                     # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                    # Example with BaseModel and required fields.
-
-                    # when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
-                    # This ensures it's corrected to `{"response": {}}` and other required parameters and type.
-                    if not self._validate_required_json_shema(output_json, result_tool.parameters_json_schema):
+                    # Example with BaseModel and required fields.
+                    if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
                         continue

-                    tool = ToolCallPart.from_dict(
-                        tool_name=result_tool.name,
-                        args_dict=output_json,
-                    )
+                    tool = ToolCallPart.from_raw_args(result_tool.name, output_json)
                     calls.append(tool)

         return ModelResponse(calls, timestamp=self._timestamp)

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp

     @staticmethod
-    def _validate_required_json_shema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
+    def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
         """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
         required_params = json_schema.get('required', [])
         properties = json_schema.get('properties', {})
@@ -602,7 +594,7 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):

             if isinstance(json_dict[param], dict) and 'properties' in param_schema:
                 nested_schema = param_schema
-                if not MistralStreamStructuredResponse._validate_required_json_shema(json_dict[param], nested_schema):
+                if not MistralStreamStructuredResponse._validate_required_json_schema(json_dict[param], nested_schema):
                     return False

         return True
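The hunk above swaps `json_repair.repair_json` for `pydantic_core.from_json` in partial mode: rather than repairing malformed JSON after the fact, the incomplete streaming delta is parsed as far as it goes. A small illustration of the parsing behavior the code relies on (the input string is made up):

    import pydantic_core

    # A truncated streaming delta: both the object and its last string are unterminated.
    partial = '{"city": "Par'
    # `allow_partial='trailing-strings'` keeps the incomplete trailing string
    # fragment instead of raising on the unterminated JSON.
    print(pydantic_core.from_json(partial, allow_partial='trailing-strings'))
    #> {'city': 'Par'}

Because the partial dict may still be missing required keys, `_validate_required_json_schema` gates each candidate tool call until the schema's `required` list is satisfied.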
@@ -633,29 +625,20 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
     tool_call_id = tool_call.id or None
     func_call = tool_call.function

-    if isinstance(func_call.arguments, str):
-        return ToolCallPart.from_json(
-            tool_name=func_call.name,
-            args_json=func_call.arguments,
-            tool_call_id=tool_call_id,
-        )
-    else:
-        return ToolCallPart.from_dict(
-            tool_name=func_call.name, args_dict=func_call.arguments, tool_call_id=tool_call_id
-        )
+    return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)


-def _map_cost(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Cost:
-    """Maps a Mistral Completion Chunk or Chat Completion Response to a Cost."""
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+    """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return Cost(
+        return Usage(
             request_tokens=response.usage.prompt_tokens,
             response_tokens=response.usage.completion_tokens,
             total_tokens=response.usage.total_tokens,
             details=None,
         )
     else:
-        return Cost()
+        return Usage()


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
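Note the collapse of the branching `from_json`/`from_dict` calls into `ToolCallPart.from_raw_args`, which accepts the arguments in whichever raw form the provider produced. A short sketch (tool name and arguments are invented for illustration):

    from pydantic_ai.messages import ToolCallPart

    # One constructor for both raw forms, matching the two old code paths:
    from_json_str = ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}')
    from_a_dict = ToolCallPart.from_raw_args('get_weather', {'city': 'Paris'}, 'call_123')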
pydantic_ai/models/openai.py
CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -25,7 +24,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -147,9 +146,9 @@ class OpenAIAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._completions_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -211,14 +210,14 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -227,19 +226,19 @@ class OpenAIAgentModel(AgentModel):
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.choices:
                 delta = chunk.choices[0].delta

                 if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return OpenAIStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )
                 # else continue until we get either delta.content or delta.tool_calls

@@ -302,7 +301,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -312,7 +311,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -328,8 +327,8 @@ class OpenAIStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -342,11 +341,11 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk)
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -372,37 +371,36 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

         return ModelResponse(items, timestamp=self._timestamp)

-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp


 def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
         id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )


-def _map_cost(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Cost:
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
     usage = response.usage
     if usage is None:
-        return result.Cost()
+        return result.Usage()
     else:
         details: dict[str, int] = {}
         if usage.completion_tokens_details is not None:
             details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
         if usage.prompt_tokens_details is not None:
             details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.Cost(
+        return result.Usage(
             request_tokens=usage.prompt_tokens,
             response_tokens=usage.completion_tokens,
             total_tokens=usage.total_tokens,
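As in the Mistral module, the streamed response classes thread a running `Usage` through `__anext__`, adding `_map_usage(chunk)` for every chunk so that `usage()` reflects whatever the provider has reported so far. The accumulation pattern, reduced to a self-contained sketch (the dicts stand in for provider usage blocks, which are often only populated on the final chunk):

    from pydantic_ai.result import Usage

    def map_usage(usage_block: dict | None) -> Usage:
        # Mirrors `_map_usage` above: a missing usage block maps to an empty Usage.
        if usage_block is None:
            return Usage()
        return Usage(
            request_tokens=usage_block['prompt_tokens'],
            response_tokens=usage_block['completion_tokens'],
            total_tokens=usage_block['total_tokens'],
        )

    running = Usage()
    for block in [None, None, {'prompt_tokens': 12, 'completion_tokens': 7, 'total_tokens': 19}]:
        running += map_usage(block)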
pydantic_ai/models/test.py
CHANGED
@@ -21,7 +21,7 @@ from ..messages import (
     ToolCallPart,
     ToolReturnPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -31,6 +31,7 @@ from . import (
     StreamStructuredResponse,
     StreamTextResponse,
 )
+from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]


 @dataclass
@@ -131,15 +132,17 @@ class TestAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
-        return self._request(messages, model_settings), Cost()
+    ) -> tuple[ModelResponse, Usage]:
+        model_response = self._request(messages, model_settings)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[EitherStreamedResponse]:
         msg = self._request(messages, model_settings)
-        cost = Cost()
+        usage = _estimate_usage(messages)

         # TODO: Rework this once we make StreamTextResponse more general
         texts: list[str] = []
@@ -153,9 +156,9 @@ class TestAgentModel(AgentModel):
             assert_never(item)

         if texts:
-            yield TestStreamTextResponse('\n\n'.join(texts), cost)
+            yield TestStreamTextResponse('\n\n'.join(texts), usage)
         else:
-            yield TestStreamStructuredResponse(msg, cost)
+            yield TestStreamStructuredResponse(msg, usage)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -164,7 +167,7 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
             )

         if messages:
@@ -176,7 +179,7 @@ class TestAgentModel(AgentModel):
         if new_retry_names:
             return ModelResponse(
                 parts=[
-                    ToolCallPart.from_dict(name, self.gen_tool_args(args))
+                    ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
                     for name, args in self.tool_calls
                     if name in new_retry_names
                 ]
@@ -202,10 +205,10 @@ class TestAgentModel(AgentModel):
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])


 @dataclass
@@ -213,7 +216,7 @@ class TestStreamTextResponse(StreamTextResponse):
     """A text response that streams test data."""

     _text: str
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[str] = field(init=False)
     _timestamp: datetime = field(default_factory=_utils.now_utc)
     _buffer: list[str] = field(default_factory=list, init=False)
@@ -228,14 +231,17 @@ class TestStreamTextResponse(StreamTextResponse):
         self._iter = iter(words)

     async def __anext__(self) -> None:
-        self._buffer.append(_utils.sync_anext(self._iter))
+        next_str = _utils.sync_anext(self._iter)
+        response_tokens = _estimate_string_usage(next_str)
+        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+        self._buffer.append(next_str)

     def get(self, *, final: bool = False) -> Iterable[str]:
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -246,7 +252,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
    """A structured response that streams test data."""

     _structured_response: ModelResponse
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

@@ -256,8 +262,8 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
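Instead of returning an empty value, `TestAgentModel` now estimates token usage with the private `_estimate_usage`/`_estimate_string_usage` helpers from `function.py`. Their exact heuristic is not part of this diff; a hypothetical stand-in showing the per-chunk accumulation the stream performs:

    import re

    from pydantic_ai.result import Usage

    def estimate_string_tokens(text: str) -> int:
        # Hypothetical stand-in for the private `_estimate_string_usage`:
        # a crude whitespace/punctuation split as a token-count proxy.
        return len([t for t in re.split(r'[\s",.:]+', text) if t])

    usage = Usage()
    for chunk in ['The answer', ' is 42.']:
        tokens = estimate_string_tokens(chunk)
        usage += Usage(response_tokens=tokens, total_tokens=tokens)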