aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,1923 @@
|
|
|
1
|
+
"""ReactAgent - Autonomous agent with think-act-observe loop.
|
|
2
|
+
|
|
3
|
+
ReactAgent uses the unified BaseAgent constructor:
|
|
4
|
+
__init__(self, ctx: InvocationContext, config: AgentConfig | None = None)
|
|
5
|
+
|
|
6
|
+
All services (llm, tools, storage, etc.) are accessed through ctx.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import json
|
|
13
|
+
from dataclasses import asdict
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from typing import Any, AsyncIterator, ClassVar, Literal, TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
from ..core.base import AgentConfig, BaseAgent, ToolInjectionMode
|
|
18
|
+
from ..core.context import InvocationContext
|
|
19
|
+
from ..core.logging import react_logger as logger
|
|
20
|
+
from ..core.event_bus import Events
|
|
21
|
+
from ..context_providers import ContextProvider, AgentContext
|
|
22
|
+
from ..core.types.block import BlockEvent, BlockKind, BlockOp
|
|
23
|
+
from ..llm import LLMMessage, ToolDefinition
|
|
24
|
+
from ..middleware import HookAction
|
|
25
|
+
from ..core.types import (
|
|
26
|
+
Invocation,
|
|
27
|
+
InvocationState,
|
|
28
|
+
PromptInput,
|
|
29
|
+
ToolContext,
|
|
30
|
+
ToolResult,
|
|
31
|
+
ToolInvocation,
|
|
32
|
+
ToolInvocationState,
|
|
33
|
+
generate_id,
|
|
34
|
+
)
|
|
35
|
+
from ..core.state import State
|
|
36
|
+
from ..core.signals import SuspendSignal, HITLSuspend
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from ..llm import LLMProvider
|
|
40
|
+
from ..tool import ToolSet
|
|
41
|
+
from ..core.types.tool import BaseTool
|
|
42
|
+
from ..core.types.session import Session
|
|
43
|
+
from ..backends import Backends
|
|
44
|
+
from ..backends.state import StateBackend
|
|
45
|
+
from ..backends.snapshot import SnapshotBackend
|
|
46
|
+
from ..backends.subagent import AgentConfig as SubAgentConfig
|
|
47
|
+
from ..core.event_bus import Bus
|
|
48
|
+
from ..middleware import MiddlewareChain, Middleware
|
|
49
|
+
from ..memory import MemoryManager
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class SessionNotFoundError(Exception):
    """Raised when a session ID cannot be found in the storage backend."""
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ReactAgent(BaseAgent):
|
|
58
|
+
"""ReAct Agent - Autonomous agent with tool calling loop.
|
|
59
|
+
|
|
60
|
+
Implements the think-act-observe pattern:
|
|
61
|
+
1. Think: LLM generates reasoning and decides on actions
|
|
62
|
+
2. Act: Execute tool calls
|
|
63
|
+
3. Observe: Process tool results
|
|
64
|
+
4. Repeat until done or max steps reached
|
|
65
|
+
|
|
66
|
+
Two ways to create:
|
|
67
|
+
|
|
68
|
+
1. Simple (recommended for most cases):
|
|
69
|
+
agent = ReactAgent.create(llm=llm, tools=tools, config=config)
|
|
70
|
+
|
|
71
|
+
2. Advanced (for custom Session/Backends/Bus):
|
|
72
|
+
ctx = InvocationContext(session=session, backends=backends, bus=bus, llm=llm, tools=tools)
|
|
73
|
+
agent = ReactAgent(ctx, config)
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
# Class-level config
|
|
77
|
+
agent_type: ClassVar[Literal["react", "workflow"]] = "react"
|
|
78
|
+
|
|
79
|
+
    @classmethod
    def create(
        cls,
        llm: "LLMProvider",
        tools: "ToolSet | list[BaseTool] | None" = None,
        config: AgentConfig | None = None,
        *,
        backends: "Backends | None" = None,
        session: "Session | None" = None,
        bus: "Bus | None" = None,
        middlewares: "list[Middleware] | None" = None,
        subagents: "list[SubAgentConfig] | None" = None,
        memory: "MemoryManager | None" = None,
        snapshot: "SnapshotBackend | None" = None,
        # ContextProvider system
        context_providers: "list[ContextProvider] | None" = None,
        enable_history: bool = True,
        history_limit: int = 50,
        # Tool customization
        delegate_tool_class: "type[BaseTool] | None" = None,
    ) -> "ReactAgent":
        """Create ReactAgent with minimal boilerplate.

        This is the recommended way to create a ReactAgent for simple use cases.
        Session, Storage, and Bus are auto-created if not provided.

        Args:
            llm: LLM provider (required)
            tools: Tool registry or list of tools (optional)
            config: Agent configuration (optional)
            backends: Backends container (recommended, auto-created if None)
            session: Session object (auto-created if None)
            bus: Event bus (auto-created if None)
            middlewares: List of middlewares (auto-creates chain)
            subagents: List of sub-agent configs (auto-creates SubAgentManager)
            memory: Memory manager (optional)
            snapshot: Snapshot backend (optional)
            context_providers: Additional custom context providers (optional)
            enable_history: Enable message history (default True)
            history_limit: Max conversation turns to keep (default 50)
            delegate_tool_class: Custom DelegateTool class (optional)

        Returns:
            Configured ReactAgent ready to run

        Example:
            # Minimal
            agent = ReactAgent.create(llm=my_llm)

            # With backends
            agent = ReactAgent.create(
                llm=my_llm,
                backends=Backends.create_default(),
            )

            # With tools and middlewares
            agent = ReactAgent.create(
                llm=my_llm,
                tools=[tool1, tool2],
                middlewares=[MessageContainerMiddleware()],
            )

            # With sub-agents
            agent = ReactAgent.create(
                llm=my_llm,
                subagents=[
                    AgentConfig(key="researcher", agent=researcher_agent),
                ],
            )

            # With custom context providers
            agent = ReactAgent.create(
                llm=my_llm,
                tools=[tool1],
                context_providers=[MyRAGProvider(), MyProjectProvider()],
            )
        """
        # Local imports avoid import cycles at module load time.
        from ..core.event_bus import EventBus
        from ..core.types.session import Session, generate_id
        from ..backends import Backends
        from ..backends.subagent import ListSubAgentBackend
        from ..tool import ToolSet
        from ..tool.builtin import DelegateTool
        from ..middleware import MiddlewareChain, MessageBackendMiddleware
        from ..context_providers import MessageContextProvider

        # Auto-create backends if not provided
        if backends is None:
            backends = Backends.create_default()

        # Auto-create missing components
        if session is None:
            session = Session(id=generate_id("sess"))
        if bus is None:
            bus = EventBus()

        # Create middleware chain (add MessageBackendMiddleware if history enabled).
        # NOTE: if enable_history is True but backends.message is None and no user
        # middlewares are given, an empty chain is still created.
        middleware_chain: MiddlewareChain | None = None
        if middlewares or enable_history:
            middleware_chain = MiddlewareChain()
            # Add message persistence middleware first (uses backends.message)
            if enable_history and backends.message is not None:
                middleware_chain.use(MessageBackendMiddleware(max_history=history_limit))
            # Add user middlewares
            if middlewares:
                for mw in middlewares:
                    middleware_chain.use(mw)

        # === Build tools list (direct, no provider) ===
        tool_list: list["BaseTool"] = []
        if tools is not None:
            if isinstance(tools, ToolSet):
                tool_list = list(tools.all())
            else:
                tool_list = list(tools)

        # Handle subagents - create DelegateTool directly so the LLM can
        # delegate work to them via a regular tool call.
        if subagents:
            backend = ListSubAgentBackend(subagents)
            tool_cls = delegate_tool_class or DelegateTool
            delegate_tool = tool_cls(backend, middleware=middleware_chain)
            tool_list.append(delegate_tool)

        # === Build providers ===
        default_providers: list["ContextProvider"] = []

        # MessageContextProvider - for fetching history (uses backends.message).
        # max_messages is doubled because each turn is a user+assistant pair.
        if enable_history:
            message_provider = MessageContextProvider(max_messages=history_limit * 2)
            default_providers.append(message_provider)

        # Combine default + custom context_providers (defaults run first)
        all_providers = default_providers + (context_providers or [])

        # Build context
        ctx = InvocationContext(
            session=session,
            invocation_id=generate_id("inv"),
            agent_id=config.name if config else "react_agent",
            backends=backends,
            bus=bus,
            llm=llm,
            middleware=middleware_chain,
            memory=memory,
            snapshot=snapshot,
        )

        # Instantiate, then attach the pieces the constructor does not take.
        agent = cls(ctx, config)
        agent._tools = tool_list  # Direct tools (not from context_provider)
        agent._context_providers = all_providers
        agent._delegate_tool_class = delegate_tool_class or DelegateTool
        agent._middleware_chain = middleware_chain
        return agent
|
|
232
|
+
|
|
233
|
+
    @classmethod
    async def restore(
        cls,
        session_id: str,
        llm: "LLMProvider",
        *,
        backends: "Backends | None" = None,
        tools: "ToolSet | list[BaseTool] | None" = None,
        config: AgentConfig | None = None,
        bus: "Bus | None" = None,
        middleware: "MiddlewareChain | None" = None,
        memory: "MemoryManager | None" = None,
        snapshot: "SnapshotBackend | None" = None,
    ) -> "ReactAgent":
        """Restore agent from persisted state.

        Use this to resume an agent after:
        - Page refresh
        - Process restart
        - Cross-process recovery

        Args:
            session_id: Session ID to restore
            llm: LLM provider
            backends: Backends container (recommended, auto-created if None)
            tools: Tool registry or list of tools
            config: Agent configuration
            bus: Event bus (auto-created if None)
            middleware: Middleware chain
            memory: Memory manager
            snapshot: Snapshot backend

        Returns:
            Restored ReactAgent ready to continue

        Raises:
            SessionNotFoundError: If session not found
            ValueError: If no state backend is available for restoring

        Example:
            agent = await ReactAgent.restore(
                session_id="sess_xxx",
                backends=my_backends,
                llm=my_llm,
            )

            # Check if waiting for HITL response
            if agent.is_suspended:
                print(f"Waiting for: {agent.pending_request}")
            else:
                # Continue conversation
                await agent.run("Continue...")
        """
        # Local imports avoid import cycles at module load time.
        from ..core.event_bus import Bus
        from ..core.types.session import Session, Invocation, InvocationState, generate_id
        from ..core.state import State
        from ..tool import ToolSet
        from ..backends import Backends

        # Auto-create backends if not provided
        if backends is None:
            backends = Backends.create_default()

        # Validate storage backend is available
        if backends.state is None:
            raise ValueError("Cannot restore: no storage backend available (backends.state is None)")

        storage = backends.state

        # 1. Load session
        session_data = await storage.get("sessions", session_id)
        if not session_data:
            raise SessionNotFoundError(f"Session not found: {session_id}")
        session = Session.from_dict(session_data)

        # 2. Load current invocation (may be absent or already cleaned up)
        invocation: Invocation | None = None
        if session_data.get("current_invocation_id"):
            inv_data = await storage.get("invocations", session_data["current_invocation_id"])
            if inv_data:
                invocation = Invocation.from_dict(inv_data)

        # 3. Load state
        state = State(storage, session_id)
        await state.restore()

        # 4. Handle tools (normalize to a ToolSet; empty set when none given)
        tool_set: ToolSet | None = None
        if tools is not None:
            if isinstance(tools, ToolSet):
                tool_set = tools
            else:
                tool_set = ToolSet()
                for tool in tools:
                    tool_set.add(tool)
        else:
            tool_set = ToolSet()

        # 5. Create bus if needed
        # NOTE(review): create() builds EventBus() here while restore() builds
        # Bus() — confirm these are equivalent / Bus is the intended default.
        if bus is None:
            bus = Bus()

        # 6. Build context (reuse the suspended invocation's id when present
        # so events correlate with the restored invocation)
        ctx = InvocationContext(
            session=session,
            invocation_id=invocation.id if invocation else generate_id("inv"),
            agent_id=config.name if config else "react_agent",
            backends=backends,
            bus=bus,
            llm=llm,
            tools=tool_set,
            middleware=middleware,
            memory=memory,
            snapshot=snapshot,
        )

        # 7. Create agent and attach restored state
        agent = cls(ctx, config)
        agent._restored_invocation = invocation
        agent._state = state

        return agent
|
|
354
|
+
|
|
355
|
+
    def __init__(
        self,
        ctx: InvocationContext,
        config: AgentConfig | None = None,
    ):
        """Initialize ReactAgent.

        Args:
            ctx: InvocationContext with llm, tools, storage, bus, session
            config: Agent configuration

        Raises:
            ValueError: If ctx.llm is None. (Only the LLM is validated here;
                tools may legitimately be absent/empty.)
        """
        super().__init__(ctx, config)

        # Validate required services
        if ctx.llm is None:
            raise ValueError("ReactAgent requires ctx.llm (LLMProvider)")

        # Current execution state (reset per run)
        self._current_invocation: Invocation | None = None
        self._current_step: int = 0
        self._message_history: list[LLMMessage] = []
        self._text_buffer: str = ""
        self._thinking_buffer: str = ""
        self._tool_invocations: list[ToolInvocation] = []

        # Block ID tracking for streaming (ensures consecutive deltas use same block_id)
        self._current_text_block_id: str | None = None
        self._current_thinking_block_id: str | None = None

        # Tool call tracking for streaming arguments
        self._call_id_to_tool: dict[str, str] = {}  # call_id -> tool_name
        self._tool_call_blocks: dict[str, str] = {}  # call_id -> block_id

        # Pause/resume support
        self._paused = False

        # Restore support (populated by restore(), cleared after each run)
        self._restored_invocation: "Invocation | None" = None
        self._state: "State | None" = None

        # Direct tools (passed to create())
        self._tools: list["BaseTool"] = []

        # ContextProviders for context engineering
        self._context_providers: list[ContextProvider] = []

        # DelegateTool class and middleware for dynamic subagent handling
        self._delegate_tool_class: type | None = None
        self._middleware_chain: "MiddlewareChain | None" = None

        # Current AgentContext from providers (set by _fetch_agent_context)
        self._agent_context: AgentContext | None = None
|
|
410
|
+
|
|
411
|
+
# ========== Suspension properties ==========
|
|
412
|
+
|
|
413
|
+
@property
|
|
414
|
+
def is_suspended(self) -> bool:
|
|
415
|
+
"""Check if agent is suspended (waiting for HITL input)."""
|
|
416
|
+
if self._restored_invocation:
|
|
417
|
+
return self._restored_invocation.state == InvocationState.SUSPENDED
|
|
418
|
+
return False
|
|
419
|
+
|
|
420
|
+
    @property
    def state(self) -> "State | None":
        """Get session state (for checkpoint/restore).

        Populated by restore(); None for agents that were never restored.
        """
        return self._state
|
|
424
|
+
|
|
425
|
+
# ========== Service accessors ==========
|
|
426
|
+
|
|
427
|
+
@property
|
|
428
|
+
def llm(self) -> "LLMProvider":
|
|
429
|
+
"""Get LLM provider (runtime override or context default)."""
|
|
430
|
+
# Check runtime override first
|
|
431
|
+
if self._run_config.get("llm") is not None:
|
|
432
|
+
return self._run_config["llm"]
|
|
433
|
+
return self._ctx.llm # type: ignore (validated in __init__)
|
|
434
|
+
|
|
435
|
+
    @property
    def snapshot(self) -> "SnapshotBackend | None":
        """Get snapshot backend from context (None when snapshots are disabled)."""
        return self._ctx.snapshot
|
|
439
|
+
|
|
440
|
+
# ========== Runtime config helpers ==========
|
|
441
|
+
|
|
442
|
+
def _get_enable_thinking(self) -> bool:
|
|
443
|
+
"""Get enable_thinking (runtime override or config default)."""
|
|
444
|
+
if self._run_config.get("enable_thinking") is not None:
|
|
445
|
+
return self._run_config["enable_thinking"]
|
|
446
|
+
return self.config.enable_thinking
|
|
447
|
+
|
|
448
|
+
def _get_reasoning_effort(self) -> str | None:
|
|
449
|
+
"""Get reasoning_effort (runtime override or config default)."""
|
|
450
|
+
if self._run_config.get("reasoning_effort") is not None:
|
|
451
|
+
return self._run_config["reasoning_effort"]
|
|
452
|
+
return self.config.reasoning_effort
|
|
453
|
+
|
|
454
|
+
def _get_stream_thinking(self) -> bool:
|
|
455
|
+
"""Get stream_thinking (runtime override or config default)."""
|
|
456
|
+
if self._run_config.get("stream_thinking") is not None:
|
|
457
|
+
return self._run_config["stream_thinking"]
|
|
458
|
+
return self.config.stream_thinking
|
|
459
|
+
|
|
460
|
+
    async def _execute(self, input: PromptInput | str) -> None:
        """Execute the React loop (think -> act -> observe until done).

        Drives one full invocation: creates an Invocation record, runs
        middleware hooks, iterates _execute_step() up to config.max_steps,
        persists messages/state after every step, and finalizes the
        invocation as COMPLETED / ABORTED / SUSPENDED / FAILED.

        Args:
            input: User prompt input (PromptInput or str)
        """
        # Normalize input
        if isinstance(input, str):
            input = PromptInput(text=input)

        # NOTE: To resume HITL into the *same* invocation (instead of creating
        # a new one), we could check self._restored_invocation.state == SUSPENDED
        # and restore the exact state. Current design: every run() creates a new
        # invocation, and an HITL reply is a new invocation as well.

        self.reset()
        self._running = True

        logger.info(
            "Starting ReactAgent run",
            extra={
                "session_id": self.session.id,
                "agent": self.name,
            }
        )

        # Build middleware context (shared dict passed to every hook)
        from ..core.context import emit as global_emit
        mw_context = {
            "session_id": self.session.id,
            "agent_id": self.name,
            "agent_type": self.agent_type,
            "emit": global_emit,  # For middleware to emit ActionEvent
            "backends": self.ctx.backends,
        }

        try:
            # Create new invocation
            self._current_invocation = Invocation(
                id=generate_id("inv"),
                session_id=self.session.id,
                state=InvocationState.RUNNING,
                started_at=datetime.now(),
            )
            mw_context["invocation_id"] = self._current_invocation.id

            logger.debug("Created invocation", extra={"invocation_id": self._current_invocation.id})

            # === Middleware: on_agent_start ===
            # STOP emits an error block and aborts; SKIP aborts silently.
            if self.middleware:
                hook_result = await self.middleware.process_agent_start(
                    self.name, input, mw_context
                )
                if hook_result.action == HookAction.STOP:
                    logger.info("Agent stopped by middleware on_agent_start")
                    await self.ctx.emit(BlockEvent(
                        kind=BlockKind.ERROR,
                        op=BlockOp.APPLY,
                        data={"message": hook_result.message or "Stopped by middleware"},
                    ))
                    return
                elif hook_result.action == HookAction.SKIP:
                    logger.info("Agent skipped by middleware on_agent_start")
                    return

            await self.bus.publish(
                Events.INVOCATION_START,
                {
                    "invocation_id": self._current_invocation.id,
                    "session_id": self.session.id,
                },
            )

            # Build initial messages (loads history from storage)
            self._message_history = await self._build_messages(input)
            self._current_step = 0

            # Save user message (real-time persistence)
            await self._save_user_message(input)

            # 3. Main loop
            finish_reason = None

            while not await self._check_abort():
                self._current_step += 1

                # Check step limit
                if self._current_step > self.config.max_steps:
                    logger.warning(
                        "Max steps exceeded",
                        extra={
                            "max_steps": self.config.max_steps,
                            "invocation_id": self._current_invocation.id,
                        },
                    )
                    await self.ctx.emit(BlockEvent(
                        kind=BlockKind.ERROR,
                        op=BlockOp.APPLY,
                        data={"message": f"Max steps ({self.config.max_steps}) exceeded"},
                    ))
                    break

                # Take snapshot before step
                # NOTE(review): snapshot_id is captured but never used
                # afterwards in this method — confirm whether it should be
                # recorded on the invocation (e.g. for revert support).
                snapshot_id = None
                if self.snapshot:
                    snapshot_id = await self.snapshot.track()

                # Execute step
                finish_reason = await self._execute_step()

                # Save assistant message (real-time persistence)
                await self._save_assistant_message()

                # Save message_history to state and checkpoint
                if self._state:
                    self._save_messages_to_state()
                    await self._state.checkpoint()

                # Check if we should exit
                if finish_reason == "end_turn" and not self._tool_invocations:
                    break

                # Process tool results and continue
                if self._tool_invocations:
                    await self._process_tool_results()

                    # Save tool messages (real-time persistence)
                    await self._save_tool_messages()

                    self._tool_invocations.clear()

                    # Save message_history to state and checkpoint
                    if self._state:
                        self._save_messages_to_state()
                        await self._state.checkpoint()

            # 4. Check if aborted
            is_aborted = self.is_cancelled

            # 5. Complete invocation
            if is_aborted:
                self._current_invocation.state = InvocationState.ABORTED
            else:
                self._current_invocation.state = InvocationState.COMPLETED
            self._current_invocation.finished_at = datetime.now()

            # Save to invocation backend
            if self.ctx.backends and self.ctx.backends.invocation:
                await self.ctx.backends.invocation.update(
                    self._current_invocation.id,
                    self._current_invocation.to_dict(),
                )

            duration_ms = self._current_invocation.duration_ms or 0
            logger.info(
                f"ReactAgent run {'aborted' if is_aborted else 'completed'}",
                extra={
                    "invocation_id": self._current_invocation.id,
                    "steps": self._current_step,
                    "duration_ms": duration_ms,
                    "finish_reason": "aborted" if is_aborted else finish_reason,
                },
            )

            # === Middleware: on_agent_end ===
            if self.middleware:
                await self.middleware.process_agent_end(
                    self.name,
                    {"steps": self._current_step, "finish_reason": finish_reason},
                    mw_context,
                )

            await self.bus.publish(
                Events.INVOCATION_END,
                {
                    "invocation_id": self._current_invocation.id,
                    "steps": self._current_step,
                    "state": self._current_invocation.state.value,
                },
            )

            # Clear message_history from State after successful completion
            # Historical messages are already persisted (truncated) via MessageStore
            self._clear_messages_from_state()
            if self._state:
                await self._state.checkpoint()

        except SuspendSignal as e:
            # HITL/Suspend signal - invocation waits for user input
            logger.info(
                "Agent suspended",
                extra={
                    "invocation_id": self._current_invocation.id
                    if self._current_invocation
                    else None,
                    "signal_type": type(e).__name__,
                },
            )

            if self._current_invocation:
                self._current_invocation.state = InvocationState.SUSPENDED

                # Save invocation state
                if self.ctx.backends and self.ctx.backends.invocation:
                    await self.ctx.backends.invocation.update(
                        self._current_invocation.id,
                        self._current_invocation.to_dict(),
                    )

            # Save pending_request to execution state so restore() can
            # surface it (see is_suspended / pending_request)
            if self._state:
                self._state.execution["pending_request"] = e.to_dict()
                self._save_messages_to_state()
                await self._state.checkpoint()

            # Don't raise - just return to exit cleanly
            return

        except Exception as e:
            logger.error(
                "ReactAgent run failed",
                extra={
                    "error": str(e),
                    "invocation_id": self._current_invocation.id
                    if self._current_invocation
                    else None,
                },
                exc_info=True,
            )

            # === Middleware: on_error ===
            if self.middleware:
                processed_error = await self.middleware.process_error(e, mw_context)
                if processed_error is None:
                    # Error suppressed by middleware
                    logger.info("Error suppressed by middleware")
                    return

            if self._current_invocation:
                self._current_invocation.state = InvocationState.FAILED
                self._current_invocation.finished_at = datetime.now()

            await self.ctx.emit(BlockEvent(
                kind=BlockKind.ERROR,
                op=BlockOp.APPLY,
                data={"message": str(e)},
            ))
            raise

        finally:
            # Always clear running flag and drop the restored invocation —
            # a second run() starts fresh.
            self._running = False
            self._restored_invocation = None
|
|
711
|
+
|
|
712
|
+
async def pause(self) -> str:
    """Pause execution and return invocation ID for later resume.

    Saves current state to the invocation for later resumption.

    Snapshot order matters: the in-memory flags are set first, then the
    serializable agent state is written onto the invocation, then the
    invocation is persisted, and only then is the pause event published.

    Returns:
        Invocation ID for resuming (pass to resume()).

    Raises:
        RuntimeError: If there is no active invocation to pause.
    """
    if not self._current_invocation:
        raise RuntimeError("No active invocation to pause")

    # Mark as paused: the flag stops the step loop; mark_paused() updates
    # the invocation's own state machine.
    self._paused = True
    self._current_invocation.mark_paused()

    # Save state for resumption. Messages are flattened to plain
    # role/content dicts so the snapshot is JSON-serializable; note this
    # drops any extra LLMMessage fields (e.g. tool_call_id) — the restore
    # path in _resume_internal rebuilds only role/content.
    self._current_invocation.agent_state = {
        "step": self._current_step,
        "message_history": [
            {"role": m.role, "content": m.content} for m in self._message_history
        ],
        "text_buffer": self._text_buffer,
    }
    self._current_invocation.step_count = self._current_step

    # Save pending tool calls: only invocations still in CALL state (not
    # yet resolved to a result) need to be re-driven after resume.
    self._current_invocation.pending_tool_ids = [
        inv.tool_call_id
        for inv in self._tool_invocations
        if inv.state == ToolInvocationState.CALL
    ]

    # Persist invocation (best-effort: skipped when no backend configured).
    if self.ctx.backends and self.ctx.backends.invocation:
        await self.ctx.backends.invocation.update(
            self._current_invocation.id,
            self._current_invocation.to_dict(),
        )

    # Notify listeners after the snapshot is durable.
    await self.bus.publish(
        Events.INVOCATION_PAUSE,
        {
            "invocation_id": self._current_invocation.id,
            "step": self._current_step,
        },
    )

    return self._current_invocation.id
async def _resume_internal(self, invocation_id: str) -> None:
    """Internal resume logic using emit.

    Loads a PAUSED invocation from the backend, restores the agent's
    in-memory state from its snapshot, and re-enters the step loop until
    completion, abort, or another pause.

    Args:
        invocation_id: ID of a previously paused invocation.

    Raises:
        ValueError: If no invocation backend is available, the invocation
            does not exist, or it is not in the PAUSED state.
    """
    # Load invocation
    if not self.ctx.backends or not self.ctx.backends.invocation:
        raise ValueError("No invocation backend available")
    inv_data = await self.ctx.backends.invocation.get(invocation_id)
    if not inv_data:
        raise ValueError(f"Invocation not found: {invocation_id}")

    invocation = Invocation.from_dict(inv_data)

    if invocation.state != InvocationState.PAUSED:
        raise ValueError(f"Invocation is not paused: {invocation.state}")

    # Restore state
    self._current_invocation = invocation
    self._paused = False
    self._running = True

    agent_state = invocation.agent_state or {}
    self._current_step = agent_state.get("step", 0)
    self._text_buffer = agent_state.get("text_buffer", "")

    # Restore message history (pause() saved only role/content, so any
    # richer message fields were not round-tripped).
    self._message_history = [
        LLMMessage(role=m["role"], content=m["content"])
        for m in agent_state.get("message_history", [])
    ]

    # Mark as running (in-memory only at this point).
    invocation.state = InvocationState.RUNNING

    await self.bus.publish(
        Events.INVOCATION_RESUME,
        {
            "invocation_id": invocation_id,
            "step": self._current_step,
        },
    )

    # Continue execution loop (mirrors the main run loop).
    try:
        finish_reason = None

        while not await self._check_abort() and not self._paused:
            self._current_step += 1

            if self._current_step > self.config.max_steps:
                await self.ctx.emit(BlockEvent(
                    kind=BlockKind.ERROR,
                    op=BlockOp.APPLY,
                    data={"message": f"Max steps ({self.config.max_steps}) exceeded"},
                ))
                break

            finish_reason = await self._execute_step()

            # Save assistant message (real-time persistence)
            await self._save_assistant_message()

            # end_turn with no pending tools means the agent is done.
            if finish_reason == "end_turn" and not self._tool_invocations:
                break

            if self._tool_invocations:
                await self._process_tool_results()

                # Save tool messages (real-time persistence)
                await self._save_tool_messages()

                self._tool_invocations.clear()

        if not self._paused:
            self._current_invocation.state = InvocationState.COMPLETED
            self._current_invocation.finished_at = datetime.now()
        # NOTE(review): the final COMPLETED/RUNNING state change is never
        # written back via backends.invocation.update() here, unlike the
        # main run path — confirm whether the resumed invocation should be
        # persisted on completion.

    except Exception as e:
        self._current_invocation.state = InvocationState.FAILED
        await self.ctx.emit(BlockEvent(
            kind=BlockKind.ERROR,
            op=BlockOp.APPLY,
            data={"message": str(e)},
        ))
        raise

    finally:
        self._running = False
async def resume(self, invocation_id: str) -> AsyncIterator[BlockEvent]:
    """Resume paused execution.

    Runs _resume_internal() as a background task while streaming the
    BlockEvents it emits. Emission is wired through a ContextVar-held
    asyncio.Queue: ctx.emit() inside the task pushes onto the queue, and
    this generator drains it event-driven (no polling timeouts).

    Args:
        invocation_id: ID from pause()

    Yields:
        BlockEvent streaming events
    """
    from ..core.context import _emit_queue_var

    queue: asyncio.Queue[BlockEvent] = asyncio.Queue()
    # Bind the queue into the context so emits from the exec task land here;
    # the token lets us restore the previous binding in finally.
    token = _emit_queue_var.set(queue)

    try:
        exec_task = asyncio.create_task(self._resume_internal(invocation_id))
        get_task: asyncio.Task | None = None

        # Event-driven processing - no timeout delays
        while True:
            # First drain any pending items from queue (non-blocking)
            while True:
                try:
                    block = queue.get_nowait()
                    yield block
                except asyncio.QueueEmpty:
                    break

            # Exit if task is done and queue is empty
            if exec_task.done() and queue.empty():
                break

            # Create get_task if needed (reuse a still-pending getter so
            # no queue item is consumed twice).
            if get_task is None or get_task.done():
                get_task = asyncio.create_task(queue.get())

            # Wait for EITHER: queue item OR exec_task completion
            done, _ = await asyncio.wait(
                {get_task, exec_task},
                return_when=asyncio.FIRST_COMPLETED,
            )

            if get_task in done:
                try:
                    block = get_task.result()
                    yield block
                    get_task = None
                except asyncio.CancelledError:
                    pass

        # Cancel pending get_task if any (it would otherwise wait forever
        # on an empty queue).
        if get_task and not get_task.done():
            get_task.cancel()
            try:
                await get_task
            except asyncio.CancelledError:
                pass

        # Final drain after task completion
        while not queue.empty():
            try:
                block = queue.get_nowait()
                yield block
            except asyncio.QueueEmpty:
                break

        # Re-raise any exception from the execution task.
        await exec_task

    finally:
        _emit_queue_var.reset(token)
async def _fetch_agent_context(self, input: PromptInput) -> AgentContext:
    """Fetch context from all providers and merge with direct tools.

    Process:
    1. Fetch from all providers and merge
    2. Add direct tools (from create())
    3. If providers returned subagents, create DelegateTool

    Also sets ctx.input for providers to access.

    Args:
        input: The current prompt input; made visible to providers via
            self._ctx.input before fetching.

    Returns:
        A new AgentContext combining all provider outputs with the
        agent's directly-registered tools (direct tools win on name
        collisions).
    """
    # Local imports avoid an import cycle with the tool/backends packages.
    from ..tool.builtin import DelegateTool
    from ..backends.subagent import ListSubAgentBackend

    # Set input on context for providers to access
    self._ctx.input = input

    # Fetch from all context_providers; a failing provider is logged and
    # skipped so one bad provider cannot break the whole prompt build.
    outputs: list[AgentContext] = []
    for provider in self._context_providers:
        try:
            output = await provider.fetch(self._ctx)
            outputs.append(output)
        except Exception as e:
            logger.warning(f"Provider {provider.name} fetch failed: {e}")

    # Merge all provider outputs
    merged = AgentContext.merge(outputs)

    # Add direct tools (from create()) — copied so we never mutate
    # self._tools.
    all_tools = list(self._tools)  # Copy direct tools
    seen_names = {t.name for t in all_tools}

    # Add tools from providers (deduplicate by name; first-seen wins, so
    # direct tools take precedence over provider tools).
    for tool in merged.tools:
        if tool.name not in seen_names:
            seen_names.add(tool.name)
            all_tools.append(tool)

    # If providers returned subagents, create DelegateTool
    if merged.subagents:
        # Check if we already have a delegate tool
        has_delegate = any(t.name == "delegate" for t in all_tools)
        if not has_delegate:
            backend = ListSubAgentBackend(merged.subagents)
            # Allow a project-specific delegate tool class override.
            tool_cls = self._delegate_tool_class or DelegateTool
            delegate_tool = tool_cls(backend, middleware=self._middleware_chain)
            all_tools.append(delegate_tool)

    # Return merged context with combined tools
    return AgentContext(
        system_content=merged.system_content,
        user_content=merged.user_content,
        tools=all_tools,
        messages=merged.messages,
        subagents=merged.subagents,
        skills=merged.skills,
    )
async def _build_messages(self, input: PromptInput) -> list[LLMMessage]:
    """Build message history for LLM.

    Uses AgentContext from providers for system content, messages, etc.

    Layout of the returned list:
      1. One system message (config prompt or generated default, plus any
         provider-supplied system_content appended after a blank line).
      2. Historical messages supplied by providers.
      3. The current user message, optionally prefixed with provider
         user_content and expanded to multimodal parts when attachments
         are present.

    Side effect: stores the freshly fetched context on
    self._agent_context for later steps (tool lookup, prompts).
    """
    messages = []

    # Fetch context from providers (must happen before
    # _default_system_prompt(), which reads self._agent_context).
    self._agent_context = await self._fetch_agent_context(input)

    # System message: config.system_prompt + agent_context.system_content
    system_prompt = self.config.system_prompt or self._default_system_prompt()
    if self._agent_context.system_content:
        system_prompt = system_prompt + "\n\n" + self._agent_context.system_content
    messages.append(LLMMessage(role="system", content=system_prompt))

    # Historical messages from AgentContext (provided by MessageContextProvider)
    for msg in self._agent_context.messages:
        messages.append(LLMMessage(
            role=msg.get("role", "user"),
            content=msg.get("content", ""),
        ))

    # User content prefix (from providers) + current user message
    content = input.text
    if self._agent_context.user_content:
        content = self._agent_context.user_content + "\n\n" + content

    if input.attachments:
        # Build multimodal content: text first, then each attachment dict
        # appended as-is (assumes attachments are provider-ready content
        # parts — TODO confirm against callers).
        content_parts = [{"type": "text", "text": content}]
        for attachment in input.attachments:
            content_parts.append(attachment)
        content = content_parts

    messages.append(LLMMessage(role="user", content=content))

    return messages
def _default_system_prompt(self) -> str:
|
|
1017
|
+
"""Generate default system prompt with tool descriptions."""
|
|
1018
|
+
# Get tools from AgentContext (from providers)
|
|
1019
|
+
all_tools = self._agent_context.tools if self._agent_context else []
|
|
1020
|
+
|
|
1021
|
+
tool_list = []
|
|
1022
|
+
for tool in all_tools:
|
|
1023
|
+
info = tool.get_info()
|
|
1024
|
+
tool_list.append(f"- {info.name}: {info.description}")
|
|
1025
|
+
|
|
1026
|
+
tools_desc = "\n".join(tool_list) if tool_list else "No tools available."
|
|
1027
|
+
|
|
1028
|
+
return f"""You are a helpful AI assistant with access to tools.
|
|
1029
|
+
|
|
1030
|
+
Available tools:
|
|
1031
|
+
{tools_desc}
|
|
1032
|
+
|
|
1033
|
+
When you need to use a tool, make a tool call. After receiving the tool result, continue reasoning or provide your final response.
|
|
1034
|
+
|
|
1035
|
+
Think step by step and use tools when necessary to complete the user's request."""
|
|
1036
|
+
|
|
1037
|
+
def _get_effective_tool_mode(self) -> ToolInjectionMode:
    """Get effective tool mode (auto-detect based on model capabilities).

    Returns:
        FUNCTION_CALL if model supports tools, else PROMPT
    """
    # An explicit PROMPT configuration always wins.
    if self.config.tool_mode == ToolInjectionMode.PROMPT:
        return ToolInjectionMode.PROMPT

    # Otherwise auto-detect: models without native function calling fall
    # back to prompt-injected tools.
    if not self.llm.capabilities.supports_tools:
        logger.info(
            f"Model {self.llm.model} does not support function calling, "
            "auto-switching to PROMPT mode for tools"
        )
        return ToolInjectionMode.PROMPT

    return ToolInjectionMode.FUNCTION_CALL
def _build_tool_prompt(self, tools: list) -> str:
|
|
1059
|
+
"""Build tool description for PROMPT mode injection.
|
|
1060
|
+
|
|
1061
|
+
Args:
|
|
1062
|
+
tools: List of BaseTool objects
|
|
1063
|
+
|
|
1064
|
+
Returns:
|
|
1065
|
+
Tool prompt string to inject into system message
|
|
1066
|
+
"""
|
|
1067
|
+
if not tools:
|
|
1068
|
+
return ""
|
|
1069
|
+
|
|
1070
|
+
tool_descriptions = []
|
|
1071
|
+
for tool in tools:
|
|
1072
|
+
info = tool.get_info()
|
|
1073
|
+
# Build parameter description
|
|
1074
|
+
params_desc = ""
|
|
1075
|
+
if info.parameters and "properties" in info.parameters:
|
|
1076
|
+
params = []
|
|
1077
|
+
properties = info.parameters.get("properties", {})
|
|
1078
|
+
required = info.parameters.get("required", [])
|
|
1079
|
+
for name, schema in properties.items():
|
|
1080
|
+
param_type = schema.get("type", "any")
|
|
1081
|
+
param_desc = schema.get("description", "")
|
|
1082
|
+
is_required = "required" if name in required else "optional"
|
|
1083
|
+
params.append(f" - {name} ({param_type}, {is_required}): {param_desc}")
|
|
1084
|
+
params_desc = "\n" + "\n".join(params) if params else ""
|
|
1085
|
+
|
|
1086
|
+
tool_descriptions.append(
|
|
1087
|
+
f"### {info.name}\n"
|
|
1088
|
+
f"{info.description}{params_desc}"
|
|
1089
|
+
)
|
|
1090
|
+
|
|
1091
|
+
return f"""## Available Tools
|
|
1092
|
+
|
|
1093
|
+
You have access to the following tools. To use a tool, output a JSON block in this exact format:
|
|
1094
|
+
|
|
1095
|
+
```tool_call
|
|
1096
|
+
{{
|
|
1097
|
+
"tool": "tool_name",
|
|
1098
|
+
"arguments": {{
|
|
1099
|
+
"param1": "value1",
|
|
1100
|
+
"param2": "value2"
|
|
1101
|
+
}}
|
|
1102
|
+
}}
|
|
1103
|
+
```
|
|
1104
|
+
|
|
1105
|
+
IMPORTANT:
|
|
1106
|
+
- Use the exact format above with ```tool_call code block
|
|
1107
|
+
- You can make multiple tool calls in one response
|
|
1108
|
+
- Wait for tool results before continuing
|
|
1109
|
+
|
|
1110
|
+
{chr(10).join(tool_descriptions)}
|
|
1111
|
+
"""
|
|
1112
|
+
|
|
1113
|
+
def _parse_tool_calls_from_text(self, text: str) -> list[dict]:
|
|
1114
|
+
"""Parse tool calls from LLM text output (for PROMPT mode).
|
|
1115
|
+
|
|
1116
|
+
Looks for ```tool_call blocks in the format:
|
|
1117
|
+
```tool_call
|
|
1118
|
+
{"tool": "name", "arguments": {...}}
|
|
1119
|
+
```
|
|
1120
|
+
|
|
1121
|
+
Args:
|
|
1122
|
+
text: LLM output text
|
|
1123
|
+
|
|
1124
|
+
Returns:
|
|
1125
|
+
List of parsed tool calls: [{"name": str, "arguments": dict}, ...]
|
|
1126
|
+
"""
|
|
1127
|
+
import re
|
|
1128
|
+
|
|
1129
|
+
tool_calls = []
|
|
1130
|
+
|
|
1131
|
+
# Match ```tool_call ... ``` blocks
|
|
1132
|
+
pattern = r"```tool_call\s*\n?(.+?)\n?```"
|
|
1133
|
+
matches = re.findall(pattern, text, re.DOTALL)
|
|
1134
|
+
|
|
1135
|
+
for match in matches:
|
|
1136
|
+
try:
|
|
1137
|
+
data = json.loads(match.strip())
|
|
1138
|
+
if "tool" in data:
|
|
1139
|
+
tool_calls.append({
|
|
1140
|
+
"name": data["tool"],
|
|
1141
|
+
"arguments": data.get("arguments", {}),
|
|
1142
|
+
})
|
|
1143
|
+
except json.JSONDecodeError as e:
|
|
1144
|
+
logger.warning(f"Failed to parse tool call JSON: {e}")
|
|
1145
|
+
continue
|
|
1146
|
+
|
|
1147
|
+
return tool_calls
|
|
1148
|
+
|
|
1149
|
+
async def _execute_step(self) -> str | None:
    """Execute a single LLM step with middleware hooks.

    One step = one streaming LLM call. Text/thinking deltas are emitted
    as BlockEvents while streaming; tool calls are collected into
    self._tool_invocations for the caller to execute afterwards. The
    assistant message is appended to self._message_history at the end.

    Returns:
        finish_reason from LLM (None if cancelled by middleware or never
        reported by the stream).
    """
    # Get tools from AgentContext (from providers)
    all_tools = self._agent_context.tools if self._agent_context else []

    # Determine effective tool mode (auto-detect based on capabilities)
    effective_tool_mode = self._get_effective_tool_mode()

    # Get tool definitions (only for FUNCTION_CALL mode)
    tool_defs = None
    if effective_tool_mode == ToolInjectionMode.FUNCTION_CALL and all_tools:
        tool_defs = [
            ToolDefinition(
                name=t.name,
                description=t.description,
                input_schema=t.parameters,
            )
            for t in all_tools
        ]

    # For PROMPT mode, inject tools into system message.
    # NOTE(review): this appends the tool prompt to the system message on
    # EVERY step; across a multi-step run the system message accumulates
    # duplicate tool sections — confirm whether injection should happen
    # only once.
    if effective_tool_mode == ToolInjectionMode.PROMPT and all_tools:
        tool_prompt = self._build_tool_prompt(all_tools)
        # Inject into first system message
        if self._message_history and self._message_history[0].role == "system":
            original_content = self._message_history[0].content
            self._message_history[0] = LLMMessage(
                role="system",
                content=f"{original_content}\n\n{tool_prompt}",
            )

    # Reset buffers
    self._text_buffer = ""
    self._thinking_buffer = ""  # Buffer for non-streaming thinking
    self._tool_invocations = []
    # NOTE(review): current_tool_invocation is never used below — dead local.
    current_tool_invocation: ToolInvocation | None = None

    # Reset block IDs for this step (each step gets fresh block IDs)
    self._current_text_block_id = None
    self._current_thinking_block_id = None

    # Reset tool call tracking
    self._call_id_to_tool = {}
    self._tool_call_blocks = {}

    # Build middleware context for this step
    from ..core.context import emit as global_emit
    mw_context = {
        "session_id": self.session.id,
        "invocation_id": self._current_invocation.id if self._current_invocation else "",
        "step": self._current_step,
        "agent_id": self.name,
        "emit": global_emit,  # For middleware to emit BlockEvent/ActionEvent
        "backends": self.ctx.backends,
        "tool_mode": effective_tool_mode.value,  # Add tool mode to context
    }

    # Build LLM call kwargs
    # Note: temperature, max_tokens, timeout, retries are configured on LLMProvider
    llm_kwargs: dict[str, Any] = {
        "messages": self._message_history,
        "tools": tool_defs,  # None for PROMPT mode
    }

    # Get model capabilities
    caps = self.llm.capabilities

    # Add thinking configuration (use runtime override if set)
    # Only if model supports thinking
    enable_thinking = self._get_enable_thinking()
    reasoning_effort = self._get_reasoning_effort()
    if enable_thinking:
        if caps.supports_thinking:
            llm_kwargs["enable_thinking"] = True
            if reasoning_effort:
                llm_kwargs["reasoning_effort"] = reasoning_effort
        else:
            logger.debug(
                f"Model {self.llm.model} does not support thinking, "
                "enable_thinking will be ignored"
            )

    # === Middleware: on_request ===
    # Middleware may rewrite the kwargs or cancel the call by returning None.
    if self.middleware:
        llm_kwargs = await self.middleware.process_request(llm_kwargs, mw_context)
        if llm_kwargs is None:
            logger.info("LLM request cancelled by middleware")
            return None

    # Debug: log message history before LLM call
    logger.debug(
        f"LLM call - Step {self._current_step}, messages: {len(self._message_history)}, "
        f"tools: {len(tool_defs) if tool_defs else 0}"
    )
    # Detailed message log (for debugging model issues like repeated calls)
    for i, msg in enumerate(self._message_history):
        content_preview = str(msg.content)[:300] if msg.content else "<empty>"
        tool_call_id = getattr(msg, 'tool_call_id', None)
        logger.debug(
            f"  msg[{i}] role={msg.role}"
            f"{f', tool_call_id={tool_call_id}' if tool_call_id else ''}"
            f", content={content_preview}"
        )

    # Call LLM
    await self.bus.publish(
        Events.LLM_START,
        {
            "provider": self.llm.provider,
            "model": self.llm.model,
            "step": self._current_step,
            "enable_thinking": enable_thinking,
        },
    )

    finish_reason = None
    llm_response_data: dict[str, Any] = {}  # Collect response for middleware

    # Reset middleware stream state
    if self.middleware:
        self.middleware.reset_stream_state()

    # Consume the streaming response event-by-event.
    async for event in self.llm.complete(**llm_kwargs):
        if await self._check_abort():
            break

        if event.type == "content":
            # Text content
            if event.delta:
                # === Middleware: on_model_stream ===
                stream_chunk = {"delta": event.delta, "type": "content"}
                if self.middleware:
                    stream_chunk = await self.middleware.process_stream_chunk(
                        stream_chunk, mw_context
                    )
                    if stream_chunk is None:
                        continue  # Skip this chunk

                delta = stream_chunk.get("delta", event.delta)
                self._text_buffer += delta

                # Reuse or create block_id for text streaming
                if self._current_text_block_id is None:
                    self._current_text_block_id = generate_id("blk")

                await self.ctx.emit(BlockEvent(
                    block_id=self._current_text_block_id,
                    kind=BlockKind.TEXT,
                    op=BlockOp.DELTA,
                    data={"content": delta},
                ))

                await self.bus.publish(
                    Events.LLM_STREAM,
                    {
                        "delta": delta,
                        "step": self._current_step,
                    },
                )

        elif event.type == "thinking":
            # Thinking content - only emit if thinking is enabled
            stream_thinking = self._get_stream_thinking()
            if event.delta and enable_thinking:
                if stream_thinking:
                    # Reuse or create block_id for thinking streaming
                    if self._current_thinking_block_id is None:
                        self._current_thinking_block_id = generate_id("blk")

                    # Stream thinking in real-time
                    await self.ctx.emit(BlockEvent(
                        block_id=self._current_thinking_block_id,
                        kind=BlockKind.THINKING,
                        op=BlockOp.DELTA,
                        data={"content": event.delta},
                    ))
                else:
                    # Buffer thinking for batch output
                    self._thinking_buffer += event.delta

        elif event.type == "tool_call_start":
            # Tool call started (name known, arguments pending)
            if event.tool_call:
                tc = event.tool_call
                self._call_id_to_tool[tc.id] = tc.name

                # Always emit start notification (privacy-safe, no arguments)
                block_id = generate_id("blk")
                self._tool_call_blocks[tc.id] = block_id

                await self.ctx.emit(BlockEvent(
                    block_id=block_id,
                    kind=BlockKind.TOOL_USE,
                    op=BlockOp.APPLY,
                    data={
                        "name": tc.name,
                        "call_id": tc.id,
                        "status": "streaming",  # Indicate arguments are streaming
                    },
                ))

        elif event.type == "tool_call_delta":
            # Tool arguments delta (streaming)
            if event.tool_call_delta:
                call_id = event.tool_call_delta.get("call_id")
                arguments_delta = event.tool_call_delta.get("arguments_delta")

                if call_id and arguments_delta:
                    tool_name = self._call_id_to_tool.get(call_id)
                    if tool_name:
                        tool = self._get_tool(tool_name)

                        # Check if tool allows streaming arguments
                        # (privacy opt-in per tool).
                        if tool and tool.config.stream_arguments:
                            block_id = self._tool_call_blocks.get(call_id)
                            if block_id:
                                await self.ctx.emit(BlockEvent(
                                    block_id=block_id,
                                    kind=BlockKind.TOOL_USE,
                                    op=BlockOp.DELTA,
                                    data={
                                        "call_id": call_id,
                                        "arguments_delta": arguments_delta,
                                    },
                                ))

        elif event.type == "tool_call_progress":
            # Tool arguments progress (bytes received)
            if event.tool_call_progress:
                call_id = event.tool_call_progress.get("call_id")
                bytes_received = event.tool_call_progress.get("bytes_received")

                if call_id and bytes_received is not None:
                    block_id = self._tool_call_blocks.get(call_id)
                    if block_id:
                        # Always emit progress (privacy-safe, no content)
                        await self.ctx.emit(BlockEvent(
                            block_id=block_id,
                            kind=BlockKind.TOOL_USE,
                            op=BlockOp.PATCH,
                            data={
                                "call_id": call_id,
                                "bytes_received": bytes_received,
                                "status": "receiving",
                            },
                        ))

        elif event.type == "tool_call":
            # Tool call complete (arguments fully received)
            if event.tool_call:
                tc = event.tool_call
                invocation = ToolInvocation(
                    tool_call_id=tc.id,
                    tool_name=tc.name,
                    args_raw=tc.arguments,
                    state=ToolInvocationState.CALL,
                )

                # Parse arguments; malformed JSON degrades to empty args
                # rather than aborting the step.
                try:
                    invocation.args = json.loads(tc.arguments)
                except json.JSONDecodeError:
                    invocation.args = {}

                self._tool_invocations.append(invocation)

                # Strict mode: require tool_call_start to be received first
                # TODO: Uncomment below for compatibility with providers that don't send tool_call_start
                # block_id = self._tool_call_blocks.get(tc.id)
                # if block_id is None:
                #     # No streaming start event, create block now
                #     block_id = generate_id("blk")
                #     self._tool_call_blocks[tc.id] = block_id
                #     self._call_id_to_tool[tc.id] = tc.name
                #
                #     # Emit APPLY with full data
                #     await self.ctx.emit(BlockEvent(
                #         block_id=block_id,
                #         kind=BlockKind.TOOL_USE,
                #         op=BlockOp.APPLY,
                #         data={
                #             "name": tc.name,
                #             "call_id": tc.id,
                #             "arguments": invocation.args,
                #             "status": "ready",
                #         },
                #     ))
                # else:
                #     # Update existing block with complete arguments
                #     await self.ctx.emit(BlockEvent(
                #         block_id=block_id,
                #         kind=BlockKind.TOOL_USE,
                #         op=BlockOp.PATCH,
                #         data={
                #             "call_id": tc.id,
                #             "arguments": invocation.args,
                #             "status": "ready",
                #         },
                #     ))

                # Strict mode: tool_call_start must have been received
                block_id = self._tool_call_blocks[tc.id]  # Will raise KeyError if not found
                await self.ctx.emit(BlockEvent(
                    block_id=block_id,
                    kind=BlockKind.TOOL_USE,
                    op=BlockOp.PATCH,
                    data={
                        "call_id": tc.id,
                        "arguments": invocation.args,
                        "status": "ready",
                    },
                ))

                await self.bus.publish(
                    Events.TOOL_START,
                    {
                        "call_id": tc.id,
                        "tool": tc.name,
                        "arguments": invocation.args,
                    },
                )

        elif event.type == "completed":
            finish_reason = event.finish_reason

        elif event.type == "usage":
            if event.usage:
                await self.bus.publish(
                    Events.USAGE_RECORDED,
                    {
                        "provider": self.llm.provider,
                        "model": self.llm.model,
                        "input_tokens": event.usage.input_tokens,
                        "output_tokens": event.usage.output_tokens,
                        "cache_read_tokens": event.usage.cache_read_tokens,
                        "cache_write_tokens": event.usage.cache_write_tokens,
                        "reasoning_tokens": event.usage.reasoning_tokens,
                    },
                )

        elif event.type == "error":
            await self.ctx.emit(BlockEvent(
                kind=BlockKind.ERROR,
                op=BlockOp.APPLY,
                data={"message": event.error or "Unknown LLM error"},
            ))

    # If thinking was buffered, emit it now.
    # NOTE(review): this checks self.config.stream_thinking while the
    # streaming branch above uses self._get_stream_thinking() (runtime
    # override) — if the override differs from the config, buffered
    # thinking may be dropped or double-emitted; confirm and unify.
    if self._thinking_buffer and not self.config.stream_thinking:
        await self.ctx.emit(BlockEvent(
            kind=BlockKind.THINKING,
            op=BlockOp.APPLY,
            data={"content": self._thinking_buffer},
        ))

    # PROMPT mode: parse tool calls from text output
    if effective_tool_mode == ToolInjectionMode.PROMPT and self._text_buffer:
        parsed_calls = self._parse_tool_calls_from_text(self._text_buffer)
        for i, call in enumerate(parsed_calls):
            call_id = generate_id("call")
            invocation = ToolInvocation(
                tool_call_id=call_id,
                tool_name=call["name"],
                args_raw=json.dumps(call["arguments"]),
                args=call["arguments"],
                state=ToolInvocationState.CALL,
            )
            self._tool_invocations.append(invocation)

            # Create block for tool call (no streaming events in PROMPT mode)
            block_id = generate_id("blk")
            self._tool_call_blocks[call_id] = block_id
            self._call_id_to_tool[call_id] = call["name"]

            await self.ctx.emit(BlockEvent(
                block_id=block_id,
                kind=BlockKind.TOOL_USE,
                op=BlockOp.APPLY,
                data={
                    "name": call["name"],
                    "call_id": call_id,
                    "arguments": call["arguments"],
                    "status": "ready",
                    "source": "prompt",  # Indicate parsed from text
                },
            ))

            await self.bus.publish(
                Events.TOOL_START,
                {
                    "call_id": call_id,
                    "tool": call["name"],
                    "arguments": call["arguments"],
                    "source": "prompt",
                },
            )

        if parsed_calls:
            logger.debug(f"PROMPT mode: parsed {len(parsed_calls)} tool calls from text")

    # === Middleware: on_response ===
    llm_response_data = {
        "text": self._text_buffer,
        "thinking": self._thinking_buffer,
        "tool_calls": len(self._tool_invocations),
        "finish_reason": finish_reason,
    }
    if self.middleware:
        llm_response_data = await self.middleware.process_response(
            llm_response_data, mw_context
        )

    await self.bus.publish(
        Events.LLM_END,
        {
            "step": self._current_step,
            "finish_reason": finish_reason,
            "text_length": len(self._text_buffer),
            "thinking_length": len(self._thinking_buffer),
            "tool_calls": len(self._tool_invocations),
        },
    )

    # Add assistant message to history (plain text when there are no tool
    # calls; structured text + tool_use parts otherwise).
    if self._text_buffer or self._tool_invocations:
        assistant_content: Any = self._text_buffer
        if self._tool_invocations:
            # Build content with tool calls
            content_parts = []
            if self._text_buffer:
                content_parts.append({"type": "text", "text": self._text_buffer})
            for inv in self._tool_invocations:
                content_parts.append(
                    {
                        "type": "tool_use",
                        "id": inv.tool_call_id,
                        "name": inv.tool_name,
                        "input": inv.args,
                    }
                )
            assistant_content = content_parts

        self._message_history.append(
            LLMMessage(
                role="assistant",
                content=assistant_content,
            )
        )

    return finish_reason
async def _process_tool_results(self) -> None:
|
|
1605
|
+
"""Execute tool calls and add results to history.
|
|
1606
|
+
|
|
1607
|
+
Executes tools in parallel or sequentially based on config.
|
|
1608
|
+
"""
|
|
1609
|
+
if not self._tool_invocations:
|
|
1610
|
+
return
|
|
1611
|
+
|
|
1612
|
+
# Execute tools based on configuration
|
|
1613
|
+
if self.config.parallel_tool_execution:
|
|
1614
|
+
# Parallel execution using asyncio.gather with create_task
|
|
1615
|
+
# create_task ensures each task gets its own ContextVar copy
|
|
1616
|
+
tasks = [asyncio.create_task(self._execute_tool(inv)) for inv in self._tool_invocations]
|
|
1617
|
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
1618
|
+
else:
|
|
1619
|
+
# Sequential execution
|
|
1620
|
+
results = []
|
|
1621
|
+
for inv in self._tool_invocations:
|
|
1622
|
+
try:
|
|
1623
|
+
result = await self._execute_tool(inv)
|
|
1624
|
+
results.append(result)
|
|
1625
|
+
except Exception as e:
|
|
1626
|
+
results.append(e)
|
|
1627
|
+
|
|
1628
|
+
# Check for SuspendSignal first - must propagate
|
|
1629
|
+
for result in results:
|
|
1630
|
+
if isinstance(result, SuspendSignal):
|
|
1631
|
+
raise result
|
|
1632
|
+
|
|
1633
|
+
# Process results
|
|
1634
|
+
tool_results = []
|
|
1635
|
+
|
|
1636
|
+
for invocation, result in zip(self._tool_invocations, results):
|
|
1637
|
+
# Handle exceptions from gather
|
|
1638
|
+
if isinstance(result, Exception):
|
|
1639
|
+
error_msg = f"Tool execution error: {str(result)}"
|
|
1640
|
+
invocation.mark_result(error_msg, is_error=True)
|
|
1641
|
+
result = ToolResult.error(error_msg)
|
|
1642
|
+
|
|
1643
|
+
# Get parent block_id from tool_call mapping
|
|
1644
|
+
parent_block_id = self._tool_call_blocks.get(invocation.tool_call_id)
|
|
1645
|
+
|
|
1646
|
+
await self.ctx.emit(BlockEvent(
|
|
1647
|
+
kind=BlockKind.TOOL_RESULT,
|
|
1648
|
+
op=BlockOp.APPLY,
|
|
1649
|
+
parent_id=parent_block_id,
|
|
1650
|
+
data={
|
|
1651
|
+
"call_id": invocation.tool_call_id,
|
|
1652
|
+
"content": result.output,
|
|
1653
|
+
"is_error": invocation.is_error,
|
|
1654
|
+
},
|
|
1655
|
+
))
|
|
1656
|
+
|
|
1657
|
+
await self.bus.publish(
|
|
1658
|
+
Events.TOOL_END,
|
|
1659
|
+
{
|
|
1660
|
+
"call_id": invocation.tool_call_id,
|
|
1661
|
+
"tool": invocation.tool_name,
|
|
1662
|
+
"result": result.output[:500], # Truncate for event
|
|
1663
|
+
"is_error": invocation.is_error,
|
|
1664
|
+
"duration_ms": invocation.duration_ms,
|
|
1665
|
+
},
|
|
1666
|
+
)
|
|
1667
|
+
|
|
1668
|
+
tool_results.append(
|
|
1669
|
+
{
|
|
1670
|
+
"type": "tool_result",
|
|
1671
|
+
"tool_use_id": invocation.tool_call_id,
|
|
1672
|
+
"content": result.output,
|
|
1673
|
+
"is_error": invocation.is_error,
|
|
1674
|
+
}
|
|
1675
|
+
)
|
|
1676
|
+
|
|
1677
|
+
# Add tool results as tool messages (OpenAI format)
|
|
1678
|
+
for tr in tool_results:
|
|
1679
|
+
print(f"[DEBUG _process_tool_results] Adding tool_result to history: {tr}")
|
|
1680
|
+
self._message_history.append(
|
|
1681
|
+
LLMMessage(
|
|
1682
|
+
role="tool",
|
|
1683
|
+
content=tr["content"],
|
|
1684
|
+
tool_call_id=tr["tool_use_id"],
|
|
1685
|
+
)
|
|
1686
|
+
)
|
|
1687
|
+
|
|
1688
|
+
def _save_messages_to_state(self) -> None:
|
|
1689
|
+
"""Save execution state for recovery.
|
|
1690
|
+
|
|
1691
|
+
This saves to state.execution namespace:
|
|
1692
|
+
- step: current step number
|
|
1693
|
+
- message_ids: references to raw messages (if using RawMessageMiddleware)
|
|
1694
|
+
- For legacy/fallback: message_history as serialized data
|
|
1695
|
+
|
|
1696
|
+
Note: With RawMessageMiddleware, message_ids are automatically populated
|
|
1697
|
+
by the middleware. This method saves additional execution state.
|
|
1698
|
+
"""
|
|
1699
|
+
if not self._state:
|
|
1700
|
+
return
|
|
1701
|
+
|
|
1702
|
+
# Save step to execution namespace
|
|
1703
|
+
self._state.execution["step"] = self._current_step
|
|
1704
|
+
|
|
1705
|
+
# Save invocation_id for recovery context
|
|
1706
|
+
if self._current_invocation:
|
|
1707
|
+
self._state.execution["invocation_id"] = self._current_invocation.id
|
|
1708
|
+
|
|
1709
|
+
# Fallback: if message_ids not populated by middleware, save full history
|
|
1710
|
+
# This ensures backward compatibility when RawMessageMiddleware is not used
|
|
1711
|
+
if "message_ids" not in self._state.execution:
|
|
1712
|
+
messages_data = []
|
|
1713
|
+
for msg in self._message_history:
|
|
1714
|
+
msg_dict = {"role": msg.role, "content": msg.content}
|
|
1715
|
+
if hasattr(msg, "tool_call_id") and msg.tool_call_id:
|
|
1716
|
+
msg_dict["tool_call_id"] = msg.tool_call_id
|
|
1717
|
+
messages_data.append(msg_dict)
|
|
1718
|
+
self._state.execution["message_history"] = messages_data
|
|
1719
|
+
|
|
1720
|
+
def _clear_messages_from_state(self) -> None:
|
|
1721
|
+
"""Clear execution state after invocation completes.
|
|
1722
|
+
|
|
1723
|
+
Called when invocation completes normally. Historical messages
|
|
1724
|
+
are already persisted (truncated) via MessageStore.
|
|
1725
|
+
"""
|
|
1726
|
+
if not self._state:
|
|
1727
|
+
return
|
|
1728
|
+
|
|
1729
|
+
# Clear execution namespace
|
|
1730
|
+
self._state.execution.clear()
|
|
1731
|
+
|
|
1732
|
+
async def _trigger_message_save(self, message: dict) -> dict | None:
|
|
1733
|
+
"""Trigger on_message_save hook via middleware.
|
|
1734
|
+
|
|
1735
|
+
Message persistence is handled by MessageBackendMiddleware.
|
|
1736
|
+
Agent only triggers the hook, doesn't save directly.
|
|
1737
|
+
|
|
1738
|
+
Args:
|
|
1739
|
+
message: Message dict with role, content, etc.
|
|
1740
|
+
|
|
1741
|
+
Returns:
|
|
1742
|
+
Modified message or None if blocked
|
|
1743
|
+
"""
|
|
1744
|
+
# Check if message saving is disabled (e.g., for sub-agents with record_messages=False)
|
|
1745
|
+
if getattr(self, '_disable_message_save', False):
|
|
1746
|
+
return message
|
|
1747
|
+
|
|
1748
|
+
if not self.middleware:
|
|
1749
|
+
return message
|
|
1750
|
+
|
|
1751
|
+
namespace = getattr(self, '_message_namespace', None)
|
|
1752
|
+
mw_context = {
|
|
1753
|
+
"session_id": self.session.id,
|
|
1754
|
+
"agent_id": self.name,
|
|
1755
|
+
"namespace": namespace,
|
|
1756
|
+
}
|
|
1757
|
+
|
|
1758
|
+
return await self.middleware.process_message_save(message, mw_context)
|
|
1759
|
+
|
|
1760
|
+
async def _save_user_message(self, input: PromptInput) -> None:
    """Assemble the user message for this prompt and trigger its save hook."""
    # Start from the prompt text, optionally prefixed by injected
    # user_content from the agent context.
    content: str | list[dict] = input.text
    agent_ctx = self._agent_context
    if agent_ctx and agent_ctx.user_content:
        content = agent_ctx.user_content + "\n\n" + input.text

    # Attachments turn the content into a multi-part list.
    if input.attachments:
        parts: list[dict] = [{"type": "text", "text": content}]
        parts.extend(input.attachments)
        content = parts

    # Hand the assembled message to the save hook.
    invocation = self._current_invocation
    message = {
        "role": "user",
        "content": content,
        "invocation_id": invocation.id if invocation else "",
    }

    await self._trigger_message_save(message)
|
|
1781
|
+
|
|
1782
|
+
async def _save_assistant_message(self) -> None:
|
|
1783
|
+
"""Trigger save for assistant message."""
|
|
1784
|
+
if not self._text_buffer and not self._tool_invocations:
|
|
1785
|
+
return
|
|
1786
|
+
|
|
1787
|
+
# Build assistant content
|
|
1788
|
+
content: str | list[dict] = self._text_buffer
|
|
1789
|
+
if self._tool_invocations:
|
|
1790
|
+
content_parts: list[dict] = []
|
|
1791
|
+
if self._text_buffer:
|
|
1792
|
+
content_parts.append({"type": "text", "text": self._text_buffer})
|
|
1793
|
+
for inv in self._tool_invocations:
|
|
1794
|
+
content_parts.append({
|
|
1795
|
+
"type": "tool_use",
|
|
1796
|
+
"id": inv.tool_call_id,
|
|
1797
|
+
"name": inv.tool_name,
|
|
1798
|
+
"input": inv.args,
|
|
1799
|
+
})
|
|
1800
|
+
content = content_parts
|
|
1801
|
+
|
|
1802
|
+
# Build message and trigger hook
|
|
1803
|
+
message = {
|
|
1804
|
+
"role": "assistant",
|
|
1805
|
+
"content": content,
|
|
1806
|
+
"invocation_id": self._current_invocation.id if self._current_invocation else "",
|
|
1807
|
+
}
|
|
1808
|
+
|
|
1809
|
+
await self._trigger_message_save(message)
|
|
1810
|
+
|
|
1811
|
+
async def _save_tool_messages(self) -> None:
|
|
1812
|
+
"""Trigger save for tool result messages."""
|
|
1813
|
+
for inv in self._tool_invocations:
|
|
1814
|
+
if inv.result is not None:
|
|
1815
|
+
# Build tool result message
|
|
1816
|
+
content: list[dict] = [{
|
|
1817
|
+
"type": "tool_result",
|
|
1818
|
+
"tool_use_id": inv.tool_call_id,
|
|
1819
|
+
"content": inv.result,
|
|
1820
|
+
"is_error": inv.is_error,
|
|
1821
|
+
}]
|
|
1822
|
+
|
|
1823
|
+
message = {
|
|
1824
|
+
"role": "tool",
|
|
1825
|
+
"content": content,
|
|
1826
|
+
"tool_call_id": inv.tool_call_id,
|
|
1827
|
+
"invocation_id": self._current_invocation.id if self._current_invocation else "",
|
|
1828
|
+
}
|
|
1829
|
+
|
|
1830
|
+
await self._trigger_message_save(message)
|
|
1831
|
+
|
|
1832
|
+
def _get_tool(self, tool_name: str) -> "BaseTool | None":
|
|
1833
|
+
"""Get tool by name from agent context."""
|
|
1834
|
+
if self._agent_context:
|
|
1835
|
+
for tool in self._agent_context.tools:
|
|
1836
|
+
if tool.name == tool_name:
|
|
1837
|
+
return tool
|
|
1838
|
+
return None
|
|
1839
|
+
|
|
1840
|
+
async def _execute_tool(self, invocation: ToolInvocation) -> ToolResult:
    """Execute a single tool call with middleware hooks and timeout.

    Pipeline: resolve the tool by name -> on_tool_call middleware hook
    (may SKIP the call or rewrite its args via RETRY+modified_data) ->
    run the tool, optionally under asyncio.wait_for when
    tool.config.timeout is set -> on_tool_end middleware hook -> record
    the result on the invocation and return it.

    Errors (unknown tool, timeout, any other exception) are converted to
    error ToolResults rather than raised; SuspendSignal is the one
    exception that propagates so HITL/suspend flows reach the caller.
    """
    # Mark the call phase finished before execution starts.
    # NOTE(review): mark_call_complete runs even if execution later
    # fails — presumably it timestamps the call boundary; confirm.
    invocation.mark_call_complete()

    # Build middleware context shared by the on_tool_call/on_tool_end hooks
    mw_context = {
        "session_id": self.session.id,
        "invocation_id": self._current_invocation.id if self._current_invocation else "",
        "tool_call_id": invocation.tool_call_id,
        "agent_id": self.name,
    }

    try:
        # Get tool from agent context
        tool = self._get_tool(invocation.tool_name)
        if tool is None:
            error_msg = f"Unknown tool: {invocation.tool_name}"
            invocation.mark_result(error_msg, is_error=True)
            return ToolResult.error(error_msg)

        # === Middleware: on_tool_call ===
        if self.middleware:
            hook_result = await self.middleware.process_tool_call(
                tool, invocation.args, mw_context
            )
            # SKIP: middleware vetoes the call; the hook message becomes
            # the (non-error) output so the model still sees a result.
            if hook_result.action == HookAction.SKIP:
                logger.info(f"Tool {invocation.tool_name} skipped by middleware")
                return ToolResult(
                    output=hook_result.message or "Skipped by middleware",
                    is_error=False,
                )
            # RETRY with modified_data: middleware rewrote the arguments.
            elif hook_result.action == HookAction.RETRY and hook_result.modified_data:
                invocation.args = hook_result.modified_data

        # Create ToolContext handed to the tool implementation.
        # block_id is intentionally empty here; update_metadata is a no-op.
        tool_ctx = ToolContext(
            session_id=self.session.id,
            invocation_id=self._current_invocation.id if self._current_invocation else "",
            block_id="",
            call_id=invocation.tool_call_id,
            agent=self.config.name,
            abort_signal=self._abort,
            update_metadata=self._noop_update_metadata,
            middleware=self.middleware,
        )

        # Execute tool (with optional timeout from tool.config)
        timeout = tool.config.timeout
        if timeout is not None:
            result = await asyncio.wait_for(
                tool.execute(invocation.args, tool_ctx),
                timeout=timeout,
            )
        else:
            # No timeout - tool runs until completion
            result = await tool.execute(invocation.args, tool_ctx)

        # === Middleware: on_tool_end ===
        if self.middleware:
            hook_result = await self.middleware.process_tool_end(tool, result, mw_context)
            # NOTE(review): RETRY from on_tool_end is only logged, not
            # acted upon — confirm whether a re-execution was intended.
            if hook_result.action == HookAction.RETRY:
                logger.info(f"Tool {invocation.tool_name} retry requested by middleware")

        invocation.mark_result(result.output, is_error=result.is_error)
        return result

    except asyncio.TimeoutError:
        # tool is always bound here: wait_for only runs after resolution.
        timeout = tool.config.timeout if tool else None
        error_msg = f"Tool {invocation.tool_name} timed out after {timeout}s"
        invocation.mark_result(error_msg, is_error=True)
        return ToolResult.error(error_msg)

    except SuspendSignal:
        # HITL/Suspend signal must propagate up
        raise

    except Exception as e:
        # Any other failure becomes an error result, never a crash.
        error_msg = f"Tool execution error: {str(e)}"
        invocation.mark_result(error_msg, is_error=True)
        return ToolResult.error(error_msg)
|
|
1920
|
+
|
|
1921
|
+
async def _noop_update_metadata(self, metadata: dict[str, Any]) -> None:
|
|
1922
|
+
"""No-op metadata updater."""
|
|
1923
|
+
pass
|