pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff compares publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
Potentially problematic release: the registry flagged this version of pydantic-ai-slim; see the registry page for details.
- pydantic_ai/__init__.py +14 -3
- pydantic_ai/_result.py +6 -9
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/agent.py +154 -90
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +29 -7
- pydantic_ai/models/__init__.py +10 -9
- pydantic_ai/models/anthropic.py +12 -12
- pydantic_ai/models/function.py +16 -22
- pydantic_ai/models/gemini.py +16 -18
- pydantic_ai/models/groq.py +21 -23
- pydantic_ai/models/mistral.py +34 -51
- pydantic_ai/models/openai.py +21 -23
- pydantic_ai/models/test.py +23 -17
- pydantic_ai/result.py +82 -35
- pydantic_ai/settings.py +69 -0
- pydantic_ai/tools.py +22 -28
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/METADATA +1 -2
- pydantic_ai_slim-0.0.15.dist-info/RECORD +26 -0
- pydantic_ai_slim-0.0.13.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.13.dist-info → pydantic_ai_slim-0.0.15.dist-info}/WHEEL +0 -0
pydantic_ai/exceptions.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

 import json

-__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
+__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'


 class ModelRetry(Exception):
@@ -30,7 +30,25 @@ class UserError(RuntimeError):
         super().__init__(message)


-class UnexpectedModelBehavior(RuntimeError):
+class AgentRunError(RuntimeError):
+    """Base class for errors occurring during an agent run."""
+
+    message: str
+    """The error message."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return self.message
+
+
+class UsageLimitExceeded(AgentRunError):
+    """Error raised when a Model's usage exceeds the specified limits."""
+
+
+class UnexpectedModelBehavior(AgentRunError):
     """Error caused by unexpected Model behavior, e.g. an unexpected response code."""

     message: str
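The practical effect of the new hierarchy: UsageLimitExceeded and UnexpectedModelBehavior now share the AgentRunError base class, so callers can catch any agent-run failure in one place. A minimal sketch (the error message is invented):

    from pydantic_ai.exceptions import AgentRunError, UsageLimitExceeded

    try:
        raise UsageLimitExceeded('response_tokens_limit of 100 exceeded')
    except AgentRunError as exc:
        print(exc.message)  # str(exc) returns the same message, per AgentRunError.__str__ above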
pydantic_ai/messages.py
CHANGED

@@ -2,11 +2,11 @@ from __future__ import annotations as _annotations

 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Union, cast

 import pydantic
 import pydantic_core
-from typing_extensions import Self
+from typing_extensions import Self, assert_never

 from ._utils import now_utc as _now_utc

@@ -190,12 +190,34 @@ class ToolCallPart:
     """Part type identifier, this is available on all parts as a discriminator."""

     @classmethod
-    def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
-        return cls(tool_name, ArgsJson(args_json), tool_call_id)
+    def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
+        """Create a `ToolCallPart` from raw arguments."""
+        if isinstance(args, str):
+            return cls(tool_name, ArgsJson(args), tool_call_id)
+        elif isinstance(args, dict):
+            return cls(tool_name, ArgsDict(args), tool_call_id)
+        else:
+            assert_never(args)

-    @classmethod
-    def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
-        return cls(tool_name, ArgsDict(args_dict), tool_call_id)
+    def args_as_dict(self) -> dict[str, Any]:
+        """Return the arguments as a Python dictionary.
+
+        This is just for convenience with models that require dicts as input.
+        """
+        if isinstance(self.args, ArgsDict):
+            return self.args.args_dict
+        args = pydantic_core.from_json(self.args.args_json)
+        assert isinstance(args, dict), 'args should be a dict'
+        return cast(dict[str, Any], args)
+
+    def args_as_json_str(self) -> str:
+        """Return the arguments as a JSON string.
+
+        This is just for convenience with models that require JSON strings as input.
+        """
+        if isinstance(self.args, ArgsJson):
+            return self.args.args_json
+        return pydantic_core.to_json(self.args.args_dict).decode()

     def has_content(self) -> bool:
         if isinstance(self.args, ArgsDict):
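A quick sketch of the new ToolCallPart helpers in use (tool name and arguments invented; pydantic_core.to_json emits compact JSON, hence no spaces in the expected string):

    from pydantic_ai.messages import ToolCallPart

    # dict input is stored as ArgsDict, str input as ArgsJson
    part = ToolCallPart.from_raw_args('get_weather', {'city': 'London'})
    assert part.args_as_dict() == {'city': 'London'}
    assert part.args_as_json_str() == '{"city":"London"}'

    part = ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}')
    assert part.args_as_dict() == {'city': 'Paris'}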
pydantic_ai/models/__init__.py
CHANGED

@@ -20,7 +20,7 @@ from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings

 if TYPE_CHECKING:
-    from ..result import Cost
+    from ..result import Usage
     from ..tools import ToolDefinition


@@ -31,6 +31,7 @@ KnownModelName = Literal[
     'openai:gpt-4',
     'openai:o1-preview',
     'openai:o1-mini',
+    'openai:o1',
     'openai:gpt-3.5-turbo',
     'groq:llama-3.3-70b-versatile',
     'groq:llama-3.1-70b-versatile',
@@ -122,7 +123,7 @@ class AgentModel(ABC):
     @abstractmethod
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
+    ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()

@@ -164,10 +165,10 @@ class StreamTextResponse(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def cost(self) -> Cost:
-        """Return the cost of the request.
+    def usage(self) -> Usage:
+        """Return the usage of the request.

-        NOTE: this won't return the cost until the stream is finished.
+        NOTE: this won't return the full usage until the stream is finished.
         """
         raise NotImplementedError()

@@ -205,10 +206,10 @@ class StreamStructuredResponse(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def cost(self) -> Cost:
-        """Get the cost of the request.
+    def usage(self) -> Usage:
+        """Get the usage of the request.

-        NOTE: this won't return the full cost until the stream is finished.
+        NOTE: this won't return the full usage until the stream is finished.
         """
         raise NotImplementedError()

@@ -235,7 +236,7 @@ The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
 def check_allow_model_requests() -> None:
     """Check if model requests are allowed.

-    If you're defining your own models that have …
+    If you're defining your own models that have costs or latency associated with their use, you should call this in
     [`Model.agent_model`][pydantic_ai.models.Model.agent_model].

     Raises:
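For custom model authors, the visible change is the renamed method and return type. A hedged sketch of a minimal AgentModel against the new interface (class name and response text invented; streaming support omitted):

    from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
    from pydantic_ai.models import AgentModel
    from pydantic_ai.result import Usage
    from pydantic_ai.settings import ModelSettings


    class EchoAgentModel(AgentModel):
        async def request(
            self, messages: list[ModelMessage], model_settings: ModelSettings | None
        ) -> tuple[ModelResponse, Usage]:
            # request() now returns Usage (previously Cost) alongside the response
            return ModelResponse([TextPart('hello')]), Usage()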
pydantic_ai/models/anthropic.py
CHANGED

@@ -158,9 +158,9 @@ class AnthropicAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._messages_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -220,7 +220,7 @@ class AnthropicAgentModel(AgentModel):
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
-                    ToolCallPart.from_dict(
+                    ToolCallPart.from_raw_args(
                         item.name,
                         cast(dict[str, Any], item.input),
                         item.id,
@@ -282,11 +282,11 @@ class AnthropicAgentModel(AgentModel):
                 MessageParam(
                     role='user',
                     content=[
-…
-…
-…
-…
-…
+                        ToolResultBlockParam(
+                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                            type='tool_result',
+                            content=part.model_response(),
+                            is_error=True,
                         ),
                     ],
                 )
@@ -311,11 +311,11 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
         name=t.tool_name,
-        input=t.args.args_dict,
+        input=t.args_as_dict(),
     )


-def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
+def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
     if isinstance(message, AnthropicMessage):
         usage = message.usage
     else:
@@ -332,11 +332,11 @@ def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
         usage = None

     if usage is None:
-        return result.Cost()
+        return result.Usage()

     request_tokens = getattr(usage, 'input_tokens', None)

-    return result.Cost(
+    return result.Usage(
         # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
         request_tokens=request_tokens,
         response_tokens=usage.output_tokens,
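The mapping above has one subtlety worth illustrating: Anthropic stream delta events carry no input-token count, hence the getattr fallback. A standalone sketch with a stand-in usage object (the dataclass is invented for illustration):

    from dataclasses import dataclass

    from pydantic_ai.result import Usage


    @dataclass
    class FakeDeltaUsage:  # stand-in for a stream-delta usage object: no input_tokens attribute
        output_tokens: int


    def map_usage(usage: FakeDeltaUsage | None) -> Usage:
        if usage is None:
            return Usage()
        return Usage(
            request_tokens=getattr(usage, 'input_tokens', None),  # None for delta events
            response_tokens=usage.output_tokens,
        )


    print(map_usage(FakeDeltaUsage(output_tokens=12)))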
pydantic_ai/models/function.py
CHANGED

@@ -9,12 +9,10 @@ from datetime import datetime
 from itertools import chain
 from typing import Callable, Union, cast

-import pydantic_core
 from typing_extensions import TypeAlias, assert_never, overload

 from .. import _utils, result
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -144,7 +142,7 @@ class FunctionAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         agent_info = replace(self.agent_info, model_settings=model_settings)

         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -155,7 +153,7 @@ class FunctionAgentModel(AgentModel):
         assert isinstance(response_, ModelResponse), response_
         response = response_
         # TODO is `messages` right here? Should it just be new messages?
-        return response, _estimate_cost(chain(messages, [response]))
+        return response, _estimate_usage(chain(messages, [response]))

     @asynccontextmanager
     async def request_stream(
@@ -198,8 +196,8 @@ class FunctionStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> result.Cost:
-        return result.Cost()
+    def usage(self) -> result.Usage:
+        return result.Usage()

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -232,19 +230,19 @@ class FunctionStreamStructuredResponse(StreamStructuredResponse):
         calls: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if c.name is not None and c.json_args is not None:
-                calls.append(ToolCallPart.from_json(c.name, c.json_args))
+                calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))

         return ModelResponse(calls, timestamp=self._timestamp)

-    def cost(self) -> result.Cost:
-        return _estimate_cost([self.get()])
+    def usage(self) -> result.Usage:
+        return _estimate_usage([self.get()])

     def timestamp(self) -> datetime:
         return self._timestamp


-def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
-    """Very rough guesstimate of the cost associated with a series of messages.
+def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
+    """Very rough guesstimate of the token usage associated with a series of messages.

     This is designed to be used solely to give plausible numbers for testing!
     """
@@ -255,32 +253,28 @@ def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
         if isinstance(message, ModelRequest):
             for part in message.parts:
                 if isinstance(part, (SystemPromptPart, UserPromptPart)):
-                    request_tokens += _string_cost(part.content)
+                    request_tokens += _estimate_string_usage(part.content)
                 elif isinstance(part, ToolReturnPart):
-                    request_tokens += _string_cost(part.model_response_str())
+                    request_tokens += _estimate_string_usage(part.model_response_str())
                 elif isinstance(part, RetryPromptPart):
-                    request_tokens += _string_cost(part.model_response())
+                    request_tokens += _estimate_string_usage(part.model_response())
                 else:
                     assert_never(part)
         elif isinstance(message, ModelResponse):
             for part in message.parts:
                 if isinstance(part, TextPart):
-                    response_tokens += _string_cost(part.content)
+                    response_tokens += _estimate_string_usage(part.content)
                 elif isinstance(part, ToolCallPart):
                     call = part
-                    if isinstance(call.args, ArgsJson):
-                        args_str = call.args.args_json
-                    else:
-                        args_str = pydantic_core.to_json(call.args.args_dict).decode()
-                    response_tokens += 1 + _string_cost(args_str)
+                    response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
                 else:
                     assert_never(part)
         else:
             assert_never(message)
-    return result.Cost(
+    return result.Usage(
         request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
     )


-def _string_cost(content: str) -> int:
+def _estimate_string_usage(content: str) -> int:
     return len(re.split(r'[\s",.:]+', content))
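The guesstimate heuristic at the bottom of the file just counts chunks separated by whitespace and light punctuation. A standalone restatement with sample strings (outputs shown in comments):

    import re


    def estimate_string_usage(content: str) -> int:
        # same regex as _estimate_string_usage above
        return len(re.split(r'[\s",.:]+', content))


    print(estimate_string_usage('hello world'))      # 2
    print(estimate_string_usage('one, two: three'))  # 3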
pydantic_ai/models/gemini.py
CHANGED

@@ -16,7 +16,6 @@ from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never

 from .. import UnexpectedModelBehavior, _utils, exceptions, result
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -172,10 +171,10 @@ class GeminiAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         async with self._make_request(messages, False, model_settings) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
-        return self._process_response(response), _metadata_as_cost(response)
+        return self._process_response(response), _metadata_as_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -301,7 +300,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
     _stream: AsyncIterator[bytes]
     _position: int = 0
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)

     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -321,7 +320,7 @@ class GeminiStreamTextResponse(StreamTextResponse):
             new_items, experimental_allow_partial='trailing-strings'
         )
         for r in new_responses:
-            self._cost += _metadata_as_cost(r)
+            self._usage += _metadata_as_usage(r)
             parts = r['candidates'][0]['content']['parts']
             if _all_text_parts(parts):
                 for part in parts:
@@ -331,8 +330,8 @@ class GeminiStreamTextResponse(StreamTextResponse):
                     'Streamed response with unexpected content, expected all parts to be text'
                 )

-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -345,7 +344,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
     _content: bytearray
     _stream: AsyncIterator[bytes]
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-    _cost: result.Cost = field(default_factory=result.Cost, init=False)
+    _usage: result.Usage = field(default_factory=result.Usage, init=False)

     async def __anext__(self) -> None:
         chunk = await self._stream.__anext__()
@@ -365,15 +364,15 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse):
             experimental_allow_partial='off' if final else 'trailing-strings',
         )
         combined_parts: list[_GeminiPartUnion] = []
-        self._cost = result.Cost()
+        self._usage = result.Usage()
         for r in responses:
-            self._cost += _metadata_as_cost(r)
+            self._usage += _metadata_as_usage(r)
             candidate = r['candidates'][0]
             combined_parts.extend(candidate['content']['parts'])
         return _process_response_from_parts(combined_parts, timestamp=self._timestamp)

-    def cost(self) -> result.Cost:
-        return self._cost
+    def usage(self) -> result.Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -460,8 +459,7 @@ class _GeminiFunctionCallPart(TypedDict):


 def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
-    assert isinstance(tool.args, ArgsDict), f'Expected ArgsDict, got {tool.args}'
-    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
+    return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))


 def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
@@ -470,7 +468,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
         if 'text' in part:
             items.append(TextPart(part['text']))
         elif 'function_call' in part:
-            items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
+            items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
         elif 'function_response' in part:
             raise exceptions.UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
@@ -640,14 +638,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]


-def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
+def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
     metadata = response.get('usage_metadata')
     if metadata is None:
-        return result.Cost()
+        return result.Usage()
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
         details['cached_content_token_count'] = cached_content_token_count
-    return result.Cost(
+    return result.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
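Both Gemini streaming classes accumulate usage chunk by chunk with +=, which Usage supports. A tiny sketch of that accumulation (token counts invented):

    from pydantic_ai.result import Usage

    total = Usage()
    for chunk_usage in (
        Usage(request_tokens=10, response_tokens=2, total_tokens=12),
        Usage(request_tokens=0, response_tokens=3, total_tokens=3),
    ):
        total += chunk_usage
    print(total)  # sums to 10 request, 5 response, 15 total tokens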
pydantic_ai/models/groq.py
CHANGED

@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, result
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -25,7 +24,7 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Cost
+from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -158,9 +157,9 @@ class GroqAgentModel(AgentModel):

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Cost]:
+    ) -> tuple[ModelResponse, result.Usage]:
         response = await self._completions_create(messages, False, model_settings)
-        return self._process_response(response), _map_cost(response)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
@@ -221,14 +220,14 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        start_cost = Cost()
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -236,19 +235,19 @@ class GroqAgentModel(AgentModel):
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_cost += _map_cost(chunk)
+            start_usage += _map_usage(chunk)

             if chunk.choices:
                 delta = chunk.choices[0].delta

                 if delta.content is not None:
-                    return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
+                    return GroqStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return GroqStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        start_cost,
+                        start_usage,
                     )

     @classmethod
@@ -308,7 +307,7 @@ class GroqStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -318,7 +317,7 @@ class GroqStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)

         try:
             choice = chunk.choices[0]
@@ -335,8 +334,8 @@ class GroqStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -349,11 +348,11 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    _cost: result.Cost
+    _usage: result.Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self._cost = _map_cost(chunk)
+        self._usage = _map_usage(chunk)

         try:
             choice = chunk.choices[0]
@@ -380,27 +379,26 @@ class GroqStreamStructuredResponse(StreamStructuredResponse):
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

         return ModelResponse(items, timestamp=self._timestamp)

-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp


 def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
         id=_guard_tool_call_id(t=t, model_source='Groq'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.args.args_json},
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )


-def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
+def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage:
     usage = None
     if isinstance(completion, ChatCompletion):
         usage = completion.usage
@@ -408,9 +406,9 @@ def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost:
         usage = completion.x_groq.usage

     if usage is None:
-        return result.Cost()
+        return result.Usage()

-    return result.Cost(
+    return result.Usage(
         request_tokens=usage.prompt_tokens,
         response_tokens=usage.completion_tokens,
         total_tokens=usage.total_tokens,
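From the caller's side, the _map_tool_call change means any ToolCallPart can be serialized into an OpenAI-style function param regardless of whether its args are stored as a dict or a JSON string (previously this path asserted ArgsJson). A short sketch (names invented):

    from pydantic_ai.messages import ToolCallPart

    part = ToolCallPart.from_raw_args('add', {'a': 1, 'b': 2}, tool_call_id='call_1')
    function_param = {'name': part.tool_name, 'arguments': part.args_as_json_str()}
    print(function_param)  # {'name': 'add', 'arguments': '{"a":1,"b":2}'}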