pydantic-ai-slim 0.0.55__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/PKG-INFO +5 -5
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/__init__.py +10 -3
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_agent_graph.py +67 -55
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_cli.py +1 -2
- pydantic_ai_slim-0.0.55/pydantic_ai/_result.py → pydantic_ai_slim-0.1.0/pydantic_ai/_output.py +69 -47
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_utils.py +20 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/agent.py +501 -161
- pydantic_ai_slim-0.1.0/pydantic_ai/format_as_xml.py +9 -0
- pydantic_ai_slim-0.0.55/pydantic_ai/format_as_xml.py → pydantic_ai_slim-0.1.0/pydantic_ai/format_prompt.py +1 -1
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/messages.py +104 -21
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/__init__.py +24 -4
- pydantic_ai_slim-0.1.0/pydantic_ai/models/_json_schema.py +156 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/anthropic.py +5 -3
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/bedrock.py +100 -22
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/cohere.py +48 -44
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/fallback.py +2 -1
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/function.py +8 -8
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/gemini.py +65 -75
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/groq.py +32 -28
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/instrumented.py +4 -4
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/mistral.py +62 -58
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/openai.py +110 -158
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/test.py +45 -46
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/result.py +203 -90
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/tools.py +3 -3
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pyproject.toml +2 -2
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/README.md +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/usage.py +0 -0
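The dominant theme of this release is the rename of "result" concepts to "output", visible in the `_result.py → _output.py` rename above and throughout the diffs below. A minimal sketch of the user-visible effect (the model name and prompt are illustrative only):

```python
# Run results expose `.output` in 0.1.0 where 0.0.55 used `.data`,
# as shown in the _cli.py diff below (`content = result.output`).
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')
result = agent.run_sync('What is the capital of France?')
print(result.output)  # 0.0.55 spelled this `result.data`
```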
{pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.55
+Version: 0.1.0
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,13 +29,13 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.55
+Requires-Dist: pydantic-graph==0.1.0
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
 Provides-Extra: bedrock
-Requires-Dist: boto3>=1.
+Requires-Dist: boto3>=1.35.74; extra == 'bedrock'
 Provides-Extra: cli
 Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
 Requires-Dist: prompt-toolkit>=3; extra == 'cli'
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.0.55; extra == 'evals'
+Requires-Dist: pydantic-evals==0.1.0; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
@@ -55,7 +55,7 @@ Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
-Requires-Dist: openai>=1.
+Requires-Dist: openai>=1.74.0; extra == 'openai'
 Provides-Extra: tavily
 Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
 Provides-Extra: vertexai
```
{pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/__init__.py

```diff
@@ -1,4 +1,4 @@
-from importlib.metadata import version
+from importlib.metadata import version as _metadata_version
 
 from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import (
@@ -10,7 +10,9 @@ from .exceptions import (
     UsageLimitExceeded,
     UserError,
 )
-from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl
+from .format_prompt import format_as_xml
+from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
+from .result import ToolOutput
 from .tools import RunContext, Tool
 
 __all__ = (
@@ -33,10 +35,15 @@ __all__ = (
     # messages
     'ImageUrl',
     'AudioUrl',
+    'VideoUrl',
     'DocumentUrl',
     'BinaryContent',
     # tools
     'Tool',
     'RunContext',
+    # result
+    'ToolOutput',
+    # format_prompt
+    'format_as_xml',
 )
-__version__ = version('pydantic_ai_slim')
+__version__ = _metadata_version('pydantic_ai_slim')
```
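The `__init__.py` changes surface three new top-level names. A quick smoke test of the 0.1.0 import surface; the `format_as_xml` call shown is illustrative, its exact signature is not part of this diff:

```python
# VideoUrl, ToolOutput and format_as_xml are newly exported at the top level.
from pydantic_ai import BinaryContent, ToolOutput, VideoUrl, format_as_xml

# format_as_xml now lives in pydantic_ai.format_prompt; the old
# pydantic_ai/format_as_xml.py module survives as a 9-line shim (see the file list above).
print(format_as_xml({'city': 'Paris', 'country': 'France'}))
```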
{pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_agent_graph.py

```diff
@@ -16,7 +16,7 @@ from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
 
 from . import (
-    _result,
+    _output,
     _system_prompt,
     exceptions,
     messages as _messages,
@@ -25,7 +25,7 @@ from . import (
     usage as _usage,
 )
 from .models.instrumented import InstrumentedModel
-from .result import ResultDataT
+from .result import OutputDataT, ToolOutput
 from .settings import ModelSettings, merge_model_settings
 from .tools import RunContext, Tool, ToolDefinition
 
@@ -53,7 +53,7 @@ EndStrategy = Literal['early', 'exhaustive']
 - `'exhaustive'`: Process all tool calls even after finding a final result
 """
 DepsT = TypeVar('DepsT')
-
+OutputT = TypeVar('OutputT')
 
 
 @dataclasses.dataclass
@@ -74,7 +74,7 @@ class GraphAgentState:
 
 
 @dataclasses.dataclass
-class GraphAgentDeps(Generic[DepsT, ResultDataT]):
+class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     """Dependencies/config passed to the agent graph."""
 
     user_deps: DepsT
@@ -88,9 +88,8 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     max_result_retries: int
     end_strategy: EndStrategy
 
-
-
-    result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
+    output_schema: _output.OutputSchema[OutputDataT] | None
+    output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
     mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
@@ -126,6 +125,9 @@ def is_agent_node(
 class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent] | None
 
+    instructions: str | None
+    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
+
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
@@ -167,6 +169,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
             ctx_messages.used = True
 
         parts: list[_messages.ModelRequestPart] = []
+        instructions = await self._instructions(run_context)
         if message_history:
             # Shallow copy messages
             messages.extend(message_history)
@@ -177,7 +180,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
         if user_prompt is not None:
             parts.append(_messages.UserPromptPart(user_prompt))
-        return messages, _messages.ModelRequest(parts)
+        return messages, _messages.ModelRequest(parts, instructions=instructions)
 
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
@@ -207,6 +210,15 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
                 messages.append(_messages.SystemPromptPart(prompt))
         return messages
 
+    async def _instructions(self, run_context: RunContext[DepsT]) -> str | None:
+        if self.instructions is None and not self.instructions_functions:
+            return None
+
+        instructions = self.instructions or ''
+        for instructions_runner in self.instructions_functions:
+            instructions += await instructions_runner.run(run_context)
+        return instructions
+
 
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
@@ -233,11 +245,11 @@ async def _prepare_request_parameters(
         *map(add_mcp_server_tools, ctx.deps.mcp_servers),
     )
 
-
+    output_schema = ctx.deps.output_schema
     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-
-
+        allow_text_output=allow_text_output(output_schema),
+        output_tools=output_schema.tool_defs() if output_schema is not None else [],
     )
 
 
@@ -271,8 +283,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         async with self._stream(ctx) as streamed_response:
             agent_stream = result.AgentStream[DepsT, T](
                 streamed_response,
-                ctx.deps.
-                ctx.deps.
+                ctx.deps.output_schema,
+                ctx.deps.output_validators,
                 build_run_context(ctx),
                 ctx.deps.usage_limits,
             )
@@ -290,6 +302,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
+        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
         async with ctx.deps.model.request_stream(
             ctx.state.message_history, model_settings, model_request_parameters
         ) as streamed_response:
@@ -431,17 +444,17 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
-
+        output_schema = ctx.deps.output_schema
 
-        # first look for the
+        # first, look for the output tool call
         final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
-        if
-            for call,
+        if output_schema is not None:
+            for call, output_tool in output_schema.find_tool(tool_calls):
                 try:
-                    result_data =
-                    result_data = await
-                except
+                    result_data = output_tool.validate(call)
+                    result_data = await _validate_output(result_data, ctx, call)
+                except _output.ToolRetryError as e:
                     # TODO: Should only increment retry stuff once per node execution, not for each tool call
                     # Also, should increment the tool-specific retry count rather than the run retry count
                     ctx.state.increment_retries(ctx.deps.max_result_retries)
@@ -488,9 +501,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
                 'all_messages_events': json.dumps(
                     [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
                 ),
-                'final_result': final_result.
-                if isinstance(final_result.
-                else json.dumps(InstrumentedModel.serialize_any(final_result.
+                'final_result': final_result.output
+                if isinstance(final_result.output, str)
+                else json.dumps(InstrumentedModel.serialize_any(final_result.output)),
             }
         )
         run_span.set_attributes(
@@ -507,7 +520,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
             }
         )
 
-        # End the run with self.data
         return End(final_result)
 
     async def _handle_text_response(
@@ -515,14 +527,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         texts: list[str],
     ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
-
+        output_schema = ctx.deps.output_schema
 
         text = '\n\n'.join(texts)
-        if
+        if allow_text_output(output_schema):
             result_data_input = cast(NodeRunEndT, text)
             try:
-                result_data = await
-            except
+                result_data = await _validate_output(result_data_input, ctx, None)
+            except _output.ToolRetryError as e:
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
         else:
@@ -534,7 +546,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
                 _messages.ModelRequest(
                     parts=[
                         _messages.RetryPromptPart(
-                            content='Plain text responses are not permitted, please
+                            content='Plain text responses are not permitted, please include your response in a tool call',
                         )
                     ]
                 )
@@ -555,8 +567,8 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
 
 async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
-
-
+    output_tool_name: str | None,
+    output_tool_call_id: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
     output_parts: list[_messages.ModelRequestPart],
 ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,22 +578,22 @@ async def process_function_tools(
 
     Because async iterators can't have return values, we use `output_parts` as an output argument.
     """
-    stub_function_tools = bool(
-
+    stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early'
+    output_schema = ctx.deps.output_schema
 
-    # we rely on the fact that if we found a result, it's the first
-
+    # we rely on the fact that if we found a result, it's the first output tool in the last
+    found_used_output_tool = False
     run_context = build_run_context(ctx)
 
     calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
     call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
         if (
-            call.tool_name ==
-            and call.tool_call_id ==
-            and not
+            call.tool_name == output_tool_name
+            and call.tool_call_id == output_tool_call_id
+            and not found_used_output_tool
         ):
-
+            found_used_output_tool = True
             output_parts.append(
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
@@ -618,15 +630,15 @@ async def process_function_tools(
                 yield event
             call_index_to_event_id[len(calls_to_run)] = event.call_id
             calls_to_run.append((mcp_tool, call))
-        elif
-            # if tool_name is in
+        elif output_schema is not None and call.tool_name in output_schema.tools:
+            # if tool_name is in output_schema, it means we found a output tool but an error occurred in
             # validation, we don't add another part here
-            if
-                if
-                    content = '
+            if output_tool_name is not None:
+                if found_used_output_tool:
+                    content = 'Output tool not used - a final result was already processed.'
                 else:
                     # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
-                    content = '
+                    content = 'Output tool not used - result failed validation.'
                 part = _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content=content,
@@ -706,8 +718,8 @@ def _unknown_tool(
 ) -> _messages.RetryPromptPart:
     ctx.state.increment_retries(ctx.deps.max_result_retries)
     tool_names = list(ctx.deps.function_tools.keys())
-    if
-        tool_names.extend(
+    if output_schema := ctx.deps.output_schema:
+        tool_names.extend(output_schema.tool_names())
 
     if tool_names:
         msg = f'Available tools: {", ".join(tool_names)}'
@@ -721,20 +733,20 @@ def _unknown_tool(
     )
 
 
-async def
+async def _validate_output(
     result_data: T,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     tool_call: _messages.ToolCallPart | None,
 ) -> T:
-    for validator in ctx.deps.
+    for validator in ctx.deps.output_validators:
         run_context = build_run_context(ctx)
         result_data = await validator.validate(result_data, tool_call, run_context)
     return result_data
 
 
-def
+def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool:
     """Check if the result schema allows text results."""
-    return
+    return output_schema is None or output_schema.allow_text_output
 
 
 @dataclasses.dataclass
@@ -786,19 +798,19 @@ def get_captured_run_messages() -> _RunMessages:
 
 
 def build_agent_graph(
-    name: str | None, deps_type: type[DepsT],
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[
+    name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
         CallToolsNode[DepsT],
     )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[
+    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
         nodes=nodes,
         name=name or 'Agent',
         state_type=GraphAgentState,
-        run_end_type=result.FinalResult[
+        run_end_type=result.FinalResult[OutputT],
         auto_instrument=False,
     )
     return graph
```
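`UserPromptNode` now carries `instructions` and `instructions_functions`: the new `_instructions` helper concatenates the static string with the awaited output of each runner and attaches the result via `ModelRequest(parts, instructions=...)`. A sketch of how this is likely exposed on `Agent`; the `instructions` keyword and decorator here are assumptions, since the `agent.py` diff is not reproduced above:

```python
from pydantic_ai import Agent, RunContext

# Assumed public API: a static instructions string plus dynamic instructions
# functions, mirroring the existing system_prompt machinery shown in the diff.
agent = Agent('openai:gpt-4o', deps_type=str, instructions='Reply concisely.')

@agent.instructions  # hypothetical decorator, analogous to @agent.system_prompt
def user_context(ctx: RunContext[str]) -> str:
    # UserPromptNode._instructions appends each runner's return value
    # to the static string before building the ModelRequest.
    return f' The user is named {ctx.deps}.'
```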
{pydantic_ai_slim-0.0.55 → pydantic_ai_slim-0.1.0}/pydantic_ai/_cli.py

```diff
@@ -208,14 +208,13 @@ async def ask_agent(
     if not stream:
         with status:
             result = await agent.run(prompt, message_history=messages)
-        content = result.data
+        content = result.output
         console.print(Markdown(content, code_theme=code_theme))
         return result.all_messages()
 
     with status, ExitStack() as stack:
         async with agent.iter(prompt, message_history=messages) as agent_run:
             live = Live('', refresh_per_second=15, console=console, vertical_overflow='visible')
-            content: str = ''
             async for node in agent_run:
                 if Agent.is_model_request_node(node):
                     async with node.stream(agent_run.ctx) as handle_stream:
```
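The CLI's streaming path shows the graph-iteration API unchanged in shape: iterate nodes from `agent.iter(...)` and stream only model-request nodes. Reduced to a sketch (event handling is elided; the event types printed are whatever the stream yields):

```python
import asyncio

from pydantic_ai import Agent

async def main() -> None:
    agent = Agent('openai:gpt-4o')
    # The same pattern the CLI uses above: walk the run graph,
    # opening a stream only for model-request nodes.
    async with agent.iter('Tell me a joke') as agent_run:
        async for node in agent_run:
            if Agent.is_model_request_node(node):
                async with node.stream(agent_run.ctx) as handle_stream:
                    async for event in handle_stream:
                        print(event)

asyncio.run(main())
```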
pydantic_ai_slim-0.0.55/pydantic_ai/_result.py → pydantic_ai_slim-0.1.0/pydantic_ai/_output.py
RENAMED

```diff
@@ -12,7 +12,7 @@ from typing_inspection.introspection import is_union_origin
 
 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
-from .result import
+from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput
 from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition
 
 T = TypeVar('T')
@@ -20,8 +20,8 @@ T = TypeVar('T')
 
 
 @dataclass
-class
-    function:
+class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
+    function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv]
     _takes_ctx: bool = field(init=False)
     _is_async: bool = field(init=False)
 
@@ -77,47 +77,68 @@ class ToolRetryError(Exception):
 
 
 @dataclass
-class ResultSchema(Generic[ResultDataT]):
+class OutputSchema(Generic[OutputDataT]):
     """Model the final response from an agent run.
 
-    Similar to `Tool` but for the final
+    Similar to `Tool` but for the final output of running an agent.
     """
 
-    tools: dict[str,
-
+    tools: dict[str, OutputSchemaTool[OutputDataT]]
+    allow_text_output: bool
 
     @classmethod
     def build(
-        cls: type[
-
-
-
+        cls: type[OutputSchema[T]],
+        output_type: type[T] | ToolOutput[T],
+        name: str | None = None,
+        description: str | None = None,
+        strict: bool | None = None,
+    ) -> OutputSchema[T] | None:
+        """Build an OutputSchema dataclass from a response type."""
+        if output_type is str:
             return None
 
-        if
-
-
+        if isinstance(output_type, ToolOutput):
+            # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
+            name = output_type.name
+            description = output_type.description
+            output_type_ = output_type.output_type
+            strict = output_type.strict
         else:
-
+            output_type_ = output_type
 
-
-
+        if output_type_option := extract_str_from_union(output_type):
+            output_type_ = output_type_option.value
+            allow_text_output = True
+        else:
+            allow_text_output = False
 
-        tools: dict[str,
-        if args := get_union_args(
+        tools: dict[str, OutputSchemaTool[T]] = {}
+        if args := get_union_args(output_type_):
             for i, arg in enumerate(args, start=1):
-                tool_name = union_tool_name(name, arg)
+                tool_name = raw_tool_name = union_tool_name(name, arg)
                 while tool_name in tools:
-                    tool_name = f'{
-                tools[tool_name] =
+                    tool_name = f'{raw_tool_name}_{i}'
+                tools[tool_name] = cast(
+                    OutputSchemaTool[T],
+                    OutputSchemaTool(
+                        output_type=arg, name=tool_name, description=description, multiple=True, strict=strict
+                    ),
+                )
         else:
-
+            name = name or DEFAULT_OUTPUT_TOOL_NAME
+            tools[name] = cast(
+                OutputSchemaTool[T],
+                OutputSchemaTool(
+                    output_type=output_type_, name=name, description=description, multiple=False, strict=strict
+                ),
+            )
 
-        return cls(tools=tools,
+        return cls(tools=tools, allow_text_output=allow_text_output)
 
     def find_named_tool(
         self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
-    ) -> tuple[_messages.ToolCallPart,
+    ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None:
         """Find a tool that matches one of the calls, with a specific name."""
         for part in parts:
             if isinstance(part, _messages.ToolCallPart):
@@ -127,7 +148,7 @@ class ResultSchema(Generic[ResultDataT]):
     def find_tool(
         self,
         parts: Iterable[_messages.ModelResponsePart],
-    ) -> Iterator[tuple[_messages.ToolCallPart,
+    ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]:
         """Find a tool that matches one of the calls."""
         for part in parts:
             if isinstance(part, _messages.ToolCallPart):
@@ -147,16 +168,16 @@ DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
 
 
 @dataclass(init=False)
-class ResultTool(Generic[ResultDataT]):
+class OutputSchemaTool(Generic[OutputDataT]):
     tool_def: ToolDefinition
     type_adapter: TypeAdapter[Any]
 
-    def __init__(
-
-
-
-        if _utils.is_model_like(
-            self.type_adapter = TypeAdapter(
+    def __init__(
+        self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None
+    ):
+        """Build a OutputSchemaTool from a response type."""
+        if _utils.is_model_like(output_type):
+            self.type_adapter = TypeAdapter(output_type)
             outer_typed_dict_key: str | None = None
             # noinspection PyArgumentList
             parameters_json_schema = _utils.check_object_json_schema(
@@ -165,7 +186,7 @@ class ResultTool(Generic[ResultDataT]):
         else:
             response_data_typed_dict = TypedDict( # noqa: UP013
                 'response_data_typed_dict',
-                {'response':
+                {'response': output_type}, # pyright: ignore[reportInvalidTypeForm]
             )
             self.type_adapter = TypeAdapter(response_data_typed_dict)
             outer_typed_dict_key = 'response'
@@ -184,19 +205,20 @@ class ResultTool(Generic[ResultDataT]):
         else:
             tool_description = description or DEFAULT_DESCRIPTION
             if multiple:
-                tool_description = f'{union_arg_name(
+                tool_description = f'{union_arg_name(output_type)}: {tool_description}'
 
         self.tool_def = ToolDefinition(
             name=name,
             description=tool_description,
             parameters_json_schema=parameters_json_schema,
             outer_typed_dict_key=outer_typed_dict_key,
+            strict=strict,
         )
 
     def validate(
         self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
-    ) ->
-        """Validate
+    ) -> OutputDataT:
+        """Validate an output message.
 
         Args:
             tool_call: The tool call from the LLM to validate.
@@ -204,14 +226,14 @@ class ResultTool(Generic[ResultDataT]):
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
         Returns:
-            Either the validated
+            Either the validated output data (left) or a retry message (right).
         """
         try:
             pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
             if isinstance(tool_call.args, str):
-
+                output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
             else:
-
+                output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -224,21 +246,21 @@ class ResultTool(Generic[ResultDataT]):
                 raise
         else:
             if k := self.tool_def.outer_typed_dict_key:
-
-                return
+                output = output[k]
+            return output
 
 
-def union_tool_name(base_name: str, union_arg: Any) -> str:
-    return f'{base_name}_{union_arg_name(union_arg)}'
+def union_tool_name(base_name: str | None, union_arg: Any) -> str:
+    return f'{base_name or DEFAULT_OUTPUT_TOOL_NAME}_{union_arg_name(union_arg)}'
 
 
 def union_arg_name(union_arg: Any) -> str:
     return union_arg.__name__
 
 
-def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
+def extract_str_from_union(output_type: Any) -> _utils.Option[Any]:
     """Extract the string type from a Union, return the remaining union or remaining type."""
-    union_args = get_union_args(response_type)
+    union_args = get_union_args(output_type)
     if any(t is str for t in union_args):
         remain_args: list[Any] = []
         includes_str = False
@@ -255,7 +277,7 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
 
 
 def get_union_args(tp: Any) -> tuple[Any, ...]:
-    """Extract the arguments of a Union type if `
+    """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple."""
     if typing_objects.is_typealiastype(tp):
         tp = tp.__value__
 
```
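`OutputSchema.build` (a private API, so subject to change) captures the semantics concisely: `str` means plain-text output and no schema, a union containing `str` keeps text output allowed, and every remaining union member becomes an output tool. A sketch against the code above, with `Answer` as a stand-in model:

```python
from typing import Union

from pydantic import BaseModel

from pydantic_ai._output import OutputSchema  # private module introduced by this rename

class Answer(BaseModel):
    text: str

# Plain `str` output needs no tool schema at all.
assert OutputSchema.build(str) is None

# `str | Answer`: text output stays allowed, and `Answer` gets an output tool.
schema = OutputSchema.build(Union[str, Answer])
assert schema is not None and schema.allow_text_output
print(list(schema.tools))  # a single default-named output tool
```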