pydantic-ai-slim 0.0.55__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This release of pydantic-ai-slim has been flagged as potentially problematic.
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/PKG-INFO +5 -5
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/__init__.py +10 -3
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/_agent_graph.py +70 -59
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/_cli.py +1 -2
- pydantic_ai_slim-0.0.55/pydantic_ai/_result.py → pydantic_ai_slim-0.1.1/pydantic_ai/_output.py +69 -47
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/_utils.py +20 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/agent.py +511 -161
- pydantic_ai_slim-0.1.1/pydantic_ai/format_as_xml.py +9 -0
- pydantic_ai_slim-0.0.55/pydantic_ai/format_as_xml.py → pydantic_ai_slim-0.1.1/pydantic_ai/format_prompt.py +1 -1
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/messages.py +104 -21
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/__init__.py +24 -4
- pydantic_ai_slim-0.1.1/pydantic_ai/models/_json_schema.py +160 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/anthropic.py +5 -3
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/bedrock.py +100 -22
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/cohere.py +48 -44
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/fallback.py +2 -1
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/function.py +8 -8
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/gemini.py +82 -75
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/groq.py +32 -28
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/instrumented.py +4 -4
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/mistral.py +62 -58
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/openai.py +110 -158
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/test.py +45 -46
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/result.py +203 -90
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/tools.py +4 -4
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pyproject.toml +2 -2
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/README.md +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.1}/pydantic_ai/usage.py +0 -0
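The headline change in this release is the rename of the "result" concepts to "output": `_result.py` becomes `_output.py` above, and, as the `_cli.py` hunk at the bottom shows, run results now expose `.output` instead of `.data`. A minimal before/after sketch, assuming the usual `Agent.run_sync` entry point (not shown in this diff) and an illustrative model name:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # model string is illustrative

result = agent.run_sync('What is the capital of France?')
# 0.0.55 exposed the final value as `result.data`; 0.1.1 renames it:
print(result.output)
```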
```diff
--- pydantic_ai_slim-0.0.55/PKG-INFO
+++ pydantic_ai_slim-0.1.1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.55
+Version: 0.1.1
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,13 +29,13 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.55
+Requires-Dist: pydantic-graph==0.1.1
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
 Provides-Extra: bedrock
-Requires-Dist: boto3>=1.
+Requires-Dist: boto3>=1.35.74; extra == 'bedrock'
 Provides-Extra: cli
 Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
 Requires-Dist: prompt-toolkit>=3; extra == 'cli'
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.0.55; extra == 'evals'
+Requires-Dist: pydantic-evals==0.1.1; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
@@ -55,7 +55,7 @@ Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
-Requires-Dist: openai>=1.
+Requires-Dist: openai>=1.74.0; extra == 'openai'
 Provides-Extra: tavily
 Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
 Provides-Extra: vertexai
```
```diff
--- pydantic_ai_slim-0.0.55/pydantic_ai/__init__.py
+++ pydantic_ai_slim-0.1.1/pydantic_ai/__init__.py
@@ -1,4 +1,4 @@
-from importlib.metadata import version
+from importlib.metadata import version as _metadata_version
 
 from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import (
@@ -10,7 +10,9 @@ from .exceptions import (
     UsageLimitExceeded,
     UserError,
 )
-from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl
+from .format_prompt import format_as_xml
+from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
+from .result import ToolOutput
 from .tools import RunContext, Tool
 
 __all__ = (
@@ -33,10 +35,15 @@ __all__ = (
     # messages
     'ImageUrl',
     'AudioUrl',
+    'VideoUrl',
     'DocumentUrl',
     'BinaryContent',
     # tools
     'Tool',
     'RunContext',
+    # result
+    'ToolOutput',
+    # format_prompt
+    'format_as_xml',
 )
-__version__ = version('pydantic_ai_slim')
+__version__ = _metadata_version('pydantic_ai_slim')
```
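These `__init__.py` changes make three new names importable from the package root: `VideoUrl`, `ToolOutput`, and `format_as_xml` (which now lives in `format_prompt`, with the old `format_as_xml` module kept as a nine-line shim per the file list above). A quick sketch; `format_as_xml`'s exact output format is not shown in this diff:

```python
from pydantic_ai import ToolOutput, VideoUrl, format_as_xml

# format_as_xml serializes Python data into XML for inclusion in a prompt
print(format_as_xml({'city': 'Paris', 'country': 'France'}))
```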
```diff
--- pydantic_ai_slim-0.0.55/pydantic_ai/_agent_graph.py
+++ pydantic_ai_slim-0.1.1/pydantic_ai/_agent_graph.py
@@ -3,11 +3,11 @@ from __future__ import annotations as _annotations
 import asyncio
 import dataclasses
 import json
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
 from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -16,7 +16,7 @@ from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
 
 from . import (
-    _result,
+    _output,
     _system_prompt,
     exceptions,
     messages as _messages,
@@ -25,7 +25,7 @@ from . import (
     usage as _usage,
 )
 from .models.instrumented import InstrumentedModel
-from .result import ResultDataT
+from .result import OutputDataT, ToolOutput
 from .settings import ModelSettings, merge_model_settings
 from .tools import RunContext, Tool, ToolDefinition
 
@@ -53,7 +53,7 @@ EndStrategy = Literal['early', 'exhaustive']
 - `'exhaustive'`: Process all tool calls even after finding a final result
 """
 DepsT = TypeVar('DepsT')
-ResultT = TypeVar('ResultT')
+OutputT = TypeVar('OutputT')
 
 
 @dataclasses.dataclass
@@ -74,7 +74,7 @@ class GraphAgentState:
 
 
 @dataclasses.dataclass
-class GraphAgentDeps(Generic[DepsT, ResultDataT]):
+class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""
 
     user_deps: DepsT
@@ -87,10 +87,10 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     usage_limits: _usage.UsageLimits
     max_result_retries: int
     end_strategy: EndStrategy
+    get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
 
-    result_schema: _result.ResultSchema[ResultDataT] | None
-    result_tools: list[ToolDefinition]
-    result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
+    output_schema: _output.OutputSchema[OutputDataT] | None
+    output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
     mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
@@ -126,6 +126,9 @@ def is_agent_node(
 class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent] | None
 
+    instructions: str | None
+    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
+
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
@@ -139,7 +142,9 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> _messages.ModelRequest:
         run_context = build_run_context(ctx)
-        history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
+        history, next_message = await self._prepare_messages(
+            self.user_prompt, ctx.state.message_history, ctx.deps.get_instructions, run_context
+        )
         ctx.state.message_history = history
         run_context.messages = history
 
@@ -153,6 +158,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None,
         message_history: list[_messages.ModelMessage] | None,
+        get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
         run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
@@ -167,6 +173,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
                 ctx_messages.used = True
 
         parts: list[_messages.ModelRequestPart] = []
+        instructions = await get_instructions(run_context)
         if message_history:
             # Shallow copy messages
             messages.extend(message_history)
@@ -177,7 +184,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
         if user_prompt is not None:
             parts.append(_messages.UserPromptPart(user_prompt))
-        return messages, _messages.ModelRequest(parts)
+        return messages, _messages.ModelRequest(parts, instructions=instructions)
 
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
```
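These hunks thread a `get_instructions` callable through `UserPromptNode`, so resolved instructions are attached to every outgoing `ModelRequest`. At the public level this presumably corresponds to an `instructions` option on `Agent` (the `agent.py` diff, +511/-161, is not shown here), roughly:

```python
from pydantic_ai import Agent

# Assumption: Agent exposes an `instructions` parameter that feeds the
# get_instructions plumbing shown above.
agent = Agent(
    'openai:gpt-4o',  # illustrative model string
    instructions='Reply concisely, in French.',
)
```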
pydantic_ai/_agent_graph.py (continued):

```diff
@@ -233,11 +240,11 @@ async def _prepare_request_parameters(
         *map(add_mcp_server_tools, ctx.deps.mcp_servers),
     )
 
-    result_schema = ctx.deps.result_schema
+    output_schema = ctx.deps.output_schema
     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-        allow_text_result=allow_text_result(result_schema),
-        result_tools=ctx.deps.result_tools,
+        allow_text_output=allow_text_output(output_schema),
+        output_tools=output_schema.tool_defs() if output_schema is not None else [],
     )
 
 
@@ -271,8 +278,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         async with self._stream(ctx) as streamed_response:
             agent_stream = result.AgentStream[DepsT, T](
                 streamed_response,
-                ctx.deps.result_schema,
-                ctx.deps.result_validators,
+                ctx.deps.output_schema,
+                ctx.deps.output_validators,
                 build_run_context(ctx),
                 ctx.deps.usage_limits,
             )
@@ -290,6 +297,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
+        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
         async with ctx.deps.model.request_stream(
             ctx.state.message_history, model_settings, model_request_parameters
         ) as streamed_response:
@@ -431,17 +439,17 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
-        result_schema = ctx.deps.result_schema
+        output_schema = ctx.deps.output_schema
 
-        # first look for the result tool call
+        # first, look for the output tool call
         final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
-        if result_schema is not None:
-            for call, result_tool in result_schema.find_tool(tool_calls):
+        if output_schema is not None:
+            for call, output_tool in output_schema.find_tool(tool_calls):
                 try:
-                    result_data = result_tool.validate(call)
-                    result_data = await _validate_result(result_data, ctx, call)
-                except _result.ToolRetryError as e:
+                    result_data = output_tool.validate(call)
+                    result_data = await _validate_output(result_data, ctx, call)
+                except _output.ToolRetryError as e:
                     # TODO: Should only increment retry stuff once per node execution, not for each tool call
                     # Also, should increment the tool-specific retry count rather than the run retry count
                     ctx.state.increment_retries(ctx.deps.max_result_retries)
@@ -466,7 +474,11 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         else:
             if tool_responses:
                 parts.extend(tool_responses)
-            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
+            run_context = build_run_context(ctx)
+            instructions = await ctx.deps.get_instructions(run_context)
+            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
+                _messages.ModelRequest(parts=parts, instructions=instructions)
+            )
 
     def _handle_final_result(
         self,
@@ -488,9 +500,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
                 'all_messages_events': json.dumps(
                     [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
                 ),
-                'final_result': final_result.data
-                if isinstance(final_result.data, str)
-                else json.dumps(InstrumentedModel.serialize_any(final_result.data)),
+                'final_result': final_result.output
+                if isinstance(final_result.output, str)
+                else json.dumps(InstrumentedModel.serialize_any(final_result.output)),
             }
         )
         run_span.set_attributes(
@@ -507,7 +519,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             }
         )
 
-        # End the run with self.data
         return End(final_result)
 
     async def _handle_text_response(
@@ -515,14 +526,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         texts: list[str],
     ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
-        result_schema = ctx.deps.result_schema
+        output_schema = ctx.deps.output_schema
 
         text = '\n\n'.join(texts)
-        if allow_text_result(result_schema):
+        if allow_text_output(output_schema):
             result_data_input = cast(NodeRunEndT, text)
             try:
-                result_data = await _validate_result(result_data_input, ctx, None)
-            except _result.ToolRetryError as e:
+                result_data = await _validate_output(result_data_input, ctx, None)
+            except _output.ToolRetryError as e:
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
             else:
@@ -534,7 +545,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
                 _messages.ModelRequest(
                     parts=[
                         _messages.RetryPromptPart(
-                            content='Plain text responses are not permitted, please
+                            content='Plain text responses are not permitted, please include your response in a tool call',
                         )
                     ]
                 )
@@ -555,8 +566,8 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
 
 async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
-    result_tool_name: str | None,
-    result_tool_call_id: str | None,
+    output_tool_name: str | None,
+    output_tool_call_id: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
     output_parts: list[_messages.ModelRequestPart],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,22 +577,22 @@ async def process_function_tools(
 
     Because async iterators can't have return values, we use `output_parts` as an output argument.
     """
-    stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
-    result_schema = ctx.deps.result_schema
+    stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early'
+    output_schema = ctx.deps.output_schema
 
-    # we rely on the fact that if we found a result, it's the first result tool in the last
-    found_used_result_tool = False
+    # we rely on the fact that if we found a result, it's the first output tool in the last
+    found_used_output_tool = False
     run_context = build_run_context(ctx)
 
     calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
     call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
         if (
-            call.tool_name == result_tool_name
-            and call.tool_call_id == result_tool_call_id
-            and not found_used_result_tool
+            call.tool_name == output_tool_name
+            and call.tool_call_id == output_tool_call_id
+            and not found_used_output_tool
         ):
-            found_used_result_tool = True
+            found_used_output_tool = True
             output_parts.append(
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
@@ -618,15 +629,15 @@ async def process_function_tools(
                 yield event
                 call_index_to_event_id[len(calls_to_run)] = event.call_id
                 calls_to_run.append((mcp_tool, call))
-        elif result_schema is not None and call.tool_name in result_schema.tools:
-            # if tool_name is in result_schema, it means we found a result tool but an error occurred in
+        elif output_schema is not None and call.tool_name in output_schema.tools:
+            # if tool_name is in output_schema, it means we found a output tool but an error occurred in
             # validation, we don't add another part here
-            if result_tool_name is not None:
-                if found_used_result_tool:
-                    content = 'Result tool not used - a final result was already processed.'
+            if output_tool_name is not None:
+                if found_used_output_tool:
+                    content = 'Output tool not used - a final result was already processed.'
                 else:
                     # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
-                    content = 'Result tool not used - result failed validation.'
+                    content = 'Output tool not used - result failed validation.'
             part = _messages.ToolReturnPart(
                 tool_name=call.tool_name,
                 content=content,
@@ -706,8 +717,8 @@ def _unknown_tool(
 ) -> _messages.RetryPromptPart:
     ctx.state.increment_retries(ctx.deps.max_result_retries)
     tool_names = list(ctx.deps.function_tools.keys())
-    if result_schema := ctx.deps.result_schema:
-        tool_names.extend(result_schema.tool_names())
+    if output_schema := ctx.deps.output_schema:
+        tool_names.extend(output_schema.tool_names())
 
     if tool_names:
         msg = f'Available tools: {", ".join(tool_names)}'
@@ -721,20 +732,20 @@ def _unknown_tool(
     )
 
 
-async def _validate_result(
+async def _validate_output(
     result_data: T,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     tool_call: _messages.ToolCallPart | None,
 ) -> T:
-    for validator in ctx.deps.result_validators:
+    for validator in ctx.deps.output_validators:
         run_context = build_run_context(ctx)
         result_data = await validator.validate(result_data, tool_call, run_context)
     return result_data
 
 
-def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool:
     """Check if the result schema allows text results."""
-    return result_schema is None or result_schema.allow_text_result
+    return output_schema is None or output_schema.allow_text_output
 
 
 @dataclasses.dataclass
```
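`_validate_result` becomes `_validate_output`, and each registered validator is awaited in turn against the candidate output. Assuming the corresponding agent-level decorator was renamed from `result_validator` to `output_validator` to match (the `agent.py` diff is not shown here), a validator sketch:

```python
from pydantic_ai import Agent, ModelRetry

agent = Agent('openai:gpt-4o')  # illustrative model string


@agent.output_validator  # assumed rename of `result_validator`
def check_not_empty(output: str) -> str:
    # Raising ModelRetry sends a retry prompt back to the model,
    # matching the ToolRetryError handling in the hunks above.
    if not output.strip():
        raise ModelRetry('Empty response, please answer the question.')
    return output
```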
pydantic_ai/_agent_graph.py (continued):

```diff
@@ -786,19 +797,19 @@ def get_captured_run_messages() -> _RunMessages:
 
 
 def build_agent_graph(
-    name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
+    name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
         CallToolsNode[DepsT],
     )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
+    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
        nodes=nodes,
         name=name or 'Agent',
         state_type=GraphAgentState,
-        run_end_type=result.FinalResult[ResultT],
+        run_end_type=result.FinalResult[OutputT],
         auto_instrument=False,
     )
     return graph
```
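`build_agent_graph` now accepts either a plain type or a `ToolOutput[...]` marker for the output type, matching the new top-level `ToolOutput` export. A sketch of what that enables at the `Agent` level, assuming `Agent` forwards an `output_type` argument (its diff is not shown) and that `ToolOutput` wraps the target type:

```python
from pydantic import BaseModel

from pydantic_ai import Agent, ToolOutput


class CityInfo(BaseModel):
    city: str
    country: str


# A plain type works as before:
agent = Agent('openai:gpt-4o', output_type=CityInfo)

# Or wrapped in the new ToolOutput marker; constructor arguments beyond
# the type itself are an assumption:
agent = Agent('openai:gpt-4o', output_type=ToolOutput(CityInfo))
```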
```diff
--- pydantic_ai_slim-0.0.55/pydantic_ai/_cli.py
+++ pydantic_ai_slim-0.1.1/pydantic_ai/_cli.py
@@ -208,14 +208,13 @@ async def ask_agent(
     if not stream:
         with status:
             result = await agent.run(prompt, message_history=messages)
-        content = result.data
+        content = result.output
         console.print(Markdown(content, code_theme=code_theme))
         return result.all_messages()
 
     with status, ExitStack() as stack:
         async with agent.iter(prompt, message_history=messages) as agent_run:
             live = Live('', refresh_per_second=15, console=console, vertical_overflow='visible')
-            content: str = ''
             async for node in agent_run:
                 if Agent.is_model_request_node(node):
                     async with node.stream(agent_run.ctx) as handle_stream:
```