pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +17 -3
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +187 -159
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +217 -15
- pydantic_ai/models/__init__.py +58 -71
- pydantic_ai/models/anthropic.py +112 -48
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +57 -85
- pydantic_ai/models/gemini.py +83 -129
- pydantic_ai/models/groq.py +60 -130
- pydantic_ai/models/mistral.py +86 -142
- pydantic_ai/models/ollama.py +4 -0
- pydantic_ai/models/openai.py +75 -136
- pydantic_ai/models/test.py +55 -80
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +132 -114
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +42 -23
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
pydantic_ai/result.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
4
|
+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
|
|
5
5
|
from copy import deepcopy
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
@@ -11,35 +11,49 @@ import logfire_api
|
|
|
11
11
|
from typing_extensions import TypeVar
|
|
12
12
|
|
|
13
13
|
from . import _result, _utils, exceptions, messages as _messages, models
|
|
14
|
-
from .tools import
|
|
14
|
+
from .tools import AgentDepsT, RunContext
|
|
15
15
|
from .usage import Usage, UsageLimits
|
|
16
16
|
|
|
17
|
-
__all__ = '
|
|
17
|
+
__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
"""
|
|
20
|
+
T = TypeVar('T')
|
|
21
|
+
"""An invariant TypeVar."""
|
|
22
|
+
ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
|
|
23
|
+
"""
|
|
24
|
+
An invariant type variable for the result data of a model.
|
|
25
|
+
|
|
26
|
+
We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
|
|
27
|
+
in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
|
|
28
|
+
possessing ResultValidators are covariant in the result data type, but in practice this is rarely an issue, and
|
|
29
|
+
changing it would have negative consequences for the ergonomics of the library.
|
|
30
|
+
|
|
31
|
+
At some point, it may make sense to change the input to ResultValidatorFunc to be `Any` or `object` as doing that would
|
|
32
|
+
resolve these potential variance issues.
|
|
33
|
+
"""
|
|
34
|
+
ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
|
|
35
|
+
"""Covariant type variable for the result data type of a run."""
|
|
22
36
|
|
|
23
37
|
ResultValidatorFunc = Union[
|
|
24
|
-
Callable[[RunContext[
|
|
25
|
-
Callable[[RunContext[
|
|
26
|
-
Callable[[
|
|
27
|
-
Callable[[
|
|
38
|
+
Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
|
|
39
|
+
Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
|
|
40
|
+
Callable[[ResultDataT_inv], ResultDataT_inv],
|
|
41
|
+
Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
|
|
28
42
|
]
|
|
29
43
|
"""
|
|
30
|
-
A function that always takes
|
|
44
|
+
A function that always takes and returns the same type of data (which is the result type of an agent run), and:
|
|
31
45
|
|
|
32
46
|
* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
|
|
33
47
|
* may or may not be async
|
|
34
48
|
|
|
35
|
-
Usage `ResultValidatorFunc[AgentDeps,
|
|
49
|
+
Usage `ResultValidatorFunc[AgentDeps, T]`.
|
|
36
50
|
"""
|
|
37
51
|
|
|
38
52
|
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
|
|
39
53
|
|
|
40
54
|
|
|
41
55
|
@dataclass
|
|
42
|
-
class _BaseRunResult(ABC, Generic[
|
|
56
|
+
class _BaseRunResult(ABC, Generic[ResultDataT]):
|
|
43
57
|
"""Base type for results.
|
|
44
58
|
|
|
45
59
|
You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
|
|
@@ -119,10 +133,10 @@ class _BaseRunResult(ABC, Generic[ResultData]):
|
|
|
119
133
|
|
|
120
134
|
|
|
121
135
|
@dataclass
|
|
122
|
-
class RunResult(_BaseRunResult[
|
|
136
|
+
class RunResult(_BaseRunResult[ResultDataT]):
|
|
123
137
|
"""Result of a non-streamed run."""
|
|
124
138
|
|
|
125
|
-
data:
|
|
139
|
+
data: ResultDataT
|
|
126
140
|
"""Data from the final response in the run."""
|
|
127
141
|
_result_tool_name: str | None
|
|
128
142
|
_usage: Usage
|
|
@@ -165,14 +179,14 @@ class RunResult(_BaseRunResult[ResultData]):
|
|
|
165
179
|
|
|
166
180
|
|
|
167
181
|
@dataclass
|
|
168
|
-
class StreamedRunResult(_BaseRunResult[
|
|
182
|
+
class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]):
|
|
169
183
|
"""Result of a streamed run that returns structured data via a tool call."""
|
|
170
184
|
|
|
171
185
|
_usage_limits: UsageLimits | None
|
|
172
|
-
_stream_response: models.
|
|
173
|
-
_result_schema: _result.ResultSchema[
|
|
174
|
-
_run_ctx: RunContext[
|
|
175
|
-
_result_validators: list[_result.ResultValidator[
|
|
186
|
+
_stream_response: models.StreamedResponse
|
|
187
|
+
_result_schema: _result.ResultSchema[ResultDataT] | None
|
|
188
|
+
_run_ctx: RunContext[AgentDepsT]
|
|
189
|
+
_result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
|
|
176
190
|
_result_tool_name: str | None
|
|
177
191
|
_on_complete: Callable[[], Awaitable[None]]
|
|
178
192
|
is_complete: bool = field(default=False, init=False)
|
|
@@ -185,7 +199,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
185
199
|
[`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
|
|
186
200
|
"""
|
|
187
201
|
|
|
188
|
-
async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[
|
|
202
|
+
async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
|
|
189
203
|
"""Stream the response as an async iterable.
|
|
190
204
|
|
|
191
205
|
The pydantic validator for structured data will be called in
|
|
@@ -200,20 +214,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
200
214
|
Returns:
|
|
201
215
|
An async iterable of the response data.
|
|
202
216
|
"""
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
else:
|
|
207
|
-
async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
|
|
208
|
-
yield await self.validate_structured_result(structured_message, allow_partial=not is_last)
|
|
217
|
+
async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
|
|
218
|
+
result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
|
|
219
|
+
yield result
|
|
209
220
|
|
|
210
221
|
async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
|
|
211
222
|
"""Stream the text result as an async iterable.
|
|
212
223
|
|
|
213
|
-
!!! note
|
|
214
|
-
This method will fail if the response is structured,
|
|
215
|
-
e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `True`.
|
|
216
|
-
|
|
217
224
|
!!! note
|
|
218
225
|
Result validators will NOT be called on the text result if `delta=True`.
|
|
219
226
|
|
|
@@ -224,54 +231,70 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
224
231
|
Debouncing is particularly important for long structured responses to reduce the overhead of
|
|
225
232
|
performing validation as each token is received.
|
|
226
233
|
"""
|
|
234
|
+
if self._result_schema and not self._result_schema.allow_text_result:
|
|
235
|
+
raise exceptions.UserError('stream_text() can only be used with text responses')
|
|
236
|
+
|
|
227
237
|
usage_checking_stream = _get_usage_checking_stream_response(
|
|
228
238
|
self._stream_response, self._usage_limits, self.usage
|
|
229
239
|
)
|
|
230
240
|
|
|
241
|
+
# Define a "merged" version of the iterator that will yield items that have already been retrieved
|
|
242
|
+
# and items that we receive while streaming. We define a dedicated async iterator for this so we can
|
|
243
|
+
# pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
|
|
244
|
+
async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
|
|
245
|
+
# if the response currently has any parts with content, yield those before streaming
|
|
246
|
+
msg = self._stream_response.get()
|
|
247
|
+
for i, part in enumerate(msg.parts):
|
|
248
|
+
if isinstance(part, _messages.TextPart) and part.content:
|
|
249
|
+
yield part.content, i
|
|
250
|
+
|
|
251
|
+
async for event in usage_checking_stream:
|
|
252
|
+
if (
|
|
253
|
+
isinstance(event, _messages.PartStartEvent)
|
|
254
|
+
and isinstance(event.part, _messages.TextPart)
|
|
255
|
+
and event.part.content
|
|
256
|
+
):
|
|
257
|
+
yield event.part.content, event.index
|
|
258
|
+
elif (
|
|
259
|
+
isinstance(event, _messages.PartDeltaEvent)
|
|
260
|
+
and isinstance(event.delta, _messages.TextPartDelta)
|
|
261
|
+
and event.delta.content_delta
|
|
262
|
+
):
|
|
263
|
+
yield event.delta.content_delta, event.index
|
|
264
|
+
|
|
265
|
+
async def _stream_text_deltas() -> AsyncIterator[str]:
|
|
266
|
+
async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
|
|
267
|
+
async for items in group_iter:
|
|
268
|
+
yield ''.join([content for content, _ in items])
|
|
269
|
+
|
|
231
270
|
with _logfire.span('response stream text') as lf_span:
|
|
232
|
-
if isinstance(self._stream_response, models.StreamStructuredResponse):
|
|
233
|
-
raise exceptions.UserError('stream_text() can only be used with text responses')
|
|
234
271
|
if delta:
|
|
235
|
-
async
|
|
236
|
-
|
|
237
|
-
yield ''.join(self._stream_response.get())
|
|
238
|
-
final_delta = ''.join(self._stream_response.get(final=True))
|
|
239
|
-
if final_delta:
|
|
240
|
-
yield final_delta
|
|
272
|
+
async for text in _stream_text_deltas():
|
|
273
|
+
yield text
|
|
241
274
|
else:
|
|
242
275
|
# a quick benchmark shows it's faster to build up a string with concat when we're
|
|
243
276
|
# yielding at each step
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
async
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
new = True
|
|
260
|
-
if new:
|
|
261
|
-
combined = await self._validate_text_result(''.join(chunks))
|
|
262
|
-
yield combined
|
|
263
|
-
lf_span.set_attribute('combined_text', combined)
|
|
264
|
-
await self._marked_completed(_messages.ModelResponse.from_text(combined))
|
|
277
|
+
deltas: list[str] = []
|
|
278
|
+
combined_validated_text = ''
|
|
279
|
+
async for text in _stream_text_deltas():
|
|
280
|
+
deltas.append(text)
|
|
281
|
+
combined_text = ''.join(deltas)
|
|
282
|
+
combined_validated_text = await self._validate_text_result(combined_text)
|
|
283
|
+
yield combined_validated_text
|
|
284
|
+
|
|
285
|
+
lf_span.set_attribute('combined_text', combined_validated_text)
|
|
286
|
+
await self._marked_completed(
|
|
287
|
+
_messages.ModelResponse(
|
|
288
|
+
parts=[_messages.TextPart(combined_validated_text)],
|
|
289
|
+
model_name=self._stream_response.model_name(),
|
|
290
|
+
)
|
|
291
|
+
)
|
|
265
292
|
|
|
266
293
|
async def stream_structured(
|
|
267
294
|
self, *, debounce_by: float | None = 0.1
|
|
268
295
|
) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
|
|
269
296
|
"""Stream the response as an async iterable of Structured LLM Messages.
|
|
270
297
|
|
|
271
|
-
!!! note
|
|
272
|
-
This method will fail if the response is text,
|
|
273
|
-
e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `False`.
|
|
274
|
-
|
|
275
298
|
Args:
|
|
276
299
|
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
|
|
277
300
|
Debouncing is particularly important for long structured responses to reduce the overhead of
|
|
@@ -285,28 +308,24 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
285
308
|
)
|
|
286
309
|
|
|
287
310
|
with _logfire.span('response stream structured') as lf_span:
|
|
288
|
-
if
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
311
|
+
# if the message currently has any parts with content, yield before streaming
|
|
312
|
+
msg = self._stream_response.get()
|
|
313
|
+
for part in msg.parts:
|
|
314
|
+
if part.has_content():
|
|
315
|
+
yield msg, False
|
|
316
|
+
break
|
|
317
|
+
|
|
318
|
+
async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
|
|
319
|
+
async for _events in group_iter:
|
|
320
|
+
msg = self._stream_response.get()
|
|
321
|
+
yield msg, False
|
|
292
322
|
msg = self._stream_response.get()
|
|
293
|
-
for item in msg.parts:
|
|
294
|
-
if isinstance(item, _messages.ToolCallPart) and item.has_content():
|
|
295
|
-
yield msg, False
|
|
296
|
-
break
|
|
297
|
-
async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
|
|
298
|
-
async for _ in group_iter:
|
|
299
|
-
msg = self._stream_response.get()
|
|
300
|
-
for item in msg.parts:
|
|
301
|
-
if isinstance(item, _messages.ToolCallPart) and item.has_content():
|
|
302
|
-
yield msg, False
|
|
303
|
-
break
|
|
304
|
-
msg = self._stream_response.get(final=True)
|
|
305
323
|
yield msg, True
|
|
324
|
+
# TODO: Should this now be `final_response` instead of `structured_response`?
|
|
306
325
|
lf_span.set_attribute('structured_response', msg)
|
|
307
326
|
await self._marked_completed(msg)
|
|
308
327
|
|
|
309
|
-
async def get_data(self) ->
|
|
328
|
+
async def get_data(self) -> ResultDataT:
|
|
310
329
|
"""Stream the whole response, validate and return it."""
|
|
311
330
|
usage_checking_stream = _get_usage_checking_stream_response(
|
|
312
331
|
self._stream_response, self._usage_limits, self.usage
|
|
@@ -314,21 +333,9 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
314
333
|
|
|
315
334
|
async for _ in usage_checking_stream:
|
|
316
335
|
pass
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
text = await self._validate_text_result(text)
|
|
321
|
-
await self._marked_completed(_messages.ModelResponse.from_text(text))
|
|
322
|
-
return cast(ResultData, text)
|
|
323
|
-
else:
|
|
324
|
-
message = self._stream_response.get(final=True)
|
|
325
|
-
await self._marked_completed(message)
|
|
326
|
-
return await self.validate_structured_result(message)
|
|
327
|
-
|
|
328
|
-
@property
|
|
329
|
-
def is_structured(self) -> bool:
|
|
330
|
-
"""Return whether the stream response contains structured data (as opposed to text)."""
|
|
331
|
-
return isinstance(self._stream_response, models.StreamStructuredResponse)
|
|
336
|
+
message = self._stream_response.get()
|
|
337
|
+
await self._marked_completed(message)
|
|
338
|
+
return await self.validate_structured_result(message)
|
|
332
339
|
|
|
333
340
|
def usage(self) -> Usage:
|
|
334
341
|
"""Return the usage of the whole run.
|
|
@@ -344,27 +351,36 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
344
351
|
|
|
345
352
|
async def validate_structured_result(
|
|
346
353
|
self, message: _messages.ModelResponse, *, allow_partial: bool = False
|
|
347
|
-
) ->
|
|
354
|
+
) -> ResultDataT:
|
|
348
355
|
"""Validate a structured result message."""
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
356
|
+
if self._result_schema is not None and self._result_tool_name is not None:
|
|
357
|
+
match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
|
|
358
|
+
if match is None:
|
|
359
|
+
raise exceptions.UnexpectedModelBehavior(
|
|
360
|
+
f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
call, result_tool = match
|
|
364
|
+
result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
|
|
365
|
+
|
|
366
|
+
for validator in self._result_validators:
|
|
367
|
+
result_data = await validator.validate(result_data, call, self._run_ctx)
|
|
368
|
+
return result_data
|
|
369
|
+
else:
|
|
370
|
+
text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
|
|
371
|
+
for validator in self._result_validators:
|
|
372
|
+
text = await validator.validate(
|
|
373
|
+
text,
|
|
374
|
+
None,
|
|
375
|
+
self._run_ctx,
|
|
376
|
+
)
|
|
377
|
+
# Since there is no result tool, we can assume that str is compatible with ResultDataT
|
|
378
|
+
return cast(ResultDataT, text)
|
|
363
379
|
|
|
364
380
|
async def _validate_text_result(self, text: str) -> str:
|
|
365
381
|
for validator in self._result_validators:
|
|
366
|
-
text = await validator.validate(
|
|
367
|
-
text,
|
|
382
|
+
text = await validator.validate(
|
|
383
|
+
text,
|
|
368
384
|
None,
|
|
369
385
|
self._run_ctx,
|
|
370
386
|
)
|
|
@@ -377,8 +393,10 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
|
|
|
377
393
|
|
|
378
394
|
|
|
379
395
|
def _get_usage_checking_stream_response(
|
|
380
|
-
stream_response:
|
|
381
|
-
|
|
396
|
+
stream_response: AsyncIterable[_messages.ModelResponseStreamEvent],
|
|
397
|
+
limits: UsageLimits | None,
|
|
398
|
+
get_usage: Callable[[], Usage],
|
|
399
|
+
) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
|
|
382
400
|
if limits is not None and limits.has_token_limits():
|
|
383
401
|
|
|
384
402
|
async def _usage_checking_iterator():
|
pydantic_ai/settings.py
CHANGED
|
@@ -12,7 +12,8 @@ if TYPE_CHECKING:
|
|
|
12
12
|
class ModelSettings(TypedDict, total=False):
|
|
13
13
|
"""Settings to configure an LLM.
|
|
14
14
|
|
|
15
|
-
Here we include only settings which apply to multiple models / model providers
|
|
15
|
+
Here we include only settings which apply to multiple models / model providers,
|
|
16
|
+
though not all of these settings are supported by all models.
|
|
16
17
|
"""
|
|
17
18
|
|
|
18
19
|
max_tokens: int
|
|
@@ -24,6 +25,8 @@ class ModelSettings(TypedDict, total=False):
|
|
|
24
25
|
* Anthropic
|
|
25
26
|
* OpenAI
|
|
26
27
|
* Groq
|
|
28
|
+
* Cohere
|
|
29
|
+
* Mistral
|
|
27
30
|
"""
|
|
28
31
|
|
|
29
32
|
temperature: float
|
|
@@ -40,6 +43,8 @@ class ModelSettings(TypedDict, total=False):
|
|
|
40
43
|
* Anthropic
|
|
41
44
|
* OpenAI
|
|
42
45
|
* Groq
|
|
46
|
+
* Cohere
|
|
47
|
+
* Mistral
|
|
43
48
|
"""
|
|
44
49
|
|
|
45
50
|
top_p: float
|
|
@@ -55,6 +60,8 @@ class ModelSettings(TypedDict, total=False):
|
|
|
55
60
|
* Anthropic
|
|
56
61
|
* OpenAI
|
|
57
62
|
* Groq
|
|
63
|
+
* Cohere
|
|
64
|
+
* Mistral
|
|
58
65
|
"""
|
|
59
66
|
|
|
60
67
|
timeout: float | Timeout
|
|
@@ -66,6 +73,16 @@ class ModelSettings(TypedDict, total=False):
|
|
|
66
73
|
* Anthropic
|
|
67
74
|
* OpenAI
|
|
68
75
|
* Groq
|
|
76
|
+
* Mistral
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
parallel_tool_calls: bool
|
|
80
|
+
"""Whether to allow parallel tool calls.
|
|
81
|
+
|
|
82
|
+
Supported by:
|
|
83
|
+
* OpenAI
|
|
84
|
+
* Groq
|
|
85
|
+
* Anthropic
|
|
69
86
|
"""
|
|
70
87
|
|
|
71
88
|
|
pydantic_ai/tools.py
CHANGED
|
@@ -4,7 +4,7 @@ import dataclasses
|
|
|
4
4
|
import inspect
|
|
5
5
|
from collections.abc import Awaitable
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
|
|
8
8
|
|
|
9
9
|
from pydantic import ValidationError
|
|
10
10
|
from pydantic_core import SchemaValidator
|
|
@@ -17,7 +17,8 @@ if TYPE_CHECKING:
|
|
|
17
17
|
from .result import Usage
|
|
18
18
|
|
|
19
19
|
__all__ = (
|
|
20
|
-
'
|
|
20
|
+
'AgentDepsT',
|
|
21
|
+
'DocstringFormat',
|
|
21
22
|
'RunContext',
|
|
22
23
|
'SystemPromptFunc',
|
|
23
24
|
'ToolFuncContext',
|
|
@@ -30,15 +31,15 @@ __all__ = (
|
|
|
30
31
|
'ToolDefinition',
|
|
31
32
|
)
|
|
32
33
|
|
|
33
|
-
|
|
34
|
+
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
|
|
34
35
|
"""Type variable for agent dependencies."""
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
@dataclasses.dataclass
|
|
38
|
-
class RunContext(Generic[
|
|
39
|
+
class RunContext(Generic[AgentDepsT]):
|
|
39
40
|
"""Information about the current call."""
|
|
40
41
|
|
|
41
|
-
deps:
|
|
42
|
+
deps: AgentDepsT
|
|
42
43
|
"""Dependencies for the agent."""
|
|
43
44
|
model: models.Model
|
|
44
45
|
"""The model used in this run."""
|
|
@@ -57,7 +58,7 @@ class RunContext(Generic[AgentDeps]):
|
|
|
57
58
|
|
|
58
59
|
def replace_with(
|
|
59
60
|
self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
|
|
60
|
-
) -> RunContext[
|
|
61
|
+
) -> RunContext[AgentDepsT]:
|
|
61
62
|
# Create a new `RunContext` with a new `retry` value and `tool_name`.
|
|
62
63
|
kwargs = {}
|
|
63
64
|
if retry is not None:
|
|
@@ -71,8 +72,8 @@ ToolParams = ParamSpec('ToolParams', default=...)
|
|
|
71
72
|
"""Retrieval function param spec."""
|
|
72
73
|
|
|
73
74
|
SystemPromptFunc = Union[
|
|
74
|
-
Callable[[RunContext[
|
|
75
|
-
Callable[[RunContext[
|
|
75
|
+
Callable[[RunContext[AgentDepsT]], str],
|
|
76
|
+
Callable[[RunContext[AgentDepsT]], Awaitable[str]],
|
|
76
77
|
Callable[[], str],
|
|
77
78
|
Callable[[], Awaitable[str]],
|
|
78
79
|
]
|
|
@@ -81,7 +82,7 @@ SystemPromptFunc = Union[
|
|
|
81
82
|
Usage `SystemPromptFunc[AgentDeps]`.
|
|
82
83
|
"""
|
|
83
84
|
|
|
84
|
-
ToolFuncContext = Callable[Concatenate[RunContext[
|
|
85
|
+
ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
|
|
85
86
|
"""A tool function that takes `RunContext` as the first argument.
|
|
86
87
|
|
|
87
88
|
Usage `ToolFuncContext[AgentDeps, ToolParams]`.
|
|
@@ -91,7 +92,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
|
|
|
91
92
|
|
|
92
93
|
Usage `ToolFuncPlain[ToolParams]`.
|
|
93
94
|
"""
|
|
94
|
-
ToolFuncEither = Union[ToolFuncContext[
|
|
95
|
+
ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]]
|
|
95
96
|
"""Either kind of tool function.
|
|
96
97
|
|
|
97
98
|
This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
|
|
@@ -99,14 +100,14 @@ This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] a
|
|
|
99
100
|
|
|
100
101
|
Usage `ToolFuncEither[AgentDeps, ToolParams]`.
|
|
101
102
|
"""
|
|
102
|
-
ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[
|
|
103
|
+
ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
|
|
103
104
|
"""Definition of a function that can prepare a tool definition at call time.
|
|
104
105
|
|
|
105
106
|
See [tool docs](../tools.md#tool-prepare) for more information.
|
|
106
107
|
|
|
107
108
|
Example — here `only_if_42` is valid as a `ToolPrepareFunc`:
|
|
108
109
|
|
|
109
|
-
```python {
|
|
110
|
+
```python {noqa="I001"}
|
|
110
111
|
from typing import Union
|
|
111
112
|
|
|
112
113
|
from pydantic_ai import RunContext, Tool
|
|
@@ -127,19 +128,30 @@ hitchhiker = Tool(hitchhiker, prepare=only_if_42)
|
|
|
127
128
|
Usage `ToolPrepareFunc[AgentDeps]`.
|
|
128
129
|
"""
|
|
129
130
|
|
|
131
|
+
DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
|
|
132
|
+
"""Supported docstring formats.
|
|
133
|
+
|
|
134
|
+
* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
|
|
135
|
+
* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
|
|
136
|
+
* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
|
|
137
|
+
* `'auto'` — Automatically infer the format based on the structure of the docstring.
|
|
138
|
+
"""
|
|
139
|
+
|
|
130
140
|
A = TypeVar('A')
|
|
131
141
|
|
|
132
142
|
|
|
133
143
|
@dataclass(init=False)
|
|
134
|
-
class Tool(Generic[
|
|
144
|
+
class Tool(Generic[AgentDepsT]):
|
|
135
145
|
"""A tool function for an agent."""
|
|
136
146
|
|
|
137
|
-
function: ToolFuncEither[
|
|
147
|
+
function: ToolFuncEither[AgentDepsT]
|
|
138
148
|
takes_ctx: bool
|
|
139
149
|
max_retries: int | None
|
|
140
150
|
name: str
|
|
141
151
|
description: str
|
|
142
|
-
prepare: ToolPrepareFunc[
|
|
152
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None
|
|
153
|
+
docstring_format: DocstringFormat
|
|
154
|
+
require_parameter_descriptions: bool
|
|
143
155
|
_is_async: bool = field(init=False)
|
|
144
156
|
_single_arg_name: str | None = field(init=False)
|
|
145
157
|
_positional_fields: list[str] = field(init=False)
|
|
@@ -150,19 +162,21 @@ class Tool(Generic[AgentDeps]):
|
|
|
150
162
|
|
|
151
163
|
def __init__(
|
|
152
164
|
self,
|
|
153
|
-
function: ToolFuncEither[
|
|
165
|
+
function: ToolFuncEither[AgentDepsT],
|
|
154
166
|
*,
|
|
155
167
|
takes_ctx: bool | None = None,
|
|
156
168
|
max_retries: int | None = None,
|
|
157
169
|
name: str | None = None,
|
|
158
170
|
description: str | None = None,
|
|
159
|
-
prepare: ToolPrepareFunc[
|
|
171
|
+
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
|
|
172
|
+
docstring_format: DocstringFormat = 'auto',
|
|
173
|
+
require_parameter_descriptions: bool = False,
|
|
160
174
|
):
|
|
161
175
|
"""Create a new tool instance.
|
|
162
176
|
|
|
163
177
|
Example usage:
|
|
164
178
|
|
|
165
|
-
```python {
|
|
179
|
+
```python {noqa="I001"}
|
|
166
180
|
from pydantic_ai import Agent, RunContext, Tool
|
|
167
181
|
|
|
168
182
|
async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
|
|
@@ -173,7 +187,7 @@ class Tool(Generic[AgentDeps]):
|
|
|
173
187
|
|
|
174
188
|
or with a custom prepare method:
|
|
175
189
|
|
|
176
|
-
```python {
|
|
190
|
+
```python {noqa="I001"}
|
|
177
191
|
from typing import Union
|
|
178
192
|
|
|
179
193
|
from pydantic_ai import Agent, RunContext, Tool
|
|
@@ -203,17 +217,22 @@ class Tool(Generic[AgentDeps]):
|
|
|
203
217
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
204
218
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
205
219
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
220
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
221
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
222
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
206
223
|
"""
|
|
207
224
|
if takes_ctx is None:
|
|
208
225
|
takes_ctx = _pydantic.takes_ctx(function)
|
|
209
226
|
|
|
210
|
-
f = _pydantic.function_schema(function, takes_ctx)
|
|
227
|
+
f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
|
|
211
228
|
self.function = function
|
|
212
229
|
self.takes_ctx = takes_ctx
|
|
213
230
|
self.max_retries = max_retries
|
|
214
231
|
self.name = name or function.__name__
|
|
215
232
|
self.description = description or f['description']
|
|
216
233
|
self.prepare = prepare
|
|
234
|
+
self.docstring_format = docstring_format
|
|
235
|
+
self.require_parameter_descriptions = require_parameter_descriptions
|
|
217
236
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
218
237
|
self._single_arg_name = f['single_arg_name']
|
|
219
238
|
self._positional_fields = f['positional_fields']
|
|
@@ -221,7 +240,7 @@ class Tool(Generic[AgentDeps]):
|
|
|
221
240
|
self._validator = f['validator']
|
|
222
241
|
self._parameters_json_schema = f['json_schema']
|
|
223
242
|
|
|
224
|
-
async def prepare_tool_def(self, ctx: RunContext[
|
|
243
|
+
async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
|
|
225
244
|
"""Get the tool definition.
|
|
226
245
|
|
|
227
246
|
By default, this method creates a tool definition, then either returns it, or calls `self.prepare`
|
|
@@ -241,7 +260,7 @@ class Tool(Generic[AgentDeps]):
|
|
|
241
260
|
return tool_def
|
|
242
261
|
|
|
243
262
|
async def run(
|
|
244
|
-
self, message: _messages.ToolCallPart, run_context: RunContext[
|
|
263
|
+
self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
|
|
245
264
|
) -> _messages.ModelRequestPart:
|
|
246
265
|
"""Run the tool function asynchronously."""
|
|
247
266
|
try:
|
|
@@ -274,7 +293,7 @@ class Tool(Generic[AgentDeps]):
|
|
|
274
293
|
self,
|
|
275
294
|
args_dict: dict[str, Any],
|
|
276
295
|
message: _messages.ToolCallPart,
|
|
277
|
-
run_context: RunContext[
|
|
296
|
+
run_context: RunContext[AgentDepsT],
|
|
278
297
|
) -> tuple[list[Any], dict[str, Any]]:
|
|
279
298
|
if self._single_arg_name:
|
|
280
299
|
args_dict = {self._single_arg_name: args_dict}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.20
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -26,11 +26,15 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
26
26
|
Requires-Python: >=3.9
|
|
27
27
|
Requires-Dist: eval-type-backport>=0.2.0
|
|
28
28
|
Requires-Dist: griffe>=1.3.2
|
|
29
|
-
Requires-Dist: httpx>=0.27
|
|
29
|
+
Requires-Dist: httpx>=0.27
|
|
30
30
|
Requires-Dist: logfire-api>=1.2.0
|
|
31
31
|
Requires-Dist: pydantic>=2.10
|
|
32
32
|
Provides-Extra: anthropic
|
|
33
33
|
Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
|
|
34
|
+
Provides-Extra: cohere
|
|
35
|
+
Requires-Dist: cohere>=5.13.11; extra == 'cohere'
|
|
36
|
+
Provides-Extra: graph
|
|
37
|
+
Requires-Dist: pydantic-graph==0.0.20; extra == 'graph'
|
|
34
38
|
Provides-Extra: groq
|
|
35
39
|
Requires-Dist: groq>=0.12.0; extra == 'groq'
|
|
36
40
|
Provides-Extra: logfire
|
|
@@ -38,7 +42,7 @@ Requires-Dist: logfire>=2.3; extra == 'logfire'
|
|
|
38
42
|
Provides-Extra: mistral
|
|
39
43
|
Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
|
|
40
44
|
Provides-Extra: openai
|
|
41
|
-
Requires-Dist: openai>=1.
|
|
45
|
+
Requires-Dist: openai>=1.59.0; extra == 'openai'
|
|
42
46
|
Provides-Extra: vertexai
|
|
43
47
|
Requires-Dist: google-auth>=2.36.0; extra == 'vertexai'
|
|
44
48
|
Requires-Dist: requests>=2.32.3; extra == 'vertexai'
|