pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_agent_graph.py +310 -140
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +4 -4
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +3 -22
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +7 -8
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +23 -2
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +2 -2
- pydantic_ai/messages.py +81 -28
- pydantic_ai/models/__init__.py +19 -7
- pydantic_ai/models/anthropic.py +6 -6
- pydantic_ai/models/bedrock.py +63 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +10 -13
- pydantic_ai/models/google.py +4 -4
- pydantic_ai/models/groq.py +5 -5
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +44 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +20 -29
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +10 -29
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +126 -22
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +13 -4
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +7 -5
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +6 -7
- pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/result.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
3
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
|
4
4
|
from copy import copy
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime
|
|
7
|
-
from typing import Generic, cast
|
|
7
|
+
from typing import Generic, cast, overload
|
|
8
8
|
|
|
9
9
|
from pydantic import ValidationError
|
|
10
|
-
from typing_extensions import TypeVar
|
|
11
|
-
|
|
12
|
-
from pydantic_ai._tool_manager import ToolManager
|
|
10
|
+
from typing_extensions import TypeVar, deprecated
|
|
13
11
|
|
|
14
12
|
from . import _utils, exceptions, messages as _messages, models
|
|
15
13
|
from ._output import (
|
|
@@ -22,11 +20,14 @@ from ._output import (
|
|
|
22
20
|
ToolOutputSchema,
|
|
23
21
|
)
|
|
24
22
|
from ._run_context import AgentDepsT, RunContext
|
|
23
|
+
from ._tool_manager import ToolManager
|
|
25
24
|
from .messages import ModelResponseStreamEvent
|
|
26
25
|
from .output import (
|
|
26
|
+
DeferredToolRequests,
|
|
27
27
|
OutputDataT,
|
|
28
28
|
ToolOutput,
|
|
29
29
|
)
|
|
30
|
+
from .run import AgentRunResult
|
|
30
31
|
from .usage import RunUsage, UsageLimits
|
|
31
32
|
|
|
32
33
|
__all__ = (
|
|
@@ -41,7 +42,7 @@ T = TypeVar('T')
|
|
|
41
42
|
"""An invariant TypeVar."""
|
|
42
43
|
|
|
43
44
|
|
|
44
|
-
@dataclass
|
|
45
|
+
@dataclass(kw_only=True)
|
|
45
46
|
class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
46
47
|
_raw_stream_response: models.StreamedResponse
|
|
47
48
|
_output_schema: OutputSchema[OutputDataT]
|
|
@@ -62,11 +63,11 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
62
63
|
async for response in self.stream_responses(debounce_by=debounce_by):
|
|
63
64
|
if self._raw_stream_response.final_result_event is not None:
|
|
64
65
|
try:
|
|
65
|
-
yield await self.
|
|
66
|
+
yield await self.validate_response_output(response, allow_partial=True)
|
|
66
67
|
except ValidationError:
|
|
67
68
|
pass
|
|
68
69
|
if self._raw_stream_response.final_result_event is not None: # pragma: no branch
|
|
69
|
-
yield await self.
|
|
70
|
+
yield await self.validate_response_output(self._raw_stream_response.get())
|
|
70
71
|
|
|
71
72
|
async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
|
|
72
73
|
"""Asynchronously stream the (unvalidated) model responses for the agent."""
|
|
@@ -127,9 +128,11 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
127
128
|
async for _ in self:
|
|
128
129
|
pass
|
|
129
130
|
|
|
130
|
-
return await self.
|
|
131
|
+
return await self.validate_response_output(self._raw_stream_response.get())
|
|
131
132
|
|
|
132
|
-
async def
|
|
133
|
+
async def validate_response_output(
|
|
134
|
+
self, message: _messages.ModelResponse, *, allow_partial: bool = False
|
|
135
|
+
) -> OutputDataT:
|
|
133
136
|
"""Validate a structured result message."""
|
|
134
137
|
final_result_event = self._raw_stream_response.final_result_event
|
|
135
138
|
if final_result_event is None:
|
|
@@ -153,12 +156,12 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
153
156
|
return await self._tool_manager.handle_call(
|
|
154
157
|
tool_call, allow_partial=allow_partial, wrap_validation_errors=False
|
|
155
158
|
)
|
|
156
|
-
elif
|
|
157
|
-
if not self._output_schema.
|
|
159
|
+
elif deferred_tool_requests := _get_deferred_tool_requests(message.parts, self._tool_manager):
|
|
160
|
+
if not self._output_schema.allows_deferred_tools:
|
|
158
161
|
raise exceptions.UserError(
|
|
159
|
-
'A deferred tool call was present, but `
|
|
162
|
+
'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.'
|
|
160
163
|
)
|
|
161
|
-
return cast(OutputDataT,
|
|
164
|
+
return cast(OutputDataT, deferred_tool_requests)
|
|
162
165
|
elif isinstance(self._output_schema, TextOutputSchema):
|
|
163
166
|
text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
|
|
164
167
|
|
|
@@ -231,26 +234,61 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
|
|
|
231
234
|
return self._agent_stream_iterator
|
|
232
235
|
|
|
233
236
|
|
|
234
|
-
@dataclass
|
|
237
|
+
@dataclass(init=False)
|
|
235
238
|
class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
|
|
236
239
|
"""Result of a streamed run that returns structured data via a tool call."""
|
|
237
240
|
|
|
238
241
|
_all_messages: list[_messages.ModelMessage]
|
|
239
242
|
_new_message_index: int
|
|
240
243
|
|
|
241
|
-
_stream_response: AgentStream[AgentDepsT, OutputDataT]
|
|
242
|
-
_on_complete: Callable[[], Awaitable[None]]
|
|
244
|
+
_stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None
|
|
245
|
+
_on_complete: Callable[[], Awaitable[None]] | None = None
|
|
246
|
+
|
|
247
|
+
_run_result: AgentRunResult[OutputDataT] | None = None
|
|
243
248
|
|
|
244
249
|
is_complete: bool = field(default=False, init=False)
|
|
245
250
|
"""Whether the stream has all been received.
|
|
246
251
|
|
|
247
252
|
This is set to `True` when one of
|
|
248
|
-
[`
|
|
253
|
+
[`stream_output`][pydantic_ai.result.StreamedRunResult.stream_output],
|
|
249
254
|
[`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text],
|
|
250
|
-
[`
|
|
255
|
+
[`stream_responses`][pydantic_ai.result.StreamedRunResult.stream_responses] or
|
|
251
256
|
[`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes.
|
|
252
257
|
"""
|
|
253
258
|
|
|
259
|
+
@overload
|
|
260
|
+
def __init__(
|
|
261
|
+
self,
|
|
262
|
+
all_messages: list[_messages.ModelMessage],
|
|
263
|
+
new_message_index: int,
|
|
264
|
+
stream_response: AgentStream[AgentDepsT, OutputDataT] | None,
|
|
265
|
+
on_complete: Callable[[], Awaitable[None]] | None,
|
|
266
|
+
) -> None: ...
|
|
267
|
+
|
|
268
|
+
@overload
|
|
269
|
+
def __init__(
|
|
270
|
+
self,
|
|
271
|
+
all_messages: list[_messages.ModelMessage],
|
|
272
|
+
new_message_index: int,
|
|
273
|
+
*,
|
|
274
|
+
run_result: AgentRunResult[OutputDataT],
|
|
275
|
+
) -> None: ...
|
|
276
|
+
|
|
277
|
+
def __init__(
|
|
278
|
+
self,
|
|
279
|
+
all_messages: list[_messages.ModelMessage],
|
|
280
|
+
new_message_index: int,
|
|
281
|
+
stream_response: AgentStream[AgentDepsT, OutputDataT] | None = None,
|
|
282
|
+
on_complete: Callable[[], Awaitable[None]] | None = None,
|
|
283
|
+
run_result: AgentRunResult[OutputDataT] | None = None,
|
|
284
|
+
) -> None:
|
|
285
|
+
self._all_messages = all_messages
|
|
286
|
+
self._new_message_index = new_message_index
|
|
287
|
+
|
|
288
|
+
self._stream_response = stream_response
|
|
289
|
+
self._on_complete = on_complete
|
|
290
|
+
self._run_result = run_result
|
|
291
|
+
|
|
254
292
|
def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
|
|
255
293
|
"""Return the history of _messages.
|
|
256
294
|
|
|
@@ -318,24 +356,35 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
|
|
|
318
356
|
self.new_messages(output_tool_return_content=output_tool_return_content)
|
|
319
357
|
)
|
|
320
358
|
|
|
359
|
+
@deprecated('`StreamedRunResult.stream` is deprecated, use `stream_output` instead.')
|
|
321
360
|
async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
|
|
322
|
-
|
|
361
|
+
async for output in self.stream_output(debounce_by=debounce_by):
|
|
362
|
+
yield output
|
|
363
|
+
|
|
364
|
+
async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]:
|
|
365
|
+
"""Stream the output as an async iterable.
|
|
323
366
|
|
|
324
367
|
The pydantic validator for structured data will be called in
|
|
325
368
|
[partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
|
|
326
369
|
on each iteration.
|
|
327
370
|
|
|
328
371
|
Args:
|
|
329
|
-
debounce_by: by how much (if at all) to debounce/group the
|
|
330
|
-
Debouncing is particularly important for long structured
|
|
372
|
+
debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
|
|
373
|
+
Debouncing is particularly important for long structured outputs to reduce the overhead of
|
|
331
374
|
performing validation as each token is received.
|
|
332
375
|
|
|
333
376
|
Returns:
|
|
334
377
|
An async iterable of the response data.
|
|
335
378
|
"""
|
|
336
|
-
|
|
337
|
-
yield output
|
|
338
|
-
|
|
379
|
+
if self._run_result is not None:
|
|
380
|
+
yield self._run_result.output
|
|
381
|
+
await self._marked_completed()
|
|
382
|
+
elif self._stream_response is not None:
|
|
383
|
+
async for output in self._stream_response.stream_output(debounce_by=debounce_by):
|
|
384
|
+
yield output
|
|
385
|
+
await self._marked_completed(self._stream_response.get())
|
|
386
|
+
else:
|
|
387
|
+
raise ValueError('No stream response or run result provided') # pragma: no cover
|
|
339
388
|
|
|
340
389
|
async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
|
|
341
390
|
"""Stream the text result as an async iterable.
|
|
@@ -350,12 +399,30 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
|
|
|
350
399
|
Debouncing is particularly important for long structured responses to reduce the overhead of
|
|
351
400
|
performing validation as each token is received.
|
|
352
401
|
"""
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
402
|
+
if self._run_result is not None: # pragma: no cover
|
|
403
|
+
# We can't really get here, as `_run_result` is only set in `run_stream` when `CallToolsNode` produces `DeferredToolRequests` output
|
|
404
|
+
# as a result of a tool function raising `CallDeferred` or `ApprovalRequired`.
|
|
405
|
+
# That'll change if we ever support something like `raise EndRun(output: OutputT)` where `OutputT` could be `str`.
|
|
406
|
+
if not isinstance(self._run_result.output, str):
|
|
407
|
+
raise exceptions.UserError('stream_text() can only be used with text responses')
|
|
408
|
+
yield self._run_result.output
|
|
409
|
+
await self._marked_completed()
|
|
410
|
+
elif self._stream_response is not None:
|
|
411
|
+
async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
|
|
412
|
+
yield text
|
|
413
|
+
await self._marked_completed(self._stream_response.get())
|
|
414
|
+
else:
|
|
415
|
+
raise ValueError('No stream response or run result provided') # pragma: no cover
|
|
356
416
|
|
|
417
|
+
@deprecated('`StreamedRunResult.stream_structured` is deprecated, use `stream_responses` instead.')
|
|
357
418
|
async def stream_structured(
|
|
358
419
|
self, *, debounce_by: float | None = 0.1
|
|
420
|
+
) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
|
|
421
|
+
async for msg, last in self.stream_responses(debounce_by=debounce_by):
|
|
422
|
+
yield msg, last
|
|
423
|
+
|
|
424
|
+
async def stream_responses(
|
|
425
|
+
self, *, debounce_by: float | None = 0.1
|
|
359
426
|
) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
|
|
360
427
|
"""Stream the response as an async iterable of Structured LLM Messages.
|
|
361
428
|
|
|
@@ -367,20 +434,34 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
|
|
|
367
434
|
Returns:
|
|
368
435
|
An async iterable of the structured response message and whether that is the last message.
|
|
369
436
|
"""
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
yield
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
437
|
+
if self._run_result is not None:
|
|
438
|
+
model_response = cast(_messages.ModelResponse, self.all_messages()[-1])
|
|
439
|
+
yield model_response, True
|
|
440
|
+
await self._marked_completed()
|
|
441
|
+
elif self._stream_response is not None:
|
|
442
|
+
# if the message currently has any parts with content, yield before streaming
|
|
443
|
+
async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
|
|
444
|
+
yield msg, False
|
|
445
|
+
|
|
446
|
+
msg = self._stream_response.get()
|
|
447
|
+
yield msg, True
|
|
448
|
+
|
|
449
|
+
await self._marked_completed(msg)
|
|
450
|
+
else:
|
|
451
|
+
raise ValueError('No stream response or run result provided') # pragma: no cover
|
|
378
452
|
|
|
379
453
|
async def get_output(self) -> OutputDataT:
|
|
380
454
|
"""Stream the whole response, validate and return it."""
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
455
|
+
if self._run_result is not None:
|
|
456
|
+
output = self._run_result.output
|
|
457
|
+
await self._marked_completed()
|
|
458
|
+
return output
|
|
459
|
+
elif self._stream_response is not None:
|
|
460
|
+
output = await self._stream_response.get_output()
|
|
461
|
+
await self._marked_completed(self._stream_response.get())
|
|
462
|
+
return output
|
|
463
|
+
else:
|
|
464
|
+
raise ValueError('No stream response or run result provided') # pragma: no cover
|
|
384
465
|
|
|
385
466
|
def usage(self) -> RunUsage:
|
|
386
467
|
"""Return the usage of the whole run.
|
|
@@ -388,24 +469,45 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
|
|
|
388
469
|
!!! note
|
|
389
470
|
This won't return the full usage until the stream is finished.
|
|
390
471
|
"""
|
|
391
|
-
|
|
472
|
+
if self._run_result is not None:
|
|
473
|
+
return self._run_result.usage()
|
|
474
|
+
elif self._stream_response is not None:
|
|
475
|
+
return self._stream_response.usage()
|
|
476
|
+
else:
|
|
477
|
+
raise ValueError('No stream response or run result provided') # pragma: no cover
|
|
392
478
|
|
|
393
479
|
def timestamp(self) -> datetime:
|
|
394
480
|
"""Get the timestamp of the response."""
|
|
395
|
-
|
|
481
|
+
if self._run_result is not None:
|
|
482
|
+
return self._run_result.timestamp()
|
|
483
|
+
elif self._stream_response is not None:
|
|
484
|
+
return self._stream_response.timestamp()
|
|
485
|
+
else:
|
|
486
|
+
raise ValueError('No stream response or run result provided') # pragma: no cover
|
|
396
487
|
|
|
488
|
+
@deprecated('`validate_structured_output` is deprecated, use `validate_response_output` instead.')
|
|
397
489
|
async def validate_structured_output(
|
|
398
490
|
self, message: _messages.ModelResponse, *, allow_partial: bool = False
|
|
491
|
+
) -> OutputDataT:
|
|
492
|
+
return await self.validate_response_output(message, allow_partial=allow_partial)
|
|
493
|
+
|
|
494
|
+
async def validate_response_output(
|
|
495
|
+
self, message: _messages.ModelResponse, *, allow_partial: bool = False
|
|
399
496
|
) -> OutputDataT:
|
|
400
497
|
"""Validate a structured result message."""
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
498
|
+
if self._run_result is not None:
|
|
499
|
+
return self._run_result.output
|
|
500
|
+
elif self._stream_response is not None:
|
|
501
|
+
return await self._stream_response.validate_response_output(message, allow_partial=allow_partial)
|
|
502
|
+
else:
|
|
503
|
+
raise ValueError('No stream response or run result provided') # pragma: no cover
|
|
404
504
|
|
|
405
|
-
async def _marked_completed(self, message: _messages.ModelResponse) -> None:
|
|
505
|
+
async def _marked_completed(self, message: _messages.ModelResponse | None = None) -> None:
|
|
406
506
|
self.is_complete = True
|
|
407
|
-
|
|
408
|
-
|
|
507
|
+
if message is not None:
|
|
508
|
+
self._all_messages.append(message)
|
|
509
|
+
if self._on_complete is not None:
|
|
510
|
+
await self._on_complete()
|
|
409
511
|
|
|
410
512
|
|
|
411
513
|
@dataclass(repr=False)
|
|
@@ -414,8 +516,10 @@ class FinalResult(Generic[OutputDataT]):
|
|
|
414
516
|
|
|
415
517
|
output: OutputDataT
|
|
416
518
|
"""The final result data."""
|
|
519
|
+
|
|
417
520
|
tool_name: str | None = None
|
|
418
521
|
"""Name of the final output tool; `None` if the output came from unstructured text content."""
|
|
522
|
+
|
|
419
523
|
tool_call_id: str | None = None
|
|
420
524
|
"""ID of the tool call that produced the final output; `None` if the output came from unstructured text content."""
|
|
421
525
|
|
|
@@ -436,9 +540,26 @@ def _get_usage_checking_stream_response(
|
|
|
436
540
|
|
|
437
541
|
return _usage_checking_iterator()
|
|
438
542
|
else:
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
543
|
+
return aiter(stream_response)
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def _get_deferred_tool_requests(
|
|
547
|
+
parts: Iterable[_messages.ModelResponsePart], tool_manager: ToolManager[AgentDepsT]
|
|
548
|
+
) -> DeferredToolRequests | None:
|
|
549
|
+
"""Get the deferred tool requests from the model response parts."""
|
|
550
|
+
approvals: list[_messages.ToolCallPart] = []
|
|
551
|
+
calls: list[_messages.ToolCallPart] = []
|
|
552
|
+
|
|
553
|
+
for part in parts:
|
|
554
|
+
if isinstance(part, _messages.ToolCallPart):
|
|
555
|
+
tool_def = tool_manager.get_tool_def(part.tool_name)
|
|
556
|
+
if tool_def is not None: # pragma: no branch
|
|
557
|
+
if tool_def.kind == 'unapproved':
|
|
558
|
+
approvals.append(part)
|
|
559
|
+
elif tool_def.kind == 'external':
|
|
560
|
+
calls.append(part)
|
|
561
|
+
|
|
562
|
+
if not calls and not approvals:
|
|
563
|
+
return None
|
|
443
564
|
|
|
444
|
-
|
|
565
|
+
return DeferredToolRequests(calls=calls, approvals=approvals)
|
pydantic_ai/retries.py
CHANGED
|
@@ -24,17 +24,17 @@ from httpx import (
|
|
|
24
24
|
)
|
|
25
25
|
|
|
26
26
|
try:
|
|
27
|
-
from tenacity import
|
|
27
|
+
from tenacity import RetryCallState, RetryError, retry, wait_exponential
|
|
28
28
|
except ImportError as _import_error:
|
|
29
29
|
raise ImportError(
|
|
30
30
|
'Please install `tenacity` to use the retries utilities, '
|
|
31
31
|
'you can use the `retries` optional group — `pip install "pydantic-ai-slim[retries]"`'
|
|
32
32
|
) from _import_error
|
|
33
33
|
|
|
34
|
-
from collections.abc import Awaitable
|
|
34
|
+
from collections.abc import Awaitable, Callable
|
|
35
35
|
from datetime import datetime, timezone
|
|
36
36
|
from email.utils import parsedate_to_datetime
|
|
37
|
-
from typing import TYPE_CHECKING, Any,
|
|
37
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
38
38
|
|
|
39
39
|
from typing_extensions import TypedDict
|
|
40
40
|
|
|
@@ -134,8 +134,9 @@ class TenacityTransport(BaseTransport):
|
|
|
134
134
|
|
|
135
135
|
Example:
|
|
136
136
|
```python
|
|
137
|
-
from httpx import Client,
|
|
138
|
-
from tenacity import
|
|
137
|
+
from httpx import Client, HTTPStatusError, HTTPTransport
|
|
138
|
+
from tenacity import retry_if_exception_type, stop_after_attempt
|
|
139
|
+
|
|
139
140
|
from pydantic_ai.retries import RetryConfig, TenacityTransport, wait_retry_after
|
|
140
141
|
|
|
141
142
|
transport = TenacityTransport(
|
|
@@ -157,18 +158,7 @@ class TenacityTransport(BaseTransport):
|
|
|
157
158
|
config: RetryConfig,
|
|
158
159
|
wrapped: BaseTransport | None = None,
|
|
159
160
|
validate_response: Callable[[Response], Any] | None = None,
|
|
160
|
-
**kwargs: NoReturn,
|
|
161
161
|
):
|
|
162
|
-
# TODO: Remove the following checks (and **kwargs) during v1 release
|
|
163
|
-
if 'controller' in kwargs: # pragma: no cover
|
|
164
|
-
raise TypeError('The `controller` argument has been renamed to `config`, and now requires a `RetryConfig`.')
|
|
165
|
-
if kwargs: # pragma: no cover
|
|
166
|
-
raise TypeError(f'Unexpected keyword arguments: {", ".join(kwargs)}')
|
|
167
|
-
if isinstance(config, Retrying): # pragma: no cover
|
|
168
|
-
raise ValueError(
|
|
169
|
-
'Passing a Retrying instance is no longer supported; the `config` argument must be a `pydantic_ai.retries.RetryConfig`.'
|
|
170
|
-
)
|
|
171
|
-
|
|
172
162
|
self.config = config
|
|
173
163
|
self.wrapped = wrapped or HTTPTransport()
|
|
174
164
|
self.validate_response = validate_response
|
|
@@ -224,7 +214,8 @@ class AsyncTenacityTransport(AsyncBaseTransport):
|
|
|
224
214
|
Example:
|
|
225
215
|
```python
|
|
226
216
|
from httpx import AsyncClient, HTTPStatusError
|
|
227
|
-
from tenacity import
|
|
217
|
+
from tenacity import retry_if_exception_type, stop_after_attempt
|
|
218
|
+
|
|
228
219
|
from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after
|
|
229
220
|
|
|
230
221
|
transport = AsyncTenacityTransport(
|
|
@@ -245,18 +236,7 @@ class AsyncTenacityTransport(AsyncBaseTransport):
|
|
|
245
236
|
config: RetryConfig,
|
|
246
237
|
wrapped: AsyncBaseTransport | None = None,
|
|
247
238
|
validate_response: Callable[[Response], Any] | None = None,
|
|
248
|
-
**kwargs: NoReturn,
|
|
249
239
|
):
|
|
250
|
-
# TODO: Remove the following checks (and **kwargs) during v1 release
|
|
251
|
-
if 'controller' in kwargs: # pragma: no cover
|
|
252
|
-
raise TypeError('The `controller` argument has been renamed to `config`, and now requires a `RetryConfig`.')
|
|
253
|
-
if kwargs: # pragma: no cover
|
|
254
|
-
raise TypeError(f'Unexpected keyword arguments: {", ".join(kwargs)}')
|
|
255
|
-
if isinstance(config, AsyncRetrying): # pragma: no cover
|
|
256
|
-
raise ValueError(
|
|
257
|
-
'Passing an AsyncRetrying instance is no longer supported; the `config` argument must be a `pydantic_ai.retries.RetryConfig`.'
|
|
258
|
-
)
|
|
259
|
-
|
|
260
240
|
self.config = config
|
|
261
241
|
self.wrapped = wrapped or AsyncHTTPTransport()
|
|
262
242
|
self.validate_response = validate_response
|
|
@@ -314,7 +294,8 @@ def wait_retry_after(
|
|
|
314
294
|
Example:
|
|
315
295
|
```python
|
|
316
296
|
from httpx import AsyncClient, HTTPStatusError
|
|
317
|
-
from tenacity import
|
|
297
|
+
from tenacity import retry_if_exception_type, stop_after_attempt
|
|
298
|
+
|
|
318
299
|
from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after
|
|
319
300
|
|
|
320
301
|
transport = AsyncTenacityTransport(
|
pydantic_ai/run.py
CHANGED
|
@@ -3,9 +3,8 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import dataclasses
|
|
4
4
|
from collections.abc import AsyncIterator
|
|
5
5
|
from copy import deepcopy
|
|
6
|
-
from
|
|
7
|
-
|
|
8
|
-
from typing_extensions import Literal
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, overload
|
|
9
8
|
|
|
10
9
|
from pydantic_graph import End, GraphRun, GraphRunContext
|
|
11
10
|
|
|
@@ -16,9 +15,11 @@ from . import (
|
|
|
16
15
|
usage as _usage,
|
|
17
16
|
)
|
|
18
17
|
from .output import OutputDataT
|
|
19
|
-
from .result import FinalResult
|
|
20
18
|
from .tools import AgentDepsT
|
|
21
19
|
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from .result import FinalResult
|
|
22
|
+
|
|
22
23
|
|
|
23
24
|
@dataclasses.dataclass(repr=False)
|
|
24
25
|
class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
@@ -100,7 +101,7 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
100
101
|
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
|
|
101
102
|
"""The current context of the agent run."""
|
|
102
103
|
return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]](
|
|
103
|
-
self._graph_run.state, self._graph_run.deps
|
|
104
|
+
state=self._graph_run.state, deps=self._graph_run.deps
|
|
104
105
|
)
|
|
105
106
|
|
|
106
107
|
@property
|
|
@@ -348,3 +349,9 @@ class AgentRunResult(Generic[OutputDataT]):
|
|
|
348
349
|
def usage(self) -> _usage.RunUsage:
|
|
349
350
|
"""Return the usage of the whole run."""
|
|
350
351
|
return self._state.usage
|
|
352
|
+
|
|
353
|
+
def timestamp(self) -> datetime:
|
|
354
|
+
"""Return the timestamp of last response."""
|
|
355
|
+
model_response = self.all_messages()[-1]
|
|
356
|
+
assert isinstance(model_response, _messages.ModelResponse)
|
|
357
|
+
return model_response.timestamp
|