pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_result.py +4 -7
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/agent.py +85 -75
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +29 -7
- pydantic_ai/models/__init__.py +10 -9
- pydantic_ai/models/anthropic.py +12 -12
- pydantic_ai/models/function.py +16 -22
- pydantic_ai/models/gemini.py +16 -18
- pydantic_ai/models/groq.py +21 -23
- pydantic_ai/models/mistral.py +24 -36
- pydantic_ai/models/openai.py +21 -23
- pydantic_ai/models/test.py +23 -17
- pydantic_ai/result.py +63 -33
- pydantic_ai/settings.py +65 -0
- pydantic_ai/tools.py +24 -14
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.13.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +0 -0
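The headline change in this release is a rename: `result.Cost` becomes `result.Usage`, the `cost()` method on the streamed-response classes becomes `usage()`, and the `ToolCallPart.from_dict`/`from_json` constructors are consolidated into `ToolCallPart.from_raw_args`, as the per-file diffs below show. A rough sketch of what the rename looks like from calling code follows; the `result.usage()` accessor on a completed run is an assumption based on the rename, since the `agent.py` and `result.py` diffs are not expanded on this page.

from pydantic_ai import Agent

agent = Agent('openai:gpt-4', system_prompt='Be concise.')

result = agent.run_sync('What is the capital of France?')
# 0.0.13 exposed token counts via a Cost object; 0.0.14 renames the class to
# Usage (the exact accessor name on the run result is assumed here).
print(result.usage())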
pydantic_ai/models/__init__.py
CHANGED

@@ -20,7 +20,7 @@ from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings

 if TYPE_CHECKING:
-    from ..result import Cost
+    from ..result import Usage
     from ..tools import ToolDefinition


@@ -31,6 +31,7 @@ KnownModelName = Literal[
     'openai:gpt-4',
     'openai:o1-preview',
     'openai:o1-mini',
+    'openai:o1',
     'openai:gpt-3.5-turbo',
     'groq:llama-3.3-70b-versatile',
     'groq:llama-3.1-70b-versatile',
@@ -122,7 +123,7 @@ class AgentModel(ABC):
     @abstractmethod
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
+    ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()

@@ -164,10 +165,10 @@ class StreamTextResponse(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def cost(self) -> Cost:
-        """Return the cost of the request.
+    def usage(self) -> Usage:
+        """Return the usage of the request.

-        NOTE: this won't return the full cost until the stream is finished.
+        NOTE: this won't return the full usage until the stream is finished.
         """
         raise NotImplementedError()

@@ -205,10 +206,10 @@ class StreamStructuredResponse(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def cost(self) -> Cost:
-        """Get the cost of the request.
+    def usage(self) -> Usage:
+        """Get the usage of the request.

-        NOTE: this won't return the full cost until the stream is finished.
+        NOTE: this won't return the full usage until the stream is finished.
         """
         raise NotImplementedError()

@@ -235,7 +236,7 @@ The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
 def check_allow_model_requests() -> None:
     """Check if model requests are allowed.

-    If you're defining your own models that have
+    If you're defining your own models that have costs or latency associated with their use, you should call this in
     [`Model.agent_model`][pydantic_ai.models.Model.agent_model].

     Raises:
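For anyone implementing a custom model, the abstract `AgentModel.request` above now returns `tuple[ModelResponse, Usage]` rather than the old `tuple[ModelResponse, Cost]`. Below is a minimal sketch of a conforming implementation, assuming the import paths shown in this diff; the `EchoAgentModel` name is hypothetical and streaming via `request_stream` is omitted.

from __future__ import annotations

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import AgentModel
from pydantic_ai.result import Usage
from pydantic_ai.settings import ModelSettings


class EchoAgentModel(AgentModel):
    """Hypothetical AgentModel that answers every request with a fixed string."""

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        # Real models map provider token counts into Usage; an empty Usage()
        # is valid when no counts are available (FunctionModel below does this).
        return ModelResponse([TextPart('hello')]), Usage()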
pydantic_ai/models/anthropic.py
CHANGED

@@ -158,9 +158,9 @@ class AnthropicAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._messages_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -220,7 +220,7 @@ class AnthropicAgentModel(AgentModel):
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
-                    ToolCallPart.from_dict(
+                    ToolCallPart.from_raw_args(
                         item.name,
                         cast(dict[str, Any], item.input),
                         item.id,
@@ -282,11 +282,11 @@ class AnthropicAgentModel(AgentModel):
                         MessageParam(
                             role='user',
                             content=[
-
-
-
-
-
+                                ToolResultBlockParam(
+                                    tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                    type='tool_result',
+                                    content=part.model_response(),
+                                    is_error=True,
                                 ),
                             ],
                         )
@@ -311,11 +311,11 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
         name=t.tool_name,
-        input=t.args.args_dict,
+        input=t.args_as_dict(),
     )


-def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
+def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
     if isinstance(message, AnthropicMessage):
         usage = message.usage
     else:
@@ -332,11 +332,11 @@ def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
         usage = None

     if usage is None:
-        return result.Cost()
+        return result.Usage()

     request_tokens = getattr(usage, 'input_tokens', None)

-    return result.Cost(
+    return result.Usage(
         # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
         request_tokens=request_tokens,
         response_tokens=usage.output_tokens,
pydantic_ai/models/function.py
CHANGED

@@ -9,12 +9,10 @@ from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast

-import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload

 from .. import _utils, result
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -144,7 +142,7 @@ class FunctionAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         agent_info = replace(self.agent_info, model_settings=model_settings)

         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -155,7 +153,7 @@ class FunctionAgentModel(AgentModel):
         assert isinstance(response_, ModelResponse), response_
         response = response_
         # TODO is `messages` right here? Should it just be new messages?
-        return response, _estimate_cost(chain(messages, [response]))
+        return response, _estimate_usage(chain(messages, [response]))

     @asynccontextmanager
     async def request_stream(
@@ -198,8 +196,8 @@ class FunctionStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> result.Cost:
-        return result.Cost()
+    def usage(self) -> result.Usage:
+        return result.Usage()

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -232,19 +230,19 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(ToolCallPart.from_json(c.name, c.json_args))
+                calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))

         return ModelResponse(calls, timestamp=self._timestamp)

-    def cost(self) -> result.Cost:
-        return _estimate_cost([self.get()])
+    def usage(self) -> result.Usage:
+        return _estimate_usage([self.get()])

     def timestamp(self) -> datetime:
         return self._timestamp


-def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
-    """Very rough guesstimate of the
+def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+    """Very rough guesstimate of the token usage associated with a series of messages.

     This is designed to be used solely to give plausible numbers for testing!
     """
@@ -255,32 +253,28 @@ def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
         if isinstance(message, ModelRequest):
             for part in message.parts:
                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
-                    request_tokens += _string_cost(part.content)
+                    request_tokens += _estimate_string_usage(part.content)
                 elif isinstance(part, ToolReturnPart):
-                    request_tokens += _string_cost(part.model_response_str())
+                    request_tokens += _estimate_string_usage(part.model_response_str())
                 elif isinstance(part, RetryPromptPart):
-                    request_tokens += _string_cost(part.model_response())
+                    request_tokens += _estimate_string_usage(part.model_response())
                 else:
                     assert_never(part)
         elif isinstance(message, ModelResponse):
             for part in message.parts:
                 if isinstance(part, TextPart):
-                    response_tokens += _string_cost(part.content)
+                    response_tokens += _estimate_string_usage(part.content)
                 elif isinstance(part, ToolCallPart):
                     call = part
-                    if isinstance(call.args, ArgsJson):
-                        args_str = call.args.args_json
-                    else:
-                        args_str = pydantic_core.to_json(call.args.args_dict).decode()
-                    response_tokens += 1 + _string_cost(args_str)
+                    response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
                 else:
                     assert_never(part)
         else:
             assert_never(message)
-    return result.Cost(
+    return result.Usage(
         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
     )


-def _string_cost(content: str) -> int:
+def _estimate_string_usage(content: str) -> int:
     return len(re.split(r'[\s",.:]+', content))
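The usage numbers produced by FunctionModel come from the deliberately crude heuristic above: text is split on whitespace and a little punctuation, each fragment counts as one token, and each tool call adds one extra token. A standalone illustration of that estimate, using the same regex but a hypothetical helper name:

import re


def estimate_tokens(content: str) -> int:
    # Same split as _estimate_string_usage above: whitespace, double quotes,
    # commas, periods and colons are all treated as token boundaries.
    # Good enough for plausible test numbers, not a real tokenizer.
    return len(re.split(r'[\s",.:]+', content))


print(estimate_tokens('What is the capital of France?'))  # 6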
pydantic_ai/models/gemini.py
CHANGED

@@ -16,7 +16,6 @@ from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never

 from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -172,10 +171,10 @@ class GeminiAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
-        return self._process_response(response), _metadata_as_cost(response)
+        return self._process_response(response), _metadata_as_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -301,7 +300,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
     _stream: AsyncIterator[bytes]
     _position: int = 0
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)

     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -321,7 +320,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
                 new_items, experimental_allow_partial='trailing-strings'
             )
             for r in new_responses:
-                self._cost += _metadata_as_cost(r)
+                self._usage += _metadata_as_usage(r)
                 parts = r['candidates'][0]['content']['parts']
                 if _all_text_parts(parts):
                     for part in parts:
@@ -331,8 +330,8 @@ class GeminiStreamTextResponse(StreamTextResponse):
                         'Streamed response with unexpected content, expected all parts to be text'
                     )

-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -345,7 +344,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)

     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -365,15 +364,15 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
         combined_parts: list[_GeminiPartUnion] = []
-        self._cost = result.Cost()
+        self._usage = result.Usage()
         for r in responses:
-            self._cost += _metadata_as_cost(r)
+            self._usage += _metadata_as_usage(r)
             candidate = r['candidates'][0]
             combined_parts.extend(candidate['content']['parts'])
         return _process_response_from_parts(combined_parts, timestamp=self._timestamp)

-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -460,8 +459,7 @@ class _GeminiFunctionCallPart(TypedDict):


 def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
-    assert isinstance(tool.args, ArgsDict), f'Expected ArgsDict, got {tool.args}'
-    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
+    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))


 def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
@@ -470,7 +468,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
         if 'text' in part:
             items.append(TextPart(part['text']))
         elif 'function_call' in part:
-            items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
+            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
@@ -640,14 +638,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]


-def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
+def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Cost()
+        return result.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.Cost(
+    return result.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
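Note that the Gemini streaming classes above (and the Groq and Mistral ones below) accumulate usage chunk by chunk with `+=`, which implies that `result.Usage` supports addition. The real class lives in pydantic_ai/result.py, whose diff is not expanded on this page; the sketch below only illustrates that implied contract, with field names taken from the Usage(...) constructor calls in this diff.

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SketchUsage:
    """Illustrative stand-in for result.Usage, not the library's implementation."""

    request_tokens: int | None = None
    response_tokens: int | None = None
    total_tokens: int | None = None
    details: dict[str, int] | None = None

    def __add__(self, other: SketchUsage) -> SketchUsage:
        def add(a: int | None, b: int | None) -> int | None:
            if a is None and b is None:
                return None
            return (a or 0) + (b or 0)

        # Merging of `details` is omitted here for brevity.
        return SketchUsage(
            add(self.request_tokens, other.request_tokens),
            add(self.response_tokens, other.response_tokens),
            add(self.total_tokens, other.total_tokens),
        )


total = SketchUsage()
total += SketchUsage(request_tokens=10, response_tokens=3, total_tokens=13)
total += SketchUsage(request_tokens=0, response_tokens=5, total_tokens=5)
print(total)  # request_tokens=10, response_tokens=8, total_tokens=18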
pydantic_ai/models/groq.py
CHANGED

@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -25,7 +24,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -158,9 +157,9 @@ class GroqAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._completions_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -221,14 +220,14 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -236,19 +235,19 @@ class GroqAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.choices:
                 delta = chunk.choices[0].delta

                 if delta.content is not None:
-                    return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return GroqStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )

     @classmethod
@@ -308,7 +307,7 @@ class GroqStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -318,7 +317,7 @@ class GroqStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)

         try:
             choice = chunk.choices[0]
@@ -335,8 +334,8 @@ class GroqStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -349,11 +348,11 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)

         try:
             choice = chunk.choices[0]
@@ -380,27 +379,26 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

         return ModelResponse(items, timestamp=self._timestamp)

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp


 def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
         id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )


-def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
     usage = None
     if isinstance(completion, ChatCompletion):
         usage = completion.usage
@@ -408,9 +406,9 @@ def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
         usage = completion.x_groq.usage

     if usage is None:
-        return result.Cost()
+        return result.Usage()

-    return result.Cost(
+    return result.Usage(
         request_tokens=usage.prompt_tokens,
         response_tokens=usage.completion_tokens,
         total_tokens=usage.total_tokens,
pydantic_ai/models/mistral.py
CHANGED

@@ -26,7 +26,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -156,10 +156,10 @@ class MistralAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
+    ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
         response = await self._completions_create(messages, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -297,7 +297,7 @@ class MistralAgentModel(AgentModel):
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        start_cost = Cost()
+        start_usage = Usage()

         # Iterate until we get either `tool_calls` or `content` from the first chunk.
         while True:
@@ -307,7 +307,7 @@ class MistralAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.created:
                 timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
@@ -329,11 +329,11 @@ class MistralAgentModel(AgentModel):
                     response,
                     content,
                     timestamp,
-                    start_cost,
+                    start_usage,
                 )

             elif content:
-                return MistralStreamTextResponse(content, response, timestamp, start_cost)
+                return MistralStreamTextResponse(content, response, timestamp, start_usage)

     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
@@ -474,7 +474,7 @@ class MistralStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: MistralEventStreamAsync[MistralCompletionEvent]
     _timestamp: datetime
-    _cost: Cost
+    _usage: Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -484,7 +484,7 @@ class MistralStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk.data)
+        self._usage += _map_usage(chunk.data)

         try:
             choice = chunk.data.choices[0]
@@ -502,8 +502,8 @@ class MistralStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -518,11 +518,11 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
     _response: MistralEventStreamAsync[MistralCompletionEvent]
     _delta_content: str | None
     _timestamp: datetime
-    _cost: Cost
+    _usage: Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost += _map_cost(chunk.data)
+        self._usage += _map_usage(chunk.data)

         try:
             choice = chunk.data.choices[0]
@@ -560,25 +560,22 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
                 # when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
                 # when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
                 # This ensures it's corrected to `{"response": {}}` and other required parameters and type.
-                if not self.
+                if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
                     continue

-                tool = ToolCallPart.from_dict(
-                    tool_name=result_tool.name,
-                    args_dict=output_json,
-                )
+                tool = ToolCallPart.from_raw_args(result_tool.name, output_json)
                 calls.append(tool)

         return ModelResponse(calls, timestamp=self._timestamp)

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp

     @staticmethod
-    def
+    def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
         """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
         required_params = json_schema.get('required', [])
         properties = json_schema.get('properties', {})
@@ -602,7 +599,7 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):

             if isinstance(json_dict[param], dict) and 'properties' in param_schema:
                 nested_schema = param_schema
-                if not MistralStreamStructuredResponse.
+                if not MistralStreamStructuredResponse._validate_required_json_schema(json_dict[param], nested_schema):
                     return False

         return True
@@ -633,29 +630,20 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
     tool_call_id = tool_call.id or None
     func_call = tool_call.function

-    if isinstance(func_call.arguments, str):
-        return ToolCallPart.from_json(
-            tool_name=func_call.name,
-            args_json=func_call.arguments,
-            tool_call_id=tool_call_id,
-        )
-    else:
-        return ToolCallPart.from_dict(
-            tool_name=func_call.name, args_dict=func_call.arguments, tool_call_id=tool_call_id
-        )
+    return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)


-def _map_cost(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Cost:
-    """Maps a Mistral Completion Chunk or Chat Completion Response to a
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+    """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return Cost(
+        return Usage(
             request_tokens=response.usage.prompt_tokens,
             response_tokens=response.usage.completion_tokens,
             total_tokens=response.usage.total_tokens,
             details=None,
         )
     else:
-        return Cost()
+        return Usage()


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None: