pydantic-ai-slim 1.0.6__tar.gz → 1.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/PKG-INFO +3 -3
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_agent_graph.py +208 -127
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/ag_ui.py +44 -33
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/agent/__init__.py +38 -46
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/agent/abstract.py +7 -7
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/agent/wrapper.py +0 -1
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/dbos/_agent.py +14 -10
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/dbos/_mcp_server.py +4 -2
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/_agent.py +0 -1
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/_logfire.py +15 -3
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/_toolset.py +17 -12
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/mcp.py +5 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/run.py +0 -2
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/tools.py +11 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/function.py +50 -9
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/.gitignore +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/LICENSE +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/README.md +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_a2a.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_function_schema.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_mcp.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_otel_messages.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_output.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_run_context.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_thinking_part.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_tool_manager.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/builtin_tools.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/direct.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/dbos/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/dbos/_model.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/dbos/_utils.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/_function_toolset.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/_mcp_server.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/_model.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/durable_exec/temporal/_run_context.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/ext/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/ext/aci.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/ext/langchain.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/format_prompt.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/anthropic.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/bedrock.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/gemini.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/google.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/huggingface.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/mcp_sampling.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/output.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/_json_schema.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/amazon.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/anthropic.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/cohere.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/deepseek.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/google.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/grok.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/groq.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/harmony.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/meta.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/mistral.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/moonshotai.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/openai.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/profiles/qwen.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/cerebras.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/fireworks.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/gateway.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/github.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/google.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/grok.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/heroku.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/huggingface.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/litellm.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/moonshotai.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/ollama.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/openrouter.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/together.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/providers/vercel.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/retries.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/__init__.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/_dynamic.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/abstract.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/approval_required.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/combined.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/external.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/filtered.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/prefixed.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/prepared.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/renamed.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/toolsets/wrapper.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pydantic_ai/usage.py +0 -0
- {pydantic_ai_slim-1.0.6 → pydantic_ai_slim-1.0.7}/pyproject.toml +0 -0
--- pydantic_ai_slim-1.0.6/PKG-INFO
+++ pydantic_ai_slim-1.0.7/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 1.0.6
+Version: 1.0.7
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Project-URL: Homepage, https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim
 Project-URL: Source, https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim
@@ -33,7 +33,7 @@ Requires-Dist: genai-prices>=0.0.23
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==1.0.6
+Requires-Dist: pydantic-graph==1.0.7
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
@@ -57,7 +57,7 @@ Requires-Dist: dbos>=1.13.0; extra == 'dbos'
 Provides-Extra: duckduckgo
 Requires-Dist: ddgs>=9.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==1.0.6; extra == 'evals'
+Requires-Dist: pydantic-evals==1.0.7; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.31.0; extra == 'google'
 Provides-Extra: groq
--- pydantic_ai_slim-1.0.6/pydantic_ai/_agent_graph.py
+++ pydantic_ai_slim-1.0.7/pydantic_ai/_agent_graph.py
@@ -8,7 +8,7 @@ from collections import defaultdict, deque
 from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
-from dataclasses import field
+from dataclasses import field, replace
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast
 
 from opentelemetry.trace import Tracer
@@ -16,7 +16,7 @@ from typing_extensions import TypeVar, assert_never
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
 from pydantic_ai._tool_manager import ToolManager
-from pydantic_ai._utils import is_async_callable, run_in_executor
+from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor
 from pydantic_ai.builtin_tools import AbstractBuiltinTool
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
@@ -26,7 +26,9 @@ from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings
 from .tools import (
+    DeferredToolCallResult,
     DeferredToolResult,
+    DeferredToolResults,
     RunContext,
     ToolApproved,
     ToolDefinition,
@@ -123,7 +125,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
     tool_manager: ToolManager[DepsT]
-    tool_call_results: dict[str, DeferredToolResult] | None
 
     tracer: Tracer
     instrumentation_settings: InstrumentationSettings | None
@@ -160,14 +161,18 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     _: dataclasses.KW_ONLY
 
-    instructions: str | None
-    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
+    deferred_tool_results: DeferredToolResults | None = None
 
-    system_prompts: tuple[str, ...]
-    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
-    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
+    instructions: str | None = None
+    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
 
-    async def run(
+    system_prompts: tuple[str, ...] = dataclasses.field(default_factory=tuple)
+    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
+    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
+        default_factory=dict
+    )
+
+    async def run(  # noqa: C901
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]:
         try:
@@ -181,119 +186,127 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
                 messages = ctx_messages.messages
                 ctx_messages.used = True
 
+        message_history = _clean_message_history(ctx.state.message_history)
         # Add message history to the `capture_run_messages` list, which will be empty at this point
-        messages.extend(ctx.state.message_history)
+        messages.extend(message_history)
         # Use the `capture_run_messages` list as the message history so that new messages are added to it
         ctx.state.message_history = messages
+        ctx.deps.new_message_index = len(messages)
 
-        run_context = build_run_context(ctx)
-
-        parts: list[_messages.ModelRequestPart] = []
-        if messages:
-            # Reevaluate any dynamic system prompt parts
-            await self._reevaluate_dynamic_prompts(messages, run_context)
-        else:
-            parts.extend(await self._sys_parts(run_context))
-
-        if (tool_call_results := ctx.deps.tool_call_results) is not None:
-            if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
-                # If tool call results were provided, that means the previous run ended on deferred tool calls.
-                # That run would typically have ended on a `ModelResponse`, but if it had a mix of deferred tool calls and ones that could already be executed,
-                # a `ModelRequest` would already have been added to the history with the preliminary results, even if it wouldn't have been sent to the model yet.
-                # So now that we have all of the deferred results, we roll back to the last `ModelResponse` and store the contents of the `ModelRequest` on `deferred_tool_results` to be handled by `CallToolsNode`.
-                ctx.deps.tool_call_results = self._update_tool_call_results_from_model_request(
-                    tool_call_results, last_message
-                )
-                messages.pop()
+        if self.deferred_tool_results is not None:
+            return await self._handle_deferred_tool_results(self.deferred_tool_results, messages, ctx)
 
-            else:
-                raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+        next_message: _messages.ModelRequest | None = None
 
         if messages and (last_message := messages[-1]):
             if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
                 # Drop last message from history and reuse its parts
                 messages.pop()
-                parts.extend(last_message.parts)
+                next_message = _messages.ModelRequest(parts=last_message.parts)
+
+                # Extract `UserPromptPart` content from the popped message and add to `ctx.deps.prompt`
+                user_prompt_parts = [part for part in last_message.parts if isinstance(part, _messages.UserPromptPart)]
+                if user_prompt_parts:
+                    if len(user_prompt_parts) == 1:
+                        ctx.deps.prompt = user_prompt_parts[0].content
+                    else:
+                        combined_content: list[_messages.UserContent] = []
+                        for part in user_prompt_parts:
+                            if isinstance(part.content, str):
+                                combined_content.append(part.content)
+                            else:
+                                combined_content.extend(part.content)
+                        ctx.deps.prompt = combined_content
             elif isinstance(last_message, _messages.ModelResponse):
-                node = await self._handle_message_history_model_response(ctx, last_message)
-                if node is not None:
-                    return node
+                if self.user_prompt is None:
+                    # Skip ModelRequestNode and go directly to CallToolsNode
+                    return CallToolsNode[DepsT, NodeRunEndT](last_message)
+                elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
+                    raise exceptions.UserError(
+                        'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+                    )
 
-        if self.user_prompt is not None:
-            parts.append(_messages.UserPromptPart(self.user_prompt))
+        # Build the run context after `ctx.deps.prompt` has been updated
+        run_context = build_run_context(ctx)
 
-        instructions = await ctx.deps.get_instructions(run_context)
-        next_message = _messages.ModelRequest(parts, instructions=instructions)
+        parts: list[_messages.ModelRequestPart] = []
+        if messages:
+            await self._reevaluate_dynamic_prompts(messages, run_context)
 
-        return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
+        if next_message:
+            await self._reevaluate_dynamic_prompts([next_message], run_context)
+        else:
+            parts: list[_messages.ModelRequestPart] = []
+            if not messages:
+                parts.extend(await self._sys_parts(run_context))
 
-    async def _handle_message_history_model_response(
-        self,
-        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-        message: _messages.ModelResponse,
-    ) -> CallToolsNode[DepsT, NodeRunEndT] | None:
-        unprocessed_tool_calls = any(isinstance(part, _messages.ToolCallPart) for part in message.parts)
-        if unprocessed_tool_calls:
             if self.user_prompt is not None:
-                raise exceptions.UserError(
-                    'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
-                )
-        else:
-            if ctx.deps.tool_call_results is not None:
-                raise exceptions.UserError(
-                    'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
-                )
+                parts.append(_messages.UserPromptPart(self.user_prompt))
 
-            return None
-        # `CallToolsNode` requires the tool manager to be prepared for the run step
-        # This will raise errors for any tool name conflicts
-        run_context = build_run_context(ctx)
-        ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
-        # Skip ModelRequestNode and go directly to CallToolsNode
-        return CallToolsNode[DepsT, NodeRunEndT](model_response=message)
-
-    def _update_tool_call_results_from_model_request(
-        self, tool_call_results: dict[str, DeferredToolResult], message: _messages.ModelRequest
-    ) -> dict[str, DeferredToolResult]:
-        last_tool_return: _messages.ToolReturn | None = None
-        user_content: list[str | _messages.UserContent] = []
-        for part in message.parts:
-            if isinstance(part, _messages.ToolReturnPart):
-                if part.tool_call_id in tool_call_results:
-                    raise exceptions.UserError(
-                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
-                    )
+            next_message = _messages.ModelRequest(parts=parts)
 
-                last_tool_return = _messages.ToolReturn(part.content)
-                tool_call_results[part.tool_call_id] = last_tool_return
-            elif isinstance(part, _messages.RetryPromptPart):
-                if part.tool_call_id in tool_call_results:
-                    raise exceptions.UserError(
-                        f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
-                    )
+        next_message.instructions = await ctx.deps.get_instructions(run_context)
 
-                ...
-            elif isinstance(part, _messages.UserPromptPart):
-                # Tools can return user parts via `ToolReturn.content` or by returning multi-modal content.
-                # These go together with a specific `ToolReturnPart`, but we don't have a way to know which,
-                # so (below) we just add them to the last one, matching the tool-results-before-user-parts order of the request.
-                if isinstance(part.content, str):
-                    user_content.append(part.content)
-                else:
-                    user_content.extend(part.content)
-            else:
-                raise exceptions.UserError(f'Unexpected message part type: {type(part)}')  # pragma: no cover
+        return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
 
-        ...
+    async def _handle_deferred_tool_results(  # noqa: C901
+        self,
+        deferred_tool_results: DeferredToolResults,
+        messages: list[_messages.ModelMessage],
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+    ) -> CallToolsNode[DepsT, NodeRunEndT]:
+        if not messages:
+            raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
+
+        last_model_request: _messages.ModelRequest | None = None
+        last_model_response: _messages.ModelResponse | None = None
+        for message in reversed(messages):
+            if isinstance(message, _messages.ModelRequest):
+                last_model_request = message
+            elif isinstance(message, _messages.ModelResponse):  # pragma: no branch
+                last_model_response = message
+                break
+
+        if not last_model_response:
+            raise exceptions.UserError(
+                'Tool call results were provided, but the message history does not contain a `ModelResponse`.'
+            )
+        if not any(isinstance(part, _messages.ToolCallPart) for part in last_model_response.parts):
+            raise exceptions.UserError(
+                'Tool call results were provided, but the message history does not contain any unprocessed tool calls.'
+            )
+        if self.user_prompt is not None:
+            raise exceptions.UserError(
+                'Cannot provide a new user prompt when the message history contains unprocessed tool calls.'
+            )
+
+        tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None
+        tool_call_results = {}
+        for tool_call_id, approval in deferred_tool_results.approvals.items():
+            if approval is True:
+                approval = ToolApproved()
+            elif approval is False:
+                approval = ToolDenied()
+            tool_call_results[tool_call_id] = approval
+
+        if calls := deferred_tool_results.calls:
+            call_result_types = get_union_args(DeferredToolCallResult)
+            for tool_call_id, result in calls.items():
+                if not isinstance(result, call_result_types):
+                    result = _messages.ToolReturn(result)
+                tool_call_results[tool_call_id] = result
+
+        if last_model_request:
+            for part in last_model_request.parts:
+                if isinstance(part, _messages.ToolReturnPart | _messages.RetryPromptPart):
+                    if part.tool_call_id in tool_call_results:
+                        raise exceptions.UserError(
+                            f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.'
+                        )
+                    tool_call_results[part.tool_call_id] = 'skip'
 
-        return tool_call_results
+        # Skip ModelRequestNode and go directly to CallToolsNode
+        return CallToolsNode[DepsT, NodeRunEndT](last_model_response, tool_call_results=tool_call_results)
 
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
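The hunk above replaces the 1.0.6 plumbing (a `tool_call_results` dict threaded through `GraphAgentDeps` and rebuilt from the last `ModelRequest`) with a `DeferredToolResults` value carried on `UserPromptNode` and normalized by the new `_handle_deferred_tool_results`. A standalone sketch of the normalization rules visible in the diff, using simplified stand-in classes rather than the real pydantic_ai types:

# Standalone sketch of the normalization in `_handle_deferred_tool_results`;
# the classes below are simplified stand-ins, not the pydantic_ai types.
from dataclasses import dataclass, field
from typing import Any, Literal

@dataclass
class ToolApproved:  # stand-in for pydantic_ai.tools.ToolApproved
    pass

@dataclass
class ToolDenied:  # stand-in for pydantic_ai.tools.ToolDenied
    pass

@dataclass
class ToolReturn:  # stand-in for pydantic_ai.messages.ToolReturn
    content: Any

@dataclass
class DeferredToolResults:  # mirrors the two fields the diff reads
    approvals: dict[str, bool | ToolApproved | ToolDenied] = field(default_factory=dict)
    calls: dict[str, Any] = field(default_factory=dict)

def normalize(results: DeferredToolResults, already_answered: set[str]) -> dict[str, Any]:
    normalized: dict[str, Any | Literal['skip']] = {}
    # Bare booleans become ToolApproved()/ToolDenied().
    for tool_call_id, approval in results.approvals.items():
        if approval is True:
            approval = ToolApproved()
        elif approval is False:
            approval = ToolDenied()
        normalized[tool_call_id] = approval
    # Raw values are wrapped in ToolReturn unless they already are a result type.
    for tool_call_id, result in results.calls.items():
        if not isinstance(result, (ToolApproved, ToolDenied, ToolReturn)):
            result = ToolReturn(result)
        normalized[tool_call_id] = result
    # Calls answered in an earlier partial run are marked 'skip', never overridden.
    for tool_call_id in already_answered:
        if tool_call_id in normalized:
            raise ValueError(f'Tool call {tool_call_id!r} was already executed.')
        normalized[tool_call_id] = 'skip'
    return normalized

print(normalize(DeferredToolResults(approvals={'a': True}, calls={'b': 42}), {'c'}))
# -> {'a': ToolApproved(), 'b': ToolReturn(content=42), 'c': 'skip'}

The 'skip' marker is what lets `process_tool_calls` (below) filter out calls that were already executed in a partial previous run.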
@@ -330,6 +343,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
             messages.append(_messages.SystemPromptPart(prompt))
         return messages
 
+    __repr__ = dataclasses_no_defaults_repr
+
 
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
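This release gives all three graph node classes a shared `__repr__` imported from `pydantic_ai._utils`. Judging by the name, it omits fields that still hold their defaults, which matters now that every `UserPromptNode` field has one. A plausible sketch of such a helper (the real implementation may differ):

# Hypothetical sketch of a "no defaults" dataclass repr; the real
# dataclasses_no_defaults_repr in pydantic_ai._utils may differ.
import dataclasses

def dataclasses_no_defaults_repr(self) -> str:
    parts = []
    for f in dataclasses.fields(self):
        if not f.repr:
            continue
        value = getattr(self, f.name)
        # Skip fields still at their default (or default_factory) value.
        if f.default is not dataclasses.MISSING and value == f.default:
            continue
        if f.default_factory is not dataclasses.MISSING and value == f.default_factory():
            continue
        parts.append(f'{f.name}={value!r}')
    return f'{self.__class__.__name__}({", ".join(parts)})'

@dataclasses.dataclass
class Node:
    user_prompt: str | None = None
    retries: int = 0
    __repr__ = dataclasses_no_defaults_repr

print(Node(user_prompt='hi'))  # Node(user_prompt='hi') -- retries omitted

Assigning `__repr__` in the class body works because `@dataclass` only generates a `__repr__` when the class doesn't define one itself.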
@@ -441,6 +456,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
         ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
 
         message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
+        message_history = _clean_message_history(message_history)
 
         model_request_parameters = await _prepare_request_parameters(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
@@ -476,12 +492,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 
         return self._result
 
+    __repr__ = dataclasses_no_defaults_repr
+
 
 @dataclasses.dataclass
 class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
     """The node that processes a model response, and decides whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
+    tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None
 
     _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
     _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
@@ -582,11 +601,20 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
         run_context = build_run_context(ctx)
 
+        # This will raise errors for any tool name conflicts
+        ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
         output_parts: list[_messages.ModelRequestPart] = []
         output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1)
 
-        async for event in process_function_tools(
-            ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result
+        async for event in process_tool_calls(
+            tool_manager=ctx.deps.tool_manager,
+            tool_calls=tool_calls,
+            tool_call_results=self.tool_call_results,
+            final_result=None,
+            ctx=ctx,
+            output_parts=output_parts,
+            output_final_result=output_final_result,
         ):
             yield event
 
@@ -639,6 +667,8 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
         else:
             return self._handle_final_result(ctx, result.FinalResult(result_data), [])
 
+    __repr__ = dataclasses_no_defaults_repr
+
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
@@ -652,13 +682,14 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
-        tool_call_approved=ctx.state.run_step == 0
+        tool_call_approved=ctx.state.run_step == 0,
     )
 
 
-async def process_function_tools(  # noqa: C901
+async def process_tool_calls(  # noqa: C901
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
+    tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None,
     final_result: result.FinalResult[NodeRunEndT] | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
     output_parts: list[_messages.ModelRequestPart],
@@ -739,14 +770,13 @@ async def process_function_tools(  # noqa: C901
         ctx.state.increment_retries(ctx.deps.max_result_retries)
         calls_to_run.extend(tool_calls_by_kind['unknown'])
 
-    deferred_tool_results: dict[str, DeferredToolResult] | None = None
-    if ctx.deps.tool_call_results is not None:
-        deferred_tool_results = ctx.deps.tool_call_results
+    calls_to_run_results: dict[str, DeferredToolResult] = {}
+    if tool_call_results is not None:
         # Deferred tool calls are "run" as well, by reading their value from the tool call results
         calls_to_run.extend(tool_calls_by_kind['external'])
         calls_to_run.extend(tool_calls_by_kind['unapproved'])
 
-        result_tool_call_ids = set(deferred_tool_results.keys())
+        result_tool_call_ids = set(tool_call_results.keys())
         tool_call_ids_to_run = {call.tool_call_id for call in calls_to_run}
         if tool_call_ids_to_run != result_tool_call_ids:
             raise exceptions.UserError(
@@ -754,24 +784,29 @@ async def process_function_tools(  # noqa: C901
                 f'Expected: {tool_call_ids_to_run}, got: {result_tool_call_ids}'
             )
 
+        # Filter out calls that were already executed before and should now be skipped
+        calls_to_run_results = {call_id: result for call_id, result in tool_call_results.items() if result != 'skip'}
+        calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]
+
     deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
 
     if calls_to_run:
         async for event in _call_tools(
-            tool_manager,
-            calls_to_run,
-            deferred_tool_results,
-            ctx.deps.tracer,
-            ctx.deps.usage_limits,
-            output_parts,
-            deferred_calls,
+            tool_manager=tool_manager,
+            tool_calls=calls_to_run,
+            tool_call_results=calls_to_run_results,
+            tracer=ctx.deps.tracer,
+            usage_limits=ctx.deps.usage_limits,
+            output_parts=output_parts,
+            output_deferred_calls=deferred_calls,
         ):
             yield event
 
     # Finally, we handle deferred tool calls (unless they were already included in the run because results were provided)
-    if deferred_tool_results is None:
+    if tool_call_results is None:
+        calls = [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]
         if final_result:
-            for call in [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]:
+            for call in calls:
                 output_parts.append(
                     _messages.ToolReturnPart(
                         tool_name=call.tool_name,
|
|
|
779
814
|
tool_call_id=call.tool_call_id,
|
|
780
815
|
)
|
|
781
816
|
)
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
yield _messages.FunctionToolCallEvent(call)
|
|
817
|
+
elif calls:
|
|
818
|
+
deferred_calls['external'].extend(tool_calls_by_kind['external'])
|
|
819
|
+
deferred_calls['unapproved'].extend(tool_calls_by_kind['unapproved'])
|
|
786
820
|
|
|
787
|
-
for call in
|
|
788
|
-
deferred_calls['unapproved'].append(call)
|
|
821
|
+
for call in calls:
|
|
789
822
|
yield _messages.FunctionToolCallEvent(call)
|
|
790
823
|
|
|
791
824
|
if not final_result and deferred_calls:
|
|
@@ -807,7 +840,7 @@
 async def _call_tools(
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
-    deferred_tool_results: dict[str, DeferredToolResult],
+    tool_call_results: dict[str, DeferredToolResult],
     tracer: Tracer,
     usage_limits: _usage.UsageLimits | None,
     output_parts: list[_messages.ModelRequestPart],
@@ -853,7 +886,7 @@ async def _call_tools(
     if tool_manager.should_call_sequentially(tool_calls):
         for index, call in enumerate(tool_calls):
             if event := await handle_call_or_result(
-                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
                 index,
             ):
                 yield event
@@ -861,7 +894,7 @@ async def _call_tools(
     else:
         tasks = [
            asyncio.create_task(
-                _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
+                _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
                 name=call.tool_name,
             )
             for call in tool_calls
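The two hunks above only rename the results parameter, but they show `_call_tools`'s two dispatch modes: a sequential loop when `tool_manager.should_call_sequentially(...)` says so, and fan-out via named asyncio tasks otherwise. A simplified, self-contained sketch of that pattern (the tool and result handling here are stand-ins, not the library internals):

# Sketch of sequential-vs-concurrent tool dispatch, assuming toy tools.
import asyncio

async def call_tool(name: str, arg: int) -> int:
    await asyncio.sleep(0.01)  # stands in for real tool work
    return arg * 2

async def call_tools(calls: list[tuple[str, int]], sequential: bool) -> list[int]:
    if sequential:
        # One at a time, preserving order, like the should_call_sequentially branch.
        return [await call_tool(name, arg) for name, arg in calls]
    # Otherwise fan out with named tasks, like the asyncio.create_task branch.
    tasks = [asyncio.create_task(call_tool(name, arg), name=name) for name, arg in calls]
    return list(await asyncio.gather(*tasks))

print(asyncio.run(call_tools([('a', 1), ('b', 2)], sequential=False)))  # [2, 4]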
@@ -1079,3 +1112,51 @@ async def _process_message_history(
     # Replaces the message history in the state with the processed messages
     state.message_history = messages
     return messages
+
+
+def _clean_message_history(messages: list[_messages.ModelMessage]) -> list[_messages.ModelMessage]:
+    """Clean the message history by merging consecutive messages of the same type."""
+    clean_messages: list[_messages.ModelMessage] = []
+    for message in messages:
+        last_message = clean_messages[-1] if len(clean_messages) > 0 else None
+
+        if isinstance(message, _messages.ModelRequest):
+            if (
+                last_message
+                and isinstance(last_message, _messages.ModelRequest)
+                # Requests can only be merged if they have the same instructions
+                and (
+                    not last_message.instructions
+                    or not message.instructions
+                    or last_message.instructions == message.instructions
+                )
+            ):
+                parts = [*last_message.parts, *message.parts]
+                parts.sort(
+                    # Tool return parts always need to be at the start
+                    key=lambda x: 0 if isinstance(x, _messages.ToolReturnPart | _messages.RetryPromptPart) else 1
+                )
+                merged_message = _messages.ModelRequest(
+                    parts=parts,
+                    instructions=last_message.instructions or message.instructions,
+                )
+                clean_messages[-1] = merged_message
+            else:
+                clean_messages.append(message)
+        elif isinstance(message, _messages.ModelResponse):  # pragma: no branch
+            if (
+                last_message
+                and isinstance(last_message, _messages.ModelResponse)
+                # Responses can only be merged if they didn't really come from an API
+                and last_message.provider_response_id is None
+                and last_message.provider_name is None
+                and last_message.model_name is None
+                and message.provider_response_id is None
+                and message.provider_name is None
+                and message.model_name is None
+            ):
+                merged_message = replace(last_message, parts=[*last_message.parts, *message.parts])
+                clean_messages[-1] = merged_message
+            else:
+                clean_messages.append(message)
+    return clean_messages
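The merge rule of `_clean_message_history` in miniature: consecutive `ModelRequest`s collapse into one when their instructions are compatible, with tool returns (and retry prompts) sorted to the front of the merged parts. The classes below are simplified stand-ins for illustration, not the pydantic_ai message types:

# Toy illustration of the request-merge rule above.
from dataclasses import dataclass

@dataclass
class ToolReturnPart:
    content: str

@dataclass
class UserPromptPart:
    content: str

@dataclass
class ModelRequest:
    parts: list
    instructions: str | None = None

def merge(a: ModelRequest, b: ModelRequest) -> ModelRequest | None:
    # Mirrors the guard in the diff: only merge when instructions are compatible.
    if a.instructions and b.instructions and a.instructions != b.instructions:
        return None
    parts = [*a.parts, *b.parts]
    parts.sort(key=lambda p: 0 if isinstance(p, ToolReturnPart) else 1)  # tool returns first
    return ModelRequest(parts=parts, instructions=a.instructions or b.instructions)

merged = merge(
    ModelRequest([UserPromptPart('hi')], instructions='Be terse.'),
    ModelRequest([ToolReturnPart('42')]),
)
print(merged)
# ModelRequest(parts=[ToolReturnPart(content='42'), UserPromptPart(content='hi')], instructions='Be terse.')

Responses are only merged when `provider_response_id`, `provider_name`, and `model_name` are all unset on both sides, i.e. when neither response actually came back from an API.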
--- pydantic_ai_slim-1.0.6/pydantic_ai/ag_ui.py
+++ pydantic_ai_slim-1.0.7/pydantic_ai/ag_ui.py
@@ -23,6 +23,7 @@ from typing import (
 )
 
 from pydantic import BaseModel, ValidationError
+from typing_extensions import assert_never
 
 from . import _utils
 from ._agent_graph import CallToolsNode, ModelRequestNode
@@ -32,7 +33,9 @@ from .messages import (
     FunctionToolResultEvent,
     ModelMessage,
     ModelRequest,
+    ModelRequestPart,
     ModelResponse,
+    ModelResponsePart,
     ModelResponseStreamEvent,
     PartDeltaEvent,
     PartStartEvent,
@@ -573,49 +576,57 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
     """Convert a AG-UI history to a Pydantic AI one."""
     result: list[ModelMessage] = []
     tool_calls: dict[str, str] = {}  # Tool call ID to tool name mapping.
+    request_parts: list[ModelRequestPart] | None = None
+    response_parts: list[ModelResponsePart] | None = None
     for msg in messages:
-        if isinstance(msg, UserMessage):
-            result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
+        if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage | ToolMessage):
+            if request_parts is None:
+                request_parts = []
+                result.append(ModelRequest(parts=request_parts))
+                response_parts = None
+
+            if isinstance(msg, UserMessage):
+                request_parts.append(UserPromptPart(content=msg.content))
+            elif isinstance(msg, SystemMessage | DeveloperMessage):
+                request_parts.append(SystemPromptPart(content=msg.content))
+            elif isinstance(msg, ToolMessage):
+                tool_name = tool_calls.get(msg.tool_call_id)
+                if tool_name is None:  # pragma: no cover
+                    raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
+
+                request_parts.append(
+                    ToolReturnPart(
+                        tool_name=tool_name,
+                        content=msg.content,
+                        tool_call_id=msg.tool_call_id,
+                    )
+                )
+            else:
+                assert_never(msg)
+
         elif isinstance(msg, AssistantMessage):
+            if response_parts is None:
+                response_parts = []
+                result.append(ModelResponse(parts=response_parts))
+                request_parts = None
+
             if msg.content:
-                result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
+                response_parts.append(TextPart(content=msg.content))
 
             if msg.tool_calls:
                 for tool_call in msg.tool_calls:
                     tool_calls[tool_call.id] = tool_call.function.name
 
-                result.append(
-                    ModelResponse(
-                        parts=[
-                            ToolCallPart(
-                                tool_name=tool_call.function.name,
-                                tool_call_id=tool_call.id,
-                                args=tool_call.function.arguments,
-                            )
-                            for tool_call in msg.tool_calls
-                        ]
-                    )
-                )
-        elif isinstance(msg, SystemMessage):
-            result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
-        elif isinstance(msg, ToolMessage):
-            tool_name = tool_calls.get(msg.tool_call_id)
-            if tool_name is None:  # pragma: no cover
-                raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
-
-            result.append(
-                ModelRequest(
-                    parts=[
-                        ToolReturnPart(
-                            tool_name=tool_name,
-                            content=msg.content,
-                            tool_call_id=msg.tool_call_id,
-                        )
-                    ]
-                )
-            )
-        elif isinstance(msg, DeveloperMessage):  # pragma: no branch
-            result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
+                response_parts.extend(
+                    ToolCallPart(
+                        tool_name=tool_call.function.name,
+                        tool_call_id=tool_call.id,
+                        args=tool_call.function.arguments,
+                    )
+                    for tool_call in msg.tool_calls
+                )
+        else:
+            assert_never(msg)
 
     return result
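The net effect of the rewritten `_messages_from_ag_ui`: instead of one `ModelRequest`/`ModelResponse` per AG-UI message, consecutive messages on the same side now share a single request or response, because the parts list of the last-appended message stays live and is appended to until the side flips. A toy version of that grouping, with simplified message shapes rather than the AG-UI protocol types:

# Toy version of the grouping in _messages_from_ag_ui.
def group(messages: list[tuple[str, str]]) -> list[tuple[str, list[str]]]:
    result: list[tuple[str, list[str]]] = []
    request_parts: list[str] | None = None
    response_parts: list[str] | None = None
    for role, content in messages:
        if role == 'assistant':
            if response_parts is None:
                response_parts = []
                result.append(('response', response_parts))
                request_parts = None  # next user/system/tool message opens a new request
            response_parts.append(content)
        else:  # user, system, developer, tool
            if request_parts is None:
                request_parts = []
                result.append(('request', request_parts))
                response_parts = None
            request_parts.append(content)
    return result

history = [('user', 'hi'), ('user', 'there'), ('assistant', 'hello'), ('tool', 'result')]
print(group(history))
# [('request', ['hi', 'there']), ('response', ['hello']), ('request', ['result'])]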