aury-agent 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/agents/backends/__init__.py +8 -0
- aury/agents/backends/hitl/__init__.py +8 -0
- aury/agents/backends/hitl/memory.py +100 -0
- aury/agents/backends/hitl/types.py +132 -0
- aury/agents/core/base.py +5 -0
- aury/agents/core/context.py +1 -0
- aury/agents/core/signals.py +37 -17
- aury/agents/core/types/__init__.py +0 -2
- aury/agents/core/types/block.py +6 -23
- aury/agents/core/types/session.py +10 -3
- aury/agents/core/types/tool.py +194 -18
- aury/agents/hitl/__init__.py +2 -0
- aury/agents/hitl/ask_user.py +59 -47
- aury/agents/hitl/exceptions.py +214 -13
- aury/agents/react/agent.py +47 -0
- aury/agents/react/context.py +51 -25
- aury/agents/react/factory.py +2 -0
- aury/agents/react/pause.py +13 -2
- aury/agents/react/step.py +39 -12
- aury/agents/react/tools.py +277 -147
- aury/agents/tool/builtin/ask_user.py +1 -5
- aury/agents/tool/builtin/delegate.py +3 -15
- aury/agents/tool/builtin/plan.py +1 -5
- aury/agents/tool/builtin/thinking.py +1 -6
- aury/agents/tool/builtin/yield_result.py +1 -6
- {aury_agent-0.0.12.dist-info → aury_agent-0.0.14.dist-info}/METADATA +1 -1
- {aury_agent-0.0.12.dist-info → aury_agent-0.0.14.dist-info}/RECORD +29 -26
- {aury_agent-0.0.12.dist-info → aury_agent-0.0.14.dist-info}/WHEEL +0 -0
- {aury_agent-0.0.12.dist-info → aury_agent-0.0.14.dist-info}/entry_points.txt +0 -0
aury/agents/hitl/exceptions.py
CHANGED
|
@@ -35,16 +35,16 @@ class HITLRequest:
|
|
|
35
35
|
|
|
36
36
|
Stored in invocation for persistence.
|
|
37
37
|
"""
|
|
38
|
-
|
|
39
|
-
|
|
38
|
+
hitl_id: str
|
|
39
|
+
hitl_type: str # ask_user, confirm, permission, external_auth, workflow_human
|
|
40
40
|
|
|
41
|
-
#
|
|
42
|
-
|
|
43
|
-
options: list[str] | None = None
|
|
41
|
+
# Type-specific data
|
|
42
|
+
data: dict[str, Any] = field(default_factory=dict) # {message, options, ...}
|
|
44
43
|
|
|
45
44
|
# Context
|
|
46
45
|
tool_name: str | None = None # If triggered by tool
|
|
47
46
|
node_id: str | None = None # If triggered by workflow node
|
|
47
|
+
block_id: str | None = None # Associated UI block
|
|
48
48
|
|
|
49
49
|
# Metadata
|
|
50
50
|
metadata: dict[str, Any] = field(default_factory=dict)
|
|
@@ -52,12 +52,12 @@ class HITLRequest:
|
|
|
52
52
|
def to_dict(self) -> dict[str, Any]:
|
|
53
53
|
"""Convert to dictionary for serialization."""
|
|
54
54
|
return {
|
|
55
|
-
"
|
|
56
|
-
"
|
|
57
|
-
"
|
|
58
|
-
"options": self.options,
|
|
55
|
+
"hitl_id": self.hitl_id,
|
|
56
|
+
"hitl_type": self.hitl_type,
|
|
57
|
+
"data": self.data,
|
|
59
58
|
"tool_name": self.tool_name,
|
|
60
59
|
"node_id": self.node_id,
|
|
60
|
+
"block_id": self.block_id,
|
|
61
61
|
"metadata": self.metadata,
|
|
62
62
|
}
|
|
63
63
|
|
|
@@ -65,16 +65,216 @@ class HITLRequest:
|
|
|
65
65
|
def from_dict(cls, data: dict[str, Any]) -> "HITLRequest":
|
|
66
66
|
"""Create from dictionary."""
|
|
67
67
|
return cls(
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
options=data.get("options"),
|
|
68
|
+
hitl_id=data["hitl_id"],
|
|
69
|
+
hitl_type=data.get("hitl_type", "ask_user"),
|
|
70
|
+
data=data.get("data", {}),
|
|
72
71
|
tool_name=data.get("tool_name"),
|
|
73
72
|
node_id=data.get("node_id"),
|
|
73
|
+
block_id=data.get("block_id"),
|
|
74
74
|
metadata=data.get("metadata", {}),
|
|
75
75
|
)
|
|
76
76
|
|
|
77
77
|
|
|
78
|
+
@dataclass
|
|
79
|
+
class ToolCheckpoint:
|
|
80
|
+
"""Tool execution checkpoint for continuation mode.
|
|
81
|
+
|
|
82
|
+
When a tool raises HITLSuspend with resume_mode="continuation",
|
|
83
|
+
the framework creates a ToolCheckpoint to save the tool's execution
|
|
84
|
+
state. When the user responds, the tool is resumed from this checkpoint.
|
|
85
|
+
|
|
86
|
+
Use cases:
|
|
87
|
+
- OAuth authorization flow (wait for callback)
|
|
88
|
+
- Payment confirmation (wait for payment gateway)
|
|
89
|
+
- Multi-step wizards with user confirmation
|
|
90
|
+
- External system integration with async callbacks
|
|
91
|
+
|
|
92
|
+
Storage:
|
|
93
|
+
- Stored via CheckpointBackend (Redis/DB)
|
|
94
|
+
- Keyed by checkpoint_id and callback_id
|
|
95
|
+
- Has TTL for automatic expiration
|
|
96
|
+
|
|
97
|
+
Example:
|
|
98
|
+
# In tool execution:
|
|
99
|
+
raise HITLSuspend(
|
|
100
|
+
request_id="hitl_123",
|
|
101
|
+
request_type="external_auth",
|
|
102
|
+
resume_mode="continuation",
|
|
103
|
+
tool_state={"step": 2, "partial_data": {...}},
|
|
104
|
+
metadata={"auth_url": "https://...", "callback_id": "cb_456"},
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
# Framework creates ToolCheckpoint and saves it
|
|
108
|
+
# When callback arrives, framework loads checkpoint and resumes tool
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
# Identity
|
|
112
|
+
checkpoint_id: str
|
|
113
|
+
callback_id: str | None = None # For external callback matching
|
|
114
|
+
|
|
115
|
+
# Association
|
|
116
|
+
session_id: str | None = None
|
|
117
|
+
invocation_id: str | None = None
|
|
118
|
+
block_id: str | None = None # Frontend HITL block
|
|
119
|
+
|
|
120
|
+
# Tool execution context
|
|
121
|
+
tool_name: str = ""
|
|
122
|
+
tool_call_id: str = ""
|
|
123
|
+
params: dict[str, Any] = field(default_factory=dict) # Original params
|
|
124
|
+
tool_state: dict[str, Any] = field(default_factory=dict) # Internal state
|
|
125
|
+
|
|
126
|
+
# HITL info
|
|
127
|
+
hitl_id: str = ""
|
|
128
|
+
hitl_type: str = "" # ask_user, confirm, external_auth, etc.
|
|
129
|
+
|
|
130
|
+
# Status
|
|
131
|
+
status: str = "pending" # pending | completed | expired | failed | cancelled
|
|
132
|
+
expires_at: int | None = None # Unix timestamp
|
|
133
|
+
|
|
134
|
+
# User response (filled after callback/response)
|
|
135
|
+
user_response: Any | None = None
|
|
136
|
+
error: str | None = None
|
|
137
|
+
|
|
138
|
+
# Timestamps
|
|
139
|
+
created_at: int = 0
|
|
140
|
+
updated_at: int = 0
|
|
141
|
+
|
|
142
|
+
def __post_init__(self):
|
|
143
|
+
import time
|
|
144
|
+
now = int(time.time())
|
|
145
|
+
if not self.created_at:
|
|
146
|
+
self.created_at = now
|
|
147
|
+
if not self.updated_at:
|
|
148
|
+
self.updated_at = now
|
|
149
|
+
|
|
150
|
+
@property
|
|
151
|
+
def is_expired(self) -> bool:
|
|
152
|
+
"""Check if checkpoint has expired."""
|
|
153
|
+
if self.expires_at is None:
|
|
154
|
+
return False
|
|
155
|
+
import time
|
|
156
|
+
return time.time() > self.expires_at
|
|
157
|
+
|
|
158
|
+
@property
|
|
159
|
+
def is_pending(self) -> bool:
|
|
160
|
+
"""Check if checkpoint is waiting for response."""
|
|
161
|
+
return self.status == "pending" and not self.is_expired
|
|
162
|
+
|
|
163
|
+
def mark_completed(self, response: Any) -> None:
|
|
164
|
+
"""Mark checkpoint as completed with user response."""
|
|
165
|
+
import time
|
|
166
|
+
self.status = "completed"
|
|
167
|
+
self.user_response = response
|
|
168
|
+
self.updated_at = int(time.time())
|
|
169
|
+
|
|
170
|
+
def mark_failed(self, error: str) -> None:
|
|
171
|
+
"""Mark checkpoint as failed."""
|
|
172
|
+
import time
|
|
173
|
+
self.status = "failed"
|
|
174
|
+
self.error = error
|
|
175
|
+
self.updated_at = int(time.time())
|
|
176
|
+
|
|
177
|
+
def mark_cancelled(self, reason: str = "user_cancelled") -> None:
|
|
178
|
+
"""Mark checkpoint as cancelled."""
|
|
179
|
+
import time
|
|
180
|
+
self.status = "cancelled"
|
|
181
|
+
self.error = reason
|
|
182
|
+
self.updated_at = int(time.time())
|
|
183
|
+
|
|
184
|
+
def to_dict(self) -> dict[str, Any]:
|
|
185
|
+
"""Convert to dictionary for serialization."""
|
|
186
|
+
return {
|
|
187
|
+
"checkpoint_id": self.checkpoint_id,
|
|
188
|
+
"callback_id": self.callback_id,
|
|
189
|
+
"session_id": self.session_id,
|
|
190
|
+
"invocation_id": self.invocation_id,
|
|
191
|
+
"block_id": self.block_id,
|
|
192
|
+
"tool_name": self.tool_name,
|
|
193
|
+
"tool_call_id": self.tool_call_id,
|
|
194
|
+
"params": self.params,
|
|
195
|
+
"tool_state": self.tool_state,
|
|
196
|
+
"hitl_id": self.hitl_id,
|
|
197
|
+
"hitl_type": self.hitl_type,
|
|
198
|
+
"status": self.status,
|
|
199
|
+
"expires_at": self.expires_at,
|
|
200
|
+
"user_response": self.user_response,
|
|
201
|
+
"error": self.error,
|
|
202
|
+
"created_at": self.created_at,
|
|
203
|
+
"updated_at": self.updated_at,
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
@classmethod
|
|
207
|
+
def from_dict(cls, data: dict[str, Any]) -> "ToolCheckpoint":
|
|
208
|
+
"""Create from dictionary."""
|
|
209
|
+
return cls(
|
|
210
|
+
checkpoint_id=data["checkpoint_id"],
|
|
211
|
+
callback_id=data.get("callback_id"),
|
|
212
|
+
session_id=data.get("session_id"),
|
|
213
|
+
invocation_id=data.get("invocation_id"),
|
|
214
|
+
block_id=data.get("block_id"),
|
|
215
|
+
tool_name=data.get("tool_name", ""),
|
|
216
|
+
tool_call_id=data.get("tool_call_id", ""),
|
|
217
|
+
params=data.get("params", {}),
|
|
218
|
+
tool_state=data.get("tool_state", {}),
|
|
219
|
+
hitl_id=data.get("hitl_id", ""),
|
|
220
|
+
hitl_type=data.get("hitl_type", ""),
|
|
221
|
+
status=data.get("status", "pending"),
|
|
222
|
+
expires_at=data.get("expires_at"),
|
|
223
|
+
user_response=data.get("user_response"),
|
|
224
|
+
error=data.get("error"),
|
|
225
|
+
created_at=data.get("created_at", 0),
|
|
226
|
+
updated_at=data.get("updated_at", 0),
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
@classmethod
|
|
230
|
+
def from_hitl_suspend(
|
|
231
|
+
cls,
|
|
232
|
+
suspend: HITLSuspend,
|
|
233
|
+
*,
|
|
234
|
+
tool_call_id: str,
|
|
235
|
+
params: dict[str, Any],
|
|
236
|
+
session_id: str | None = None,
|
|
237
|
+
invocation_id: str | None = None,
|
|
238
|
+
block_id: str | None = None,
|
|
239
|
+
expires_in: int | None = 600, # Default 10 minutes
|
|
240
|
+
) -> "ToolCheckpoint":
|
|
241
|
+
"""Create checkpoint from HITLSuspend signal.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
suspend: The HITLSuspend signal
|
|
245
|
+
tool_call_id: Tool call ID
|
|
246
|
+
params: Original tool parameters
|
|
247
|
+
session_id: Session ID
|
|
248
|
+
invocation_id: Invocation ID
|
|
249
|
+
block_id: Frontend block ID
|
|
250
|
+
expires_in: Expiration in seconds (None = no expiration)
|
|
251
|
+
"""
|
|
252
|
+
from ..core.types.session import generate_id
|
|
253
|
+
import time
|
|
254
|
+
|
|
255
|
+
checkpoint_id = suspend.checkpoint_id or generate_id("ckpt")
|
|
256
|
+
callback_id = suspend.metadata.get("callback_id")
|
|
257
|
+
|
|
258
|
+
expires_at = None
|
|
259
|
+
if expires_in is not None:
|
|
260
|
+
expires_at = int(time.time()) + expires_in
|
|
261
|
+
|
|
262
|
+
return cls(
|
|
263
|
+
checkpoint_id=checkpoint_id,
|
|
264
|
+
callback_id=callback_id,
|
|
265
|
+
session_id=session_id,
|
|
266
|
+
invocation_id=invocation_id,
|
|
267
|
+
block_id=block_id,
|
|
268
|
+
tool_name=suspend.tool_name or "",
|
|
269
|
+
tool_call_id=tool_call_id,
|
|
270
|
+
params=params,
|
|
271
|
+
tool_state=suspend.tool_state or {},
|
|
272
|
+
hitl_id=suspend.hitl_id,
|
|
273
|
+
hitl_type=suspend.hitl_type,
|
|
274
|
+
expires_at=expires_at,
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
|
|
78
278
|
__all__ = [
|
|
79
279
|
# Signals
|
|
80
280
|
"SuspendSignal",
|
|
@@ -84,4 +284,5 @@ __all__ = [
|
|
|
84
284
|
"HITLCancelledError",
|
|
85
285
|
# Types
|
|
86
286
|
"HITLRequest",
|
|
287
|
+
"ToolCheckpoint",
|
|
87
288
|
]
|
aury/agents/react/agent.py
CHANGED
|
@@ -377,6 +377,42 @@ class ReactAgent(BaseAgent):
|
|
|
377
377
|
))
|
|
378
378
|
break
|
|
379
379
|
|
|
380
|
+
# Re-fetch context from providers (providers decide whether to update)
|
|
381
|
+
logger.debug(
|
|
382
|
+
"Re-fetching agent context for step",
|
|
383
|
+
extra={
|
|
384
|
+
"invocation_id": self._current_invocation.id,
|
|
385
|
+
"step": self._current_step,
|
|
386
|
+
},
|
|
387
|
+
)
|
|
388
|
+
self._agent_context = await ctx_helpers.fetch_agent_context(
|
|
389
|
+
self._ctx,
|
|
390
|
+
input,
|
|
391
|
+
self._context_providers,
|
|
392
|
+
self._tools,
|
|
393
|
+
self._delegate_tool_class,
|
|
394
|
+
self._middleware_chain,
|
|
395
|
+
)
|
|
396
|
+
|
|
397
|
+
# Update system message with new context (in case providers updated system_content)
|
|
398
|
+
if self._message_history and self._message_history[0].role == "system":
|
|
399
|
+
# Rebuild system message using helper
|
|
400
|
+
final_system_prompt = ctx_helpers.build_system_message(
|
|
401
|
+
self._agent_context,
|
|
402
|
+
self.config.system_prompt,
|
|
403
|
+
input,
|
|
404
|
+
)
|
|
405
|
+
|
|
406
|
+
# Log if context was injected
|
|
407
|
+
if self._agent_context.system_content:
|
|
408
|
+
logger.info(
|
|
409
|
+
f"Updated system message with context (length: {len(self._agent_context.system_content)})",
|
|
410
|
+
extra={"invocation_id": self._current_invocation.id, "step": self._current_step},
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
# Update the system message
|
|
414
|
+
self._message_history[0] = LLMMessage(role="system", content=final_system_prompt)
|
|
415
|
+
|
|
380
416
|
# Take snapshot before step
|
|
381
417
|
snapshot_id = None
|
|
382
418
|
if self.snapshot:
|
|
@@ -528,6 +564,17 @@ class ReactAgent(BaseAgent):
|
|
|
528
564
|
if self._current_invocation:
|
|
529
565
|
self._current_invocation.state = InvocationState.SUSPENDED
|
|
530
566
|
|
|
567
|
+
# Save agent_state for resume (only if persist_hitl_state is enabled)
|
|
568
|
+
if self.config.persist_hitl_state:
|
|
569
|
+
self._current_invocation.agent_state = {
|
|
570
|
+
"step": self._current_step,
|
|
571
|
+
"message_history": [
|
|
572
|
+
{"role": m.role, "content": m.content} for m in self._message_history
|
|
573
|
+
],
|
|
574
|
+
"text_buffer": self._text_buffer,
|
|
575
|
+
}
|
|
576
|
+
self._current_invocation.step_count = self._current_step
|
|
577
|
+
|
|
531
578
|
# Save invocation state
|
|
532
579
|
if self.ctx.backends and self.ctx.backends.invocation:
|
|
533
580
|
await self.ctx.backends.invocation.update(
|
aury/agents/react/context.py
CHANGED
|
@@ -221,31 +221,11 @@ async def build_messages(
|
|
|
221
221
|
messages = []
|
|
222
222
|
|
|
223
223
|
# System message: config.system_prompt + agent_context.system_content
|
|
224
|
-
final_system_prompt =
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
# Build template variables: datetime + custom vars from input
|
|
230
|
-
template_vars = {
|
|
231
|
-
"current_date": now.strftime("%Y-%m-%d"),
|
|
232
|
-
"current_time": now.strftime("%H:%M:%S"),
|
|
233
|
-
"current_datetime": now.strftime("%Y-%m-%d %H:%M:%S"),
|
|
234
|
-
}
|
|
235
|
-
|
|
236
|
-
# Add custom variables from PromptInput (user_name, tenant, etc.)
|
|
237
|
-
if hasattr(input, 'vars') and input.vars:
|
|
238
|
-
template_vars.update(input.vars)
|
|
239
|
-
|
|
240
|
-
try:
|
|
241
|
-
final_system_prompt = final_system_prompt.format(**template_vars)
|
|
242
|
-
except KeyError as e:
|
|
243
|
-
# Log missing variable but continue
|
|
244
|
-
logger.debug(f"System prompt template variable not found: {e}")
|
|
245
|
-
pass
|
|
246
|
-
|
|
247
|
-
if agent_context.system_content:
|
|
248
|
-
final_system_prompt = final_system_prompt + "\n\n" + agent_context.system_content
|
|
224
|
+
final_system_prompt = build_system_message(
|
|
225
|
+
agent_context,
|
|
226
|
+
system_prompt,
|
|
227
|
+
input,
|
|
228
|
+
)
|
|
249
229
|
messages.append(LLMMessage(role="system", content=final_system_prompt))
|
|
250
230
|
|
|
251
231
|
# Historical messages from AgentContext (provided by MessageContextProvider)
|
|
@@ -283,6 +263,52 @@ async def build_messages(
|
|
|
283
263
|
return messages
|
|
284
264
|
|
|
285
265
|
|
|
266
|
+
def build_system_message(
|
|
267
|
+
agent_context: AgentContext,
|
|
268
|
+
base_system_prompt: str | None,
|
|
269
|
+
input: "PromptInput | None" = None,
|
|
270
|
+
) -> str:
|
|
271
|
+
"""Build system message with agent context.
|
|
272
|
+
|
|
273
|
+
Args:
|
|
274
|
+
agent_context: Agent context with system_content, tools, etc.
|
|
275
|
+
base_system_prompt: Base system prompt (or None for default)
|
|
276
|
+
input: Prompt input for custom template variables
|
|
277
|
+
|
|
278
|
+
Returns:
|
|
279
|
+
Final system prompt string
|
|
280
|
+
"""
|
|
281
|
+
from datetime import datetime
|
|
282
|
+
|
|
283
|
+
# Get base prompt
|
|
284
|
+
final_system_prompt = base_system_prompt or default_system_prompt(agent_context.tools)
|
|
285
|
+
|
|
286
|
+
# Build template variables: datetime + custom vars from input
|
|
287
|
+
now = datetime.now()
|
|
288
|
+
template_vars = {
|
|
289
|
+
"current_date": now.strftime("%Y-%m-%d"),
|
|
290
|
+
"current_time": now.strftime("%H:%M:%S"),
|
|
291
|
+
"current_datetime": now.strftime("%Y-%m-%d %H:%M:%S"),
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
# Add custom variables from PromptInput
|
|
295
|
+
if input and hasattr(input, 'vars') and input.vars:
|
|
296
|
+
template_vars.update(input.vars)
|
|
297
|
+
|
|
298
|
+
# Format with template variables
|
|
299
|
+
try:
|
|
300
|
+
final_system_prompt = final_system_prompt.format(**template_vars)
|
|
301
|
+
except KeyError as e:
|
|
302
|
+
logger.debug(f"System prompt template variable not found: {e}")
|
|
303
|
+
pass
|
|
304
|
+
|
|
305
|
+
# Append system_content if available
|
|
306
|
+
if agent_context.system_content:
|
|
307
|
+
final_system_prompt = final_system_prompt + "\n\n" + agent_context.system_content
|
|
308
|
+
|
|
309
|
+
return final_system_prompt
|
|
310
|
+
|
|
311
|
+
|
|
286
312
|
def default_system_prompt(tools: list["BaseTool"]) -> str:
|
|
287
313
|
"""Generate default system prompt with tool descriptions.
|
|
288
314
|
|
aury/agents/react/factory.py
CHANGED
|
@@ -47,6 +47,8 @@ def create_react_agent(
|
|
|
47
47
|
delegate_tool_class: "type[BaseTool] | None" = None,
|
|
48
48
|
# Context metadata
|
|
49
49
|
context_metadata: dict | None = None,
|
|
50
|
+
# HITL resume support
|
|
51
|
+
invocation_id: str | None = None,
|
|
50
52
|
) -> "ReactAgent":
|
|
51
53
|
"""Create ReactAgent with minimal boilerplate.
|
|
52
54
|
|
aury/agents/react/pause.py
CHANGED
|
@@ -91,8 +91,9 @@ async def resume_agent_internal(agent: "ReactAgent", invocation_id: str) -> None
|
|
|
91
91
|
|
|
92
92
|
invocation = Invocation.from_dict(inv_data)
|
|
93
93
|
|
|
94
|
-
|
|
95
|
-
|
|
94
|
+
# Support both PAUSED and SUSPENDED (HITL) states
|
|
95
|
+
if invocation.state not in (InvocationState.PAUSED, InvocationState.SUSPENDED):
|
|
96
|
+
raise ValueError(f"Invocation is not paused/suspended: {invocation.state}")
|
|
96
97
|
|
|
97
98
|
# Restore state
|
|
98
99
|
agent._current_invocation = invocation
|
|
@@ -154,6 +155,16 @@ async def resume_agent_internal(agent: "ReactAgent", invocation_id: str) -> None
|
|
|
154
155
|
if not agent._paused:
|
|
155
156
|
agent._current_invocation.state = InvocationState.COMPLETED
|
|
156
157
|
agent._current_invocation.finished_at = __import__("datetime").datetime.now()
|
|
158
|
+
|
|
159
|
+
# Clear agent_state after successful completion (save space)
|
|
160
|
+
agent._current_invocation.agent_state = None
|
|
161
|
+
|
|
162
|
+
# Update invocation to database
|
|
163
|
+
if agent.ctx.backends and agent.ctx.backends.invocation:
|
|
164
|
+
await agent.ctx.backends.invocation.update(
|
|
165
|
+
agent._current_invocation.id,
|
|
166
|
+
agent._current_invocation.to_dict(),
|
|
167
|
+
)
|
|
157
168
|
|
|
158
169
|
except Exception as e:
|
|
159
170
|
agent._current_invocation.state = InvocationState.FAILED
|
aury/agents/react/step.py
CHANGED
|
@@ -289,8 +289,12 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
289
289
|
extra={"invocation_id": agent._current_invocation.id, "model": agent.llm.model},
|
|
290
290
|
)
|
|
291
291
|
|
|
292
|
+
# Track whether we aborted mid-stream
|
|
293
|
+
aborted = False
|
|
294
|
+
|
|
292
295
|
async for event in agent.llm.complete(**llm_kwargs):
|
|
293
296
|
if await agent._check_abort():
|
|
297
|
+
aborted = True
|
|
294
298
|
break
|
|
295
299
|
|
|
296
300
|
if event.type == "content":
|
|
@@ -360,6 +364,8 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
360
364
|
|
|
361
365
|
elif event.type == "thinking_completed":
|
|
362
366
|
# Thinking completed - emit block completed status
|
|
367
|
+
# Note: thinking_completed from LLM means it finished naturally,
|
|
368
|
+
# so we always use "completed" here (not aborted)
|
|
363
369
|
if agent._current_thinking_block_id and not thinking_completed_emitted:
|
|
364
370
|
await agent.ctx.emit(BlockEvent(
|
|
365
371
|
block_id=agent._current_thinking_block_id,
|
|
@@ -386,14 +392,19 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
386
392
|
block_id = generate_id("blk")
|
|
387
393
|
agent._tool_call_blocks[tc.id] = block_id
|
|
388
394
|
|
|
395
|
+
# Get display_name from tool if available
|
|
396
|
+
tool = agent._get_tool(tc.name)
|
|
397
|
+
display_name = tool.display_name if tool else tc.name
|
|
398
|
+
|
|
389
399
|
await agent.ctx.emit(BlockEvent(
|
|
390
400
|
block_id=block_id,
|
|
391
401
|
kind=BlockKind.TOOL_USE,
|
|
392
402
|
op=BlockOp.APPLY,
|
|
393
403
|
data={
|
|
394
404
|
"name": tc.name,
|
|
405
|
+
"display_name": display_name,
|
|
395
406
|
"call_id": tc.id,
|
|
396
|
-
"status": "
|
|
407
|
+
"status": "pending", # Initial status, arguments pending
|
|
397
408
|
},
|
|
398
409
|
))
|
|
399
410
|
|
|
@@ -498,9 +509,13 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
498
509
|
# Tool call complete (arguments fully received)
|
|
499
510
|
if event.tool_call:
|
|
500
511
|
tc = event.tool_call
|
|
512
|
+
# Strict mode: tool_call_start must have been received
|
|
513
|
+
block_id = agent._tool_call_blocks[tc.id] # Will raise KeyError if not found
|
|
514
|
+
|
|
501
515
|
invocation = ToolInvocation(
|
|
502
516
|
tool_call_id=tc.id,
|
|
503
517
|
tool_name=tc.name,
|
|
518
|
+
block_id=block_id,
|
|
504
519
|
args_raw=tc.arguments,
|
|
505
520
|
state=ToolInvocationState.CALL,
|
|
506
521
|
)
|
|
@@ -512,18 +527,19 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
512
527
|
invocation.args = {}
|
|
513
528
|
|
|
514
529
|
agent._tool_invocations.append(invocation)
|
|
515
|
-
|
|
516
|
-
#
|
|
517
|
-
|
|
530
|
+
|
|
531
|
+
# Build patch data
|
|
532
|
+
patch_data: dict[str, Any] = {
|
|
533
|
+
"call_id": tc.id,
|
|
534
|
+
"arguments": invocation.args,
|
|
535
|
+
"status": "ready",
|
|
536
|
+
}
|
|
537
|
+
|
|
518
538
|
await agent.ctx.emit(BlockEvent(
|
|
519
539
|
block_id=block_id,
|
|
520
540
|
kind=BlockKind.TOOL_USE,
|
|
521
541
|
op=BlockOp.PATCH,
|
|
522
|
-
data=
|
|
523
|
-
"call_id": tc.id,
|
|
524
|
-
"arguments": invocation.args,
|
|
525
|
-
"status": "ready",
|
|
526
|
-
},
|
|
542
|
+
data=patch_data,
|
|
527
543
|
))
|
|
528
544
|
|
|
529
545
|
await agent.bus.publish(
|
|
@@ -605,16 +621,27 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
605
621
|
extra={"invocation_id": agent._current_invocation.id},
|
|
606
622
|
)
|
|
607
623
|
|
|
608
|
-
# Emit
|
|
624
|
+
# Emit thinking block final status if streaming and not yet completed
|
|
625
|
+
if agent._current_thinking_block_id and not thinking_completed_emitted:
|
|
626
|
+
status = "aborted" if aborted else "completed"
|
|
627
|
+
await agent.ctx.emit(BlockEvent(
|
|
628
|
+
block_id=agent._current_thinking_block_id,
|
|
629
|
+
kind=BlockKind.THINKING,
|
|
630
|
+
op=BlockOp.PATCH,
|
|
631
|
+
data={"status": status},
|
|
632
|
+
))
|
|
633
|
+
|
|
634
|
+
# Emit text block final status (completed or aborted)
|
|
609
635
|
if agent._current_text_block_id:
|
|
636
|
+
status = "aborted" if aborted else "completed"
|
|
610
637
|
await agent.ctx.emit(BlockEvent(
|
|
611
638
|
block_id=agent._current_text_block_id,
|
|
612
639
|
kind=BlockKind.TEXT,
|
|
613
640
|
op=BlockOp.PATCH,
|
|
614
|
-
data={"status":
|
|
641
|
+
data={"status": status},
|
|
615
642
|
))
|
|
616
643
|
|
|
617
|
-
# If thinking was buffered, emit it now
|
|
644
|
+
# If thinking was buffered, emit it now (non-streaming mode)
|
|
618
645
|
if agent._thinking_buffer and not agent.config.stream_thinking:
|
|
619
646
|
await agent.ctx.emit(BlockEvent(
|
|
620
647
|
kind=BlockKind.THINKING,
|