pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_agent_graph.py +310 -140
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +4 -4
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +3 -22
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +7 -8
- pydantic_ai/agent/__init__.py +70 -9
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +4 -2
- pydantic_ai/durable_exec/temporal/_agent.py +23 -2
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +2 -2
- pydantic_ai/messages.py +73 -25
- pydantic_ai/models/__init__.py +5 -4
- pydantic_ai/models/anthropic.py +5 -5
- pydantic_ai/models/bedrock.py +58 -56
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +9 -12
- pydantic_ai/models/google.py +3 -3
- pydantic_ai/models/groq.py +4 -4
- pydantic_ai/models/huggingface.py +4 -4
- pydantic_ai/models/instrumented.py +30 -16
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +6 -6
- pydantic_ai/models/openai.py +18 -27
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/result.py +144 -41
- pydantic_ai/retries.py +10 -29
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +126 -22
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +13 -4
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +7 -5
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +5 -6
- pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
- pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_griffe.py
CHANGED
|
@@ -2,9 +2,10 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import re
|
|
5
|
+
from collections.abc import Callable
|
|
5
6
|
from contextlib import contextmanager
|
|
6
7
|
from inspect import Signature
|
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal, cast
|
|
8
9
|
|
|
9
10
|
from griffe import Docstring, DocstringSectionKind, Object as GriffeObject
|
|
10
11
|
|
pydantic_ai/_otel_messages.py
CHANGED
|
@@ -5,10 +5,10 @@ Based on https://github.com/lmolkova/semantic-conventions/blob/eccd1f806e426a32c
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from typing import Literal
|
|
8
|
+
from typing import Literal, TypeAlias
|
|
9
9
|
|
|
10
10
|
from pydantic import JsonValue
|
|
11
|
-
from typing_extensions import NotRequired,
|
|
11
|
+
from typing_extensions import NotRequired, TypedDict
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class TextPart(TypedDict):
|
pydantic_ai/_output.py
CHANGED
|
@@ -3,9 +3,9 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import json
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
-
from collections.abc import Awaitable, Sequence
|
|
6
|
+
from collections.abc import Awaitable, Callable, Sequence
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
|
-
from typing import TYPE_CHECKING, Any,
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
|
|
9
9
|
|
|
10
10
|
from pydantic import TypeAdapter, ValidationError
|
|
11
11
|
from pydantic_core import SchemaValidator, to_json
|
|
@@ -15,7 +15,7 @@ from . import _function_schema, _utils, messages as _messages
|
|
|
15
15
|
from ._run_context import AgentDepsT, RunContext
|
|
16
16
|
from .exceptions import ModelRetry, ToolRetryError, UserError
|
|
17
17
|
from .output import (
|
|
18
|
-
|
|
18
|
+
DeferredToolRequests,
|
|
19
19
|
NativeOutput,
|
|
20
20
|
OutputDataT,
|
|
21
21
|
OutputMode,
|
|
@@ -49,12 +49,12 @@ At some point, it may make sense to change the input to OutputValidatorFunc to b
|
|
|
49
49
|
resolve these potential variance issues.
|
|
50
50
|
"""
|
|
51
51
|
|
|
52
|
-
OutputValidatorFunc =
|
|
53
|
-
Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv]
|
|
54
|
-
Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
55
|
-
Callable[[OutputDataT_inv], OutputDataT_inv]
|
|
56
|
-
Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
57
|
-
|
|
52
|
+
OutputValidatorFunc = (
|
|
53
|
+
Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv]
|
|
54
|
+
| Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
55
|
+
| Callable[[OutputDataT_inv], OutputDataT_inv]
|
|
56
|
+
| Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]]
|
|
57
|
+
)
|
|
58
58
|
"""
|
|
59
59
|
A function that always takes and returns the same type of data (which is the result type of an agent run), and:
|
|
60
60
|
|
|
@@ -196,7 +196,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
|
|
|
196
196
|
|
|
197
197
|
@dataclass
|
|
198
198
|
class BaseOutputSchema(ABC, Generic[OutputDataT]):
|
|
199
|
-
|
|
199
|
+
allows_deferred_tools: bool
|
|
200
200
|
|
|
201
201
|
@abstractmethod
|
|
202
202
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
@@ -249,10 +249,10 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
249
249
|
"""Build an OutputSchema dataclass from an output type."""
|
|
250
250
|
raw_outputs = _flatten_output_spec(output_spec)
|
|
251
251
|
|
|
252
|
-
outputs = [output for output in raw_outputs if output is not
|
|
253
|
-
|
|
254
|
-
if len(outputs) == 0 and
|
|
255
|
-
raise UserError('At least one output type must be provided other than `
|
|
252
|
+
outputs = [output for output in raw_outputs if output is not DeferredToolRequests]
|
|
253
|
+
allows_deferred_tools = len(outputs) < len(raw_outputs)
|
|
254
|
+
if len(outputs) == 0 and allows_deferred_tools:
|
|
255
|
+
raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')
|
|
256
256
|
|
|
257
257
|
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
|
|
258
258
|
if len(outputs) > 1:
|
|
@@ -265,7 +265,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
265
265
|
description=output.description,
|
|
266
266
|
strict=output.strict,
|
|
267
267
|
),
|
|
268
|
-
|
|
268
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
269
269
|
)
|
|
270
270
|
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
|
|
271
271
|
if len(outputs) > 1:
|
|
@@ -278,7 +278,7 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
278
278
|
description=output.description,
|
|
279
279
|
),
|
|
280
280
|
template=output.template,
|
|
281
|
-
|
|
281
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
282
282
|
)
|
|
283
283
|
|
|
284
284
|
text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = []
|
|
@@ -313,21 +313,21 @@ class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
|
|
|
313
313
|
|
|
314
314
|
if toolset:
|
|
315
315
|
return ToolOrTextOutputSchema(
|
|
316
|
-
processor=text_output_schema,
|
|
316
|
+
processor=text_output_schema,
|
|
317
|
+
toolset=toolset,
|
|
318
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
317
319
|
)
|
|
318
320
|
else:
|
|
319
|
-
return PlainTextOutputSchema(
|
|
320
|
-
processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls
|
|
321
|
-
)
|
|
321
|
+
return PlainTextOutputSchema(processor=text_output_schema, allows_deferred_tools=allows_deferred_tools)
|
|
322
322
|
|
|
323
323
|
if len(tool_outputs) > 0:
|
|
324
|
-
return ToolOutputSchema(toolset=toolset,
|
|
324
|
+
return ToolOutputSchema(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
|
|
325
325
|
|
|
326
326
|
if len(other_outputs) > 0:
|
|
327
327
|
schema = OutputSchemaWithoutMode(
|
|
328
328
|
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
|
|
329
329
|
toolset=toolset,
|
|
330
|
-
|
|
330
|
+
allows_deferred_tools=allows_deferred_tools,
|
|
331
331
|
)
|
|
332
332
|
if default_mode:
|
|
333
333
|
schema = schema.with_default_mode(default_mode)
|
|
@@ -371,23 +371,19 @@ class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
|
|
|
371
371
|
self,
|
|
372
372
|
processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT],
|
|
373
373
|
toolset: OutputToolset[Any] | None,
|
|
374
|
-
|
|
374
|
+
allows_deferred_tools: bool,
|
|
375
375
|
):
|
|
376
|
-
super().__init__(
|
|
376
|
+
super().__init__(allows_deferred_tools)
|
|
377
377
|
self.processor = processor
|
|
378
378
|
self._toolset = toolset
|
|
379
379
|
|
|
380
380
|
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
|
|
381
381
|
if mode == 'native':
|
|
382
|
-
return NativeOutputSchema(
|
|
383
|
-
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
384
|
-
)
|
|
382
|
+
return NativeOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
|
|
385
383
|
elif mode == 'prompted':
|
|
386
|
-
return PromptedOutputSchema(
|
|
387
|
-
processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls
|
|
388
|
-
)
|
|
384
|
+
return PromptedOutputSchema(processor=self.processor, allows_deferred_tools=self.allows_deferred_tools)
|
|
389
385
|
elif mode == 'tool':
|
|
390
|
-
return ToolOutputSchema(toolset=self.toolset,
|
|
386
|
+
return ToolOutputSchema(toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools)
|
|
391
387
|
else:
|
|
392
388
|
assert_never(mode)
|
|
393
389
|
|
|
@@ -550,8 +546,8 @@ class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]):
|
|
|
550
546
|
class ToolOutputSchema(OutputSchema[OutputDataT]):
|
|
551
547
|
_toolset: OutputToolset[Any] | None
|
|
552
548
|
|
|
553
|
-
def __init__(self, toolset: OutputToolset[Any] | None,
|
|
554
|
-
super().__init__(
|
|
549
|
+
def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tools: bool):
|
|
550
|
+
super().__init__(allows_deferred_tools)
|
|
555
551
|
self._toolset = toolset
|
|
556
552
|
|
|
557
553
|
@property
|
|
@@ -575,9 +571,9 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem
|
|
|
575
571
|
self,
|
|
576
572
|
processor: PlainTextOutputProcessor[OutputDataT] | None,
|
|
577
573
|
toolset: OutputToolset[Any] | None,
|
|
578
|
-
|
|
574
|
+
allows_deferred_tools: bool,
|
|
579
575
|
):
|
|
580
|
-
super().__init__(toolset=toolset,
|
|
576
|
+
super().__init__(toolset=toolset, allows_deferred_tools=allows_deferred_tools)
|
|
581
577
|
self.processor = processor
|
|
582
578
|
|
|
583
579
|
@property
|
pydantic_ai/_parts_manager.py
CHANGED
|
@@ -15,7 +15,7 @@ from __future__ import annotations as _annotations
|
|
|
15
15
|
|
|
16
16
|
from collections.abc import Hashable
|
|
17
17
|
from dataclasses import dataclass, field, replace
|
|
18
|
-
from typing import Any
|
|
18
|
+
from typing import Any
|
|
19
19
|
|
|
20
20
|
from pydantic_ai.exceptions import UnexpectedModelBehavior
|
|
21
21
|
from pydantic_ai.messages import (
|
|
@@ -38,7 +38,7 @@ VendorId = Hashable
|
|
|
38
38
|
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
|
|
39
39
|
"""
|
|
40
40
|
|
|
41
|
-
ManagedPart =
|
|
41
|
+
ManagedPart = ModelResponsePart | ToolCallPartDelta
|
|
42
42
|
"""
|
|
43
43
|
A union of types that are managed by the ModelResponsePartsManager.
|
|
44
44
|
Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
|
|
@@ -262,14 +262,14 @@ class ModelResponsePartsManager:
|
|
|
262
262
|
if tool_name is None and self._parts:
|
|
263
263
|
part_index = len(self._parts) - 1
|
|
264
264
|
latest_part = self._parts[part_index]
|
|
265
|
-
if isinstance(latest_part,
|
|
265
|
+
if isinstance(latest_part, ToolCallPart | ToolCallPartDelta): # pragma: no branch
|
|
266
266
|
existing_matching_part_and_index = latest_part, part_index
|
|
267
267
|
else:
|
|
268
268
|
# vendor_part_id is provided, so look up the corresponding part or delta
|
|
269
269
|
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
|
|
270
270
|
if part_index is not None:
|
|
271
271
|
existing_part = self._parts[part_index]
|
|
272
|
-
if not isinstance(existing_part,
|
|
272
|
+
if not isinstance(existing_part, ToolCallPartDelta | ToolCallPart):
|
|
273
273
|
raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
|
|
274
274
|
existing_matching_part_and_index = existing_part, part_index
|
|
275
275
|
|
pydantic_ai/_run_context.py
CHANGED
|
@@ -18,7 +18,7 @@ AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
|
|
|
18
18
|
"""Type variable for agent dependencies."""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
@dataclasses.dataclass(repr=False)
|
|
21
|
+
@dataclasses.dataclass(repr=False, kw_only=True)
|
|
22
22
|
class RunContext(Generic[AgentDepsT]):
|
|
23
23
|
"""Information about the current call."""
|
|
24
24
|
|
|
@@ -46,5 +46,7 @@ class RunContext(Generic[AgentDepsT]):
|
|
|
46
46
|
"""Number of retries so far."""
|
|
47
47
|
run_step: int = 0
|
|
48
48
|
"""The current step in the run."""
|
|
49
|
+
tool_call_approved: bool = False
|
|
50
|
+
"""Whether a tool call that required approval has now been approved."""
|
|
49
51
|
|
|
50
52
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
pydantic_ai/_system_prompt.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from collections.abc import Awaitable
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Generic, cast
|
|
7
7
|
|
|
8
8
|
from . import _utils
|
|
9
9
|
from ._run_context import AgentDepsT, RunContext
|
pydantic_ai/_tool_manager.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from collections.abc import Iterable
|
|
5
4
|
from dataclasses import dataclass, field, replace
|
|
6
5
|
from typing import Any, Generic
|
|
7
6
|
|
|
@@ -13,7 +12,6 @@ from . import messages as _messages
|
|
|
13
12
|
from ._run_context import AgentDepsT, RunContext
|
|
14
13
|
from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
|
|
15
14
|
from .messages import ToolCallPart
|
|
16
|
-
from .output import DeferredToolCalls
|
|
17
15
|
from .tools import ToolDefinition
|
|
18
16
|
from .toolsets.abstract import AbstractToolset, ToolsetTool
|
|
19
17
|
|
|
@@ -106,6 +104,9 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
106
104
|
msg = 'No tools available.'
|
|
107
105
|
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
|
|
108
106
|
|
|
107
|
+
if tool.tool_def.defer:
|
|
108
|
+
raise RuntimeError('Deferred tools cannot be called')
|
|
109
|
+
|
|
109
110
|
ctx = replace(
|
|
110
111
|
self.ctx,
|
|
111
112
|
tool_name=name,
|
|
@@ -204,23 +205,3 @@ class ToolManager(Generic[AgentDepsT]):
|
|
|
204
205
|
)
|
|
205
206
|
|
|
206
207
|
return tool_result
|
|
207
|
-
|
|
208
|
-
def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None:
|
|
209
|
-
"""Get the deferred tool calls from the model response parts."""
|
|
210
|
-
deferred_calls_and_defs = [
|
|
211
|
-
(part, tool_def)
|
|
212
|
-
for part in parts
|
|
213
|
-
if isinstance(part, _messages.ToolCallPart)
|
|
214
|
-
and (tool_def := self.get_tool_def(part.tool_name))
|
|
215
|
-
and tool_def.kind == 'deferred'
|
|
216
|
-
]
|
|
217
|
-
if not deferred_calls_and_defs:
|
|
218
|
-
return None
|
|
219
|
-
|
|
220
|
-
deferred_calls: list[_messages.ToolCallPart] = []
|
|
221
|
-
deferred_tool_defs: dict[str, ToolDefinition] = {}
|
|
222
|
-
for part, tool_def in deferred_calls_and_defs:
|
|
223
|
-
deferred_calls.append(part)
|
|
224
|
-
deferred_tool_defs[part.tool_name] = tool_def
|
|
225
|
-
|
|
226
|
-
return DeferredToolCalls(deferred_calls, deferred_tool_defs)
|
pydantic_ai/_utils.py
CHANGED
|
@@ -4,34 +4,28 @@ import asyncio
|
|
|
4
4
|
import functools
|
|
5
5
|
import inspect
|
|
6
6
|
import re
|
|
7
|
-
import sys
|
|
8
7
|
import time
|
|
9
8
|
import uuid
|
|
10
|
-
import
|
|
11
|
-
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
|
|
9
|
+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator
|
|
12
10
|
from contextlib import asynccontextmanager, suppress
|
|
13
11
|
from dataclasses import dataclass, fields, is_dataclass
|
|
14
12
|
from datetime import datetime, timezone
|
|
15
13
|
from functools import partial
|
|
16
14
|
from types import GenericAlias
|
|
17
|
-
from typing import TYPE_CHECKING, Any,
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload
|
|
18
16
|
|
|
19
17
|
from anyio.to_thread import run_sync
|
|
20
18
|
from pydantic import BaseModel, TypeAdapter
|
|
21
19
|
from pydantic.json_schema import JsonSchemaValue
|
|
22
20
|
from typing_extensions import (
|
|
23
21
|
ParamSpec,
|
|
24
|
-
TypeAlias,
|
|
25
|
-
TypeGuard,
|
|
26
22
|
TypeIs,
|
|
27
|
-
get_args,
|
|
28
|
-
get_origin,
|
|
29
23
|
is_typeddict,
|
|
30
24
|
)
|
|
31
25
|
from typing_inspection import typing_objects
|
|
32
26
|
from typing_inspection.introspection import is_union_origin
|
|
33
27
|
|
|
34
|
-
from pydantic_graph._utils import AbstractSpan
|
|
28
|
+
from pydantic_graph._utils import AbstractSpan
|
|
35
29
|
|
|
36
30
|
from . import exceptions
|
|
37
31
|
|
|
@@ -96,7 +90,7 @@ class Some(Generic[T]):
|
|
|
96
90
|
value: T
|
|
97
91
|
|
|
98
92
|
|
|
99
|
-
Option: TypeAlias =
|
|
93
|
+
Option: TypeAlias = Some[T] | None
|
|
100
94
|
"""Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""
|
|
101
95
|
|
|
102
96
|
|
|
@@ -459,28 +453,22 @@ def strip_markdown_fences(text: str) -> str:
|
|
|
459
453
|
return text
|
|
460
454
|
|
|
461
455
|
|
|
456
|
+
def _unwrap_annotated(tp: Any) -> Any:
|
|
457
|
+
origin = get_origin(tp)
|
|
458
|
+
while typing_objects.is_annotated(origin):
|
|
459
|
+
tp = tp.__origin__
|
|
460
|
+
origin = get_origin(tp)
|
|
461
|
+
return tp
|
|
462
|
+
|
|
463
|
+
|
|
462
464
|
def get_union_args(tp: Any) -> tuple[Any, ...]:
|
|
463
465
|
"""Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
|
|
464
466
|
if typing_objects.is_typealiastype(tp):
|
|
465
467
|
tp = tp.__value__
|
|
466
468
|
|
|
469
|
+
tp = _unwrap_annotated(tp)
|
|
467
470
|
origin = get_origin(tp)
|
|
468
471
|
if is_union_origin(origin):
|
|
469
|
-
return get_args(tp)
|
|
472
|
+
return tuple(_unwrap_annotated(arg) for arg in get_args(tp))
|
|
470
473
|
else:
|
|
471
474
|
return ()
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
# The `asyncio.Lock` `loop` argument was deprecated in 3.8 and removed in 3.10,
|
|
475
|
-
# but 3.9 still needs it to have the intended behavior.
|
|
476
|
-
|
|
477
|
-
if sys.version_info < (3, 10):
|
|
478
|
-
|
|
479
|
-
def get_async_lock() -> asyncio.Lock: # pragma: lax no cover
|
|
480
|
-
with warnings.catch_warnings():
|
|
481
|
-
warnings.simplefilter('ignore', DeprecationWarning)
|
|
482
|
-
return asyncio.Lock(loop=get_event_loop())
|
|
483
|
-
else:
|
|
484
|
-
|
|
485
|
-
def get_async_lock() -> asyncio.Lock: # pragma: lax no cover
|
|
486
|
-
return asyncio.Lock()
|
pydantic_ai/ag_ui.py
CHANGED
|
@@ -8,12 +8,11 @@ from __future__ import annotations
|
|
|
8
8
|
|
|
9
9
|
import json
|
|
10
10
|
import uuid
|
|
11
|
-
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
|
|
11
|
+
from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence
|
|
12
12
|
from dataclasses import Field, dataclass, replace
|
|
13
13
|
from http import HTTPStatus
|
|
14
14
|
from typing import (
|
|
15
15
|
Any,
|
|
16
|
-
Callable,
|
|
17
16
|
ClassVar,
|
|
18
17
|
Final,
|
|
19
18
|
Generic,
|
|
@@ -46,11 +45,11 @@ from .messages import (
|
|
|
46
45
|
UserPromptPart,
|
|
47
46
|
)
|
|
48
47
|
from .models import KnownModelName, Model
|
|
49
|
-
from .output import
|
|
48
|
+
from .output import OutputDataT, OutputSpec
|
|
50
49
|
from .settings import ModelSettings
|
|
51
|
-
from .tools import AgentDepsT, ToolDefinition
|
|
50
|
+
from .tools import AgentDepsT, DeferredToolRequests, ToolDefinition
|
|
52
51
|
from .toolsets import AbstractToolset
|
|
53
|
-
from .toolsets.
|
|
52
|
+
from .toolsets.external import ExternalToolset
|
|
54
53
|
from .usage import RunUsage, UsageLimits
|
|
55
54
|
|
|
56
55
|
try:
|
|
@@ -343,7 +342,7 @@ async def run_ag_ui(
|
|
|
343
342
|
|
|
344
343
|
async with agent.iter(
|
|
345
344
|
user_prompt=None,
|
|
346
|
-
output_type=[output_type or agent.output_type,
|
|
345
|
+
output_type=[output_type or agent.output_type, DeferredToolRequests],
|
|
347
346
|
message_history=messages,
|
|
348
347
|
model=model,
|
|
349
348
|
deps=deps,
|
|
@@ -515,7 +514,7 @@ async def _handle_tool_result_event(
|
|
|
515
514
|
content = result.content
|
|
516
515
|
if isinstance(content, BaseEvent):
|
|
517
516
|
yield content
|
|
518
|
-
elif isinstance(content,
|
|
517
|
+
elif isinstance(content, str | bytes): # pragma: no branch
|
|
519
518
|
# Avoid iterable check for strings and bytes.
|
|
520
519
|
pass
|
|
521
520
|
elif isinstance(content, Iterable): # pragma: no branch
|
|
@@ -681,7 +680,7 @@ class _ToolCallNotFoundError(_RunError, ValueError):
|
|
|
681
680
|
)
|
|
682
681
|
|
|
683
682
|
|
|
684
|
-
class _AGUIFrontendToolset(
|
|
683
|
+
class _AGUIFrontendToolset(ExternalToolset[AgentDepsT]):
|
|
685
684
|
def __init__(self, tools: list[AGUITool]):
|
|
686
685
|
super().__init__(
|
|
687
686
|
[
|