pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +82 -74
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +218 -9
- pydantic_ai/models/__init__.py +31 -72
- pydantic_ai/models/anthropic.py +21 -21
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +76 -122
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/models/test.py
CHANGED
|
@@ -2,21 +2,22 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import re
|
|
4
4
|
import string
|
|
5
|
-
from collections.abc import AsyncIterator, Iterable
|
|
5
|
+
from collections.abc import AsyncIterator, Iterable
|
|
6
6
|
from contextlib import asynccontextmanager
|
|
7
|
-
from dataclasses import dataclass, field
|
|
7
|
+
from dataclasses import InitVar, dataclass, field
|
|
8
8
|
from datetime import date, datetime, timedelta
|
|
9
9
|
from typing import Any, Literal
|
|
10
10
|
|
|
11
11
|
import pydantic_core
|
|
12
|
-
from typing_extensions import assert_never
|
|
13
12
|
|
|
14
13
|
from .. import _utils
|
|
15
14
|
from ..messages import (
|
|
15
|
+
ArgsJson,
|
|
16
16
|
ModelMessage,
|
|
17
17
|
ModelRequest,
|
|
18
18
|
ModelResponse,
|
|
19
19
|
ModelResponsePart,
|
|
20
|
+
ModelResponseStreamEvent,
|
|
20
21
|
RetryPromptPart,
|
|
21
22
|
TextPart,
|
|
22
23
|
ToolCallPart,
|
|
@@ -27,12 +28,10 @@ from ..settings import ModelSettings
|
|
|
27
28
|
from ..tools import ToolDefinition
|
|
28
29
|
from . import (
|
|
29
30
|
AgentModel,
|
|
30
|
-
EitherStreamedResponse,
|
|
31
31
|
Model,
|
|
32
|
-
|
|
33
|
-
StreamTextResponse,
|
|
32
|
+
StreamedResponse,
|
|
34
33
|
)
|
|
35
|
-
from .function import
|
|
34
|
+
from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage]
|
|
36
35
|
|
|
37
36
|
|
|
38
37
|
@dataclass
|
|
@@ -141,25 +140,9 @@ class TestAgentModel(AgentModel):
|
|
|
141
140
|
@asynccontextmanager
|
|
142
141
|
async def request_stream(
|
|
143
142
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
144
|
-
) -> AsyncIterator[
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
# TODO: Rework this once we make StreamTextResponse more general
|
|
149
|
-
texts: list[str] = []
|
|
150
|
-
tool_calls: list[ToolCallPart] = []
|
|
151
|
-
for item in msg.parts:
|
|
152
|
-
if isinstance(item, TextPart):
|
|
153
|
-
texts.append(item.content)
|
|
154
|
-
elif isinstance(item, ToolCallPart):
|
|
155
|
-
tool_calls.append(item)
|
|
156
|
-
else:
|
|
157
|
-
assert_never(item)
|
|
158
|
-
|
|
159
|
-
if texts:
|
|
160
|
-
yield TestStreamTextResponse('\n\n'.join(texts), usage)
|
|
161
|
-
else:
|
|
162
|
-
yield TestStreamStructuredResponse(msg, usage)
|
|
143
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
144
|
+
model_response = self._request(messages, model_settings)
|
|
145
|
+
yield TestStreamedResponse(model_response, messages)
|
|
163
146
|
|
|
164
147
|
def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
|
|
165
148
|
return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
|
|
@@ -223,58 +206,37 @@ class TestAgentModel(AgentModel):
|
|
|
223
206
|
|
|
224
207
|
|
|
225
208
|
@dataclass
|
|
226
|
-
class
|
|
227
|
-
"""A text response that streams test data."""
|
|
228
|
-
|
|
229
|
-
_text: str
|
|
230
|
-
_usage: Usage
|
|
231
|
-
_iter: Iterator[str] = field(init=False)
|
|
232
|
-
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
233
|
-
_buffer: list[str] = field(default_factory=list, init=False)
|
|
234
|
-
|
|
235
|
-
def __post_init__(self):
|
|
236
|
-
*words, last_word = self._text.split(' ')
|
|
237
|
-
words = [f'{word} ' for word in words]
|
|
238
|
-
words.append(last_word)
|
|
239
|
-
if len(words) == 1 and len(self._text) > 2:
|
|
240
|
-
mid = len(self._text) // 2
|
|
241
|
-
words = [self._text[:mid], self._text[mid:]]
|
|
242
|
-
self._iter = iter(words)
|
|
243
|
-
|
|
244
|
-
async def __anext__(self) -> None:
|
|
245
|
-
next_str = _utils.sync_anext(self._iter)
|
|
246
|
-
response_tokens = _estimate_string_usage(next_str)
|
|
247
|
-
self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
|
|
248
|
-
self._buffer.append(next_str)
|
|
249
|
-
|
|
250
|
-
def get(self, *, final: bool = False) -> Iterable[str]:
|
|
251
|
-
yield from self._buffer
|
|
252
|
-
self._buffer.clear()
|
|
253
|
-
|
|
254
|
-
def usage(self) -> Usage:
|
|
255
|
-
return self._usage
|
|
256
|
-
|
|
257
|
-
def timestamp(self) -> datetime:
|
|
258
|
-
return self._timestamp
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
@dataclass
|
|
262
|
-
class TestStreamStructuredResponse(StreamStructuredResponse):
|
|
209
|
+
class TestStreamedResponse(StreamedResponse):
|
|
263
210
|
"""A structured response that streams test data."""
|
|
264
211
|
|
|
265
212
|
_structured_response: ModelResponse
|
|
266
|
-
|
|
267
|
-
_iter: Iterator[None] = field(default_factory=lambda: iter([None]))
|
|
268
|
-
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
|
|
269
|
-
|
|
270
|
-
async def __anext__(self) -> None:
|
|
271
|
-
return _utils.sync_anext(self._iter)
|
|
213
|
+
_messages: InitVar[Iterable[ModelMessage]]
|
|
272
214
|
|
|
273
|
-
|
|
274
|
-
return self._structured_response
|
|
215
|
+
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
|
|
275
216
|
|
|
276
|
-
def
|
|
277
|
-
|
|
217
|
+
def __post_init__(self, _messages: Iterable[ModelMessage]):
|
|
218
|
+
self._usage = _estimate_usage(_messages)
|
|
219
|
+
|
|
220
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
221
|
+
for i, part in enumerate(self._structured_response.parts):
|
|
222
|
+
if isinstance(part, TextPart):
|
|
223
|
+
text = part.content
|
|
224
|
+
*words, last_word = text.split(' ')
|
|
225
|
+
words = [f'{word} ' for word in words]
|
|
226
|
+
words.append(last_word)
|
|
227
|
+
if len(words) == 1 and len(text) > 2:
|
|
228
|
+
mid = len(text) // 2
|
|
229
|
+
words = [text[:mid], text[mid:]]
|
|
230
|
+
self._usage += _get_string_usage('')
|
|
231
|
+
yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
|
|
232
|
+
for word in words:
|
|
233
|
+
self._usage += _get_string_usage(word)
|
|
234
|
+
yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
|
|
235
|
+
else:
|
|
236
|
+
args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
|
|
237
|
+
yield self._parts_manager.handle_tool_call_part(
|
|
238
|
+
vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
|
|
239
|
+
)
|
|
278
240
|
|
|
279
241
|
def timestamp(self) -> datetime:
|
|
280
242
|
return self._timestamp
|
|
@@ -434,3 +396,8 @@ class _JsonSchemaTestData:
|
|
|
434
396
|
rem //= chars
|
|
435
397
|
s += _chars[self.seed % chars]
|
|
436
398
|
return s
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _get_string_usage(text: str) -> Usage:
|
|
402
|
+
response_tokens = _estimate_string_tokens(text)
|
|
403
|
+
return Usage(response_tokens=response_tokens, total_tokens=response_tokens)
|
pydantic_ai/result.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
4
|
+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
|
|
5
5
|
from copy import deepcopy
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
@@ -169,7 +169,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
169
169
|
"""Result of a streamed run that returns structured data via a tool call."""
|
|
170
170
|
|
|
171
171
|
_usage_limits: UsageLimits | None
|
|
172
|
-
_stream_response: models.
|
|
172
|
+
_stream_response: models.StreamedResponse
|
|
173
173
|
_result_schema: _result.ResultSchema[ResultData] | None
|
|
174
174
|
_run_ctx: RunContext[AgentDeps]
|
|
175
175
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
|
|
@@ -200,20 +200,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
200
200
|
Returns:
|
|
201
201
|
An async iterable of the response data.
|
|
202
202
|
"""
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
else:
|
|
207
|
-
async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
|
|
208
|
-
yield await self.validate_structured_result(structured_message, allow_partial=not is_last)
|
|
203
|
+
async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
|
|
204
|
+
result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
|
|
205
|
+
yield result
|
|
209
206
|
|
|
210
207
|
async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
|
|
211
208
|
"""Stream the text result as an async iterable.
|
|
212
209
|
|
|
213
|
-
!!! note
|
|
214
|
-
This method will fail if the response is structured,
|
|
215
|
-
e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `True`.
|
|
216
|
-
|
|
217
210
|
!!! note
|
|
218
211
|
Result validators will NOT be called on the text result if `delta=True`.
|
|
219
212
|
|
|
@@ -224,54 +217,65 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
224
217
|
Debouncing is particularly important for long structured responses to reduce the overhead of
|
|
225
218
|
performing validation as each token is received.
|
|
226
219
|
"""
|
|
220
|
+
if self._result_schema and not self._result_schema.allow_text_result:
|
|
221
|
+
raise exceptions.UserError('stream_text() can only be used with text responses')
|
|
222
|
+
|
|
227
223
|
usage_checking_stream = _get_usage_checking_stream_response(
|
|
228
224
|
self._stream_response, self._usage_limits, self.usage
|
|
229
225
|
)
|
|
230
226
|
|
|
227
|
+
# Define a "merged" version of the iterator that will yield items that have already been retrieved
|
|
228
|
+
# and items that we receive while streaming. We define a dedicated async iterator for this so we can
|
|
229
|
+
# pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
|
|
230
|
+
async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
|
|
231
|
+
# if the response currently has any parts with content, yield those before streaming
|
|
232
|
+
msg = self._stream_response.get()
|
|
233
|
+
for i, part in enumerate(msg.parts):
|
|
234
|
+
if isinstance(part, _messages.TextPart) and part.content:
|
|
235
|
+
yield part.content, i
|
|
236
|
+
|
|
237
|
+
async for event in usage_checking_stream:
|
|
238
|
+
if (
|
|
239
|
+
isinstance(event, _messages.PartStartEvent)
|
|
240
|
+
and isinstance(event.part, _messages.TextPart)
|
|
241
|
+
and event.part.content
|
|
242
|
+
):
|
|
243
|
+
yield event.part.content, event.index
|
|
244
|
+
elif (
|
|
245
|
+
isinstance(event, _messages.PartDeltaEvent)
|
|
246
|
+
and isinstance(event.delta, _messages.TextPartDelta)
|
|
247
|
+
and event.delta.content_delta
|
|
248
|
+
):
|
|
249
|
+
yield event.delta.content_delta, event.index
|
|
250
|
+
|
|
251
|
+
async def _stream_text_deltas() -> AsyncIterator[str]:
|
|
252
|
+
async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
|
|
253
|
+
async for items in group_iter:
|
|
254
|
+
yield ''.join([content for content, _ in items])
|
|
255
|
+
|
|
231
256
|
with _logfire.span('response stream text') as lf_span:
|
|
232
|
-
if isinstance(self._stream_response, models.StreamStructuredResponse):
|
|
233
|
-
raise exceptions.UserError('stream_text() can only be used with text responses')
|
|
234
257
|
if delta:
|
|
235
|
-
async
|
|
236
|
-
|
|
237
|
-
yield ''.join(self._stream_response.get())
|
|
238
|
-
final_delta = ''.join(self._stream_response.get(final=True))
|
|
239
|
-
if final_delta:
|
|
240
|
-
yield final_delta
|
|
258
|
+
async for text in _stream_text_deltas():
|
|
259
|
+
yield text
|
|
241
260
|
else:
|
|
242
261
|
# a quick benchmark shows it's faster to build up a string with concat when we're
|
|
243
262
|
# yielding at each step
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
async
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
yield combined
|
|
255
|
-
|
|
256
|
-
new = False
|
|
257
|
-
for chunk in self._stream_response.get(final=True):
|
|
258
|
-
chunks.append(chunk)
|
|
259
|
-
new = True
|
|
260
|
-
if new:
|
|
261
|
-
combined = await self._validate_text_result(''.join(chunks))
|
|
262
|
-
yield combined
|
|
263
|
-
lf_span.set_attribute('combined_text', combined)
|
|
264
|
-
await self._marked_completed(_messages.ModelResponse.from_text(combined))
|
|
263
|
+
deltas: list[str] = []
|
|
264
|
+
combined_validated_text = ''
|
|
265
|
+
async for text in _stream_text_deltas():
|
|
266
|
+
deltas.append(text)
|
|
267
|
+
combined_text = ''.join(deltas)
|
|
268
|
+
combined_validated_text = await self._validate_text_result(combined_text)
|
|
269
|
+
yield combined_validated_text
|
|
270
|
+
|
|
271
|
+
lf_span.set_attribute('combined_text', combined_validated_text)
|
|
272
|
+
await self._marked_completed(_messages.ModelResponse.from_text(combined_validated_text))
|
|
265
273
|
|
|
266
274
|
async def stream_structured(
|
|
267
275
|
self, *, debounce_by: float | None = 0.1
|
|
268
276
|
) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
|
|
269
277
|
"""Stream the response as an async iterable of Structured LLM Messages.
|
|
270
278
|
|
|
271
|
-
!!! note
|
|
272
|
-
This method will fail if the response is text,
|
|
273
|
-
e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `False`.
|
|
274
|
-
|
|
275
279
|
Args:
|
|
276
280
|
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
|
|
277
281
|
Debouncing is particularly important for long structured responses to reduce the overhead of
|
|
@@ -285,24 +289,20 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
285
289
|
)
|
|
286
290
|
|
|
287
291
|
with _logfire.span('response stream structured') as lf_span:
|
|
288
|
-
if
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
+
# if the message currently has any parts with content, yield before streaming
|
|
293
|
+
msg = self._stream_response.get()
|
|
294
|
+
for part in msg.parts:
|
|
295
|
+
if part.has_content():
|
|
296
|
+
yield msg, False
|
|
297
|
+
break
|
|
298
|
+
|
|
299
|
+
async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
|
|
300
|
+
async for _events in group_iter:
|
|
301
|
+
msg = self._stream_response.get()
|
|
302
|
+
yield msg, False
|
|
292
303
|
msg = self._stream_response.get()
|
|
293
|
-
for item in msg.parts:
|
|
294
|
-
if isinstance(item, _messages.ToolCallPart) and item.has_content():
|
|
295
|
-
yield msg, False
|
|
296
|
-
break
|
|
297
|
-
async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
|
|
298
|
-
async for _ in group_iter:
|
|
299
|
-
msg = self._stream_response.get()
|
|
300
|
-
for item in msg.parts:
|
|
301
|
-
if isinstance(item, _messages.ToolCallPart) and item.has_content():
|
|
302
|
-
yield msg, False
|
|
303
|
-
break
|
|
304
|
-
msg = self._stream_response.get(final=True)
|
|
305
304
|
yield msg, True
|
|
305
|
+
# TODO: Should this now be `final_response` instead of `structured_response`?
|
|
306
306
|
lf_span.set_attribute('structured_response', msg)
|
|
307
307
|
await self._marked_completed(msg)
|
|
308
308
|
|
|
@@ -314,21 +314,9 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
314
314
|
|
|
315
315
|
async for _ in usage_checking_stream:
|
|
316
316
|
pass
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
text = await self._validate_text_result(text)
|
|
321
|
-
await self._marked_completed(_messages.ModelResponse.from_text(text))
|
|
322
|
-
return cast(ResultData, text)
|
|
323
|
-
else:
|
|
324
|
-
message = self._stream_response.get(final=True)
|
|
325
|
-
await self._marked_completed(message)
|
|
326
|
-
return await self.validate_structured_result(message)
|
|
327
|
-
|
|
328
|
-
@property
|
|
329
|
-
def is_structured(self) -> bool:
|
|
330
|
-
"""Return whether the stream response contains structured data (as opposed to text)."""
|
|
331
|
-
return isinstance(self._stream_response, models.StreamStructuredResponse)
|
|
317
|
+
message = self._stream_response.get()
|
|
318
|
+
await self._marked_completed(message)
|
|
319
|
+
return await self.validate_structured_result(message)
|
|
332
320
|
|
|
333
321
|
def usage(self) -> Usage:
|
|
334
322
|
"""Return the usage of the whole run.
|
|
@@ -346,20 +334,29 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
346
334
|
self, message: _messages.ModelResponse, *, allow_partial: bool = False
|
|
347
335
|
) -> ResultData:
|
|
348
336
|
"""Validate a structured result message."""
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
337
|
+
if self._result_schema is not None and self._result_tool_name is not None:
|
|
338
|
+
match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
|
|
339
|
+
if match is None:
|
|
340
|
+
raise exceptions.UnexpectedModelBehavior(
|
|
341
|
+
f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
call, result_tool = match
|
|
345
|
+
result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
|
|
346
|
+
|
|
347
|
+
for validator in self._result_validators:
|
|
348
|
+
result_data = await validator.validate(result_data, call, self._run_ctx)
|
|
349
|
+
return result_data
|
|
350
|
+
else:
|
|
351
|
+
text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
|
|
352
|
+
for validator in self._result_validators:
|
|
353
|
+
text = await validator.validate(
|
|
354
|
+
text, # pyright: ignore[reportArgumentType]
|
|
355
|
+
None,
|
|
356
|
+
self._run_ctx,
|
|
357
|
+
)
|
|
358
|
+
# Since there is no result tool, we can assume that str is compatible with ResultData
|
|
359
|
+
return cast(ResultData, text)
|
|
363
360
|
|
|
364
361
|
async def _validate_text_result(self, text: str) -> str:
|
|
365
362
|
for validator in self._result_validators:
|
|
@@ -377,8 +374,10 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
377
374
|
|
|
378
375
|
|
|
379
376
|
def _get_usage_checking_stream_response(
|
|
380
|
-
stream_response:
|
|
381
|
-
|
|
377
|
+
stream_response: AsyncIterable[_messages.ModelResponseStreamEvent],
|
|
378
|
+
limits: UsageLimits | None,
|
|
379
|
+
get_usage: Callable[[], Usage],
|
|
380
|
+
) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
|
|
382
381
|
if limits is not None and limits.has_token_limits():
|
|
383
382
|
|
|
384
383
|
async def _usage_checking_iterator():
|
pydantic_ai/tools.py
CHANGED
|
@@ -4,7 +4,7 @@ import dataclasses
|
|
|
4
4
|
import inspect
|
|
5
5
|
from collections.abc import Awaitable
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
|
|
8
8
|
|
|
9
9
|
from pydantic import ValidationError
|
|
10
10
|
from pydantic_core import SchemaValidator
|
|
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
|
|
|
18
18
|
|
|
19
19
|
__all__ = (
|
|
20
20
|
'AgentDeps',
|
|
21
|
+
'DocstringFormat',
|
|
21
22
|
'RunContext',
|
|
22
23
|
'SystemPromptFunc',
|
|
23
24
|
'ToolFuncContext',
|
|
@@ -106,7 +107,7 @@ See [tool docs](../tools.md#tool-prepare) for more information.
|
|
|
106
107
|
|
|
107
108
|
Example — here `only_if_42` is valid as a `ToolPrepareFunc`:
|
|
108
109
|
|
|
109
|
-
```python {
|
|
110
|
+
```python {noqa="I001"}
|
|
110
111
|
from typing import Union
|
|
111
112
|
|
|
112
113
|
from pydantic_ai import RunContext, Tool
|
|
@@ -127,6 +128,15 @@ hitchhiker = Tool(hitchhiker, prepare=only_if_42)
|
|
|
127
128
|
Usage `ToolPrepareFunc[AgentDeps]`.
|
|
128
129
|
"""
|
|
129
130
|
|
|
131
|
+
DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
|
|
132
|
+
"""Supported docstring formats.
|
|
133
|
+
|
|
134
|
+
* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
|
|
135
|
+
* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
|
|
136
|
+
* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
|
|
137
|
+
* `'auto'` — Automatically infer the format based on the structure of the docstring.
|
|
138
|
+
"""
|
|
139
|
+
|
|
130
140
|
A = TypeVar('A')
|
|
131
141
|
|
|
132
142
|
|
|
@@ -140,6 +150,8 @@ class Tool(Generic[AgentDeps]):
|
|
|
140
150
|
name: str
|
|
141
151
|
description: str
|
|
142
152
|
prepare: ToolPrepareFunc[AgentDeps] | None
|
|
153
|
+
docstring_format: DocstringFormat
|
|
154
|
+
require_parameter_descriptions: bool
|
|
143
155
|
_is_async: bool = field(init=False)
|
|
144
156
|
_single_arg_name: str | None = field(init=False)
|
|
145
157
|
_positional_fields: list[str] = field(init=False)
|
|
@@ -157,12 +169,14 @@ class Tool(Generic[AgentDeps]):
|
|
|
157
169
|
name: str | None = None,
|
|
158
170
|
description: str | None = None,
|
|
159
171
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
172
|
+
docstring_format: DocstringFormat = 'auto',
|
|
173
|
+
require_parameter_descriptions: bool = False,
|
|
160
174
|
):
|
|
161
175
|
"""Create a new tool instance.
|
|
162
176
|
|
|
163
177
|
Example usage:
|
|
164
178
|
|
|
165
|
-
```python {
|
|
179
|
+
```python {noqa="I001"}
|
|
166
180
|
from pydantic_ai import Agent, RunContext, Tool
|
|
167
181
|
|
|
168
182
|
async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
|
|
@@ -173,7 +187,7 @@ class Tool(Generic[AgentDeps]):
|
|
|
173
187
|
|
|
174
188
|
or with a custom prepare method:
|
|
175
189
|
|
|
176
|
-
```python {
|
|
190
|
+
```python {noqa="I001"}
|
|
177
191
|
from typing import Union
|
|
178
192
|
|
|
179
193
|
from pydantic_ai import Agent, RunContext, Tool
|
|
@@ -203,17 +217,22 @@ class Tool(Generic[AgentDeps]):
|
|
|
203
217
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
204
218
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
205
219
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
220
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
221
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
222
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
206
223
|
"""
|
|
207
224
|
if takes_ctx is None:
|
|
208
225
|
takes_ctx = _pydantic.takes_ctx(function)
|
|
209
226
|
|
|
210
|
-
f = _pydantic.function_schema(function, takes_ctx)
|
|
227
|
+
f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
|
|
211
228
|
self.function = function
|
|
212
229
|
self.takes_ctx = takes_ctx
|
|
213
230
|
self.max_retries = max_retries
|
|
214
231
|
self.name = name or function.__name__
|
|
215
232
|
self.description = description or f['description']
|
|
216
233
|
self.prepare = prepare
|
|
234
|
+
self.docstring_format = docstring_format
|
|
235
|
+
self.require_parameter_descriptions = require_parameter_descriptions
|
|
217
236
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
218
237
|
self._single_arg_name = f['single_arg_name']
|
|
219
238
|
self._positional_fields = f['positional_fields']
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.19
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -31,6 +31,8 @@ Requires-Dist: logfire-api>=1.2.0
|
|
|
31
31
|
Requires-Dist: pydantic>=2.10
|
|
32
32
|
Provides-Extra: anthropic
|
|
33
33
|
Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
|
|
34
|
+
Provides-Extra: graph
|
|
35
|
+
Requires-Dist: pydantic-graph==0.0.19; extra == 'graph'
|
|
34
36
|
Provides-Extra: groq
|
|
35
37
|
Requires-Dist: groq>=0.12.0; extra == 'groq'
|
|
36
38
|
Provides-Extra: logfire
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
pydantic_ai/__init__.py,sha256=FbYetEgT6OO25u2KF5ZnFxKpz5DtnSpfckRXP4mjl8E,489
|
|
2
|
+
pydantic_ai/_griffe.py,sha256=RYRKiLbgG97QxnazbAwlnc74XxevGHLQet-FGfq9qls,3960
|
|
3
|
+
pydantic_ai/_parts_manager.py,sha256=pMDZs6BGC8EmaNa-73QvuptmxdG2MhBrBLIydCOl-gM,11886
|
|
4
|
+
pydantic_ai/_pydantic.py,sha256=Zvjd2te6EzPrFnz--oDSdqwZuPw3vCiflTHriRhpNsY,8698
|
|
5
|
+
pydantic_ai/_result.py,sha256=cUSugZQV0n5Z4fFHiMqua-2xs_0S6m-rr-yd6QS3nFE,10317
|
|
6
|
+
pydantic_ai/_system_prompt.py,sha256=Fsl1K6GdQP0WhWBzvJxCc5uTqCD06lHjJlTADah-PI0,1116
|
|
7
|
+
pydantic_ai/_utils.py,sha256=EHW866W6ZpGJLCWtoEAcwIPeWo9OQFhnD5el2DwVcwc,10949
|
|
8
|
+
pydantic_ai/agent.py,sha256=Z_79gw4BIJooBIqJwPbnDHvmBcCXp2dbNd_832tc_do,62500
|
|
9
|
+
pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
|
|
10
|
+
pydantic_ai/format_as_xml.py,sha256=QE7eMlg5-YUMw1_2kcI3h0uKYPZZyGkgXFDtfZTMeeI,4480
|
|
11
|
+
pydantic_ai/messages.py,sha256=b4RpaXogREquE8WHlGPMm0UGTNx2QtePV5GYk-9EscY,18185
|
|
12
|
+
pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
+
pydantic_ai/result.py,sha256=93ZLxr2jPx0cZeslHgphJ6XJnQMEybktzQ_LUT47h3Q,17429
|
|
14
|
+
pydantic_ai/settings.py,sha256=oTk8ZfYuUsNxpJMWLvSrO1OH_0ur7VKgDNTMQG0tPyM,1974
|
|
15
|
+
pydantic_ai/tools.py,sha256=iwa2PyhnKmvh_njy4aMfRIh7AP5igDIZ1ZPvgvvn6bM,13018
|
|
16
|
+
pydantic_ai/usage.py,sha256=60d9f6M7YEYuKMbqDGDogX4KsA73fhDtWyDXYXoIPaI,4948
|
|
17
|
+
pydantic_ai/models/__init__.py,sha256=Q4_fHy48szaA_TrIW3LZhRXiDUlhPAYf8LkhinSP3s8,10883
|
|
18
|
+
pydantic_ai/models/anthropic.py,sha256=MkFqy2F7SPb_qAbgzc04iZWmVuoEBgn30v1HY1Wjadc,13543
|
|
19
|
+
pydantic_ai/models/function.py,sha256=iT4XT8VEaJbNwYJAWtrI_jbRAb2tZO6UL93ErR4RYhM,9629
|
|
20
|
+
pydantic_ai/models/gemini.py,sha256=3RTVQBAI1jWL3Xx_hi7qdy_6H-kTeuAOTPELnlVtPp4,27498
|
|
21
|
+
pydantic_ai/models/groq.py,sha256=kzQSFT-04WmQmdRaB6Wj0mxHeAXIgyrryZkptNiA4Ng,13211
|
|
22
|
+
pydantic_ai/models/mistral.py,sha256=qyYOLBpOdI5iPBmQxf5jp1d17sxqa1r8GJ7tb4yE45U,24549
|
|
23
|
+
pydantic_ai/models/ollama.py,sha256=aHI8pNw7fqOOgvlEWcTnTYTmhf0cGg41x-p5sUQr2_k,4200
|
|
24
|
+
pydantic_ai/models/openai.py,sha256=FzV6OCuK4Sr_J2GTuM-6Vu9NbDyZPxllwQPmssdOtbQ,13774
|
|
25
|
+
pydantic_ai/models/test.py,sha256=0m2Pdn0xJMjvAVekVIoADQL0aSkOnGZJct9k4WvImrQ,15880
|
|
26
|
+
pydantic_ai/models/vertexai.py,sha256=dHGrmLMgekWAEOZkLsO5rwDtQ6mjPixvn0umlvWAZok,9323
|
|
27
|
+
pydantic_ai_slim-0.0.19.dist-info/METADATA,sha256=lnGlda0-tCapsWI72DyzGV9Sppm5I7koWbb7-xEpWcU,2808
|
|
28
|
+
pydantic_ai_slim-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
29
|
+
pydantic_ai_slim-0.0.19.dist-info/RECORD,,
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
pydantic_ai/__init__.py,sha256=FbYetEgT6OO25u2KF5ZnFxKpz5DtnSpfckRXP4mjl8E,489
|
|
2
|
-
pydantic_ai/_griffe.py,sha256=407vemPed1Eeao-8sqAeC5cHGa-5SK55OzQwWk72Sl8,3795
|
|
3
|
-
pydantic_ai/_pydantic.py,sha256=qXi5IsyiYOHeg_-qozCdxkfeqw2z0gBTjqgywBCiJWo,8125
|
|
4
|
-
pydantic_ai/_result.py,sha256=cUSugZQV0n5Z4fFHiMqua-2xs_0S6m-rr-yd6QS3nFE,10317
|
|
5
|
-
pydantic_ai/_system_prompt.py,sha256=Fsl1K6GdQP0WhWBzvJxCc5uTqCD06lHjJlTADah-PI0,1116
|
|
6
|
-
pydantic_ai/_utils.py,sha256=skWNgm89US_x1EpxdRy5wCkghBrm1XgxFCiEh6wAkAo,8753
|
|
7
|
-
pydantic_ai/agent.py,sha256=n6yu8gHhMWknalru1tTEfNWkOt_qqqD9zMLjVeHWJ7U,61593
|
|
8
|
-
pydantic_ai/exceptions.py,sha256=eGDKX6bGhgVxXBzu81Sk3iiAkXr0GUtgT7bD5Rxlqpg,2028
|
|
9
|
-
pydantic_ai/format_as_xml.py,sha256=Gm65687GL8Z6A_lPiJWL1O_E3ovHEBn2O1DKhn1CDnA,4472
|
|
10
|
-
pydantic_ai/messages.py,sha256=RCtzsJFkhKBwIXNYOVAcNx0Kmnd0iAjSvwpnHmaAQt0,9211
|
|
11
|
-
pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
-
pydantic_ai/result.py,sha256=-dpaaD24E1Ns7fxz5Gn7SKou-A8Cag4LjEyCBJbrHzY,17597
|
|
13
|
-
pydantic_ai/settings.py,sha256=oTk8ZfYuUsNxpJMWLvSrO1OH_0ur7VKgDNTMQG0tPyM,1974
|
|
14
|
-
pydantic_ai/tools.py,sha256=G4lwAb7QIowtSHk7w5cH8WQFIFqwMPn0J6Nqhgz7ubA,11757
|
|
15
|
-
pydantic_ai/usage.py,sha256=60d9f6M7YEYuKMbqDGDogX4KsA73fhDtWyDXYXoIPaI,4948
|
|
16
|
-
pydantic_ai/models/__init__.py,sha256=nAsE9pcqAW68pluxX332Z7sVomhVWEaU20F4Oi57ojs,11754
|
|
17
|
-
pydantic_ai/models/anthropic.py,sha256=VyhLeNc585xann5we3obOWKUjIv6cKF6wYzhGHAAmvo,13466
|
|
18
|
-
pydantic_ai/models/function.py,sha256=i7qkS_31aHrTbYVh6OzQ7Cwucz44F5PjT2EJK3GMphw,10573
|
|
19
|
-
pydantic_ai/models/gemini.py,sha256=xvwPGYlZhQUYunu3LpWWbDfp_97Q4foLMaaLzYgyLFM,28745
|
|
20
|
-
pydantic_ai/models/groq.py,sha256=ZoPkuWJrf78JPnTRfZhi7v0ETgxJKNN5dH8BLWagGGk,15770
|
|
21
|
-
pydantic_ai/models/mistral.py,sha256=xGVI6-b8-9vnFickPPI2cRaHEWLc0jKKUM_vMjipf-U,25894
|
|
22
|
-
pydantic_ai/models/ollama.py,sha256=ELqxhcNcnvQBnadd3gukS01zprUp6v8N_h1P5K-uf6c,4188
|
|
23
|
-
pydantic_ai/models/openai.py,sha256=qFFInL3NbgfGcsAWigxMP5mscp76hC-jJimHc9woU6Y,16518
|
|
24
|
-
pydantic_ai/models/test.py,sha256=u2pdZd9OLXQ_jI6CaVt96udXuIcv0Hfnfqd3pFGmeJM,16514
|
|
25
|
-
pydantic_ai/models/vertexai.py,sha256=dHGrmLMgekWAEOZkLsO5rwDtQ6mjPixvn0umlvWAZok,9323
|
|
26
|
-
pydantic_ai_slim-0.0.18.dist-info/METADATA,sha256=jdEPVU8__Zt4lmd3KYV3MLW7LLUrMxNvqbGJ761F6C0,2730
|
|
27
|
-
pydantic_ai_slim-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
28
|
-
pydantic_ai_slim-0.0.18.dist-info/RECORD,,
|
|
File without changes
|