aury-agent 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,16 +35,16 @@ class HITLRequest:
35
35
 
36
36
  Stored in invocation for persistence.
37
37
  """
38
- request_id: str
39
- request_type: str # ask_user, permission, form, workflow_human
38
+ hitl_id: str
39
+ hitl_type: str # ask_user, confirm, permission, external_auth, workflow_human
40
40
 
41
- # Display
42
- message: str | None = None
43
- options: list[str] | None = None
41
+ # Type-specific data
42
+ data: dict[str, Any] = field(default_factory=dict) # {message, options, ...}
44
43
 
45
44
  # Context
46
45
  tool_name: str | None = None # If triggered by tool
47
46
  node_id: str | None = None # If triggered by workflow node
47
+ block_id: str | None = None # Associated UI block
48
48
 
49
49
  # Metadata
50
50
  metadata: dict[str, Any] = field(default_factory=dict)
@@ -52,12 +52,12 @@ class HITLRequest:
52
52
  def to_dict(self) -> dict[str, Any]:
53
53
  """Convert to dictionary for serialization."""
54
54
  return {
55
- "request_id": self.request_id,
56
- "request_type": self.request_type,
57
- "message": self.message,
58
- "options": self.options,
55
+ "hitl_id": self.hitl_id,
56
+ "hitl_type": self.hitl_type,
57
+ "data": self.data,
59
58
  "tool_name": self.tool_name,
60
59
  "node_id": self.node_id,
60
+ "block_id": self.block_id,
61
61
  "metadata": self.metadata,
62
62
  }
63
63
 
@@ -65,16 +65,216 @@ class HITLRequest:
65
65
  def from_dict(cls, data: dict[str, Any]) -> "HITLRequest":
66
66
  """Create from dictionary."""
67
67
  return cls(
68
- request_id=data["request_id"],
69
- request_type=data.get("request_type", "ask_user"),
70
- message=data.get("message"),
71
- options=data.get("options"),
68
+ hitl_id=data["hitl_id"],
69
+ hitl_type=data.get("hitl_type", "ask_user"),
70
+ data=data.get("data", {}),
72
71
  tool_name=data.get("tool_name"),
73
72
  node_id=data.get("node_id"),
73
+ block_id=data.get("block_id"),
74
74
  metadata=data.get("metadata", {}),
75
75
  )
76
76
 
77
77
 
78
+ @dataclass
79
+ class ToolCheckpoint:
80
+ """Tool execution checkpoint for continuation mode.
81
+
82
+ When a tool raises HITLSuspend with resume_mode="continuation",
83
+ the framework creates a ToolCheckpoint to save the tool's execution
84
+ state. When the user responds, the tool is resumed from this checkpoint.
85
+
86
+ Use cases:
87
+ - OAuth authorization flow (wait for callback)
88
+ - Payment confirmation (wait for payment gateway)
89
+ - Multi-step wizards with user confirmation
90
+ - External system integration with async callbacks
91
+
92
+ Storage:
93
+ - Stored via CheckpointBackend (Redis/DB)
94
+ - Keyed by checkpoint_id and callback_id
95
+ - Has TTL for automatic expiration
96
+
97
+ Example:
98
+ # In tool execution:
99
+ raise HITLSuspend(
100
+ request_id="hitl_123",
101
+ request_type="external_auth",
102
+ resume_mode="continuation",
103
+ tool_state={"step": 2, "partial_data": {...}},
104
+ metadata={"auth_url": "https://...", "callback_id": "cb_456"},
105
+ )
106
+
107
+ # Framework creates ToolCheckpoint and saves it
108
+ # When callback arrives, framework loads checkpoint and resumes tool
109
+ """
110
+
111
+ # Identity
112
+ checkpoint_id: str
113
+ callback_id: str | None = None # For external callback matching
114
+
115
+ # Association
116
+ session_id: str | None = None
117
+ invocation_id: str | None = None
118
+ block_id: str | None = None # Frontend HITL block
119
+
120
+ # Tool execution context
121
+ tool_name: str = ""
122
+ tool_call_id: str = ""
123
+ params: dict[str, Any] = field(default_factory=dict) # Original params
124
+ tool_state: dict[str, Any] = field(default_factory=dict) # Internal state
125
+
126
+ # HITL info
127
+ hitl_id: str = ""
128
+ hitl_type: str = "" # ask_user, confirm, external_auth, etc.
129
+
130
+ # Status
131
+ status: str = "pending" # pending | completed | expired | failed | cancelled
132
+ expires_at: int | None = None # Unix timestamp
133
+
134
+ # User response (filled after callback/response)
135
+ user_response: Any | None = None
136
+ error: str | None = None
137
+
138
+ # Timestamps
139
+ created_at: int = 0
140
+ updated_at: int = 0
141
+
142
+ def __post_init__(self):
143
+ import time
144
+ now = int(time.time())
145
+ if not self.created_at:
146
+ self.created_at = now
147
+ if not self.updated_at:
148
+ self.updated_at = now
149
+
150
+ @property
151
+ def is_expired(self) -> bool:
152
+ """Check if checkpoint has expired."""
153
+ if self.expires_at is None:
154
+ return False
155
+ import time
156
+ return time.time() > self.expires_at
157
+
158
+ @property
159
+ def is_pending(self) -> bool:
160
+ """Check if checkpoint is waiting for response."""
161
+ return self.status == "pending" and not self.is_expired
162
+
163
+ def mark_completed(self, response: Any) -> None:
164
+ """Mark checkpoint as completed with user response."""
165
+ import time
166
+ self.status = "completed"
167
+ self.user_response = response
168
+ self.updated_at = int(time.time())
169
+
170
+ def mark_failed(self, error: str) -> None:
171
+ """Mark checkpoint as failed."""
172
+ import time
173
+ self.status = "failed"
174
+ self.error = error
175
+ self.updated_at = int(time.time())
176
+
177
+ def mark_cancelled(self, reason: str = "user_cancelled") -> None:
178
+ """Mark checkpoint as cancelled."""
179
+ import time
180
+ self.status = "cancelled"
181
+ self.error = reason
182
+ self.updated_at = int(time.time())
183
+
184
+ def to_dict(self) -> dict[str, Any]:
185
+ """Convert to dictionary for serialization."""
186
+ return {
187
+ "checkpoint_id": self.checkpoint_id,
188
+ "callback_id": self.callback_id,
189
+ "session_id": self.session_id,
190
+ "invocation_id": self.invocation_id,
191
+ "block_id": self.block_id,
192
+ "tool_name": self.tool_name,
193
+ "tool_call_id": self.tool_call_id,
194
+ "params": self.params,
195
+ "tool_state": self.tool_state,
196
+ "hitl_id": self.hitl_id,
197
+ "hitl_type": self.hitl_type,
198
+ "status": self.status,
199
+ "expires_at": self.expires_at,
200
+ "user_response": self.user_response,
201
+ "error": self.error,
202
+ "created_at": self.created_at,
203
+ "updated_at": self.updated_at,
204
+ }
205
+
206
+ @classmethod
207
+ def from_dict(cls, data: dict[str, Any]) -> "ToolCheckpoint":
208
+ """Create from dictionary."""
209
+ return cls(
210
+ checkpoint_id=data["checkpoint_id"],
211
+ callback_id=data.get("callback_id"),
212
+ session_id=data.get("session_id"),
213
+ invocation_id=data.get("invocation_id"),
214
+ block_id=data.get("block_id"),
215
+ tool_name=data.get("tool_name", ""),
216
+ tool_call_id=data.get("tool_call_id", ""),
217
+ params=data.get("params", {}),
218
+ tool_state=data.get("tool_state", {}),
219
+ hitl_id=data.get("hitl_id", ""),
220
+ hitl_type=data.get("hitl_type", ""),
221
+ status=data.get("status", "pending"),
222
+ expires_at=data.get("expires_at"),
223
+ user_response=data.get("user_response"),
224
+ error=data.get("error"),
225
+ created_at=data.get("created_at", 0),
226
+ updated_at=data.get("updated_at", 0),
227
+ )
228
+
229
+ @classmethod
230
+ def from_hitl_suspend(
231
+ cls,
232
+ suspend: HITLSuspend,
233
+ *,
234
+ tool_call_id: str,
235
+ params: dict[str, Any],
236
+ session_id: str | None = None,
237
+ invocation_id: str | None = None,
238
+ block_id: str | None = None,
239
+ expires_in: int | None = 600, # Default 10 minutes
240
+ ) -> "ToolCheckpoint":
241
+ """Create checkpoint from HITLSuspend signal.
242
+
243
+ Args:
244
+ suspend: The HITLSuspend signal
245
+ tool_call_id: Tool call ID
246
+ params: Original tool parameters
247
+ session_id: Session ID
248
+ invocation_id: Invocation ID
249
+ block_id: Frontend block ID
250
+ expires_in: Expiration in seconds (None = no expiration)
251
+ """
252
+ from ..core.types.session import generate_id
253
+ import time
254
+
255
+ checkpoint_id = suspend.checkpoint_id or generate_id("ckpt")
256
+ callback_id = suspend.metadata.get("callback_id")
257
+
258
+ expires_at = None
259
+ if expires_in is not None:
260
+ expires_at = int(time.time()) + expires_in
261
+
262
+ return cls(
263
+ checkpoint_id=checkpoint_id,
264
+ callback_id=callback_id,
265
+ session_id=session_id,
266
+ invocation_id=invocation_id,
267
+ block_id=block_id,
268
+ tool_name=suspend.tool_name or "",
269
+ tool_call_id=tool_call_id,
270
+ params=params,
271
+ tool_state=suspend.tool_state or {},
272
+ hitl_id=suspend.hitl_id,
273
+ hitl_type=suspend.hitl_type,
274
+ expires_at=expires_at,
275
+ )
276
+
277
+
78
278
  __all__ = [
79
279
  # Signals
80
280
  "SuspendSignal",
@@ -84,4 +284,5 @@ __all__ = [
84
284
  "HITLCancelledError",
85
285
  # Types
86
286
  "HITLRequest",
287
+ "ToolCheckpoint",
87
288
  ]
@@ -377,6 +377,42 @@ class ReactAgent(BaseAgent):
377
377
  ))
378
378
  break
379
379
 
380
+ # Re-fetch context from providers (providers decide whether to update)
381
+ logger.debug(
382
+ "Re-fetching agent context for step",
383
+ extra={
384
+ "invocation_id": self._current_invocation.id,
385
+ "step": self._current_step,
386
+ },
387
+ )
388
+ self._agent_context = await ctx_helpers.fetch_agent_context(
389
+ self._ctx,
390
+ input,
391
+ self._context_providers,
392
+ self._tools,
393
+ self._delegate_tool_class,
394
+ self._middleware_chain,
395
+ )
396
+
397
+ # Update system message with new context (in case providers updated system_content)
398
+ if self._message_history and self._message_history[0].role == "system":
399
+ # Rebuild system message using helper
400
+ final_system_prompt = ctx_helpers.build_system_message(
401
+ self._agent_context,
402
+ self.config.system_prompt,
403
+ input,
404
+ )
405
+
406
+ # Log if context was injected
407
+ if self._agent_context.system_content:
408
+ logger.info(
409
+ f"Updated system message with context (length: {len(self._agent_context.system_content)})",
410
+ extra={"invocation_id": self._current_invocation.id, "step": self._current_step},
411
+ )
412
+
413
+ # Update the system message
414
+ self._message_history[0] = LLMMessage(role="system", content=final_system_prompt)
415
+
380
416
  # Take snapshot before step
381
417
  snapshot_id = None
382
418
  if self.snapshot:
@@ -528,6 +564,17 @@ class ReactAgent(BaseAgent):
528
564
  if self._current_invocation:
529
565
  self._current_invocation.state = InvocationState.SUSPENDED
530
566
 
567
+ # Save agent_state for resume (only if persist_hitl_state is enabled)
568
+ if self.config.persist_hitl_state:
569
+ self._current_invocation.agent_state = {
570
+ "step": self._current_step,
571
+ "message_history": [
572
+ {"role": m.role, "content": m.content} for m in self._message_history
573
+ ],
574
+ "text_buffer": self._text_buffer,
575
+ }
576
+ self._current_invocation.step_count = self._current_step
577
+
531
578
  # Save invocation state
532
579
  if self.ctx.backends and self.ctx.backends.invocation:
533
580
  await self.ctx.backends.invocation.update(
@@ -221,31 +221,11 @@ async def build_messages(
221
221
  messages = []
222
222
 
223
223
  # System message: config.system_prompt + agent_context.system_content
224
- final_system_prompt = system_prompt or default_system_prompt(agent_context.tools)
225
-
226
- # Format system_prompt with dynamic variables
227
- now = datetime.now()
228
-
229
- # Build template variables: datetime + custom vars from input
230
- template_vars = {
231
- "current_date": now.strftime("%Y-%m-%d"),
232
- "current_time": now.strftime("%H:%M:%S"),
233
- "current_datetime": now.strftime("%Y-%m-%d %H:%M:%S"),
234
- }
235
-
236
- # Add custom variables from PromptInput (user_name, tenant, etc.)
237
- if hasattr(input, 'vars') and input.vars:
238
- template_vars.update(input.vars)
239
-
240
- try:
241
- final_system_prompt = final_system_prompt.format(**template_vars)
242
- except KeyError as e:
243
- # Log missing variable but continue
244
- logger.debug(f"System prompt template variable not found: {e}")
245
- pass
246
-
247
- if agent_context.system_content:
248
- final_system_prompt = final_system_prompt + "\n\n" + agent_context.system_content
224
+ final_system_prompt = build_system_message(
225
+ agent_context,
226
+ system_prompt,
227
+ input,
228
+ )
249
229
  messages.append(LLMMessage(role="system", content=final_system_prompt))
250
230
 
251
231
  # Historical messages from AgentContext (provided by MessageContextProvider)
@@ -283,6 +263,52 @@ async def build_messages(
283
263
  return messages
284
264
 
285
265
 
266
+ def build_system_message(
267
+ agent_context: AgentContext,
268
+ base_system_prompt: str | None,
269
+ input: "PromptInput | None" = None,
270
+ ) -> str:
271
+ """Build system message with agent context.
272
+
273
+ Args:
274
+ agent_context: Agent context with system_content, tools, etc.
275
+ base_system_prompt: Base system prompt (or None for default)
276
+ input: Prompt input for custom template variables
277
+
278
+ Returns:
279
+ Final system prompt string
280
+ """
281
+ from datetime import datetime
282
+
283
+ # Get base prompt
284
+ final_system_prompt = base_system_prompt or default_system_prompt(agent_context.tools)
285
+
286
+ # Build template variables: datetime + custom vars from input
287
+ now = datetime.now()
288
+ template_vars = {
289
+ "current_date": now.strftime("%Y-%m-%d"),
290
+ "current_time": now.strftime("%H:%M:%S"),
291
+ "current_datetime": now.strftime("%Y-%m-%d %H:%M:%S"),
292
+ }
293
+
294
+ # Add custom variables from PromptInput
295
+ if input and hasattr(input, 'vars') and input.vars:
296
+ template_vars.update(input.vars)
297
+
298
+ # Format with template variables
299
+ try:
300
+ final_system_prompt = final_system_prompt.format(**template_vars)
301
+ except KeyError as e:
302
+ logger.debug(f"System prompt template variable not found: {e}")
303
+ pass
304
+
305
+ # Append system_content if available
306
+ if agent_context.system_content:
307
+ final_system_prompt = final_system_prompt + "\n\n" + agent_context.system_content
308
+
309
+ return final_system_prompt
310
+
311
+
286
312
  def default_system_prompt(tools: list["BaseTool"]) -> str:
287
313
  """Generate default system prompt with tool descriptions.
288
314
 
@@ -47,6 +47,8 @@ def create_react_agent(
47
47
  delegate_tool_class: "type[BaseTool] | None" = None,
48
48
  # Context metadata
49
49
  context_metadata: dict | None = None,
50
+ # HITL resume support
51
+ invocation_id: str | None = None,
50
52
  ) -> "ReactAgent":
51
53
  """Create ReactAgent with minimal boilerplate.
52
54
 
@@ -91,8 +91,9 @@ async def resume_agent_internal(agent: "ReactAgent", invocation_id: str) -> None
91
91
 
92
92
  invocation = Invocation.from_dict(inv_data)
93
93
 
94
- if invocation.state != InvocationState.PAUSED:
95
- raise ValueError(f"Invocation is not paused: {invocation.state}")
94
+ # Support both PAUSED and SUSPENDED (HITL) states
95
+ if invocation.state not in (InvocationState.PAUSED, InvocationState.SUSPENDED):
96
+ raise ValueError(f"Invocation is not paused/suspended: {invocation.state}")
96
97
 
97
98
  # Restore state
98
99
  agent._current_invocation = invocation
@@ -154,6 +155,16 @@ async def resume_agent_internal(agent: "ReactAgent", invocation_id: str) -> None
154
155
  if not agent._paused:
155
156
  agent._current_invocation.state = InvocationState.COMPLETED
156
157
  agent._current_invocation.finished_at = __import__("datetime").datetime.now()
158
+
159
+ # Clear agent_state after successful completion (save space)
160
+ agent._current_invocation.agent_state = None
161
+
162
+ # Update invocation to database
163
+ if agent.ctx.backends and agent.ctx.backends.invocation:
164
+ await agent.ctx.backends.invocation.update(
165
+ agent._current_invocation.id,
166
+ agent._current_invocation.to_dict(),
167
+ )
157
168
 
158
169
  except Exception as e:
159
170
  agent._current_invocation.state = InvocationState.FAILED
aury/agents/react/step.py CHANGED
@@ -289,8 +289,12 @@ async def execute_step(agent: "ReactAgent") -> str | None:
289
289
  extra={"invocation_id": agent._current_invocation.id, "model": agent.llm.model},
290
290
  )
291
291
 
292
+ # Track whether we aborted mid-stream
293
+ aborted = False
294
+
292
295
  async for event in agent.llm.complete(**llm_kwargs):
293
296
  if await agent._check_abort():
297
+ aborted = True
294
298
  break
295
299
 
296
300
  if event.type == "content":
@@ -360,6 +364,8 @@ async def execute_step(agent: "ReactAgent") -> str | None:
360
364
 
361
365
  elif event.type == "thinking_completed":
362
366
  # Thinking completed - emit block completed status
367
+ # Note: thinking_completed from LLM means it finished naturally,
368
+ # so we always use "completed" here (not aborted)
363
369
  if agent._current_thinking_block_id and not thinking_completed_emitted:
364
370
  await agent.ctx.emit(BlockEvent(
365
371
  block_id=agent._current_thinking_block_id,
@@ -386,14 +392,19 @@ async def execute_step(agent: "ReactAgent") -> str | None:
386
392
  block_id = generate_id("blk")
387
393
  agent._tool_call_blocks[tc.id] = block_id
388
394
 
395
+ # Get display_name from tool if available
396
+ tool = agent._get_tool(tc.name)
397
+ display_name = tool.display_name if tool else tc.name
398
+
389
399
  await agent.ctx.emit(BlockEvent(
390
400
  block_id=block_id,
391
401
  kind=BlockKind.TOOL_USE,
392
402
  op=BlockOp.APPLY,
393
403
  data={
394
404
  "name": tc.name,
405
+ "display_name": display_name,
395
406
  "call_id": tc.id,
396
- "status": "streaming", # Indicate arguments are streaming
407
+ "status": "pending", # Initial status, arguments pending
397
408
  },
398
409
  ))
399
410
 
@@ -498,9 +509,13 @@ async def execute_step(agent: "ReactAgent") -> str | None:
498
509
  # Tool call complete (arguments fully received)
499
510
  if event.tool_call:
500
511
  tc = event.tool_call
512
+ # Strict mode: tool_call_start must have been received
513
+ block_id = agent._tool_call_blocks[tc.id] # Will raise KeyError if not found
514
+
501
515
  invocation = ToolInvocation(
502
516
  tool_call_id=tc.id,
503
517
  tool_name=tc.name,
518
+ block_id=block_id,
504
519
  args_raw=tc.arguments,
505
520
  state=ToolInvocationState.CALL,
506
521
  )
@@ -512,18 +527,19 @@ async def execute_step(agent: "ReactAgent") -> str | None:
512
527
  invocation.args = {}
513
528
 
514
529
  agent._tool_invocations.append(invocation)
515
-
516
- # Strict mode: tool_call_start must have been received
517
- block_id = agent._tool_call_blocks[tc.id] # Will raise KeyError if not found
530
+
531
+ # Build patch data
532
+ patch_data: dict[str, Any] = {
533
+ "call_id": tc.id,
534
+ "arguments": invocation.args,
535
+ "status": "ready",
536
+ }
537
+
518
538
  await agent.ctx.emit(BlockEvent(
519
539
  block_id=block_id,
520
540
  kind=BlockKind.TOOL_USE,
521
541
  op=BlockOp.PATCH,
522
- data={
523
- "call_id": tc.id,
524
- "arguments": invocation.args,
525
- "status": "ready",
526
- },
542
+ data=patch_data,
527
543
  ))
528
544
 
529
545
  await agent.bus.publish(
@@ -605,16 +621,27 @@ async def execute_step(agent: "ReactAgent") -> str | None:
605
621
  extra={"invocation_id": agent._current_invocation.id},
606
622
  )
607
623
 
608
- # Emit text block completed status
624
+ # Emit thinking block final status if streaming and not yet completed
625
+ if agent._current_thinking_block_id and not thinking_completed_emitted:
626
+ status = "aborted" if aborted else "completed"
627
+ await agent.ctx.emit(BlockEvent(
628
+ block_id=agent._current_thinking_block_id,
629
+ kind=BlockKind.THINKING,
630
+ op=BlockOp.PATCH,
631
+ data={"status": status},
632
+ ))
633
+
634
+ # Emit text block final status (completed or aborted)
609
635
  if agent._current_text_block_id:
636
+ status = "aborted" if aborted else "completed"
610
637
  await agent.ctx.emit(BlockEvent(
611
638
  block_id=agent._current_text_block_id,
612
639
  kind=BlockKind.TEXT,
613
640
  op=BlockOp.PATCH,
614
- data={"status": "completed"},
641
+ data={"status": status},
615
642
  ))
616
643
 
617
- # If thinking was buffered, emit it now
644
+ # If thinking was buffered, emit it now (non-streaming mode)
618
645
  if agent._thinking_buffer and not agent.config.stream_thinking:
619
646
  await agent.ctx.emit(BlockEvent(
620
647
  kind=BlockKind.THINKING,