pydantic-ai-slim 0.0.55__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +70 -59
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +511 -161
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +24 -4
- pydantic_ai/models/_json_schema.py +160 -0
- pydantic_ai/models/anthropic.py +5 -3
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +82 -75
- pydantic_ai/models/groq.py +32 -28
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +62 -58
- pydantic_ai/models/openai.py +110 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +4 -4
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.1.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.55.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/entry_points.txt +0 -0
pydantic_ai/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from importlib.metadata import version
+from importlib.metadata import version as _metadata_version
 
 from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import (
@@ -10,7 +10,9 @@ from .exceptions import (
     UsageLimitExceeded,
     UserError,
 )
-from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl
+from .format_prompt import format_as_xml
+from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
+from .result import ToolOutput
 from .tools import RunContext, Tool
 
 __all__ = (
@@ -33,10 +35,15 @@ __all__ = (
     # messages
     'ImageUrl',
     'AudioUrl',
+    'VideoUrl',
     'DocumentUrl',
     'BinaryContent',
     # tools
     'Tool',
     'RunContext',
+    # result
+    'ToolOutput',
+    # format_prompt
+    'format_as_xml',
 )
-__version__ = version('pydantic_ai_slim')
+__version__ = _metadata_version('pydantic_ai_slim')
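Net effect on the public surface: `format_as_xml` now comes from the new `format_prompt` module, `VideoUrl` and `ToolOutput` join `__all__`, and `__version__` is read from package metadata under an aliased import so the module-level name `version` isn't shadowed. A quick sketch of the 0.1.1 surface (assumes the wheel is installed; printed output is illustrative):

import pydantic_ai
from pydantic_ai import ToolOutput, VideoUrl, format_as_xml

# VideoUrl and ToolOutput are newly exported in 0.1.x
print(ToolOutput, VideoUrl)

print(pydantic_ai.__version__)  # resolved via importlib.metadata, e.g. '0.1.1'

# format_as_xml now lives in (and is re-exported from) pydantic_ai.format_prompt
print(format_as_xml({'name': 'John', 'height': 6.1}))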
pydantic_ai/_agent_graph.py
CHANGED
|
@@ -3,11 +3,11 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import dataclasses
|
|
5
5
|
import json
|
|
6
|
-
from collections.abc import AsyncIterator, Iterator, Sequence
|
|
6
|
+
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
7
7
|
from contextlib import asynccontextmanager, contextmanager
|
|
8
8
|
from contextvars import ContextVar
|
|
9
9
|
from dataclasses import field
|
|
10
|
-
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
|
|
11
11
|
|
|
12
12
|
from opentelemetry.trace import Span, Tracer
|
|
13
13
|
from typing_extensions import TypeGuard, TypeVar, assert_never
|
|
@@ -16,7 +16,7 @@ from pydantic_graph import BaseNode, Graph, GraphRunContext
|
|
|
16
16
|
from pydantic_graph.nodes import End, NodeRunEndT
|
|
17
17
|
|
|
18
18
|
from . import (
|
|
19
|
-
|
|
19
|
+
_output,
|
|
20
20
|
_system_prompt,
|
|
21
21
|
exceptions,
|
|
22
22
|
messages as _messages,
|
|
@@ -25,7 +25,7 @@ from . import (
|
|
|
25
25
|
usage as _usage,
|
|
26
26
|
)
|
|
27
27
|
from .models.instrumented import InstrumentedModel
|
|
28
|
-
from .result import
|
|
28
|
+
from .result import OutputDataT, ToolOutput
|
|
29
29
|
from .settings import ModelSettings, merge_model_settings
|
|
30
30
|
from .tools import RunContext, Tool, ToolDefinition
|
|
31
31
|
|
|
@@ -53,7 +53,7 @@ EndStrategy = Literal['early', 'exhaustive']
|
|
|
53
53
|
- `'exhaustive'`: Process all tool calls even after finding a final result
|
|
54
54
|
"""
|
|
55
55
|
DepsT = TypeVar('DepsT')
|
|
56
|
-
|
|
56
|
+
OutputT = TypeVar('OutputT')
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
@dataclasses.dataclass
|
|
@@ -74,7 +74,7 @@ class GraphAgentState:
|
|
|
74
74
|
|
|
75
75
|
|
|
76
76
|
@dataclasses.dataclass
|
|
77
|
-
class GraphAgentDeps(Generic[DepsT,
|
|
77
|
+
class GraphAgentDeps(Generic[DepsT, OutputDataT]):
|
|
78
78
|
"""Dependencies/config passed to the agent graph."""
|
|
79
79
|
|
|
80
80
|
user_deps: DepsT
|
|
@@ -87,10 +87,10 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
|
|
|
87
87
|
usage_limits: _usage.UsageLimits
|
|
88
88
|
max_result_retries: int
|
|
89
89
|
end_strategy: EndStrategy
|
|
90
|
+
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
|
|
90
91
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
|
|
92
|
+
output_schema: _output.OutputSchema[OutputDataT] | None
|
|
93
|
+
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
|
|
94
94
|
|
|
95
95
|
function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
|
|
96
96
|
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
|
|
@@ -126,6 +126,9 @@ def is_agent_node(
|
|
|
126
126
|
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
127
127
|
user_prompt: str | Sequence[_messages.UserContent] | None
|
|
128
128
|
|
|
129
|
+
instructions: str | None
|
|
130
|
+
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
|
|
131
|
+
|
|
129
132
|
system_prompts: tuple[str, ...]
|
|
130
133
|
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
|
|
131
134
|
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
|
|
@@ -139,7 +142,9 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
139
142
|
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
|
|
140
143
|
) -> _messages.ModelRequest:
|
|
141
144
|
run_context = build_run_context(ctx)
|
|
142
|
-
history, next_message = await self._prepare_messages(
|
|
145
|
+
history, next_message = await self._prepare_messages(
|
|
146
|
+
self.user_prompt, ctx.state.message_history, ctx.deps.get_instructions, run_context
|
|
147
|
+
)
|
|
143
148
|
ctx.state.message_history = history
|
|
144
149
|
run_context.messages = history
|
|
145
150
|
|
|
@@ -153,6 +158,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
153
158
|
self,
|
|
154
159
|
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
155
160
|
message_history: list[_messages.ModelMessage] | None,
|
|
161
|
+
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
|
|
156
162
|
run_context: RunContext[DepsT],
|
|
157
163
|
) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
|
|
158
164
|
try:
|
|
@@ -167,6 +173,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
167
173
|
ctx_messages.used = True
|
|
168
174
|
|
|
169
175
|
parts: list[_messages.ModelRequestPart] = []
|
|
176
|
+
instructions = await get_instructions(run_context)
|
|
170
177
|
if message_history:
|
|
171
178
|
# Shallow copy messages
|
|
172
179
|
messages.extend(message_history)
|
|
@@ -177,7 +184,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
177
184
|
|
|
178
185
|
if user_prompt is not None:
|
|
179
186
|
parts.append(_messages.UserPromptPart(user_prompt))
|
|
180
|
-
return messages, _messages.ModelRequest(parts)
|
|
187
|
+
return messages, _messages.ModelRequest(parts, instructions=instructions)
|
|
181
188
|
|
|
182
189
|
async def _reevaluate_dynamic_prompts(
|
|
183
190
|
self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
|
|
@@ -233,11 +240,11 @@ async def _prepare_request_parameters(
|
|
|
233
240
|
*map(add_mcp_server_tools, ctx.deps.mcp_servers),
|
|
234
241
|
)
|
|
235
242
|
|
|
236
|
-
|
|
243
|
+
output_schema = ctx.deps.output_schema
|
|
237
244
|
return models.ModelRequestParameters(
|
|
238
245
|
function_tools=function_tool_defs,
|
|
239
|
-
|
|
240
|
-
|
|
246
|
+
allow_text_output=allow_text_output(output_schema),
|
|
247
|
+
output_tools=output_schema.tool_defs() if output_schema is not None else [],
|
|
241
248
|
)
|
|
242
249
|
|
|
243
250
|
|
|
@@ -271,8 +278,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
271
278
|
async with self._stream(ctx) as streamed_response:
|
|
272
279
|
agent_stream = result.AgentStream[DepsT, T](
|
|
273
280
|
streamed_response,
|
|
274
|
-
ctx.deps.
|
|
275
|
-
ctx.deps.
|
|
281
|
+
ctx.deps.output_schema,
|
|
282
|
+
ctx.deps.output_validators,
|
|
276
283
|
build_run_context(ctx),
|
|
277
284
|
ctx.deps.usage_limits,
|
|
278
285
|
)
|
|
@@ -290,6 +297,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
290
297
|
assert not self._did_stream, 'stream() should only be called once per node'
|
|
291
298
|
|
|
292
299
|
model_settings, model_request_parameters = await self._prepare_request(ctx)
|
|
300
|
+
model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
|
|
293
301
|
async with ctx.deps.model.request_stream(
|
|
294
302
|
ctx.state.message_history, model_settings, model_request_parameters
|
|
295
303
|
) as streamed_response:
|
|
@@ -431,17 +439,17 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
431
439
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
432
440
|
tool_calls: list[_messages.ToolCallPart],
|
|
433
441
|
) -> AsyncIterator[_messages.HandleResponseEvent]:
|
|
434
|
-
|
|
442
|
+
output_schema = ctx.deps.output_schema
|
|
435
443
|
|
|
436
|
-
# first look for the
|
|
444
|
+
# first, look for the output tool call
|
|
437
445
|
final_result: result.FinalResult[NodeRunEndT] | None = None
|
|
438
446
|
parts: list[_messages.ModelRequestPart] = []
|
|
439
|
-
if
|
|
440
|
-
for call,
|
|
447
|
+
if output_schema is not None:
|
|
448
|
+
for call, output_tool in output_schema.find_tool(tool_calls):
|
|
441
449
|
try:
|
|
442
|
-
result_data =
|
|
443
|
-
result_data = await
|
|
444
|
-
except
|
|
450
|
+
result_data = output_tool.validate(call)
|
|
451
|
+
result_data = await _validate_output(result_data, ctx, call)
|
|
452
|
+
except _output.ToolRetryError as e:
|
|
445
453
|
# TODO: Should only increment retry stuff once per node execution, not for each tool call
|
|
446
454
|
# Also, should increment the tool-specific retry count rather than the run retry count
|
|
447
455
|
ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
@@ -466,7 +474,11 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
466
474
|
else:
|
|
467
475
|
if tool_responses:
|
|
468
476
|
parts.extend(tool_responses)
|
|
469
|
-
|
|
477
|
+
run_context = build_run_context(ctx)
|
|
478
|
+
instructions = await ctx.deps.get_instructions(run_context)
|
|
479
|
+
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
|
|
480
|
+
_messages.ModelRequest(parts=parts, instructions=instructions)
|
|
481
|
+
)
|
|
470
482
|
|
|
471
483
|
def _handle_final_result(
|
|
472
484
|
self,
|
|
@@ -488,9 +500,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
488
500
|
'all_messages_events': json.dumps(
|
|
489
501
|
[InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)]
|
|
490
502
|
),
|
|
491
|
-
'final_result': final_result.
|
|
492
|
-
if isinstance(final_result.
|
|
493
|
-
else json.dumps(InstrumentedModel.serialize_any(final_result.
|
|
503
|
+
'final_result': final_result.output
|
|
504
|
+
if isinstance(final_result.output, str)
|
|
505
|
+
else json.dumps(InstrumentedModel.serialize_any(final_result.output)),
|
|
494
506
|
}
|
|
495
507
|
)
|
|
496
508
|
run_span.set_attributes(
|
|
@@ -507,7 +519,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
507
519
|
}
|
|
508
520
|
)
|
|
509
521
|
|
|
510
|
-
# End the run with self.data
|
|
511
522
|
return End(final_result)
|
|
512
523
|
|
|
513
524
|
async def _handle_text_response(
|
|
@@ -515,14 +526,14 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
515
526
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
516
527
|
texts: list[str],
|
|
517
528
|
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
|
|
518
|
-
|
|
529
|
+
output_schema = ctx.deps.output_schema
|
|
519
530
|
|
|
520
531
|
text = '\n\n'.join(texts)
|
|
521
|
-
if
|
|
532
|
+
if allow_text_output(output_schema):
|
|
522
533
|
result_data_input = cast(NodeRunEndT, text)
|
|
523
534
|
try:
|
|
524
|
-
result_data = await
|
|
525
|
-
except
|
|
535
|
+
result_data = await _validate_output(result_data_input, ctx, None)
|
|
536
|
+
except _output.ToolRetryError as e:
|
|
526
537
|
ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
527
538
|
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
|
|
528
539
|
else:
|
|
@@ -534,7 +545,7 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
534
545
|
_messages.ModelRequest(
|
|
535
546
|
parts=[
|
|
536
547
|
_messages.RetryPromptPart(
|
|
537
|
-
content='Plain text responses are not permitted, please
|
|
548
|
+
content='Plain text responses are not permitted, please include your response in a tool call',
|
|
538
549
|
)
|
|
539
550
|
]
|
|
540
551
|
)
|
|
@@ -555,8 +566,8 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
|
|
|
555
566
|
|
|
556
567
|
async def process_function_tools(
|
|
557
568
|
tool_calls: list[_messages.ToolCallPart],
|
|
558
|
-
|
|
559
|
-
|
|
569
|
+
output_tool_name: str | None,
|
|
570
|
+
output_tool_call_id: str | None,
|
|
560
571
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
561
572
|
output_parts: list[_messages.ModelRequestPart],
|
|
562
573
|
) -> AsyncIterator[_messages.HandleResponseEvent]:
|
|
@@ -566,22 +577,22 @@ async def process_function_tools(
|
|
|
566
577
|
|
|
567
578
|
Because async iterators can't have return values, we use `output_parts` as an output argument.
|
|
568
579
|
"""
|
|
569
|
-
stub_function_tools = bool(
|
|
570
|
-
|
|
580
|
+
stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early'
|
|
581
|
+
output_schema = ctx.deps.output_schema
|
|
571
582
|
|
|
572
|
-
# we rely on the fact that if we found a result, it's the first
|
|
573
|
-
|
|
583
|
+
# we rely on the fact that if we found a result, it's the first output tool in the last
|
|
584
|
+
found_used_output_tool = False
|
|
574
585
|
run_context = build_run_context(ctx)
|
|
575
586
|
|
|
576
587
|
calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
|
|
577
588
|
call_index_to_event_id: dict[int, str] = {}
|
|
578
589
|
for call in tool_calls:
|
|
579
590
|
if (
|
|
580
|
-
call.tool_name ==
|
|
581
|
-
and call.tool_call_id ==
|
|
582
|
-
and not
|
|
591
|
+
call.tool_name == output_tool_name
|
|
592
|
+
and call.tool_call_id == output_tool_call_id
|
|
593
|
+
and not found_used_output_tool
|
|
583
594
|
):
|
|
584
|
-
|
|
595
|
+
found_used_output_tool = True
|
|
585
596
|
output_parts.append(
|
|
586
597
|
_messages.ToolReturnPart(
|
|
587
598
|
tool_name=call.tool_name,
|
|
@@ -618,15 +629,15 @@ async def process_function_tools(
|
|
|
618
629
|
yield event
|
|
619
630
|
call_index_to_event_id[len(calls_to_run)] = event.call_id
|
|
620
631
|
calls_to_run.append((mcp_tool, call))
|
|
621
|
-
elif
|
|
622
|
-
# if tool_name is in
|
|
632
|
+
elif output_schema is not None and call.tool_name in output_schema.tools:
|
|
633
|
+
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
|
|
623
634
|
# validation, we don't add another part here
|
|
624
|
-
if
|
|
625
|
-
if
|
|
626
|
-
content = '
|
|
635
|
+
if output_tool_name is not None:
|
|
636
|
+
if found_used_output_tool:
|
|
637
|
+
content = 'Output tool not used - a final result was already processed.'
|
|
627
638
|
else:
|
|
628
639
|
# TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
|
|
629
|
-
content = '
|
|
640
|
+
content = 'Output tool not used - result failed validation.'
|
|
630
641
|
part = _messages.ToolReturnPart(
|
|
631
642
|
tool_name=call.tool_name,
|
|
632
643
|
content=content,
|
|
@@ -706,8 +717,8 @@ def _unknown_tool(
|
|
|
706
717
|
) -> _messages.RetryPromptPart:
|
|
707
718
|
ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
708
719
|
tool_names = list(ctx.deps.function_tools.keys())
|
|
709
|
-
if
|
|
710
|
-
tool_names.extend(
|
|
720
|
+
if output_schema := ctx.deps.output_schema:
|
|
721
|
+
tool_names.extend(output_schema.tool_names())
|
|
711
722
|
|
|
712
723
|
if tool_names:
|
|
713
724
|
msg = f'Available tools: {", ".join(tool_names)}'
|
|
@@ -721,20 +732,20 @@ def _unknown_tool(
|
|
|
721
732
|
)
|
|
722
733
|
|
|
723
734
|
|
|
724
|
-
async def
|
|
735
|
+
async def _validate_output(
|
|
725
736
|
result_data: T,
|
|
726
737
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
|
|
727
738
|
tool_call: _messages.ToolCallPart | None,
|
|
728
739
|
) -> T:
|
|
729
|
-
for validator in ctx.deps.
|
|
740
|
+
for validator in ctx.deps.output_validators:
|
|
730
741
|
run_context = build_run_context(ctx)
|
|
731
742
|
result_data = await validator.validate(result_data, tool_call, run_context)
|
|
732
743
|
return result_data
|
|
733
744
|
|
|
734
745
|
|
|
735
|
-
def
|
|
746
|
+
def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool:
|
|
736
747
|
"""Check if the result schema allows text results."""
|
|
737
|
-
return
|
|
748
|
+
return output_schema is None or output_schema.allow_text_output
|
|
738
749
|
|
|
739
750
|
|
|
740
751
|
@dataclasses.dataclass
|
|
@@ -786,19 +797,19 @@ def get_captured_run_messages() -> _RunMessages:
|
|
|
786
797
|
|
|
787
798
|
|
|
788
799
|
def build_agent_graph(
|
|
789
|
-
name: str | None, deps_type: type[DepsT],
|
|
790
|
-
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[
|
|
800
|
+
name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
|
|
801
|
+
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
|
|
791
802
|
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
|
|
792
803
|
nodes = (
|
|
793
804
|
UserPromptNode[DepsT],
|
|
794
805
|
ModelRequestNode[DepsT],
|
|
795
806
|
CallToolsNode[DepsT],
|
|
796
807
|
)
|
|
797
|
-
graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[
|
|
808
|
+
graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
|
|
798
809
|
nodes=nodes,
|
|
799
810
|
name=name or 'Agent',
|
|
800
811
|
state_type=GraphAgentState,
|
|
801
|
-
run_end_type=result.FinalResult[
|
|
812
|
+
run_end_type=result.FinalResult[OutputT],
|
|
802
813
|
auto_instrument=False,
|
|
803
814
|
)
|
|
804
815
|
return graph
|
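Beyond the mechanical `result` → `output` rename, two behavioural threads run through this file: every `ModelRequest` (including the follow-up requests issued by `CallToolsNode`) now carries `instructions` resolved per request via `GraphAgentDeps.get_instructions`, and the text gate keeps its old semantics under a new name: `allow_text_output` lets a plain-text response end the run only when there is no output schema or the schema includes `str`. A hedged sketch of how this surfaces at the agent level, assuming the public `output_type` and `instructions` parameters mirror the internal rename (the model name is illustrative):

from pydantic import BaseModel

from pydantic_ai import Agent


class Answer(BaseModel):
    value: int


# With a structured output type that doesn't include `str`,
# allow_text_output(...) is False: a plain-text reply triggers a
# RetryPromptPart instead of ending the run. With `Answer | str`
# a text reply would be allowed to end the run.
agent = Agent(
    'openai:gpt-4o',
    instructions='Answer with a single number.',
    output_type=Answer,
)

result = agent.run_sync('What is 6 * 7?')
print(result.output)  # Answer(value=42); `.data` is replaced by `.output`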
pydantic_ai/_cli.py
CHANGED
@@ -208,14 +208,13 @@ async def ask_agent(
     if not stream:
         with status:
             result = await agent.run(prompt, message_history=messages)
-        content = result.data
+        content = result.output
         console.print(Markdown(content, code_theme=code_theme))
         return result.all_messages()
 
     with status, ExitStack() as stack:
         async with agent.iter(prompt, message_history=messages) as agent_run:
             live = Live('', refresh_per_second=15, console=console, vertical_overflow='visible')
-            content: str = ''
             async for node in agent_run:
                 if Agent.is_model_request_node(node):
                     async with node.stream(agent_run.ctx) as handle_stream:
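The CLI tweak is the user-facing half of the rename: run results now expose `.output` (formerly `.data`), and the pre-initialised `content` variable is dropped as redundant since the streaming loop assigns it. A minimal sketch of the renamed accessor (model name illustrative):

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main() -> None:
    result = await agent.run('hello')
    print(result.output)  # 0.1.x spelling; result.data in 0.0.x


asyncio.run(main())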
pydantic_ai/{_result.py → _output.py}
RENAMED
@@ -12,7 +12,7 @@ from typing_inspection.introspection import is_union_origin
 
 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
-from .result import DEFAULT_RESULT_TOOL_NAME, ResultDataT, ResultDataT_inv, ResultValidatorFunc
+from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput
 from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition
 
 T = TypeVar('T')
@@ -20,8 +20,8 @@ T = TypeVar('T')
 
 
 @dataclass
-class ResultValidator(Generic[AgentDepsT, ResultDataT_inv]):
-    function: ResultValidatorFunc[AgentDepsT, ResultDataT_inv]
+class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
+    function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv]
     _takes_ctx: bool = field(init=False)
     _is_async: bool = field(init=False)
 
@@ -77,47 +77,68 @@ class ToolRetryError(Exception):
 
 
 @dataclass
-class ResultSchema(Generic[ResultDataT]):
+class OutputSchema(Generic[OutputDataT]):
     """Model the final response from an agent run.
 
-    Similar to `Tool` but for the final result of running an agent.
+    Similar to `Tool` but for the final output of running an agent.
     """
 
-    tools: dict[str, ResultTool[ResultDataT]]
-    allow_text_result: bool
+    tools: dict[str, OutputSchemaTool[OutputDataT]]
+    allow_text_output: bool
 
     @classmethod
     def build(
-        cls: type[ResultSchema[T]], response_type: type[T], name: str, description: str | None
-    ) -> ResultSchema[T] | None:
-        """Build a ResultSchema dataclass from a response type."""
-        if response_type is str:
+        cls: type[OutputSchema[T]],
+        output_type: type[T] | ToolOutput[T],
+        name: str | None = None,
+        description: str | None = None,
+        strict: bool | None = None,
+    ) -> OutputSchema[T] | None:
+        """Build an OutputSchema dataclass from a response type."""
+        if output_type is str:
             return None
 
-        if response_type_option := extract_str_from_union(response_type):
-            response_type = response_type_option.value
-            allow_text_result = True
+        if isinstance(output_type, ToolOutput):
+            # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
+            name = output_type.name
+            description = output_type.description
+            output_type_ = output_type.output_type
+            strict = output_type.strict
         else:
-            allow_text_result = False
+            output_type_ = output_type
 
-        tools: dict[str, ResultTool[T]] = {}
-        if args := get_union_args(response_type):
+        if output_type_option := extract_str_from_union(output_type):
+            output_type_ = output_type_option.value
+            allow_text_output = True
+        else:
+            allow_text_output = False
 
+        tools: dict[str, OutputSchemaTool[T]] = {}
+        if args := get_union_args(output_type_):
             for i, arg in enumerate(args, start=1):
-                tool_name = union_tool_name(name, arg)
+                tool_name = raw_tool_name = union_tool_name(name, arg)
                 while tool_name in tools:
-                    tool_name = f'{tool_name}_{i}'
-                tools[tool_name] = …
+                    tool_name = f'{raw_tool_name}_{i}'
+                tools[tool_name] = cast(
+                    OutputSchemaTool[T],
+                    OutputSchemaTool(
+                        output_type=arg, name=tool_name, description=description, multiple=True, strict=strict
+                    ),
+                )
         else:
-            tools[name] = …
+            name = name or DEFAULT_OUTPUT_TOOL_NAME
+            tools[name] = cast(
+                OutputSchemaTool[T],
+                OutputSchemaTool(
+                    output_type=output_type_, name=name, description=description, multiple=False, strict=strict
+                ),
+            )
 
-        return cls(tools=tools, allow_text_result=allow_text_result)
+        return cls(tools=tools, allow_text_output=allow_text_output)
 
     def find_named_tool(
         self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
-    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
+    ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None:
         """Find a tool that matches one of the calls, with a specific name."""
         for part in parts:
             if isinstance(part, _messages.ToolCallPart):
@@ -127,7 +148,7 @@ class ResultSchema(Generic[ResultDataT]):
     def find_tool(
         self,
         parts: Iterable[_messages.ModelResponsePart],
-    ) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
+    ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]:
         """Find a tool that matches one of the calls."""
         for part in parts:
             if isinstance(part, _messages.ToolCallPart):
@@ -147,16 +168,16 @@ DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
 
 
 @dataclass(init=False)
-class ResultTool(Generic[ResultDataT]):
+class OutputSchemaTool(Generic[OutputDataT]):
     tool_def: ToolDefinition
     type_adapter: TypeAdapter[Any]
 
-    def __init__(
-        self, *, response_type: type[ResultDataT], name: str, description: str | None, multiple: bool
-    ):
-        """Build a ResultTool from a response type."""
-        if _utils.is_model_like(response_type):
-            self.type_adapter = TypeAdapter(response_type)
+    def __init__(
+        self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None
+    ):
+        """Build a OutputSchemaTool from a response type."""
+        if _utils.is_model_like(output_type):
+            self.type_adapter = TypeAdapter(output_type)
             outer_typed_dict_key: str | None = None
             # noinspection PyArgumentList
             parameters_json_schema = _utils.check_object_json_schema(
@@ -165,7 +186,7 @@ class ResultTool(Generic[ResultDataT]):
         else:
             response_data_typed_dict = TypedDict(  # noqa: UP013
                 'response_data_typed_dict',
-                {'response': response_type},  # pyright: ignore[reportInvalidTypeForm]
+                {'response': output_type},  # pyright: ignore[reportInvalidTypeForm]
             )
             self.type_adapter = TypeAdapter(response_data_typed_dict)
             outer_typed_dict_key = 'response'
@@ -184,19 +205,20 @@ class ResultTool(Generic[ResultDataT]):
         else:
             tool_description = description or DEFAULT_DESCRIPTION
             if multiple:
-                tool_description = f'{union_arg_name(response_type)}: {tool_description}'
+                tool_description = f'{union_arg_name(output_type)}: {tool_description}'
 
         self.tool_def = ToolDefinition(
             name=name,
             description=tool_description,
             parameters_json_schema=parameters_json_schema,
             outer_typed_dict_key=outer_typed_dict_key,
+            strict=strict,
         )
 
     def validate(
         self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
-    ) -> ResultDataT:
-        """Validate a result message.
+    ) -> OutputDataT:
+        """Validate an output message.
 
         Args:
             tool_call: The tool call from the LLM to validate.
@@ -204,14 +226,14 @@ class ResultTool(Generic[ResultDataT]):
             wrap_validation_errors: If true, wrap the validation errors in a retry message.
 
         Returns:
-            Either the validated result data (left) or a retry message (right).
+            Either the validated output data (left) or a retry message (right).
         """
         try:
             pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
             if isinstance(tool_call.args, str):
-                result = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
+                output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
             else:
-                result = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
+                output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -224,21 +246,21 @@ class ResultTool(Generic[ResultDataT]):
                 raise
         else:
             if k := self.tool_def.outer_typed_dict_key:
-                result = result[k]
-            return result
+                output = output[k]
+            return output
 
 
-def union_tool_name(base_name: str, union_arg: Any) -> str:
-    return f'{base_name}_{union_arg_name(union_arg)}'
+def union_tool_name(base_name: str | None, union_arg: Any) -> str:
+    return f'{base_name or DEFAULT_OUTPUT_TOOL_NAME}_{union_arg_name(union_arg)}'
 
 
 def union_arg_name(union_arg: Any) -> str:
     return union_arg.__name__
 
 
-def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
+def extract_str_from_union(output_type: Any) -> _utils.Option[Any]:
     """Extract the string type from a Union, return the remaining union or remaining type."""
-    union_args = get_union_args(response_type)
+    union_args = get_union_args(output_type)
     if any(t is str for t in union_args):
         remain_args: list[Any] = []
         includes_str = False
@@ -255,7 +277,7 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
 
 
 def get_union_args(tp: Any) -> tuple[Any, ...]:
-    """Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
+    """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple."""
     if typing_objects.is_typealiastype(tp):
         tp = tp.__value__
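`OutputSchema` and `OutputSchemaTool` stay private, but the `ToolOutput` marker they unpack is exported at the top level. A sketch of customising the output tool's name, description, and strict-mode flag, assuming `Agent` accepts a `ToolOutput` wherever it accepts an output type, as `OutputSchema.build` above suggests:

from pydantic import BaseModel

from pydantic_ai import Agent, ToolOutput


class Fruit(BaseModel):
    name: str
    color: str


# ToolOutput bundles tool-level metadata with the output type;
# OutputSchema.build unpacks it into name/description/strict and
# passes strict through to the generated ToolDefinition.
agent = Agent(
    'openai:gpt-4o',
    output_type=ToolOutput(Fruit, name='return_fruit', description='The fruit to return.'),
)
print(agent.run_sync('Suggest a fruit.').output)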
pydantic_ai/_utils.py
CHANGED
@@ -15,7 +15,14 @@ from pydantic import BaseModel
 from pydantic.json_schema import JsonSchemaValue
 from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
 
+from pydantic_graph._utils import AbstractSpan
+
+AbstractSpan = AbstractSpan
+
 if TYPE_CHECKING:
+    from pydantic_ai.agent import AgentRun, AgentRunResult
+    from pydantic_graph import GraphRun, GraphRunResult
+
     from . import messages as _messages
     from .tools import ObjectJsonSchema
 
@@ -281,3 +288,16 @@ class PeekableAsyncStream(Generic[T]):
         except StopAsyncIteration:
             self._exhausted = True
             raise
+
+
+def get_traceparent(x: AgentRun | AgentRunResult | GraphRun | GraphRunResult) -> str:
+    import logfire
+    import logfire_api
+    from logfire.experimental.annotations import get_traceparent
+
+    span: AbstractSpan | None = x._span(required=False)  # type: ignore[reportPrivateUsage]
+    if not span:  # pragma: no cover
+        return ''
+    if isinstance(span, logfire_api.LogfireSpan):  # pragma: no cover
+        assert isinstance(span, logfire.LogfireSpan)
+    return get_traceparent(span)