aury-agent 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,31 +9,33 @@ All services (llm, tools, storage, etc.) are accessed through ctx.
9
9
  from __future__ import annotations
10
10
 
11
11
  import asyncio
12
- import json
13
- from dataclasses import asdict
14
12
  from datetime import datetime
15
13
  from typing import Any, AsyncIterator, ClassVar, Literal, TYPE_CHECKING
16
14
 
17
- from ..core.base import AgentConfig, BaseAgent, ToolInjectionMode
15
+ from ..core.base import AgentConfig, BaseAgent
18
16
  from ..core.context import InvocationContext
19
17
  from ..core.logging import react_logger as logger
20
18
  from ..core.event_bus import Events
21
- from ..context_providers import ContextProvider, AgentContext
19
+ from ..context_providers import AgentContext
22
20
  from ..core.types.block import BlockEvent, BlockKind, BlockOp
23
- from ..llm import LLMMessage, ToolDefinition
21
+ from ..llm import LLMMessage
24
22
  from ..middleware import HookAction
25
23
  from ..core.types import (
26
24
  Invocation,
27
25
  InvocationState,
28
26
  PromptInput,
29
- ToolContext,
30
- ToolResult,
31
27
  ToolInvocation,
32
- ToolInvocationState,
33
28
  generate_id,
34
29
  )
35
- from ..core.state import State
36
- from ..core.signals import SuspendSignal, HITLSuspend
30
+ from ..core.signals import SuspendSignal
31
+
32
+ # Import helper modules
33
+ from . import context as ctx_helpers
34
+ from . import step as step_helpers
35
+ from . import tools as tool_helpers
36
+ from . import persistence as persist_helpers
37
+ from . import pause as pause_helpers
38
+ from .factory import SessionNotFoundError, create_react_agent, restore_react_agent
37
39
 
38
40
  if TYPE_CHECKING:
39
41
  from ..llm import LLMProvider
@@ -41,17 +43,12 @@ if TYPE_CHECKING:
41
43
  from ..core.types.tool import BaseTool
42
44
  from ..core.types.session import Session
43
45
  from ..backends import Backends
44
- from ..backends.state import StateBackend
45
46
  from ..backends.snapshot import SnapshotBackend
46
47
  from ..backends.subagent import AgentConfig as SubAgentConfig
47
48
  from ..core.event_bus import Bus
48
49
  from ..middleware import MiddlewareChain, Middleware
49
50
  from ..memory import MemoryManager
50
-
51
-
52
- class SessionNotFoundError(Exception):
53
- """Raised when session is not found in storage."""
54
- pass
51
+ from ..context_providers import ContextProvider
55
52
 
56
53
 
57
54
  class ReactAgent(BaseAgent):
@@ -76,6 +73,8 @@ class ReactAgent(BaseAgent):
76
73
  # Class-level config
77
74
  agent_type: ClassVar[Literal["react", "workflow"]] = "react"
78
75
 
76
+ # ========== Factory methods (delegate to factory.py) ==========
77
+
79
78
  @classmethod
80
79
  def create(
81
80
  cls,
@@ -90,145 +89,28 @@ class ReactAgent(BaseAgent):
90
89
  subagents: "list[SubAgentConfig] | None" = None,
91
90
  memory: "MemoryManager | None" = None,
92
91
  snapshot: "SnapshotBackend | None" = None,
93
- # ContextProvider system
94
92
  context_providers: "list[ContextProvider] | None" = None,
95
93
  enable_history: bool = True,
96
94
  history_limit: int = 50,
97
- # Tool customization
98
95
  delegate_tool_class: "type[BaseTool] | None" = None,
99
96
  ) -> "ReactAgent":
100
- """Create ReactAgent with minimal boilerplate.
101
-
102
- This is the recommended way to create a ReactAgent for simple use cases.
103
- Session, Storage, and Bus are auto-created if not provided.
104
-
105
- Args:
106
- llm: LLM provider (required)
107
- tools: Tool registry or list of tools (optional)
108
- config: Agent configuration (optional)
109
- backends: Backends container (recommended, auto-created if None)
110
- session: Session object (auto-created if None)
111
- bus: Event bus (auto-created if None)
112
- middlewares: List of middlewares (auto-creates chain)
113
- subagents: List of sub-agent configs (auto-creates SubAgentManager)
114
- memory: Memory manager (optional)
115
- snapshot: Snapshot backend (optional)
116
- context_providers: Additional custom context providers (optional)
117
- enable_history: Enable message history (default True)
118
- history_limit: Max conversation turns to keep (default 50)
119
- delegate_tool_class: Custom DelegateTool class (optional)
120
-
121
- Returns:
122
- Configured ReactAgent ready to run
123
-
124
- Example:
125
- # Minimal
126
- agent = ReactAgent.create(llm=my_llm)
127
-
128
- # With backends
129
- agent = ReactAgent.create(
130
- llm=my_llm,
131
- backends=Backends.create_default(),
132
- )
133
-
134
- # With tools and middlewares
135
- agent = ReactAgent.create(
136
- llm=my_llm,
137
- tools=[tool1, tool2],
138
- middlewares=[MessageContainerMiddleware()],
139
- )
140
-
141
- # With sub-agents
142
- agent = ReactAgent.create(
143
- llm=my_llm,
144
- subagents=[
145
- AgentConfig(key="researcher", agent=researcher_agent),
146
- ],
147
- )
148
-
149
- # With custom context providers
150
- agent = ReactAgent.create(
151
- llm=my_llm,
152
- tools=[tool1],
153
- context_providers=[MyRAGProvider(), MyProjectProvider()],
154
- )
155
- """
156
- from ..core.event_bus import EventBus
157
- from ..core.types.session import Session, generate_id
158
- from ..backends import Backends
159
- from ..backends.subagent import ListSubAgentBackend
160
- from ..tool import ToolSet
161
- from ..tool.builtin import DelegateTool
162
- from ..middleware import MiddlewareChain, MessageBackendMiddleware
163
- from ..context_providers import MessageContextProvider
164
-
165
- # Auto-create backends if not provided
166
- if backends is None:
167
- backends = Backends.create_default()
168
-
169
- # Auto-create missing components
170
- if session is None:
171
- session = Session(id=generate_id("sess"))
172
- if bus is None:
173
- bus = EventBus()
174
-
175
- # Create middleware chain (add MessageBackendMiddleware if history enabled)
176
- middleware_chain: MiddlewareChain | None = None
177
- if middlewares or enable_history:
178
- middleware_chain = MiddlewareChain()
179
- # Add message persistence middleware first (uses backends.message)
180
- if enable_history and backends.message is not None:
181
- middleware_chain.use(MessageBackendMiddleware(max_history=history_limit))
182
- # Add user middlewares
183
- if middlewares:
184
- for mw in middlewares:
185
- middleware_chain.use(mw)
186
-
187
- # === Build tools list (direct, no provider) ===
188
- tool_list: list["BaseTool"] = []
189
- if tools is not None:
190
- if isinstance(tools, ToolSet):
191
- tool_list = list(tools.all())
192
- else:
193
- tool_list = list(tools)
194
-
195
- # Handle subagents - create DelegateTool directly
196
- if subagents:
197
- backend = ListSubAgentBackend(subagents)
198
- tool_cls = delegate_tool_class or DelegateTool
199
- delegate_tool = tool_cls(backend, middleware=middleware_chain)
200
- tool_list.append(delegate_tool)
201
-
202
- # === Build providers ===
203
- default_providers: list["ContextProvider"] = []
204
-
205
- # MessageContextProvider - for fetching history (uses backends.message)
206
- if enable_history:
207
- message_provider = MessageContextProvider(max_messages=history_limit * 2)
208
- default_providers.append(message_provider)
209
-
210
- # Combine default + custom context_providers
211
- all_providers = default_providers + (context_providers or [])
212
-
213
- # Build context
214
- ctx = InvocationContext(
215
- session=session,
216
- invocation_id=generate_id("inv"),
217
- agent_id=config.name if config else "react_agent",
97
+ """Create ReactAgent with minimal boilerplate. See factory.create_react_agent for details."""
98
+ return create_react_agent(
99
+ llm=llm,
100
+ tools=tools,
101
+ config=config,
218
102
  backends=backends,
103
+ session=session,
219
104
  bus=bus,
220
- llm=llm,
221
- middleware=middleware_chain,
105
+ middlewares=middlewares,
106
+ subagents=subagents,
222
107
  memory=memory,
223
108
  snapshot=snapshot,
109
+ context_providers=context_providers,
110
+ enable_history=enable_history,
111
+ history_limit=history_limit,
112
+ delegate_tool_class=delegate_tool_class,
224
113
  )
225
-
226
- agent = cls(ctx, config)
227
- agent._tools = tool_list # Direct tools (not from context_provider)
228
- agent._context_providers = all_providers
229
- agent._delegate_tool_class = delegate_tool_class or DelegateTool
230
- agent._middleware_chain = middleware_chain
231
- return agent
232
114
 
233
115
  @classmethod
234
116
  async def restore(
@@ -244,113 +126,18 @@ class ReactAgent(BaseAgent):
244
126
  memory: "MemoryManager | None" = None,
245
127
  snapshot: "SnapshotBackend | None" = None,
246
128
  ) -> "ReactAgent":
247
- """Restore agent from persisted state.
248
-
249
- Use this to resume an agent after:
250
- - Page refresh
251
- - Process restart
252
- - Cross-process recovery
253
-
254
- Args:
255
- session_id: Session ID to restore
256
- llm: LLM provider
257
- backends: Backends container (recommended, auto-created if None)
258
- tools: Tool registry or list of tools
259
- config: Agent configuration
260
- bus: Event bus (auto-created if None)
261
- middleware: Middleware chain
262
- memory: Memory manager
263
- snapshot: Snapshot backend
264
-
265
- Returns:
266
- Restored ReactAgent ready to continue
267
-
268
- Raises:
269
- SessionNotFoundError: If session not found
270
-
271
- Example:
272
- agent = await ReactAgent.restore(
273
- session_id="sess_xxx",
274
- backends=my_backends,
275
- llm=my_llm,
276
- )
277
-
278
- # Check if waiting for HITL response
279
- if agent.is_suspended:
280
- print(f"Waiting for: {agent.pending_request}")
281
- else:
282
- # Continue conversation
283
- await agent.run("Continue...")
284
- """
285
- from ..core.event_bus import Bus
286
- from ..core.types.session import Session, Invocation, InvocationState, generate_id
287
- from ..core.state import State
288
- from ..tool import ToolSet
289
- from ..backends import Backends
290
-
291
- # Auto-create backends if not provided
292
- if backends is None:
293
- backends = Backends.create_default()
294
-
295
- # Validate storage backend is available
296
- if backends.state is None:
297
- raise ValueError("Cannot restore: no storage backend available (backends.state is None)")
298
-
299
- storage = backends.state
300
-
301
- # 1. Load session
302
- session_data = await storage.get("sessions", session_id)
303
- if not session_data:
304
- raise SessionNotFoundError(f"Session not found: {session_id}")
305
- session = Session.from_dict(session_data)
306
-
307
- # 2. Load current invocation
308
- invocation: Invocation | None = None
309
- if session_data.get("current_invocation_id"):
310
- inv_data = await storage.get("invocations", session_data["current_invocation_id"])
311
- if inv_data:
312
- invocation = Invocation.from_dict(inv_data)
313
-
314
- # 3. Load state
315
- state = State(storage, session_id)
316
- await state.restore()
317
-
318
- # 4. Handle tools
319
- tool_set: ToolSet | None = None
320
- if tools is not None:
321
- if isinstance(tools, ToolSet):
322
- tool_set = tools
323
- else:
324
- tool_set = ToolSet()
325
- for tool in tools:
326
- tool_set.add(tool)
327
- else:
328
- tool_set = ToolSet()
329
-
330
- # 5. Create bus if needed
331
- if bus is None:
332
- bus = Bus()
333
-
334
- # 6. Build context
335
- ctx = InvocationContext(
336
- session=session,
337
- invocation_id=invocation.id if invocation else generate_id("inv"),
338
- agent_id=config.name if config else "react_agent",
129
+ """Restore agent from persisted state. See factory.restore_react_agent for details."""
130
+ return await restore_react_agent(
131
+ session_id=session_id,
132
+ llm=llm,
339
133
  backends=backends,
134
+ tools=tools,
135
+ config=config,
340
136
  bus=bus,
341
- llm=llm,
342
- tools=tool_set,
343
137
  middleware=middleware,
344
138
  memory=memory,
345
139
  snapshot=snapshot,
346
140
  )
347
-
348
- # 7. Create agent
349
- agent = cls(ctx, config)
350
- agent._restored_invocation = invocation
351
- agent._state = state
352
-
353
- return agent
354
141
 
355
142
  def __init__(
356
143
  self,
@@ -364,7 +151,7 @@ class ReactAgent(BaseAgent):
364
151
  config: Agent configuration
365
152
 
366
153
  Raises:
367
- ValueError: If ctx.llm or ctx.tools is None
154
+ ValueError: If ctx.llm is None
368
155
  """
369
156
  super().__init__(ctx, config)
370
157
 
@@ -393,13 +180,13 @@ class ReactAgent(BaseAgent):
393
180
 
394
181
  # Restore support
395
182
  self._restored_invocation: "Invocation | None" = None
396
- self._state: "State | None" = None
183
+ self._state: "Any | None" = None # State object for checkpoint
397
184
 
398
185
  # Direct tools (passed to create())
399
186
  self._tools: list["BaseTool"] = []
400
187
 
401
188
  # ContextProviders for context engineering
402
- self._context_providers: list[ContextProvider] = []
189
+ self._context_providers: list["ContextProvider"] = []
403
190
 
404
191
  # DelegateTool class and middleware for dynamic subagent handling
405
192
  self._delegate_tool_class: type | None = None
@@ -418,7 +205,7 @@ class ReactAgent(BaseAgent):
418
205
  return False
419
206
 
420
207
  @property
421
- def state(self) -> "State | None":
208
+ def state(self) -> "Any | None":
422
209
  """Get session state (for checkpoint/restore)."""
423
210
  return self._state
424
211
 
@@ -457,6 +244,8 @@ class ReactAgent(BaseAgent):
457
244
  return self._run_config["stream_thinking"]
458
245
  return self.config.stream_thinking
459
246
 
247
+ # ========== Main execution ==========
248
+
460
249
  async def _execute(self, input: PromptInput | str) -> None:
461
250
  """Execute the React loop.
462
251
 
@@ -467,10 +256,6 @@ class ReactAgent(BaseAgent):
467
256
  if isinstance(input, str):
468
257
  input = PromptInput(text=input)
469
258
 
470
- # NOTE: 如果需要 HITL 恢复到同一个 invocation(而不是创建新的),
471
- # 可以检查 self._restored_invocation.state == SUSPENDED 并恢复精确状态。
472
- # 当前设计:每次 run() 都创建新 invocation,HITL 回复也是新 invocation。
473
-
474
259
  self.reset()
475
260
  self._running = True
476
261
 
@@ -488,7 +273,7 @@ class ReactAgent(BaseAgent):
488
273
  "session_id": self.session.id,
489
274
  "agent_id": self.name,
490
275
  "agent_type": self.agent_type,
491
- "emit": global_emit, # For middleware to emit ActionEvent
276
+ "emit": global_emit,
492
277
  "backends": self.ctx.backends,
493
278
  }
494
279
 
@@ -502,15 +287,19 @@ class ReactAgent(BaseAgent):
502
287
  )
503
288
  mw_context["invocation_id"] = self._current_invocation.id
504
289
 
505
- logger.debug("Created invocation", extra={"invocation_id": self._current_invocation.id})
290
+ logger.info("Created invocation", extra={"invocation_id": self._current_invocation.id})
506
291
 
507
292
  # === Middleware: on_agent_start ===
508
293
  if self.middleware:
294
+ logger.info(
295
+ "Calling middleware: on_agent_start",
296
+ extra={"invocation_id": self._current_invocation.id},
297
+ )
509
298
  hook_result = await self.middleware.process_agent_start(
510
299
  self.name, input, mw_context
511
300
  )
512
301
  if hook_result.action == HookAction.STOP:
513
- logger.info("Agent stopped by middleware on_agent_start")
302
+ logger.warning("Agent stopped by middleware on_agent_start", extra={"invocation_id": self._current_invocation.id})
514
303
  await self.ctx.emit(BlockEvent(
515
304
  kind=BlockKind.ERROR,
516
305
  op=BlockOp.APPLY,
@@ -518,7 +307,7 @@ class ReactAgent(BaseAgent):
518
307
  ))
519
308
  return
520
309
  elif hook_result.action == HookAction.SKIP:
521
- logger.info("Agent skipped by middleware on_agent_start")
310
+ logger.warning("Agent skipped by middleware on_agent_start", extra={"invocation_id": self._current_invocation.id})
522
311
  return
523
312
 
524
313
  await self.bus.publish(
@@ -529,18 +318,49 @@ class ReactAgent(BaseAgent):
529
318
  },
530
319
  )
531
320
 
532
- # Build initial messages (loads history from storage)
533
- self._message_history = await self._build_messages(input)
321
+ # Fetch context from providers
322
+ logger.info("Fetching agent context", extra={"invocation_id": self._current_invocation.id})
323
+ self._agent_context = await ctx_helpers.fetch_agent_context(
324
+ self._ctx,
325
+ input,
326
+ self._context_providers,
327
+ self._tools,
328
+ self._delegate_tool_class,
329
+ self._middleware_chain,
330
+ )
331
+
332
+ # Build initial messages
333
+ logger.info("Building message history", extra={"invocation_id": self._current_invocation.id})
334
+ self._message_history = await ctx_helpers.build_messages(
335
+ input,
336
+ self._agent_context,
337
+ self.config.system_prompt,
338
+ )
534
339
  self._current_step = 0
340
+ logger.info(
341
+ "Built message history",
342
+ extra={
343
+ "invocation_id": self._current_invocation.id,
344
+ "message_count": len(self._message_history),
345
+ },
346
+ )
535
347
 
536
348
  # Save user message (real-time persistence)
537
- await self._save_user_message(input)
349
+ logger.info("Saving user message", extra={"invocation_id": self._current_invocation.id})
350
+ await persist_helpers.save_user_message(self, input)
538
351
 
539
- # 3. Main loop
352
+ # Main loop
540
353
  finish_reason = None
541
354
 
542
355
  while not await self._check_abort():
543
356
  self._current_step += 1
357
+ logger.info(
358
+ "Starting step",
359
+ extra={
360
+ "invocation_id": self._current_invocation.id,
361
+ "step": self._current_step,
362
+ },
363
+ )
544
364
 
545
365
  # Check step limit
546
366
  if self._current_step > self.config.max_steps:
@@ -564,14 +384,30 @@ class ReactAgent(BaseAgent):
564
384
  snapshot_id = await self.snapshot.track()
565
385
 
566
386
  # Execute step
567
- finish_reason = await self._execute_step()
387
+ logger.info(
388
+ "Executing LLM request",
389
+ extra={
390
+ "invocation_id": self._current_invocation.id,
391
+ "step": self._current_step,
392
+ },
393
+ )
394
+ finish_reason = await step_helpers.execute_step(self)
395
+ logger.info(
396
+ "LLM response received",
397
+ extra={
398
+ "invocation_id": self._current_invocation.id,
399
+ "step": self._current_step,
400
+ "finish_reason": finish_reason,
401
+ "tool_count": len(self._tool_invocations),
402
+ },
403
+ )
568
404
 
569
405
  # Save assistant message (real-time persistence)
570
- await self._save_assistant_message()
406
+ await persist_helpers.save_assistant_message(self)
571
407
 
572
408
  # Save message_history to state and checkpoint
573
409
  if self._state:
574
- self._save_messages_to_state()
410
+ persist_helpers.save_messages_to_state(self)
575
411
  await self._state.checkpoint()
576
412
 
577
413
  # Check if we should exit
@@ -580,26 +416,57 @@ class ReactAgent(BaseAgent):
580
416
 
581
417
  # Process tool results and continue
582
418
  if self._tool_invocations:
583
- await self._process_tool_results()
419
+ logger.info(
420
+ "Processing tool invocations",
421
+ extra={
422
+ "invocation_id": self._current_invocation.id,
423
+ "step": self._current_step,
424
+ "tool_count": len(self._tool_invocations),
425
+ "tools": ", ".join([inv.tool_name for inv in self._tool_invocations]),
426
+ },
427
+ )
428
+ await tool_helpers.process_tool_results(self)
429
+ logger.info(
430
+ "Tool results processed",
431
+ extra={
432
+ "invocation_id": self._current_invocation.id,
433
+ "step": self._current_step,
434
+ },
435
+ )
584
436
 
585
437
  # Save tool messages (real-time persistence)
586
- await self._save_tool_messages()
438
+ await persist_helpers.save_tool_messages(self)
587
439
 
588
440
  self._tool_invocations.clear()
589
441
 
590
442
  # Save message_history to state and checkpoint
591
443
  if self._state:
592
- self._save_messages_to_state()
444
+ persist_helpers.save_messages_to_state(self)
593
445
  await self._state.checkpoint()
594
446
 
595
- # 4. Check if aborted
447
+ # Check if aborted
596
448
  is_aborted = self.is_cancelled
597
449
 
598
- # 5. Complete invocation
450
+ # Complete invocation
599
451
  if is_aborted:
600
452
  self._current_invocation.state = InvocationState.ABORTED
453
+ logger.info(
454
+ "Invocation aborted by user",
455
+ extra={
456
+ "invocation_id": self._current_invocation.id,
457
+ "steps": self._current_step,
458
+ },
459
+ )
601
460
  else:
602
461
  self._current_invocation.state = InvocationState.COMPLETED
462
+ logger.info(
463
+ "Invocation completed successfully",
464
+ extra={
465
+ "invocation_id": self._current_invocation.id,
466
+ "steps": self._current_step,
467
+ "finish_reason": finish_reason,
468
+ },
469
+ )
603
470
  self._current_invocation.finished_at = datetime.now()
604
471
 
605
472
  # Save to invocation backend
@@ -638,20 +505,19 @@ class ReactAgent(BaseAgent):
638
505
  )
639
506
 
640
507
  # Clear message_history from State after successful completion
641
- # Historical messages are already persisted (truncated) via MessageStore
642
- self._clear_messages_from_state()
508
+ persist_helpers.clear_messages_from_state(self)
643
509
  if self._state:
644
510
  await self._state.checkpoint()
645
511
 
646
512
  except SuspendSignal as e:
647
513
  # HITL/Suspend signal - invocation waits for user input
648
- logger.info(
649
- "Agent suspended",
514
+ logger.warning(
515
+ "Agent suspended (HITL)",
650
516
  extra={
651
- "invocation_id": self._current_invocation.id
652
- if self._current_invocation
653
- else None,
517
+ "invocation_id": self._current_invocation.id if self._current_invocation else None,
654
518
  "signal_type": type(e).__name__,
519
+ "request_type": getattr(e, "request_type", None),
520
+ "request_id": getattr(e, "request_id", None),
655
521
  },
656
522
  )
657
523
 
@@ -668,10 +534,9 @@ class ReactAgent(BaseAgent):
668
534
  # Save pending_request to execution state
669
535
  if self._state:
670
536
  self._state.execution["pending_request"] = e.to_dict()
671
- self._save_messages_to_state()
537
+ persist_helpers.save_messages_to_state(self)
672
538
  await self._state.checkpoint()
673
539
 
674
- # Don't raise - just return to exit cleanly
675
540
  return
676
541
 
677
542
  except Exception as e:
@@ -679,9 +544,7 @@ class ReactAgent(BaseAgent):
679
544
  "ReactAgent run failed",
680
545
  extra={
681
546
  "error": str(e),
682
- "invocation_id": self._current_invocation.id
683
- if self._current_invocation
684
- else None,
547
+ "invocation_id": self._current_invocation.id if self._current_invocation else None,
685
548
  },
686
549
  exc_info=True,
687
550
  )
@@ -690,8 +553,10 @@ class ReactAgent(BaseAgent):
690
553
  if self.middleware:
691
554
  processed_error = await self.middleware.process_error(e, mw_context)
692
555
  if processed_error is None:
693
- # Error suppressed by middleware
694
- logger.info("Error suppressed by middleware")
556
+ logger.warning(
557
+ "Error suppressed by middleware",
558
+ extra={"invocation_id": self._current_invocation.id if self._current_invocation else None},
559
+ )
695
560
  return
696
561
 
697
562
  if self._current_invocation:
@@ -709,1214 +574,26 @@ class ReactAgent(BaseAgent):
709
574
  self._running = False
710
575
  self._restored_invocation = None
711
576
 
712
- async def pause(self) -> str:
713
- """Pause execution and return invocation ID for later resume.
714
-
715
- Saves current state to the invocation for later resumption.
716
-
717
- Returns:
718
- Invocation ID for resuming
719
- """
720
- if not self._current_invocation:
721
- raise RuntimeError("No active invocation to pause")
722
-
723
- # Mark as paused
724
- self._paused = True
725
- self._current_invocation.mark_paused()
726
-
727
- # Save state for resumption
728
- self._current_invocation.agent_state = {
729
- "step": self._current_step,
730
- "message_history": [
731
- {"role": m.role, "content": m.content} for m in self._message_history
732
- ],
733
- "text_buffer": self._text_buffer,
734
- }
735
- self._current_invocation.step_count = self._current_step
736
-
737
- # Save pending tool calls
738
- self._current_invocation.pending_tool_ids = [
739
- inv.tool_call_id
740
- for inv in self._tool_invocations
741
- if inv.state == ToolInvocationState.CALL
742
- ]
743
-
744
- # Persist invocation
745
- if self.ctx.backends and self.ctx.backends.invocation:
746
- await self.ctx.backends.invocation.update(
747
- self._current_invocation.id,
748
- self._current_invocation.to_dict(),
749
- )
750
-
751
- await self.bus.publish(
752
- Events.INVOCATION_PAUSE,
753
- {
754
- "invocation_id": self._current_invocation.id,
755
- "step": self._current_step,
756
- },
757
- )
758
-
759
- return self._current_invocation.id
577
+ # ========== Pause/resume (delegate to pause.py) ==========
760
578
 
761
- async def _resume_internal(self, invocation_id: str) -> None:
762
- """Internal resume logic using emit."""
763
- # Load invocation
764
- if not self.ctx.backends or not self.ctx.backends.invocation:
765
- raise ValueError("No invocation backend available")
766
- inv_data = await self.ctx.backends.invocation.get(invocation_id)
767
- if not inv_data:
768
- raise ValueError(f"Invocation not found: {invocation_id}")
769
-
770
- invocation = Invocation.from_dict(inv_data)
771
-
772
- if invocation.state != InvocationState.PAUSED:
773
- raise ValueError(f"Invocation is not paused: {invocation.state}")
774
-
775
- # Restore state
776
- self._current_invocation = invocation
777
- self._paused = False
778
- self._running = True
779
-
780
- agent_state = invocation.agent_state or {}
781
- self._current_step = agent_state.get("step", 0)
782
- self._text_buffer = agent_state.get("text_buffer", "")
783
-
784
- # Restore message history
785
- self._message_history = [
786
- LLMMessage(role=m["role"], content=m["content"])
787
- for m in agent_state.get("message_history", [])
788
- ]
789
-
790
- # Mark as running
791
- invocation.state = InvocationState.RUNNING
792
-
793
- await self.bus.publish(
794
- Events.INVOCATION_RESUME,
795
- {
796
- "invocation_id": invocation_id,
797
- "step": self._current_step,
798
- },
799
- )
800
-
801
- # Continue execution loop
802
- try:
803
- finish_reason = None
804
-
805
- while not await self._check_abort() and not self._paused:
806
- self._current_step += 1
807
-
808
- if self._current_step > self.config.max_steps:
809
- await self.ctx.emit(BlockEvent(
810
- kind=BlockKind.ERROR,
811
- op=BlockOp.APPLY,
812
- data={"message": f"Max steps ({self.config.max_steps}) exceeded"},
813
- ))
814
- break
815
-
816
- finish_reason = await self._execute_step()
817
-
818
- # Save assistant message (real-time persistence)
819
- await self._save_assistant_message()
820
-
821
- if finish_reason == "end_turn" and not self._tool_invocations:
822
- break
823
-
824
- if self._tool_invocations:
825
- await self._process_tool_results()
826
-
827
- # Save tool messages (real-time persistence)
828
- await self._save_tool_messages()
829
-
830
- self._tool_invocations.clear()
831
-
832
- if not self._paused:
833
- self._current_invocation.state = InvocationState.COMPLETED
834
- self._current_invocation.finished_at = datetime.now()
835
-
836
- except Exception as e:
837
- self._current_invocation.state = InvocationState.FAILED
838
- await self.ctx.emit(BlockEvent(
839
- kind=BlockKind.ERROR,
840
- op=BlockOp.APPLY,
841
- data={"message": str(e)},
842
- ))
843
- raise
844
-
845
- finally:
846
- self._running = False
579
+ async def pause(self) -> str:
580
+ """Pause execution and return invocation ID for later resume."""
581
+ return await pause_helpers.pause_agent(self)
847
582
 
848
583
  async def resume(self, invocation_id: str) -> AsyncIterator[BlockEvent]:
849
- """Resume paused execution.
850
-
851
- Args:
852
- invocation_id: ID from pause()
853
-
854
- Yields:
855
- BlockEvent streaming events
856
- """
857
- from ..core.context import _emit_queue_var
858
-
859
- queue: asyncio.Queue[BlockEvent] = asyncio.Queue()
860
- token = _emit_queue_var.set(queue)
861
-
862
- try:
863
- exec_task = asyncio.create_task(self._resume_internal(invocation_id))
864
- get_task: asyncio.Task | None = None
865
-
866
- # Event-driven processing - no timeout delays
867
- while True:
868
- # First drain any pending items from queue (non-blocking)
869
- while True:
870
- try:
871
- block = queue.get_nowait()
872
- yield block
873
- except asyncio.QueueEmpty:
874
- break
875
-
876
- # Exit if task is done and queue is empty
877
- if exec_task.done() and queue.empty():
878
- break
879
-
880
- # Create get_task if needed
881
- if get_task is None or get_task.done():
882
- get_task = asyncio.create_task(queue.get())
883
-
884
- # Wait for EITHER: queue item OR exec_task completion
885
- done, _ = await asyncio.wait(
886
- {get_task, exec_task},
887
- return_when=asyncio.FIRST_COMPLETED,
888
- )
889
-
890
- if get_task in done:
891
- try:
892
- block = get_task.result()
893
- yield block
894
- get_task = None
895
- except asyncio.CancelledError:
896
- pass
897
-
898
- # Cancel pending get_task if any
899
- if get_task and not get_task.done():
900
- get_task.cancel()
901
- try:
902
- await get_task
903
- except asyncio.CancelledError:
904
- pass
905
-
906
- # Final drain after task completion
907
- while not queue.empty():
908
- try:
909
- block = queue.get_nowait()
910
- yield block
911
- except asyncio.QueueEmpty:
912
- break
913
-
914
- await exec_task
915
-
916
- finally:
917
- _emit_queue_var.reset(token)
918
-
919
- async def _fetch_agent_context(self, input: PromptInput) -> AgentContext:
920
- """Fetch context from all providers and merge with direct tools.
921
-
922
- Process:
923
- 1. Fetch from all providers and merge
924
- 2. Add direct tools (from create())
925
- 3. If providers returned subagents, create DelegateTool
926
-
927
- Also sets ctx.input for providers to access.
928
- """
929
- from ..tool.builtin import DelegateTool
930
- from ..backends.subagent import ListSubAgentBackend
931
-
932
- # Set input on context for providers to access
933
- self._ctx.input = input
934
-
935
- # Fetch from all context_providers
936
- outputs: list[AgentContext] = []
937
- for provider in self._context_providers:
938
- try:
939
- output = await provider.fetch(self._ctx)
940
- outputs.append(output)
941
- except Exception as e:
942
- logger.warning(f"Provider {provider.name} fetch failed: {e}")
943
-
944
- # Merge all provider outputs
945
- merged = AgentContext.merge(outputs)
946
-
947
- # Add direct tools (from create())
948
- all_tools = list(self._tools) # Copy direct tools
949
- seen_names = {t.name for t in all_tools}
950
-
951
- # Add tools from providers (deduplicate)
952
- for tool in merged.tools:
953
- if tool.name not in seen_names:
954
- seen_names.add(tool.name)
955
- all_tools.append(tool)
956
-
957
- # If providers returned subagents, create DelegateTool
958
- if merged.subagents:
959
- # Check if we already have a delegate tool
960
- has_delegate = any(t.name == "delegate" for t in all_tools)
961
- if not has_delegate:
962
- backend = ListSubAgentBackend(merged.subagents)
963
- tool_cls = self._delegate_tool_class or DelegateTool
964
- delegate_tool = tool_cls(backend, middleware=self._middleware_chain)
965
- all_tools.append(delegate_tool)
966
-
967
- # Return merged context with combined tools
968
- return AgentContext(
969
- system_content=merged.system_content,
970
- user_content=merged.user_content,
971
- tools=all_tools,
972
- messages=merged.messages,
973
- subagents=merged.subagents,
974
- skills=merged.skills,
975
- )
976
-
977
- async def _build_messages(self, input: PromptInput) -> list[LLMMessage]:
978
- """Build message history for LLM.
979
-
980
- Uses AgentContext from providers for system content, messages, etc.
981
- """
982
- messages = []
983
-
984
- # Fetch context from providers
985
- self._agent_context = await self._fetch_agent_context(input)
986
-
987
- # System message: config.system_prompt + agent_context.system_content
988
- system_prompt = self.config.system_prompt or self._default_system_prompt()
989
- if self._agent_context.system_content:
990
- system_prompt = system_prompt + "\n\n" + self._agent_context.system_content
991
- messages.append(LLMMessage(role="system", content=system_prompt))
992
-
993
- # Historical messages from AgentContext (provided by MessageContextProvider)
994
- for msg in self._agent_context.messages:
995
- messages.append(LLMMessage(
996
- role=msg.get("role", "user"),
997
- content=msg.get("content", ""),
998
- ))
999
-
1000
- # User content prefix (from providers) + current user message
1001
- content = input.text
1002
- if self._agent_context.user_content:
1003
- content = self._agent_context.user_content + "\n\n" + content
1004
-
1005
- if input.attachments:
1006
- # Build multimodal content
1007
- content_parts = [{"type": "text", "text": content}]
1008
- for attachment in input.attachments:
1009
- content_parts.append(attachment)
1010
- content = content_parts
1011
-
1012
- messages.append(LLMMessage(role="user", content=content))
1013
-
1014
- return messages
1015
-
1016
- def _default_system_prompt(self) -> str:
1017
- """Generate default system prompt with tool descriptions."""
1018
- # Get tools from AgentContext (from providers)
1019
- all_tools = self._agent_context.tools if self._agent_context else []
1020
-
1021
- tool_list = []
1022
- for tool in all_tools:
1023
- info = tool.get_info()
1024
- tool_list.append(f"- {info.name}: {info.description}")
1025
-
1026
- tools_desc = "\n".join(tool_list) if tool_list else "No tools available."
1027
-
1028
- return f"""You are a helpful AI assistant with access to tools.
1029
-
1030
- Available tools:
1031
- {tools_desc}
1032
-
1033
- When you need to use a tool, make a tool call. After receiving the tool result, continue reasoning or provide your final response.
1034
-
1035
- Think step by step and use tools when necessary to complete the user's request."""
1036
-
1037
- def _get_effective_tool_mode(self) -> ToolInjectionMode:
1038
- """Get effective tool mode (auto-detect based on model capabilities).
1039
-
1040
- Returns:
1041
- FUNCTION_CALL if model supports tools, else PROMPT
1042
- """
1043
- # If explicitly set to PROMPT, use PROMPT
1044
- if self.config.tool_mode == ToolInjectionMode.PROMPT:
1045
- return ToolInjectionMode.PROMPT
1046
-
1047
- # Auto-detect: if model doesn't support tools, use PROMPT
1048
- caps = self.llm.capabilities
1049
- if not caps.supports_tools:
1050
- logger.info(
1051
- f"Model {self.llm.model} does not support function calling, "
1052
- "auto-switching to PROMPT mode for tools"
1053
- )
1054
- return ToolInjectionMode.PROMPT
1055
-
1056
- return ToolInjectionMode.FUNCTION_CALL
1057
-
1058
- def _build_tool_prompt(self, tools: list) -> str:
1059
- """Build tool description for PROMPT mode injection.
1060
-
1061
- Args:
1062
- tools: List of BaseTool objects
1063
-
1064
- Returns:
1065
- Tool prompt string to inject into system message
1066
- """
1067
- if not tools:
1068
- return ""
1069
-
1070
- tool_descriptions = []
1071
- for tool in tools:
1072
- info = tool.get_info()
1073
- # Build parameter description
1074
- params_desc = ""
1075
- if info.parameters and "properties" in info.parameters:
1076
- params = []
1077
- properties = info.parameters.get("properties", {})
1078
- required = info.parameters.get("required", [])
1079
- for name, schema in properties.items():
1080
- param_type = schema.get("type", "any")
1081
- param_desc = schema.get("description", "")
1082
- is_required = "required" if name in required else "optional"
1083
- params.append(f" - {name} ({param_type}, {is_required}): {param_desc}")
1084
- params_desc = "\n" + "\n".join(params) if params else ""
1085
-
1086
- tool_descriptions.append(
1087
- f"### {info.name}\n"
1088
- f"{info.description}{params_desc}"
1089
- )
1090
-
1091
- return f"""## Available Tools
1092
-
1093
- You have access to the following tools. To use a tool, output a JSON block in this exact format:
1094
-
1095
- ```tool_call
1096
- {{
1097
- "tool": "tool_name",
1098
- "arguments": {{
1099
- "param1": "value1",
1100
- "param2": "value2"
1101
- }}
1102
- }}
1103
- ```
1104
-
1105
- IMPORTANT:
1106
- - Use the exact format above with ```tool_call code block
1107
- - You can make multiple tool calls in one response
1108
- - Wait for tool results before continuing
1109
-
1110
- {chr(10).join(tool_descriptions)}
1111
- """
1112
-
1113
- def _parse_tool_calls_from_text(self, text: str) -> list[dict]:
1114
- """Parse tool calls from LLM text output (for PROMPT mode).
1115
-
1116
- Looks for ```tool_call blocks in the format:
1117
- ```tool_call
1118
- {"tool": "name", "arguments": {...}}
1119
- ```
1120
-
1121
- Args:
1122
- text: LLM output text
1123
-
1124
- Returns:
1125
- List of parsed tool calls: [{"name": str, "arguments": dict}, ...]
1126
- """
1127
- import re
1128
-
1129
- tool_calls = []
1130
-
1131
- # Match ```tool_call ... ``` blocks
1132
- pattern = r"```tool_call\s*\n?(.+?)\n?```"
1133
- matches = re.findall(pattern, text, re.DOTALL)
1134
-
1135
- for match in matches:
1136
- try:
1137
- data = json.loads(match.strip())
1138
- if "tool" in data:
1139
- tool_calls.append({
1140
- "name": data["tool"],
1141
- "arguments": data.get("arguments", {}),
1142
- })
1143
- except json.JSONDecodeError as e:
1144
- logger.warning(f"Failed to parse tool call JSON: {e}")
1145
- continue
1146
-
1147
- return tool_calls
1148
-
1149
- async def _execute_step(self) -> str | None:
1150
- """Execute a single LLM step with middleware hooks.
1151
-
1152
- Returns:
1153
- finish_reason from LLM
1154
- """
1155
- # Get tools from AgentContext (from providers)
1156
- all_tools = self._agent_context.tools if self._agent_context else []
1157
-
1158
- # Determine effective tool mode (auto-detect based on capabilities)
1159
- effective_tool_mode = self._get_effective_tool_mode()
1160
-
1161
- # Get tool definitions (only for FUNCTION_CALL mode)
1162
- tool_defs = None
1163
- if effective_tool_mode == ToolInjectionMode.FUNCTION_CALL and all_tools:
1164
- tool_defs = [
1165
- ToolDefinition(
1166
- name=t.name,
1167
- description=t.description,
1168
- input_schema=t.parameters,
1169
- )
1170
- for t in all_tools
1171
- ]
1172
-
1173
- # For PROMPT mode, inject tools into system message
1174
- if effective_tool_mode == ToolInjectionMode.PROMPT and all_tools:
1175
- tool_prompt = self._build_tool_prompt(all_tools)
1176
- # Inject into first system message
1177
- if self._message_history and self._message_history[0].role == "system":
1178
- original_content = self._message_history[0].content
1179
- self._message_history[0] = LLMMessage(
1180
- role="system",
1181
- content=f"{original_content}\n\n{tool_prompt}",
1182
- )
1183
-
1184
- # Reset buffers
1185
- self._text_buffer = ""
1186
- self._thinking_buffer = "" # Buffer for non-streaming thinking
1187
- self._tool_invocations = []
1188
- current_tool_invocation: ToolInvocation | None = None
1189
-
1190
- # Reset block IDs for this step (each step gets fresh block IDs)
1191
- self._current_text_block_id = None
1192
- self._current_thinking_block_id = None
1193
-
1194
- # Reset tool call tracking
1195
- self._call_id_to_tool = {}
1196
- self._tool_call_blocks = {}
1197
-
1198
- # Build middleware context for this step
1199
- from ..core.context import emit as global_emit
1200
- mw_context = {
1201
- "session_id": self.session.id,
1202
- "invocation_id": self._current_invocation.id if self._current_invocation else "",
1203
- "step": self._current_step,
1204
- "agent_id": self.name,
1205
- "emit": global_emit, # For middleware to emit BlockEvent/ActionEvent
1206
- "backends": self.ctx.backends,
1207
- "tool_mode": effective_tool_mode.value, # Add tool mode to context
1208
- }
1209
-
1210
- # Build LLM call kwargs
1211
- # Note: temperature, max_tokens, timeout, retries are configured on LLMProvider
1212
- llm_kwargs: dict[str, Any] = {
1213
- "messages": self._message_history,
1214
- "tools": tool_defs, # None for PROMPT mode
1215
- }
1216
-
1217
- # Get model capabilities
1218
- caps = self.llm.capabilities
1219
-
1220
- # Add thinking configuration (use runtime override if set)
1221
- # Only if model supports thinking
1222
- enable_thinking = self._get_enable_thinking()
1223
- reasoning_effort = self._get_reasoning_effort()
1224
- if enable_thinking:
1225
- if caps.supports_thinking:
1226
- llm_kwargs["enable_thinking"] = True
1227
- if reasoning_effort:
1228
- llm_kwargs["reasoning_effort"] = reasoning_effort
1229
- else:
1230
- logger.debug(
1231
- f"Model {self.llm.model} does not support thinking, "
1232
- "enable_thinking will be ignored"
1233
- )
1234
-
1235
- # === Middleware: on_request ===
1236
- if self.middleware:
1237
- llm_kwargs = await self.middleware.process_request(llm_kwargs, mw_context)
1238
- if llm_kwargs is None:
1239
- logger.info("LLM request cancelled by middleware")
1240
- return None
1241
-
1242
- # Debug: log message history before LLM call
1243
- logger.debug(
1244
- f"LLM call - Step {self._current_step}, messages: {len(self._message_history)}, "
1245
- f"tools: {len(tool_defs) if tool_defs else 0}"
1246
- )
1247
- # Detailed message log (for debugging model issues like repeated calls)
1248
- for i, msg in enumerate(self._message_history):
1249
- content_preview = str(msg.content)[:300] if msg.content else "<empty>"
1250
- tool_call_id = getattr(msg, 'tool_call_id', None)
1251
- logger.debug(
1252
- f" msg[{i}] role={msg.role}"
1253
- f"{f', tool_call_id={tool_call_id}' if tool_call_id else ''}"
1254
- f", content={content_preview}"
1255
- )
1256
-
1257
- # Call LLM
1258
- await self.bus.publish(
1259
- Events.LLM_START,
1260
- {
1261
- "provider": self.llm.provider,
1262
- "model": self.llm.model,
1263
- "step": self._current_step,
1264
- "enable_thinking": enable_thinking,
1265
- },
1266
- )
1267
-
1268
- finish_reason = None
1269
- llm_response_data: dict[str, Any] = {} # Collect response for middleware
1270
-
1271
- # Reset middleware stream state
1272
- if self.middleware:
1273
- self.middleware.reset_stream_state()
1274
-
1275
- async for event in self.llm.complete(**llm_kwargs):
1276
- if await self._check_abort():
1277
- break
1278
-
1279
- if event.type == "content":
1280
- # Text content
1281
- if event.delta:
1282
- # === Middleware: on_model_stream ===
1283
- stream_chunk = {"delta": event.delta, "type": "content"}
1284
- if self.middleware:
1285
- stream_chunk = await self.middleware.process_stream_chunk(
1286
- stream_chunk, mw_context
1287
- )
1288
- if stream_chunk is None:
1289
- continue # Skip this chunk
1290
-
1291
- delta = stream_chunk.get("delta", event.delta)
1292
- self._text_buffer += delta
1293
-
1294
- # Reuse or create block_id for text streaming
1295
- if self._current_text_block_id is None:
1296
- self._current_text_block_id = generate_id("blk")
1297
-
1298
- await self.ctx.emit(BlockEvent(
1299
- block_id=self._current_text_block_id,
1300
- kind=BlockKind.TEXT,
1301
- op=BlockOp.DELTA,
1302
- data={"content": delta},
1303
- ))
1304
-
1305
- await self.bus.publish(
1306
- Events.LLM_STREAM,
1307
- {
1308
- "delta": delta,
1309
- "step": self._current_step,
1310
- },
1311
- )
1312
-
1313
- elif event.type == "thinking":
1314
- # Thinking content - only emit if thinking is enabled
1315
- stream_thinking = self._get_stream_thinking()
1316
- if event.delta and enable_thinking:
1317
- if stream_thinking:
1318
- # Reuse or create block_id for thinking streaming
1319
- if self._current_thinking_block_id is None:
1320
- self._current_thinking_block_id = generate_id("blk")
1321
-
1322
- # Stream thinking in real-time
1323
- await self.ctx.emit(BlockEvent(
1324
- block_id=self._current_thinking_block_id,
1325
- kind=BlockKind.THINKING,
1326
- op=BlockOp.DELTA,
1327
- data={"content": event.delta},
1328
- ))
1329
- else:
1330
- # Buffer thinking for batch output
1331
- self._thinking_buffer += event.delta
1332
-
1333
- elif event.type == "tool_call_start":
1334
- # Tool call started (name known, arguments pending)
1335
- if event.tool_call:
1336
- tc = event.tool_call
1337
- self._call_id_to_tool[tc.id] = tc.name
1338
-
1339
- # Always emit start notification (privacy-safe, no arguments)
1340
- block_id = generate_id("blk")
1341
- self._tool_call_blocks[tc.id] = block_id
1342
-
1343
- await self.ctx.emit(BlockEvent(
1344
- block_id=block_id,
1345
- kind=BlockKind.TOOL_USE,
1346
- op=BlockOp.APPLY,
1347
- data={
1348
- "name": tc.name,
1349
- "call_id": tc.id,
1350
- "status": "streaming", # Indicate arguments are streaming
1351
- },
1352
- ))
1353
-
1354
- elif event.type == "tool_call_delta":
1355
- # Tool arguments delta (streaming)
1356
- if event.tool_call_delta:
1357
- call_id = event.tool_call_delta.get("call_id")
1358
- arguments_delta = event.tool_call_delta.get("arguments_delta")
1359
-
1360
- if call_id and arguments_delta:
1361
- tool_name = self._call_id_to_tool.get(call_id)
1362
- if tool_name:
1363
- tool = self._get_tool(tool_name)
1364
-
1365
- # Check if tool allows streaming arguments
1366
- if tool and tool.config.stream_arguments:
1367
- block_id = self._tool_call_blocks.get(call_id)
1368
- if block_id:
1369
- await self.ctx.emit(BlockEvent(
1370
- block_id=block_id,
1371
- kind=BlockKind.TOOL_USE,
1372
- op=BlockOp.DELTA,
1373
- data={
1374
- "call_id": call_id,
1375
- "arguments_delta": arguments_delta,
1376
- },
1377
- ))
1378
-
1379
- elif event.type == "tool_call_progress":
1380
- # Tool arguments progress (bytes received)
1381
- if event.tool_call_progress:
1382
- call_id = event.tool_call_progress.get("call_id")
1383
- bytes_received = event.tool_call_progress.get("bytes_received")
1384
-
1385
- if call_id and bytes_received is not None:
1386
- block_id = self._tool_call_blocks.get(call_id)
1387
- if block_id:
1388
- # Always emit progress (privacy-safe, no content)
1389
- await self.ctx.emit(BlockEvent(
1390
- block_id=block_id,
1391
- kind=BlockKind.TOOL_USE,
1392
- op=BlockOp.PATCH,
1393
- data={
1394
- "call_id": call_id,
1395
- "bytes_received": bytes_received,
1396
- "status": "receiving",
1397
- },
1398
- ))
1399
-
1400
- elif event.type == "tool_call":
1401
- # Tool call complete (arguments fully received)
1402
- if event.tool_call:
1403
- tc = event.tool_call
1404
- invocation = ToolInvocation(
1405
- tool_call_id=tc.id,
1406
- tool_name=tc.name,
1407
- args_raw=tc.arguments,
1408
- state=ToolInvocationState.CALL,
1409
- )
1410
-
1411
- # Parse arguments
1412
- try:
1413
- invocation.args = json.loads(tc.arguments)
1414
- except json.JSONDecodeError:
1415
- invocation.args = {}
1416
-
1417
- self._tool_invocations.append(invocation)
1418
-
1419
- # Strict mode: require tool_call_start to be received first
1420
- # TODO: Uncomment below for compatibility with providers that don't send tool_call_start
1421
- # block_id = self._tool_call_blocks.get(tc.id)
1422
- # if block_id is None:
1423
- # # No streaming start event, create block now
1424
- # block_id = generate_id("blk")
1425
- # self._tool_call_blocks[tc.id] = block_id
1426
- # self._call_id_to_tool[tc.id] = tc.name
1427
- #
1428
- # # Emit APPLY with full data
1429
- # await self.ctx.emit(BlockEvent(
1430
- # block_id=block_id,
1431
- # kind=BlockKind.TOOL_USE,
1432
- # op=BlockOp.APPLY,
1433
- # data={
1434
- # "name": tc.name,
1435
- # "call_id": tc.id,
1436
- # "arguments": invocation.args,
1437
- # "status": "ready",
1438
- # },
1439
- # ))
1440
- # else:
1441
- # # Update existing block with complete arguments
1442
- # await self.ctx.emit(BlockEvent(
1443
- # block_id=block_id,
1444
- # kind=BlockKind.TOOL_USE,
1445
- # op=BlockOp.PATCH,
1446
- # data={
1447
- # "call_id": tc.id,
1448
- # "arguments": invocation.args,
1449
- # "status": "ready",
1450
- # },
1451
- # ))
1452
-
1453
- # Strict mode: tool_call_start must have been received
1454
- block_id = self._tool_call_blocks[tc.id] # Will raise KeyError if not found
1455
- await self.ctx.emit(BlockEvent(
1456
- block_id=block_id,
1457
- kind=BlockKind.TOOL_USE,
1458
- op=BlockOp.PATCH,
1459
- data={
1460
- "call_id": tc.id,
1461
- "arguments": invocation.args,
1462
- "status": "ready",
1463
- },
1464
- ))
1465
-
1466
- await self.bus.publish(
1467
- Events.TOOL_START,
1468
- {
1469
- "call_id": tc.id,
1470
- "tool": tc.name,
1471
- "arguments": invocation.args,
1472
- },
1473
- )
1474
-
1475
- elif event.type == "completed":
1476
- finish_reason = event.finish_reason
1477
-
1478
- elif event.type == "usage":
1479
- if event.usage:
1480
- await self.bus.publish(
1481
- Events.USAGE_RECORDED,
1482
- {
1483
- "provider": self.llm.provider,
1484
- "model": self.llm.model,
1485
- "input_tokens": event.usage.input_tokens,
1486
- "output_tokens": event.usage.output_tokens,
1487
- "cache_read_tokens": event.usage.cache_read_tokens,
1488
- "cache_write_tokens": event.usage.cache_write_tokens,
1489
- "reasoning_tokens": event.usage.reasoning_tokens,
1490
- },
1491
- )
1492
-
1493
- elif event.type == "error":
1494
- await self.ctx.emit(BlockEvent(
1495
- kind=BlockKind.ERROR,
1496
- op=BlockOp.APPLY,
1497
- data={"message": event.error or "Unknown LLM error"},
1498
- ))
1499
-
1500
- # If thinking was buffered, emit it now
1501
- if self._thinking_buffer and not self.config.stream_thinking:
1502
- await self.ctx.emit(BlockEvent(
1503
- kind=BlockKind.THINKING,
1504
- op=BlockOp.APPLY,
1505
- data={"content": self._thinking_buffer},
1506
- ))
1507
-
1508
- # PROMPT mode: parse tool calls from text output
1509
- if effective_tool_mode == ToolInjectionMode.PROMPT and self._text_buffer:
1510
- parsed_calls = self._parse_tool_calls_from_text(self._text_buffer)
1511
- for i, call in enumerate(parsed_calls):
1512
- call_id = generate_id("call")
1513
- invocation = ToolInvocation(
1514
- tool_call_id=call_id,
1515
- tool_name=call["name"],
1516
- args_raw=json.dumps(call["arguments"]),
1517
- args=call["arguments"],
1518
- state=ToolInvocationState.CALL,
1519
- )
1520
- self._tool_invocations.append(invocation)
1521
-
1522
- # Create block for tool call (no streaming events in PROMPT mode)
1523
- block_id = generate_id("blk")
1524
- self._tool_call_blocks[call_id] = block_id
1525
- self._call_id_to_tool[call_id] = call["name"]
1526
-
1527
- await self.ctx.emit(BlockEvent(
1528
- block_id=block_id,
1529
- kind=BlockKind.TOOL_USE,
1530
- op=BlockOp.APPLY,
1531
- data={
1532
- "name": call["name"],
1533
- "call_id": call_id,
1534
- "arguments": call["arguments"],
1535
- "status": "ready",
1536
- "source": "prompt", # Indicate parsed from text
1537
- },
1538
- ))
1539
-
1540
- await self.bus.publish(
1541
- Events.TOOL_START,
1542
- {
1543
- "call_id": call_id,
1544
- "tool": call["name"],
1545
- "arguments": call["arguments"],
1546
- "source": "prompt",
1547
- },
1548
- )
1549
-
1550
- if parsed_calls:
1551
- logger.debug(f"PROMPT mode: parsed {len(parsed_calls)} tool calls from text")
1552
-
1553
- # === Middleware: on_response ===
1554
- llm_response_data = {
1555
- "text": self._text_buffer,
1556
- "thinking": self._thinking_buffer,
1557
- "tool_calls": len(self._tool_invocations),
1558
- "finish_reason": finish_reason,
1559
- }
1560
- if self.middleware:
1561
- llm_response_data = await self.middleware.process_response(
1562
- llm_response_data, mw_context
1563
- )
1564
-
1565
- await self.bus.publish(
1566
- Events.LLM_END,
1567
- {
1568
- "step": self._current_step,
1569
- "finish_reason": finish_reason,
1570
- "text_length": len(self._text_buffer),
1571
- "thinking_length": len(self._thinking_buffer),
1572
- "tool_calls": len(self._tool_invocations),
1573
- },
1574
- )
1575
-
1576
- # Add assistant message to history
1577
- if self._text_buffer or self._tool_invocations:
1578
- assistant_content: Any = self._text_buffer
1579
- if self._tool_invocations:
1580
- # Build content with tool calls
1581
- content_parts = []
1582
- if self._text_buffer:
1583
- content_parts.append({"type": "text", "text": self._text_buffer})
1584
- for inv in self._tool_invocations:
1585
- content_parts.append(
1586
- {
1587
- "type": "tool_use",
1588
- "id": inv.tool_call_id,
1589
- "name": inv.tool_name,
1590
- "input": inv.args,
1591
- }
1592
- )
1593
- assistant_content = content_parts
1594
-
1595
- self._message_history.append(
1596
- LLMMessage(
1597
- role="assistant",
1598
- content=assistant_content,
1599
- )
1600
- )
1601
-
1602
- return finish_reason
1603
-
1604
- async def _process_tool_results(self) -> None:
1605
- """Execute tool calls and add results to history.
1606
-
1607
- Executes tools in parallel or sequentially based on config.
1608
- """
1609
- if not self._tool_invocations:
1610
- return
1611
-
1612
- # Execute tools based on configuration
1613
- if self.config.parallel_tool_execution:
1614
- # Parallel execution using asyncio.gather with create_task
1615
- # create_task ensures each task gets its own ContextVar copy
1616
- tasks = [asyncio.create_task(self._execute_tool(inv)) for inv in self._tool_invocations]
1617
- results = await asyncio.gather(*tasks, return_exceptions=True)
1618
- else:
1619
- # Sequential execution
1620
- results = []
1621
- for inv in self._tool_invocations:
1622
- try:
1623
- result = await self._execute_tool(inv)
1624
- results.append(result)
1625
- except Exception as e:
1626
- results.append(e)
1627
-
1628
- # Check for SuspendSignal first - must propagate
1629
- for result in results:
1630
- if isinstance(result, SuspendSignal):
1631
- raise result
1632
-
1633
- # Process results
1634
- tool_results = []
1635
-
1636
- for invocation, result in zip(self._tool_invocations, results):
1637
- # Handle exceptions from gather
1638
- if isinstance(result, Exception):
1639
- error_msg = f"Tool execution error: {str(result)}"
1640
- invocation.mark_result(error_msg, is_error=True)
1641
- result = ToolResult.error(error_msg)
1642
-
1643
- # Get parent block_id from tool_call mapping
1644
- parent_block_id = self._tool_call_blocks.get(invocation.tool_call_id)
1645
-
1646
- await self.ctx.emit(BlockEvent(
1647
- kind=BlockKind.TOOL_RESULT,
1648
- op=BlockOp.APPLY,
1649
- parent_id=parent_block_id,
1650
- data={
1651
- "call_id": invocation.tool_call_id,
1652
- "content": result.output,
1653
- "is_error": invocation.is_error,
1654
- },
1655
- ))
584
+ """Resume paused execution."""
585
+ async for block in pause_helpers.resume_agent(self, invocation_id):
586
+ yield block
1656
587
 
1657
- await self.bus.publish(
1658
- Events.TOOL_END,
1659
- {
1660
- "call_id": invocation.tool_call_id,
1661
- "tool": invocation.tool_name,
1662
- "result": result.output[:500], # Truncate for event
1663
- "is_error": invocation.is_error,
1664
- "duration_ms": invocation.duration_ms,
1665
- },
1666
- )
1667
-
1668
- tool_results.append(
1669
- {
1670
- "type": "tool_result",
1671
- "tool_use_id": invocation.tool_call_id,
1672
- "content": result.output,
1673
- "is_error": invocation.is_error,
1674
- }
1675
- )
1676
-
1677
- # Add tool results as tool messages (OpenAI format)
1678
- for tr in tool_results:
1679
- print(f"[DEBUG _process_tool_results] Adding tool_result to history: {tr}")
1680
- self._message_history.append(
1681
- LLMMessage(
1682
- role="tool",
1683
- content=tr["content"],
1684
- tool_call_id=tr["tool_use_id"],
1685
- )
1686
- )
1687
-
1688
- def _save_messages_to_state(self) -> None:
1689
- """Save execution state for recovery.
1690
-
1691
- This saves to state.execution namespace:
1692
- - step: current step number
1693
- - message_ids: references to raw messages (if using RawMessageMiddleware)
1694
- - For legacy/fallback: message_history as serialized data
1695
-
1696
- Note: With RawMessageMiddleware, message_ids are automatically populated
1697
- by the middleware. This method saves additional execution state.
1698
- """
1699
- if not self._state:
1700
- return
1701
-
1702
- # Save step to execution namespace
1703
- self._state.execution["step"] = self._current_step
1704
-
1705
- # Save invocation_id for recovery context
1706
- if self._current_invocation:
1707
- self._state.execution["invocation_id"] = self._current_invocation.id
1708
-
1709
- # Fallback: if message_ids not populated by middleware, save full history
1710
- # This ensures backward compatibility when RawMessageMiddleware is not used
1711
- if "message_ids" not in self._state.execution:
1712
- messages_data = []
1713
- for msg in self._message_history:
1714
- msg_dict = {"role": msg.role, "content": msg.content}
1715
- if hasattr(msg, "tool_call_id") and msg.tool_call_id:
1716
- msg_dict["tool_call_id"] = msg.tool_call_id
1717
- messages_data.append(msg_dict)
1718
- self._state.execution["message_history"] = messages_data
1719
-
1720
- def _clear_messages_from_state(self) -> None:
1721
- """Clear execution state after invocation completes.
1722
-
1723
- Called when invocation completes normally. Historical messages
1724
- are already persisted (truncated) via MessageStore.
1725
- """
1726
- if not self._state:
1727
- return
1728
-
1729
- # Clear execution namespace
1730
- self._state.execution.clear()
1731
-
1732
- async def _trigger_message_save(self, message: dict) -> dict | None:
1733
- """Trigger on_message_save hook via middleware.
1734
-
1735
- Message persistence is handled by MessageBackendMiddleware.
1736
- Agent only triggers the hook, doesn't save directly.
1737
-
1738
- Args:
1739
- message: Message dict with role, content, etc.
1740
-
1741
- Returns:
1742
- Modified message or None if blocked
1743
- """
1744
- # Check if message saving is disabled (e.g., for sub-agents with record_messages=False)
1745
- if getattr(self, '_disable_message_save', False):
1746
- return message
1747
-
1748
- if not self.middleware:
1749
- return message
1750
-
1751
- namespace = getattr(self, '_message_namespace', None)
1752
- mw_context = {
1753
- "session_id": self.session.id,
1754
- "agent_id": self.name,
1755
- "namespace": namespace,
1756
- }
1757
-
1758
- return await self.middleware.process_message_save(message, mw_context)
1759
-
1760
- async def _save_user_message(self, input: PromptInput) -> None:
1761
- """Trigger save for user message."""
1762
- # Build user content
1763
- content: str | list[dict] = input.text
1764
- if self._agent_context and self._agent_context.user_content:
1765
- content = self._agent_context.user_content + "\n\n" + input.text
1766
-
1767
- if input.attachments:
1768
- content_parts: list[dict] = [{"type": "text", "text": content}]
1769
- for attachment in input.attachments:
1770
- content_parts.append(attachment)
1771
- content = content_parts
1772
-
1773
- # Build message and trigger hook
1774
- message = {
1775
- "role": "user",
1776
- "content": content,
1777
- "invocation_id": self._current_invocation.id if self._current_invocation else "",
1778
- }
1779
-
1780
- await self._trigger_message_save(message)
1781
-
1782
- async def _save_assistant_message(self) -> None:
1783
- """Trigger save for assistant message."""
1784
- if not self._text_buffer and not self._tool_invocations:
1785
- return
1786
-
1787
- # Build assistant content
1788
- content: str | list[dict] = self._text_buffer
1789
- if self._tool_invocations:
1790
- content_parts: list[dict] = []
1791
- if self._text_buffer:
1792
- content_parts.append({"type": "text", "text": self._text_buffer})
1793
- for inv in self._tool_invocations:
1794
- content_parts.append({
1795
- "type": "tool_use",
1796
- "id": inv.tool_call_id,
1797
- "name": inv.tool_name,
1798
- "input": inv.args,
1799
- })
1800
- content = content_parts
1801
-
1802
- # Build message and trigger hook
1803
- message = {
1804
- "role": "assistant",
1805
- "content": content,
1806
- "invocation_id": self._current_invocation.id if self._current_invocation else "",
1807
- }
1808
-
1809
- await self._trigger_message_save(message)
1810
-
1811
- async def _save_tool_messages(self) -> None:
1812
- """Trigger save for tool result messages."""
1813
- for inv in self._tool_invocations:
1814
- if inv.result is not None:
1815
- # Build tool result message
1816
- content: list[dict] = [{
1817
- "type": "tool_result",
1818
- "tool_use_id": inv.tool_call_id,
1819
- "content": inv.result,
1820
- "is_error": inv.is_error,
1821
- }]
1822
-
1823
- message = {
1824
- "role": "tool",
1825
- "content": content,
1826
- "tool_call_id": inv.tool_call_id,
1827
- "invocation_id": self._current_invocation.id if self._current_invocation else "",
1828
- }
1829
-
1830
- await self._trigger_message_save(message)
588
+ # ========== Helper methods used by other modules ==========
1831
589
 
1832
590
  def _get_tool(self, tool_name: str) -> "BaseTool | None":
1833
591
  """Get tool by name from agent context."""
1834
- if self._agent_context:
1835
- for tool in self._agent_context.tools:
1836
- if tool.name == tool_name:
1837
- return tool
1838
- return None
592
+ return tool_helpers.get_tool(self, tool_name)
1839
593
 
1840
- async def _execute_tool(self, invocation: ToolInvocation) -> ToolResult:
1841
- """Execute a single tool call."""
1842
- invocation.mark_call_complete()
1843
-
1844
- # Build middleware context
1845
- mw_context = {
1846
- "session_id": self.session.id,
1847
- "invocation_id": self._current_invocation.id if self._current_invocation else "",
1848
- "tool_call_id": invocation.tool_call_id,
1849
- "agent_id": self.name,
1850
- }
1851
-
1852
- try:
1853
- # Get tool from agent context
1854
- tool = self._get_tool(invocation.tool_name)
1855
- if tool is None:
1856
- error_msg = f"Unknown tool: {invocation.tool_name}"
1857
- invocation.mark_result(error_msg, is_error=True)
1858
- return ToolResult.error(error_msg)
1859
-
1860
- # === Middleware: on_tool_call ===
1861
- if self.middleware:
1862
- hook_result = await self.middleware.process_tool_call(
1863
- tool, invocation.args, mw_context
1864
- )
1865
- if hook_result.action == HookAction.SKIP:
1866
- logger.info(f"Tool {invocation.tool_name} skipped by middleware")
1867
- return ToolResult(
1868
- output=hook_result.message or "Skipped by middleware",
1869
- is_error=False,
1870
- )
1871
- elif hook_result.action == HookAction.RETRY and hook_result.modified_data:
1872
- invocation.args = hook_result.modified_data
1873
-
1874
- # Create ToolContext
1875
- tool_ctx = ToolContext(
1876
- session_id=self.session.id,
1877
- invocation_id=self._current_invocation.id if self._current_invocation else "",
1878
- block_id="",
1879
- call_id=invocation.tool_call_id,
1880
- agent=self.config.name,
1881
- abort_signal=self._abort,
1882
- update_metadata=self._noop_update_metadata,
1883
- middleware=self.middleware,
1884
- )
1885
-
1886
- # Execute tool (with optional timeout from tool.config)
1887
- timeout = tool.config.timeout
1888
- if timeout is not None:
1889
- result = await asyncio.wait_for(
1890
- tool.execute(invocation.args, tool_ctx),
1891
- timeout=timeout,
1892
- )
1893
- else:
1894
- # No timeout - tool runs until completion
1895
- result = await tool.execute(invocation.args, tool_ctx)
1896
-
1897
- # === Middleware: on_tool_end ===
1898
- if self.middleware:
1899
- hook_result = await self.middleware.process_tool_end(tool, result, mw_context)
1900
- if hook_result.action == HookAction.RETRY:
1901
- logger.info(f"Tool {invocation.tool_name} retry requested by middleware")
1902
-
1903
- invocation.mark_result(result.output, is_error=result.is_error)
1904
- return result
1905
-
1906
- except asyncio.TimeoutError:
1907
- timeout = tool.config.timeout if tool else None
1908
- error_msg = f"Tool {invocation.tool_name} timed out after {timeout}s"
1909
- invocation.mark_result(error_msg, is_error=True)
1910
- return ToolResult.error(error_msg)
1911
-
1912
- except SuspendSignal:
1913
- # HITL/Suspend signal must propagate up
1914
- raise
1915
-
1916
- except Exception as e:
1917
- error_msg = f"Tool execution error: {str(e)}"
1918
- invocation.mark_result(error_msg, is_error=True)
1919
- return ToolResult.error(error_msg)
594
+ async def _save_tool_messages(self) -> None:
595
+ """Trigger save for tool result messages."""
596
+ await persist_helpers.save_tool_messages(self)
1920
597
 
1921
598
  async def _noop_update_metadata(self, metadata: dict[str, Any]) -> None:
1922
599
  """No-op metadata updater."""