pydantic-ai-slim 0.4.5__tar.gz → 0.4.6__tar.gz
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/.gitignore +0 -1
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/PKG-INFO +4 -4
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_function_schema.py +13 -4
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_output.py +41 -25
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/agent.py +9 -29
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/__init__.py +1 -1
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/function.py +15 -16
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/mistral.py +12 -2
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/result.py +115 -151
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pyproject.toml +1 -1
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/LICENSE +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/README.md +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_a2a.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_mcp.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_run_context.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_thinking_part.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_tool_manager.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/ag_ui.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/direct.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/ext/__init__.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/ext/aci.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/ext/langchain.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/format_prompt.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/anthropic.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/bedrock.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/gemini.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/google.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/huggingface.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/mcp_sampling.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/output.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/__init__.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/_json_schema.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/amazon.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/anthropic.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/cohere.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/deepseek.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/google.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/grok.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/meta.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/mistral.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/moonshotai.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/openai.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/profiles/qwen.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/fireworks.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/github.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/google.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/grok.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/heroku.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/huggingface.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/openrouter.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/providers/together.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/__init__.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/abstract.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/combined.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/deferred.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/filtered.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/function.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/prefixed.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/prepared.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/renamed.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/toolsets/wrapper.py +0 -0
- {pydantic_ai_slim-0.4.5 → pydantic_ai_slim-0.4.6}/pydantic_ai/usage.py +0 -0
```diff
--- pydantic_ai_slim-0.4.5/PKG-INFO
+++ pydantic_ai_slim-0.4.6/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.4.5
+Version: 0.4.6
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>, Douwe Maan <douwe@pydantic.dev>
 License-Expression: MIT
@@ -30,7 +30,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.4.5
+Requires-Dist: pydantic-graph==0.4.6
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
@@ -51,7 +51,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: ddgs>=9.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.4.5; extra == 'evals'
+Requires-Dist: pydantic-evals==0.4.6; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.24.0; extra == 'google'
 Provides-Extra: groq
@@ -63,7 +63,7 @@ Requires-Dist: logfire>=3.11.0; extra == 'logfire'
 Provides-Extra: mcp
 Requires-Dist: mcp>=1.9.4; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
-Requires-Dist: mistralai>=1.2.0; extra == 'mistral'
+Requires-Dist: mistralai>=1.9.2; extra == 'mistral'
 Provides-Extra: openai
 Requires-Dist: openai>=1.92.0; extra == 'openai'
 Provides-Extra: tavily
```
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/_function_schema.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/_function_schema.py
@@ -96,8 +96,13 @@ def function_schema(  # noqa: C901
     config = ConfigDict(title=function.__name__, use_attribute_docstrings=True)
     config_wrapper = ConfigWrapper(config)
     gen_schema = _generate_schema.GenerateSchema(config_wrapper)
+    errors: list[str] = []
 
-    sig = signature(function)
+    try:
+        sig = signature(function)
+    except ValueError as e:
+        errors.append(str(e))
+        sig = signature(lambda: None)
 
     type_hints = _typing_extra.get_function_type_hints(function)
 
@@ -105,7 +110,6 @@ def function_schema(  # noqa: C901
     fields: dict[str, core_schema.TypedDictField] = {}
     positional_fields: list[str] = []
     var_positional_field: str | None = None
-    errors: list[str] = []
    decorators = _decorators.DecoratorInfos()
 
     description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)
@@ -235,14 +239,19 @@ def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]:
     Returns:
         `True` if the function takes a `RunContext` as first argument, `False` otherwise.
     """
-    sig = signature(function)
+    try:
+        sig = signature(function)
+    except ValueError:  # pragma: no cover
+        return False  # pragma: no cover
     try:
         first_param_name = next(iter(sig.parameters.keys()))
     except StopIteration:
         return False
     else:
         type_hints = _typing_extra.get_function_type_hints(function)
-        annotation = type_hints[first_param_name]
+        annotation = type_hints.get(first_param_name)
+        if annotation is None:
+            return False  # pragma: no cover
         return annotation is not sig.empty and _is_call_ctx(annotation)
 
 
```
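The `_function_schema.py` change makes schema generation tolerant of callables that `inspect.signature()` cannot introspect: instead of crashing, the error is recorded in `errors` and an empty signature is substituted. A minimal standalone sketch of the same fallback pattern (the helper name here is illustrative, not the library's internals):

```python
from inspect import Signature, signature
from typing import Any


def safe_signature(function: Any) -> tuple[Signature, list[str]]:
    """Introspect `function`, falling back to an empty signature on failure.

    Mirrors the pattern added in `function_schema`: collect the error
    instead of raising, so schema generation can continue.
    """
    errors: list[str] = []
    try:
        sig = signature(function)
    except ValueError as e:  # many C builtins have no retrievable signature
        errors.append(str(e))
        sig = signature(lambda: None)  # placeholder: zero-parameter signature
    return sig, errors


# On CPython, `max` is one builtin whose signature cannot be reported:
sig, errors = safe_signature(max)
print(sig, errors)  # -> () plus the collected ValueError message
```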
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/_output.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/_output.py
@@ -69,12 +69,31 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
 DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
 
 
-async def execute_output_function_with_span(
+async def execute_traced_output_function(
     function_schema: _function_schema.FunctionSchema,
     run_context: RunContext[AgentDepsT],
     args: dict[str, Any] | Any,
+    wrap_validation_errors: bool = True,
 ) -> Any:
-    """Execute
+    """Execute an output function within a traced span with error handling.
+
+    This function executes the output function within an OpenTelemetry span for observability,
+    automatically records the function response, and handles ModelRetry exceptions by converting
+    them to ToolRetryError when wrap_validation_errors is True.
+
+    Args:
+        function_schema: The function schema containing the function to execute
+        run_context: The current run context containing tracing and tool information
+        args: Arguments to pass to the function
+        wrap_validation_errors: If True, wrap ModelRetry exceptions in ToolRetryError
+
+    Returns:
+        The result of the function execution
+
+    Raises:
+        ToolRetryError: When wrap_validation_errors is True and a ModelRetry is caught
+        ModelRetry: When wrap_validation_errors is False and a ModelRetry occurs
+    """
     # Set up span attributes
     tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function')
     attributes = {
@@ -96,7 +115,19 @@ async def execute_output_function_with_span(
     )
 
     with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span:
-        output = await function_schema.call(args, run_context)
+        try:
+            output = await function_schema.call(args, run_context)
+        except ModelRetry as r:
+            if wrap_validation_errors:
+                m = _messages.RetryPromptPart(
+                    content=r.message,
+                    tool_name=run_context.tool_name,
+                )
+                if run_context.tool_call_id:
+                    m.tool_call_id = run_context.tool_call_id  # pragma: no cover
+                raise ToolRetryError(m) from r
+            else:
+                raise
 
     # Record response if content inclusion is enabled
     if run_context.trace_include_content and span.is_recording():
@@ -663,16 +694,7 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
             else:
                 raise
 
-        try:
-            output = await self.call(output, run_context)
-        except ModelRetry as r:
-            if wrap_validation_errors:
-                m = _messages.RetryPromptPart(
-                    content=r.message,
-                )
-                raise ToolRetryError(m) from r
-            else:
-                raise  # pragma: no cover
+        output = await self.call(output, run_context, wrap_validation_errors)
 
         return output
 
@@ -691,12 +713,15 @@ class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]):
         self,
         output: Any,
         run_context: RunContext[AgentDepsT],
+        wrap_validation_errors: bool = True,
     ):
         if k := self.outer_typed_dict_key:
             output = output[k]
 
         if self._function_schema:
-            output = await execute_output_function_with_span(self._function_schema, run_context, output)
+            output = await execute_traced_output_function(
+                self._function_schema, run_context, output, wrap_validation_errors
+            )
 
         return output
 
@@ -856,16 +881,7 @@ class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]):
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
-        try:
-            output = await execute_output_function_with_span(self._function_schema, run_context, args)
-        except ModelRetry as r:
-            if wrap_validation_errors:
-                m = _messages.RetryPromptPart(
-                    content=r.message,
-                )
-                raise ToolRetryError(m) from r
-            else:
-                raise  # pragma: no cover
+        output = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
 
         return cast(OutputDataT, output)
 
@@ -975,7 +991,7 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
     async def call_tool(
         self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
     ) -> Any:
-        output = await self.processors[name].call(tool_args, ctx)
+        output = await self.processors[name].call(tool_args, ctx, wrap_validation_errors=False)
         for validator in self.output_validators:
             output = await validator.validate(output, ctx, wrap_validation_errors=False)
         return output
```
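The `_output.py` refactor renames `execute_output_function_with_span` to `execute_traced_output_function` and centralizes the `ModelRetry` to `ToolRetryError` conversion that was previously duplicated in `ObjectOutputProcessor` and `PlainTextOutputProcessor`; `OutputToolset.call_tool` opts out with `wrap_validation_errors=False` so retries propagate to the tool manager instead. A simplified, self-contained sketch of the consolidated pattern (the exception classes below are stand-ins, not the library's private types):

```python
from __future__ import annotations

from typing import Any, Awaitable, Callable


class ModelRetry(Exception):
    """Raised by an output function to ask the model to try again."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


class ToolRetryError(Exception):
    """Wraps a retry request as a prompt the model can act on."""

    def __init__(self, content: str, tool_name: str | None) -> None:
        super().__init__(content)
        self.content = content
        self.tool_name = tool_name


async def run_output_function(
    call: Callable[[Any], Awaitable[Any]],
    args: Any,
    *,
    tool_name: str | None = None,
    wrap_validation_errors: bool = True,
) -> Any:
    try:
        return await call(args)
    except ModelRetry as r:
        if wrap_validation_errors:
            # One place to convert the retry into a retryable prompt,
            # instead of repeating this block at every call site.
            raise ToolRetryError(r.message, tool_name) from r
        raise  # caller (e.g. the tool manager) handles the retry itself
```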
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/agent.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/agent.py
@@ -36,7 +36,7 @@ from ._tool_manager import ToolManager
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .output import OutputDataT, OutputSpec
 from .profiles import ModelProfile
-from .result import FinalResult, StreamedRunResult
+from .result import AgentStream, FinalResult, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDepsT,
@@ -1127,29 +1127,15 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         while True:
             if self.is_model_request_node(node):
                 graph_ctx = agent_run.ctx
-                async with node._stream(graph_ctx) as streamed_response:
-
-                    async def stream_to_final(
-                        s: models.StreamedResponse,
-                    ) -> FinalResult[models.StreamedResponse] | None:
-                        output_schema = graph_ctx.deps.output_schema
-                        async for maybe_part_event in streamed_response:
-                            if isinstance(maybe_part_event, _messages.PartStartEvent):
-                                new_part = maybe_part_event.part
-                                if isinstance(new_part, _messages.TextPart) and isinstance(
-                                    output_schema, _output.TextOutputSchema
-                                ):
-                                    return FinalResult(s, None, None)
-                                elif isinstance(new_part, _messages.ToolCallPart) and (
-                                    tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
-                                ):
-                                    if tool_def.kind == 'output':
-                                        return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
-                                    elif tool_def.kind == 'deferred':
-                                        return FinalResult(s, None, None)
+                async with node.stream(graph_ctx) as stream:
+
+                    async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None:
+                        async for event in stream:
+                            if isinstance(event, _messages.FinalResultEvent):
+                                return FinalResult(s, event.tool_name, event.tool_call_id)
                         return None
 
-                    final_result = await stream_to_final(streamed_response)
+                    final_result = await stream_to_final(stream)
                     if final_result is not None:
                         if yielded:
                             raise exceptions.AgentRunError('Agent run produced final results')  # pragma: no cover
@@ -1184,14 +1170,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                         yield StreamedRunResult(
                             messages,
                             graph_ctx.deps.new_message_index,
-                            usage_limits,
-                            streamed_response,
-                            graph_ctx.deps.output_schema,
-                            _agent_graph.build_run_context(graph_ctx),
-                            graph_ctx.deps.output_validators,
-                            final_result.tool_name,
+                            stream,
                             on_complete,
-                            graph_ctx.deps.tool_manager,
                         )
                         break
                 next_node = await agent_run.next(node)
```
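In `agent.py`, `run_stream` no longer re-implements final-result detection on the raw model response: it now iterates the public `AgentStream` and simply waits for a `_messages.FinalResultEvent`, which already carries the output tool name and call ID. The user-facing API is unchanged; a typical `run_stream` call, assuming a configured model and API key:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # any configured model works here


async def main() -> None:
    async with agent.run_stream('Tell me a short joke.') as result:
        # Internally this is now backed by AgentStream + FinalResultEvent.
        async for text in result.stream_text():
            print(text)
        print(result.usage())


asyncio.run(main())
```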
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/models/function.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/function.py
@@ -16,9 +16,7 @@ from pydantic_ai.profiles import ModelProfileSpec
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
-    AudioUrl,
     BinaryContent,
-    ImageUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -345,18 +343,19 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
 def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
+
     if isinstance(content, str):
-        return len(re.split(r'[\s",.:]+', content.strip()))
-    else:
-        tokens = 0
-        for part in content:
-            if isinstance(part, str):
-                tokens += len(re.split(r'[\s",.:]+', part.strip()))
-            # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
-            if isinstance(part, (AudioUrl, ImageUrl)):
-                tokens += 0
-            elif isinstance(part, BinaryContent):
-                tokens += len(part.data)
-            else:
-                tokens += 0
-        return tokens
+        return len(_TOKEN_SPLIT_RE.split(content.strip()))
+
+    tokens = 0
+    for part in content:
+        if isinstance(part, str):
+            tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
+        elif isinstance(part, BinaryContent):
+            tokens += len(part.data)
+        # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
+
+    return tokens
+
+
+_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')
```
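The token estimator in `models/function.py` (used by `FunctionModel` to fake usage numbers) now compiles the splitting regex once at module level and flattens the branching. A standalone approximation of the same heuristic, with plain `bytes` standing in for `BinaryContent`:

```python
import re
from collections.abc import Sequence

_TOKEN_SPLIT_RE = re.compile(r'[\s",.:]+')


def estimate_string_tokens(content: str | Sequence[str | bytes]) -> int:
    """Rough token estimate: split on whitespace and common punctuation."""
    if not content:
        return 0
    if isinstance(content, str):
        return len(_TOKEN_SPLIT_RE.split(content.strip()))

    tokens = 0
    for part in content:
        if isinstance(part, str):
            tokens += len(_TOKEN_SPLIT_RE.split(part.strip()))
        elif isinstance(part, bytes):  # stand-in for BinaryContent.data
            tokens += len(part)
    return tokens


print(estimate_string_tokens('The quick brown fox'))           # 4
print(estimate_string_tokens(['Hello there', b'\x00\x01']))    # 2 + 2 = 4
```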
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/models/mistral.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/models/mistral.py
@@ -52,6 +52,7 @@ try:
     CompletionChunk as MistralCompletionChunk,
     Content as MistralContent,
     ContentChunk as MistralContentChunk,
+    DocumentURLChunk as MistralDocumentURLChunk,
     FunctionCall as MistralFunctionCall,
     ImageURL as MistralImageURL,
     ImageURLChunk as MistralImageURLChunk,
@@ -539,10 +540,19 @@ class MistralModel(Model):
                 if item.is_image:
                     image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                     content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                elif item.media_type == 'application/pdf':
+                    content.append(
+                        MistralDocumentURLChunk(
+                            document_url=f'data:application/pdf;base64,{base64_encoded}', type='document_url'
+                        )
+                    )
                 else:
-                    raise RuntimeError('Only image binary content is supported for Mistral.')
+                    raise RuntimeError('BinaryContent other than image or PDF is not supported in Mistral.')
             elif isinstance(item, DocumentUrl):
-                raise RuntimeError('DocumentUrl is not supported in Mistral.')
+                if item.media_type == 'application/pdf':
+                    content.append(MistralDocumentURLChunk(document_url=item.url, type='document_url'))
+                else:
+                    raise RuntimeError('DocumentUrl other than PDF is not supported in Mistral.')
             elif isinstance(item, VideoUrl):
                 raise RuntimeError('VideoUrl is not supported in Mistral.')
             else:  # pragma: no cover
```
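Together with the `mistralai>=1.9.2` bump, `mistral.py` can now pass PDFs to Mistral's document understanding, both as `BinaryContent` (sent as a base64 data URL) and as `DocumentUrl`; anything other than images or PDFs still raises `RuntimeError`. A sketch of the new capability from the user side (the model name and URL are placeholders; requires a `MISTRAL_API_KEY`):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import DocumentUrl

agent = Agent('mistral:mistral-small-latest')  # placeholder model name

result = agent.run_sync(
    [
        'Summarize this document in two sentences.',
        DocumentUrl(url='https://example.com/report.pdf'),  # must be a PDF
    ]
)
print(result.output)
```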
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/result.py
@@ -63,22 +63,18 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         async for response in self.stream_responses(debounce_by=debounce_by):
             if self._final_result_event is not None:
                 try:
-                    yield await self._validate_response(
-                        response, self._final_result_event.tool_name, allow_partial=True
-                    )
+                    yield await self._validate_response(response, allow_partial=True)
                 except ValidationError:
                     pass
         if self._final_result_event is not None:  # pragma: no branch
-            yield await self._validate_response(
-                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
-            )
+            yield await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
 
     async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
         """Asynchronously stream the (unvalidated) model responses for the agent."""
         # if the message currently has any parts with content, yield before streaming
         msg = self._raw_stream_response.get()
         for part in msg.parts:
-            if part.has_content():
+            if part.has_content():
                 yield msg
                 break
@@ -86,6 +82,35 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         async for _items in group_iter:
             yield self._raw_stream_response.get()  # current state of the response
 
+    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
+        """Stream the text result as an async iterable.
+
+        !!! note
+            Result validators will NOT be called on the text result if `delta=True`.
+
+        Args:
+            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
+                up to the current point.
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+        """
+        if not isinstance(self._output_schema, PlainTextOutputSchema):
+            raise exceptions.UserError('stream_text() can only be used with text responses')
+
+        if delta:
+            async for text in self._stream_response_text(delta=True, debounce_by=debounce_by):
+                yield text
+        else:
+            async for text in self._stream_response_text(delta=False, debounce_by=debounce_by):
+                for validator in self._output_validators:
+                    text = await validator.validate(text, self._run_ctx)  # pragma: no cover
+                yield text
+
+    def get(self) -> _messages.ModelResponse:
+        """Get the current state of the response."""
+        return self._raw_stream_response.get()
+
     def usage(self) -> Usage:
         """Return the usage of the whole run.
 
```
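`AgentStream` gains public `stream_text()` and `get()` methods (plus `timestamp()` and `get_output()` below), so the low-level `agent.iter()` API can stream text without going through `StreamedRunResult`. A sketch, assuming a configured model:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # placeholder model


async def main() -> None:
    async with agent.iter('Explain debouncing in one paragraph.') as run:
        async for node in run:
            if agent.is_model_request_node(node):
                async with node.stream(run.ctx) as stream:  # an AgentStream
                    async for chunk in stream.stream_text(delta=True):
                        print(chunk, end='', flush=True)
    print()


asyncio.run(main())
```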
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/result.py
@@ -94,10 +119,24 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
         """
         return self._initial_run_ctx_usage + self._raw_stream_response.usage()
 
-    async def _validate_response(
-        self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
-    ) -> OutputDataT:
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._raw_stream_response.timestamp
+
+    async def get_output(self) -> OutputDataT:
+        """Stream the whole response, validate the output and return it."""
+        async for _ in self:
+            pass
+
+        return await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
+
+    async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
         """Validate a structured result message."""
+        if self._final_result_event is None:
+            raise exceptions.UnexpectedModelBehavior('Invalid response, unable to find output')  # pragma: no cover
+
+        output_tool_name = self._final_result_event.tool_name
+
         if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
             tool_call = next(
                 (
@@ -114,7 +153,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
             return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
         elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
             if not self._output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(
+                raise exceptions.UserError(
                     'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
                 )
             return cast(OutputDataT, deferred_tool_calls)
```
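`_validate_response` now derives the output tool name from the stored `FinalResultEvent` instead of taking it as a parameter, and keeps the existing guard: deferred tool calls are only returned when `DeferredToolCalls` is among the agent's output types. Registering it looks roughly like this (the import path is an assumption based on the 0.4.x docs; verify for your version):

```python
from pydantic_ai import Agent
from pydantic_ai.output import DeferredToolCalls  # assumed 0.4.x location

# Allow a run to end with unresolved (deferred) tool calls instead of erroring.
agent = Agent('openai:gpt-4o', output_type=[str, DeferredToolCalls])
```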
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/result.py
@@ -132,6 +171,54 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 'Invalid response, unable to process text output'
             )
 
+    async def _stream_response_text(
+        self, *, delta: bool = False, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[str]:
+        """Stream the response as an async iterable of text."""
+
+        # Define a "merged" version of the iterator that will yield items that have already been retrieved
+        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
+        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
+        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
+            # yields tuples of (text_content, part_index)
+            # we don't currently make use of the part_index, but in principle this may be useful
+            # so we retain it here for now to make possible future refactors simpler
+            msg = self._raw_stream_response.get()
+            for i, part in enumerate(msg.parts):
+                if isinstance(part, _messages.TextPart) and part.content:
+                    yield part.content, i
+
+            async for event in self._raw_stream_response:
+                if (
+                    isinstance(event, _messages.PartStartEvent)
+                    and isinstance(event.part, _messages.TextPart)
+                    and event.part.content
+                ):
+                    yield event.part.content, event.index  # pragma: no cover
+                elif (  # pragma: no branch
+                    isinstance(event, _messages.PartDeltaEvent)
+                    and isinstance(event.delta, _messages.TextPartDelta)
+                    and event.delta.content_delta
+                ):
+                    yield event.delta.content_delta, event.index
+
+        async def _stream_text_deltas() -> AsyncIterator[str]:
+            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
+                async for items in group_iter:
+                    # Note: we are currently just dropping the part index on the group here
+                    yield ''.join([content for content, _ in items])
+
+        if delta:
+            async for text in _stream_text_deltas():
+                yield text
+        else:
+            # a quick benchmark shows it's faster to build up a string with concat when we're
+            # yielding at each step
+            deltas: list[str] = []
+            async for text in _stream_text_deltas():
+                deltas.append(text)
+                yield ''.join(deltas)
+
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
 
```
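The text-delta plumbing above moved from `StreamedRunResult` into `AgentStream` essentially unchanged: deltas are grouped with `_utils.group_by_temporal` so validation does not run on every single token. That helper is private; a self-contained illustration of the same time-window grouping idea (illustrative only, not the library's implementation):

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator


async def group_by_window(source: AsyncIterator[str], window: float) -> AsyncIterator[list[str]]:
    """Yield lists of items that arrived within `window` seconds of each other."""
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def pump() -> None:
        async for item in source:
            await queue.put(item)
        await queue.put(None)  # sentinel: source exhausted

    task = asyncio.create_task(pump())
    try:
        done = False
        while not done:
            first = await queue.get()
            if first is None:
                break
            group = [first]
            loop = asyncio.get_running_loop()
            deadline = loop.time() + window
            while (remaining := deadline - loop.time()) > 0:
                try:
                    nxt = await asyncio.wait_for(queue.get(), timeout=remaining)
                except asyncio.TimeoutError:
                    break
                if nxt is None:
                    done = True
                    break
                group.append(nxt)
            yield group
    finally:
        task.cancel()


async def demo() -> None:
    async def tokens() -> AsyncIterator[str]:
        for t in ['Hel', 'lo ', 'wor', 'ld']:
            yield t
            await asyncio.sleep(0.01)

    async for group in group_by_window(tokens(), 0.05):
        print(group)  # tokens that arrived within 50ms of the group's first token


asyncio.run(demo())
```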
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/result.py
@@ -189,16 +276,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    _usage_limits: UsageLimits | None
-    _stream_response: models.StreamedResponse
-    _output_schema: OutputSchema[OutputDataT]
-    _run_ctx: RunContext[AgentDepsT]
-    _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
-    _output_tool_name: str | None
+    _stream_response: AgentStream[AgentDepsT, OutputDataT]
     _on_complete: Callable[[], Awaitable[None]]
-    _tool_manager: ToolManager[AgentDepsT]
 
-    _initial_run_ctx_usage: Usage = field(init=False)
     is_complete: bool = field(default=False, init=False)
     """Whether the stream has all been received.
 
@@ -209,9 +289,6 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     [`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes.
     """
 
-    def __post_init__(self):
-        self._initial_run_ctx_usage = copy(self._run_ctx.usage)
-
     @overload
     def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
 
@@ -332,12 +409,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         Returns:
             An async iterable of the response data.
         """
-        async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
-            try:
-                yield await self.validate_structured_output(structured_message, allow_partial=not is_last)
-            except ValidationError:
-                if is_last:
-                    raise  # pragma: no cover
+        async for output in self._stream_response.stream_output(debounce_by=debounce_by):
+            yield output
+        await self._marked_completed(self._stream_response.get())
 
     async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
         """Stream the text result as an async iterable.
@@ -352,16 +426,8 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
             Debouncing is particularly important for long structured responses to reduce the overhead of
                 performing validation as each token is received.
         """
-        if not isinstance(self._output_schema, TextOutputSchema):
-            raise exceptions.UserError('stream_text() can only be used with text responses')
-
-        if delta:
-            async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
-                yield text
-        else:
-            async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
-                combined_validated_text = await self._validate_text_output(text)
-                yield combined_validated_text
+        async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+            yield text
         await self._marked_completed(self._stream_response.get())
 
     async def stream_structured(
@@ -378,13 +444,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
             An async iterable of the structured response message and whether that is the last message.
         """
         # if the message currently has any parts with content, yield before streaming
-        msg = self._stream_response.get()
-        for part in msg.parts:
-            if part.has_content():
-                yield msg, False
-                break
-
-        async for msg in self._stream_response_structured(debounce_by=debounce_by):
+        async for msg in self._stream_response.stream_responses(debounce_by=debounce_by):
             yield msg, False
 
         msg = self._stream_response.get()
@@ -394,15 +454,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
 
     async def get_output(self) -> OutputDataT:
         """Stream the whole response, validate and return it."""
-        usage_checking_stream = _get_usage_checking_stream_response(
-            self._stream_response, self._usage_limits, self.usage
-        )
-
-        async for _ in usage_checking_stream:
-            pass
-        message = self._stream_response.get()
-        await self._marked_completed(message)
-        return await self.validate_structured_output(message)
+        output = await self._stream_response.get_output()
+        await self._marked_completed(self._stream_response.get())
+        return output
 
     @deprecated('`get_data` is deprecated, use `get_output` instead.')
     async def get_data(self) -> OutputDataT:
```
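With the delegation in place, `StreamedRunResult.get_output()` is three lines: drain the underlying `AgentStream`, mark the run complete, return the validated output. From the caller's side nothing changes; a self-contained usage sketch (placeholder model, as before):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # placeholder model


async def main() -> None:
    async with agent.run_stream('What is 2 + 2?') as result:
        answer = await result.get_output()  # drains the stream, validates once
        print(answer, '| finished at', result.timestamp())


asyncio.run(main())
```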
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/result.py
@@ -414,11 +468,11 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self._initial_run_ctx_usage + self._stream_response.usage()
+        return self._stream_response.usage()
 
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
-        return self._stream_response.timestamp
+        return self._stream_response.timestamp()
 
     @deprecated('`validate_structured_result` is deprecated, use `validate_structured_output` instead.')
     async def validate_structured_result(
@@ -430,105 +484,15 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
         self, message: _messages.ModelResponse, *, allow_partial: bool = False
     ) -> OutputDataT:
         """Validate a structured result message."""
-        if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None:
-            tool_call = next(
-                (
-                    part
-                    for part in message.parts
-                    if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name
-                ),
-                None,
-            )
-            if tool_call is None:
-                raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool call for {self._output_tool_name!r}'
-                )
-            return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
-        elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
-            if not self._output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(
-                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
-                )
-            return cast(OutputDataT, deferred_tool_calls)
-        elif isinstance(self._output_schema, TextOutputSchema):
-            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
-
-            result_data = await self._output_schema.process(
-                text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
-            )
-            for validator in self._output_validators:
-                result_data = await validator.validate(result_data, self._run_ctx)  # pragma: no cover
-            return result_data
-        else:
-            raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                'Invalid response, unable to process text output'
-            )
-
-    async def _validate_text_output(self, text: str) -> str:
-        for validator in self._output_validators:
-            text = await validator.validate(text, self._run_ctx)  # pragma: no cover
-        return text
+        return await self._stream_response._validate_response(  # pyright: ignore[reportPrivateUsage]
+            message, allow_partial=allow_partial
+        )
 
     async def _marked_completed(self, message: _messages.ModelResponse) -> None:
         self.is_complete = True
         self._all_messages.append(message)
         await self._on_complete()
 
-    async def _stream_response_structured(
-        self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[_messages.ModelResponse]:
-        async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
-            async for _items in group_iter:
-                yield self._stream_response.get()
-
-    async def _stream_response_text(
-        self, *, delta: bool = False, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[str]:
-        """Stream the response as an async iterable of text."""
-
-        # Define a "merged" version of the iterator that will yield items that have already been retrieved
-        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
-        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
-        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
-            # yields tuples of (text_content, part_index)
-            # we don't currently make use of the part_index, but in principle this may be useful
-            # so we retain it here for now to make possible future refactors simpler
-            msg = self._stream_response.get()
-            for i, part in enumerate(msg.parts):
-                if isinstance(part, _messages.TextPart) and part.content:
-                    yield part.content, i
-
-            async for event in self._stream_response:
-                if (
-                    isinstance(event, _messages.PartStartEvent)
-                    and isinstance(event.part, _messages.TextPart)
-                    and event.part.content
-                ):
-                    yield event.part.content, event.index  # pragma: no cover
-                elif (  # pragma: no branch
-                    isinstance(event, _messages.PartDeltaEvent)
-                    and isinstance(event.delta, _messages.TextPartDelta)
-                    and event.delta.content_delta
-                ):
-                    yield event.delta.content_delta, event.index
-
-        async def _stream_text_deltas() -> AsyncIterator[str]:
-            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
-                async for items in group_iter:
-                    # Note: we are currently just dropping the part index on the group here
-                    yield ''.join([content for content, _ in items])
-
-        if delta:
-            async for text in _stream_text_deltas():
-                yield text
-        else:
-            # a quick benchmark shows it's faster to build up a string with concat when we're
-            # yielding at each step
-            deltas: list[str] = []
-            async for text in _stream_text_deltas():
-                deltas.append(text)
-                yield ''.join(deltas)
-
 
 @dataclass(repr=False)
 class FinalResult(Generic[OutputDataT]):
```
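`validate_structured_output` is now a thin wrapper over `AgentStream._validate_response`, so streamed and non-streamed paths share one validation route, including output validators. For reference, output validators are registered like this (an established pydantic-ai API, not new in this release):

```python
from pydantic_ai import Agent, ModelRetry

agent = Agent('openai:gpt-4o', output_type=str)  # placeholder model


@agent.output_validator
def non_empty(output: str) -> str:
    if not output.strip():
        raise ModelRetry('The response was empty; please answer the question.')
    return output
```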
```diff
--- pydantic_ai_slim-0.4.5/pydantic_ai/result.py
+++ pydantic_ai_slim-0.4.6/pydantic_ai/result.py
@@ -556,12 +520,12 @@ def _get_usage_checking_stream_response(
 ) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
     if limits is not None and limits.has_token_limits():
 
-        async def _usage_checking_iterator():
+        async def _usage_checking_iterator():
            async for item in stream_response:
                 limits.check_tokens(get_usage())
                 yield item
 
-        return _usage_checking_iterator()
+        return _usage_checking_iterator()
     else:
         return stream_response
 
```
```diff
--- pydantic_ai_slim-0.4.5/pyproject.toml
+++ pydantic_ai_slim-0.4.6/pyproject.toml
@@ -68,7 +68,7 @@ vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.24.0"]
 anthropic = ["anthropic>=0.52.0"]
 groq = ["groq>=0.19.0"]
-mistral = ["mistralai>=1.2.0"]
+mistral = ["mistralai>=1.9.2"]
 bedrock = ["boto3>=1.37.24"]
 huggingface = ["huggingface-hub[inference]>=0.33.2"]
 # Tools
```