soothe-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. soothe_cli/__init__.py +5 -0
  2. soothe_cli/cli/__init__.py +1 -0
  3. soothe_cli/cli/commands/__init__.py +1 -0
  4. soothe_cli/cli/commands/autopilot_cmd.py +410 -0
  5. soothe_cli/cli/commands/config_cmd.py +277 -0
  6. soothe_cli/cli/commands/run_cmd.py +87 -0
  7. soothe_cli/cli/commands/status_cmd.py +121 -0
  8. soothe_cli/cli/commands/subagent_names.py +17 -0
  9. soothe_cli/cli/commands/thread_cmd.py +657 -0
  10. soothe_cli/cli/execution/__init__.py +6 -0
  11. soothe_cli/cli/execution/daemon.py +194 -0
  12. soothe_cli/cli/execution/headless.py +99 -0
  13. soothe_cli/cli/execution/launcher.py +31 -0
  14. soothe_cli/cli/main.py +509 -0
  15. soothe_cli/cli/renderer.py +444 -0
  16. soothe_cli/cli/stream/__init__.py +17 -0
  17. soothe_cli/cli/stream/context.py +138 -0
  18. soothe_cli/cli/stream/display_line.py +83 -0
  19. soothe_cli/cli/stream/formatter.py +412 -0
  20. soothe_cli/cli/stream/pipeline.py +521 -0
  21. soothe_cli/cli/utils.py +46 -0
  22. soothe_cli/config/__init__.py +5 -0
  23. soothe_cli/config/cli_config.py +155 -0
  24. soothe_cli/plan/__init__.py +5 -0
  25. soothe_cli/plan/rich_tree.py +54 -0
  26. soothe_cli/shared/__init__.py +107 -0
  27. soothe_cli/shared/command_router.py +246 -0
  28. soothe_cli/shared/config_loader.py +68 -0
  29. soothe_cli/shared/display_policy.py +413 -0
  30. soothe_cli/shared/essential_events.py +68 -0
  31. soothe_cli/shared/event_processor.py +823 -0
  32. soothe_cli/shared/message_processing.py +393 -0
  33. soothe_cli/shared/presentation_engine.py +173 -0
  34. soothe_cli/shared/processor_state.py +80 -0
  35. soothe_cli/shared/renderer_protocol.py +158 -0
  36. soothe_cli/shared/rendering.py +43 -0
  37. soothe_cli/shared/slash_commands.py +354 -0
  38. soothe_cli/shared/subagent_routing.py +63 -0
  39. soothe_cli/shared/suppression_state.py +188 -0
  40. soothe_cli/shared/tool_formatters/__init__.py +27 -0
  41. soothe_cli/shared/tool_formatters/base.py +109 -0
  42. soothe_cli/shared/tool_formatters/execution.py +297 -0
  43. soothe_cli/shared/tool_formatters/fallback.py +128 -0
  44. soothe_cli/shared/tool_formatters/file_ops.py +299 -0
  45. soothe_cli/shared/tool_formatters/goal_formatter.py +331 -0
  46. soothe_cli/shared/tool_formatters/media.py +291 -0
  47. soothe_cli/shared/tool_formatters/structured.py +202 -0
  48. soothe_cli/shared/tool_formatters/web.py +143 -0
  49. soothe_cli/shared/tool_output_formatter.py +227 -0
  50. soothe_cli/shared/tui_trace_log.py +40 -0
  51. soothe_cli/tui/__init__.py +5 -0
  52. soothe_cli/tui/_ask_user_types.py +50 -0
  53. soothe_cli/tui/_cli_context.py +27 -0
  54. soothe_cli/tui/_env_vars.py +56 -0
  55. soothe_cli/tui/_session_stats.py +114 -0
  56. soothe_cli/tui/_version.py +21 -0
  57. soothe_cli/tui/app.py +4992 -0
  58. soothe_cli/tui/app.tcss +302 -0
  59. soothe_cli/tui/command_registry.py +310 -0
  60. soothe_cli/tui/config.py +2381 -0
  61. soothe_cli/tui/daemon_session.py +233 -0
  62. soothe_cli/tui/file_ops.py +409 -0
  63. soothe_cli/tui/formatting.py +28 -0
  64. soothe_cli/tui/hooks.py +23 -0
  65. soothe_cli/tui/input.py +782 -0
  66. soothe_cli/tui/media_utils.py +471 -0
  67. soothe_cli/tui/model_config.py +518 -0
  68. soothe_cli/tui/output.py +69 -0
  69. soothe_cli/tui/project_utils.py +188 -0
  70. soothe_cli/tui/sessions.py +1248 -0
  71. soothe_cli/tui/skills/__init__.py +5 -0
  72. soothe_cli/tui/skills/invocation.py +74 -0
  73. soothe_cli/tui/skills/load.py +93 -0
  74. soothe_cli/tui/textual_adapter.py +1430 -0
  75. soothe_cli/tui/theme.py +838 -0
  76. soothe_cli/tui/tool_display.py +297 -0
  77. soothe_cli/tui/unicode_security.py +502 -0
  78. soothe_cli/tui/update_check.py +447 -0
  79. soothe_cli/tui/widgets/__init__.py +9 -0
  80. soothe_cli/tui/widgets/_links.py +63 -0
  81. soothe_cli/tui/widgets/approval.py +430 -0
  82. soothe_cli/tui/widgets/ask_user.py +392 -0
  83. soothe_cli/tui/widgets/autocomplete.py +666 -0
  84. soothe_cli/tui/widgets/autopilot_dashboard.py +308 -0
  85. soothe_cli/tui/widgets/autopilot_screen.py +64 -0
  86. soothe_cli/tui/widgets/chat_input.py +1834 -0
  87. soothe_cli/tui/widgets/clipboard.py +128 -0
  88. soothe_cli/tui/widgets/diff.py +240 -0
  89. soothe_cli/tui/widgets/editor.py +140 -0
  90. soothe_cli/tui/widgets/history.py +221 -0
  91. soothe_cli/tui/widgets/loading.py +194 -0
  92. soothe_cli/tui/widgets/mcp_viewer.py +352 -0
  93. soothe_cli/tui/widgets/message_store.py +693 -0
  94. soothe_cli/tui/widgets/messages.py +1720 -0
  95. soothe_cli/tui/widgets/model_selector.py +988 -0
  96. soothe_cli/tui/widgets/notification_settings.py +155 -0
  97. soothe_cli/tui/widgets/status.py +403 -0
  98. soothe_cli/tui/widgets/theme_selector.py +158 -0
  99. soothe_cli/tui/widgets/thread_selector.py +1865 -0
  100. soothe_cli/tui/widgets/tool_renderers.py +148 -0
  101. soothe_cli/tui/widgets/tool_widgets.py +254 -0
  102. soothe_cli/tui/widgets/tools.py +165 -0
  103. soothe_cli/tui/widgets/welcome.py +330 -0
  104. soothe_cli-0.1.0.dist-info/METADATA +100 -0
  105. soothe_cli-0.1.0.dist-info/RECORD +107 -0
  106. soothe_cli-0.1.0.dist-info/WHEEL +4 -0
  107. soothe_cli-0.1.0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,1430 @@
1
+ """Textual UI adapter for agent execution."""
2
+ # This module has complex streaming logic ported from execution.py
3
+
4
+ from __future__ import annotations
5
+
6
+ import asyncio
7
+ import json
8
+ import logging
9
+ import time
10
+ import uuid
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Awaitable, Callable
15
+ from pathlib import Path
16
+ from typing import Protocol
17
+
18
+ from langchain.agents.middleware.human_in_the_loop import (
19
+ ApproveDecision,
20
+ EditDecision,
21
+ HITLRequest,
22
+ RejectDecision,
23
+ )
24
+ from langchain_core.messages import AIMessage
25
+ from langchain_core.runnables import RunnableConfig
26
+ from langgraph.types import Command, Interrupt
27
+ from pydantic import TypeAdapter
28
+ from rich.console import Console
29
+
30
+ from soothe_cli.tui._ask_user_types import AskUserWidgetResult, Question
31
+
32
+ # Type alias matching HITLResponse["decisions"] element type
33
+ HITLDecision = ApproveDecision | EditDecision | RejectDecision
34
+
35
+ class _TokensUpdateCallback(Protocol):
36
+ """Callback signature for `_on_tokens_update`."""
37
+
38
+ def __call__(self, count: int, *, approximate: bool = False) -> None: ...
39
+
40
+ class _TokensShowCallback(Protocol):
41
+ """Callback signature for `_on_tokens_show`."""
42
+
43
+ def __call__(self, *, approximate: bool = False) -> None: ...
44
+
45
+
46
+ from soothe_cli.shared.essential_events import is_essential_progress_event_type
47
+ from soothe_cli.tui._ask_user_types import AskUserRequest
48
+ from soothe_cli.tui._cli_context import CLIContext # noqa: TC001
49
+ from soothe_cli.tui._session_stats import (
50
+ ModelStats as ModelStats,
51
+ )
52
+ from soothe_cli.tui._session_stats import (
53
+ SessionStats as SessionStats,
54
+ )
55
+ from soothe_cli.tui._session_stats import (
56
+ SpinnerStatus as SpinnerStatus,
57
+ )
58
+ from soothe_cli.tui._session_stats import (
59
+ format_token_count as format_token_count,
60
+ )
61
+ from soothe_cli.tui.config import build_stream_config
62
+ from soothe_cli.tui.file_ops import FileOpTracker
63
+ from soothe_cli.tui.formatting import format_duration
64
+ from soothe_cli.tui.hooks import dispatch_hook
65
+ from soothe_cli.tui.input import MediaTracker, parse_file_mentions
66
+ from soothe_cli.tui.media_utils import create_multimodal_content
67
+ from soothe_cli.tui.tool_display import format_tool_message_content
68
+ from soothe_cli.tui.widgets.messages import (
69
+ AppMessage,
70
+ AssistantMessage,
71
+ DiffMessage,
72
+ SummarizationMessage,
73
+ ToolCallMessage,
74
+ )
75
+
76
+ logger = logging.getLogger(__name__)
77
+
78
+ _hitl_adapter_cache: TypeAdapter | None = None
79
+ """Lazy singleton for the HITL request validator."""
80
+
81
+
82
def _get_hitl_request_adapter(hitl_request_type: type) -> TypeAdapter:
    """Return the shared `TypeAdapter` used to validate HITL requests.

    The adapter is built on first use and stored in the module-level
    `_hitl_adapter_cache`, which avoids re-compiling the pydantic schema on
    every `execute_task_textual` call.

    Args:
        hitl_request_type: The `HITLRequest` class (passed in because
            it is imported locally by the caller). It is only consulted when
            the cache is empty; later calls return the cached adapter
            regardless of this argument.

    Returns:
        Shared `TypeAdapter` instance.
    """
    global _hitl_adapter_cache  # noqa: PLW0603
    if _hitl_adapter_cache is not None:
        return _hitl_adapter_cache

    # Imported here because the module-level import is type-checking only.
    from pydantic import TypeAdapter

    adapter = TypeAdapter(hitl_request_type)
    _hitl_adapter_cache = adapter
    return adapter
100
+
101
+
102
def print_usage_table(
    stats: SessionStats,
    wall_time: float,
    console: Console,
) -> None:
    """Render model-usage statistics and active time on a Rich console.

    Nothing is printed when the session has no requests, no input tokens and
    a wall time below 0.1 seconds. When `stats.per_model` holds several
    models, each one gets its own row followed by a "Total" row; with a
    single model, one row labelled with that model shows the session totals.

    Args:
        stats: Cumulative session stats.
        wall_time: Total wall-clock time in seconds.
        console: Rich console for output.
    """
    from rich.table import Table

    show_elapsed = wall_time >= 0.1  # noqa: PLR2004
    nothing_recorded = not stats.request_count and not stats.input_tokens
    if nothing_recorded and not show_elapsed:
        return

    per_model = stats.per_model
    if per_model:
        usage = Table(
            show_header=True,
            header_style="bold",
            box=None,
            padding=(0, 2, 0, 0),
            show_edge=False,
        )
        usage.add_column("Model", style="dim")
        for heading in ("Reqs", "InputTok", "OutputTok"):
            usage.add_column(heading, justify="right", style="dim")

        if len(per_model) > 1:
            # One row per model, then the session-wide totals underneath.
            for name, model_stats in per_model.items():
                usage.add_row(
                    name,
                    str(model_stats.request_count),
                    format_token_count(model_stats.input_tokens),
                    format_token_count(model_stats.output_tokens),
                )
            summary_label = "Total"
        else:
            # Single model: the totals row carries the model's own name.
            summary_label = next(iter(per_model))

        usage.add_row(
            summary_label,
            str(stats.request_count),
            format_token_count(stats.input_tokens),
            format_token_count(stats.output_tokens),
        )

        console.print()
        console.print("[bold]Usage Stats[/bold]")
        console.print(usage)

    if show_elapsed:
        console.print()
        console.print(
            f"Agent active {format_duration(wall_time)}",
            style="dim",
            highlight=False,
        )
171
+
172
+
173
+ _ask_user_adapter_cache: TypeAdapter | None = None
174
+ """Lazy singleton for the `ask_user` interrupt validator."""
175
+
176
+
177
def _get_ask_user_adapter() -> TypeAdapter:
    """Return the shared `TypeAdapter(AskUserRequest)`.

    The adapter is created on the first call and kept in the module-level
    `_ask_user_adapter_cache` for every later call.

    Returns:
        Shared `TypeAdapter` instance.
    """
    global _ask_user_adapter_cache  # noqa: PLW0603
    if _ask_user_adapter_cache is not None:
        return _ask_user_adapter_cache

    # Imported here because the module-level import is type-checking only.
    from pydantic import TypeAdapter

    adapter = TypeAdapter(AskUserRequest)
    _ask_user_adapter_cache = adapter
    return adapter
189
+
190
+
191
def _is_summarization_chunk(metadata: dict | None) -> bool:
    """Tell whether a streamed message chunk came from summarization middleware.

    The summarization model is invoked with
    `config={"metadata": {"lc_source": "summarization"}}`
    (see `langchain.agents.middleware.summarization`), which
    LangChain's callback system merges into the stream metadata dict.

    Args:
        metadata: The metadata dict from the stream chunk, or `None`.

    Returns:
        `True` only when `metadata` is present and its `lc_source` entry
        equals `"summarization"`; such chunks should be filtered.
    """
    return metadata is not None and metadata.get("lc_source") == "summarization"
208
+
209
+
210
def _extract_custom_output_text(data: dict[str, Any]) -> str | None:
    """Extract assistant-visible text from daemon custom output events.

    Three event types carry user-facing text, each in a different field:

    * `CHITCHAT_RESPONSE` -> `content`
    * `AGENT_LOOP_COMPLETED` -> `final_stdout_message`
    * `FINAL_REPORT` -> `content`, falling back to `summary` when `content`
      is missing, `None` or empty

    The selected value is stringified and passed through
    `soothe_sdk.strip_internal_tags` (presumably removing internal markup;
    see the SDK for exact semantics).

    Args:
        data: Event payload; its `type` entry selects the field to read.

    Returns:
        The cleaned text, or `None` when the event type is not one of the
        three above, the field is absent or `None`, or nothing is left after
        cleaning.
    """
    from soothe_sdk import strip_internal_tags
    from soothe_sdk.events import (
        AGENT_LOOP_COMPLETED,
        CHITCHAT_RESPONSE,
        FINAL_REPORT,
    )

    event_type = str(data.get("type", ""))
    if event_type == CHITCHAT_RESPONSE:
        raw = data.get("content")
    elif event_type == AGENT_LOOP_COMPLETED:
        raw = data.get("final_stdout_message")
    elif event_type == FINAL_REPORT:
        # `dict.get(key, default)` only falls back when the key is absent, so
        # an empty or None `content` must be checked explicitly before using
        # `summary`.
        raw = data.get("content") or data.get("summary")
    else:
        return None

    # An explicit None must not be stringified into the literal text "None".
    if raw is None:
        return None
    return strip_internal_tags(str(raw)) or None
233
+
234
+
235
def _format_progress_event_lines_for_tui(
    event_data: dict[str, Any],
    namespace: tuple[str, ...],
    *,
    pipeline: Any,
) -> list[str]:
    """Format essential progress events with the same pipeline as CLI.

    Args:
        event_data: Raw event payload; its `type` entry decides whether the
            event is rendered at all.
        namespace: Namespace of the emitting agent, handed to the pipeline
            as a list under the `namespace` key of a copy of the event.
        pipeline: Object whose `process(event)` returns line objects that
            expose `format()`.

    Returns:
        The non-empty, whitespace-trimmed rendered lines, or an empty list
        when the event type is not an essential progress event.
    """
    kind = str(event_data.get("type", ""))
    if not is_essential_progress_event_type(kind):
        return []

    # Work on a copy so the caller's event dict is left untouched.
    payload = {**event_data, "namespace": list(namespace)}
    texts = (entry.format().lstrip("\n").strip() for entry in pipeline.process(payload))
    return [text for text in texts if text]
256
+
257
+
258
class TextualUIAdapter:
    """Adapter for rendering agent output to Textual widgets.

    This adapter provides an abstraction layer between the agent execution and the
    Textual UI, allowing streaming output to be rendered as widgets.

    The adapter only stores the callbacks it is given plus a small amount of
    tracking state; the streaming code reads these attributes directly.
    """

    def __init__(
        self,
        mount_message: Callable[..., Awaitable[None]],
        update_status: Callable[[str], None],
        request_approval: Callable[..., Awaitable[Any]],
        on_auto_approve_enabled: Callable[[], None] | None = None,
        set_spinner: Callable[[SpinnerStatus], Awaitable[None]] | None = None,
        set_active_message: Callable[[str | None], None] | None = None,
        sync_message_content: Callable[[str, str], None] | None = None,
        request_ask_user: (
            Callable[
                [list[Question]],
                Awaitable[asyncio.Future[AskUserWidgetResult] | None],
            ]
            | None
        ) = None,
    ) -> None:
        """Initialize the adapter.

        Args:
            mount_message: Async callback to mount a message widget to the chat.
            update_status: Callback to update the status bar text.
            request_approval: Async callback that returns a Future for HITL
                approval.
            on_auto_approve_enabled: Optional callback invoked when
                auto-approve is enabled via the HITL approval menu.
            set_spinner: Optional callback to show/hide the loading spinner.
            set_active_message: Optional callback to set the active streaming
                message ID (pass `None` to clear).
            sync_message_content: Optional callback to sync final message
                content back to the store after streaming.
            request_ask_user: Optional async callback for `ask_user`
                interrupts.
        """
        self._mount_message = mount_message
        """Async callback to mount a message widget to the chat."""

        self._update_status = update_status
        """Callback to update the status bar text."""

        self._request_approval = request_approval
        """Async callback that returns a Future for HITL approval."""

        self._on_auto_approve_enabled = on_auto_approve_enabled
        """Callback invoked when auto-approve is enabled via the HITL approval
        menu.

        Fired when the user selects "Auto-approve all" from an approval dialog,
        allowing the app to sync its status bar and session state.
        """

        self._set_spinner = set_spinner
        """Callback to show/hide loading spinner."""

        self._set_active_message = set_active_message
        """Callback to set the active streaming message ID (pass `None` to clear)."""

        self._sync_message_content = sync_message_content
        """Callback to sync final message content back to the store after streaming."""

        self._request_ask_user = request_ask_user
        """Async callback for `ask_user` interrupts.

        When awaited, returns a `Future` that resolves to user answers.
        """

        # State tracking
        self._current_tool_messages: dict[str, ToolCallMessage] = {}
        """Map of tool call IDs to their message widgets."""

        # Token display callbacks (set by the app after construction)
        self._on_tokens_update: _TokensUpdateCallback | None = None
        """Called with total context tokens after each LLM response."""

        self._on_tokens_hide: Callable[[], None] | None = None
        """Called to hide the token display during streaming."""

        self._on_tokens_show: _TokensShowCallback | None = None
        """Called to restore the token display with the cached value."""

    def finalize_pending_tools_with_error(self, error: str) -> None:
        """Mark all pending/running tool widgets as error and clear tracking.

        This is used as a safety net when an unexpected exception aborts
        streaming before matching `ToolMessage` results are received. Because
        it runs during error recovery it is best-effort: a widget whose
        `set_error` raises is logged and skipped so the remaining widgets are
        still updated, tracking is always emptied, and the active streaming
        message is always cleared.

        Args:
            error: Error text to display in each pending tool widget.
        """
        pending = list(self._current_tool_messages.values())
        # Drain tracking before touching widgets so the dict empties even if
        # a widget call raises.
        self._current_tool_messages.clear()

        try:
            for tool_msg in pending:
                try:
                    tool_msg.set_error(error)
                except Exception:  # noqa: BLE001 # Safety net must reach every widget
                    logging.getLogger(__name__).warning(
                        "Failed to mark pending tool widget as errored",
                        exc_info=True,
                    )
        finally:
            # Clear active streaming message to avoid stale "active" state in the store.
            if self._set_active_message:
                self._set_active_message(None)
345
+
346
+
347
def _build_interrupted_ai_message(
    pending_text_by_namespace: dict[tuple, str],
    current_tool_messages: dict[str, Any],
) -> AIMessage | None:
    """Build an AIMessage capturing interrupted state (text + tool calls).

    Only text accumulated under the empty-tuple namespace key is used. Tool
    calls are rebuilt from the tracked widgets' `_tool_name` and `_args`
    attributes.

    Args:
        pending_text_by_namespace: Dict of accumulated text by namespace
        current_tool_messages: Dict of tool_id -> ToolCallMessage widget

    Returns:
        AIMessage with accumulated content and tool calls, or None if empty.
    """
    from langchain_core.messages import AIMessage

    text = pending_text_by_namespace.get((), "").strip()

    # Reconstruct tool_calls from displayed tool messages
    calls = [
        {
            "id": call_id,
            "name": widget._tool_name,
            "args": widget._args,
        }
        for call_id, widget in current_tool_messages.items()
    ]

    if text or calls:
        return AIMessage(content=text, tool_calls=calls)
    return None
383
+
384
+
385
def _read_mentioned_file(file_path: Path, max_embed_bytes: int) -> str:
    """Read a mentioned file for inline embedding (sync, for use with to_thread).

    Files larger than `max_embed_bytes` are not read; a short reference with
    the size in whole kilobytes is returned instead. Smaller files are read
    as UTF-8 and wrapped in a Markdown code fence. The fence is at least
    three backticks and always longer than the longest backtick run in the
    file, so content that itself contains code fences cannot close the
    embedding block early.

    Args:
        file_path: Resolved path to the file.
        max_embed_bytes: Size threshold; larger files get a reference only.

    Returns:
        Markdown snippet with the file content or a size-exceeded reference.

    Raises:
        OSError: If the file cannot be stat-ed or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    import re

    header = f"\n### {file_path.name}\nPath: `{file_path}`\n"

    file_size = file_path.stat().st_size
    if file_size > max_embed_bytes:
        size_kb = file_size // 1024
        return f"{header}Size: {size_kb}KB (too large to embed, use read_file tool to view)"

    content = file_path.read_text(encoding="utf-8")
    # A closing fence must be at least as long as the opening one, so pick a
    # fence strictly longer than any backtick run inside the content.
    longest_run = max((len(run) for run in re.findall(r"`+", content)), default=0)
    fence = "`" * max(3, longest_run + 1)
    return f"{header}{fence}\n{content}\n{fence}"
406
+
407
+
408
+ async def execute_task_textual(
409
+ user_input: str,
410
+ agent: Any, # noqa: ANN401 # Dynamic agent graph type
411
+ assistant_id: str | None,
412
+ session_state: Any, # noqa: ANN401 # Dynamic session state type
413
+ adapter: TextualUIAdapter,
414
+ image_tracker: MediaTracker | None = None,
415
+ context: CLIContext | None = None,
416
+ *,
417
+ daemon_session: Any = None, # noqa: ANN401 # Daemon-backed TUI session
418
+ sandbox_type: str | None = None,
419
+ message_kwargs: dict[str, Any] | None = None,
420
+ turn_stats: SessionStats | None = None,
421
+ skip_daemon_send_turn: bool = False,
422
+ ) -> SessionStats:
423
+ """Execute a task with output directed to Textual UI.
424
+
425
+ This is the Textual-compatible version of execute_task() that uses
426
+ the TextualUIAdapter for all UI operations.
427
+
428
+ Args:
429
+ user_input: The user's input message
430
+ agent: The LangGraph agent to execute
431
+ daemon_session: Optional daemon-backed session for direct websocket
432
+ streaming. When provided, this becomes the primary execution path.
433
+ assistant_id: The agent identifier
434
+ session_state: Session state with auto_approve flag
435
+ adapter: The TextualUIAdapter for UI operations
436
+ image_tracker: Optional tracker for images
437
+ context: Optional `CLIContext` with model override and params, passed
438
+ to the graph via `context=`.
439
+ sandbox_type: Sandbox provider name for trace metadata, or `None`
440
+ if no sandbox is active.
441
+ message_kwargs: Extra fields merged into the stream input message
442
+ dict (e.g., `additional_kwargs` for persisting skill metadata
443
+ in the checkpoint).
444
+ turn_stats: Pre-created `SessionStats` to accumulate into.
445
+
446
+ When the caller holds a reference to the same object, stats are
447
+ available even if this coroutine is cancelled before it can return.
448
+
449
+ If `None`, a new instance is created internally.
450
+ skip_daemon_send_turn: When ``True`` with ``daemon_session`` set, skip
451
+ ``send_turn`` and only consume chunks (daemon already queued the
452
+ prompt, e.g. after ``invoke_skill``).
453
+
454
+ Returns:
455
+ Stats accumulated over this turn (request count, token counts,
456
+ wall-clock time).
457
+
458
+ Raises:
459
+ ValidationError: If HITL request validation fails (re-raised).
460
+ """
461
+ from langchain.agents.middleware.human_in_the_loop import (
462
+ ApproveDecision,
463
+ HITLRequest,
464
+ RejectDecision,
465
+ )
466
+ from langchain_core.messages import HumanMessage, ToolMessage
467
+ from langgraph.types import Command
468
+ from pydantic import ValidationError
469
+
470
+ from soothe_cli.cli.stream import StreamDisplayPipeline
471
+
472
+ hitl_request_adapter = _get_hitl_request_adapter(HITLRequest)
473
+ ask_user_adapter = _get_ask_user_adapter()
474
+ progress_pipeline = StreamDisplayPipeline(verbosity="detailed")
475
+
476
+ # Parse file mentions and inject content if any — defer blocking I/O
477
+ prompt_text, mentioned_files = await asyncio.to_thread(parse_file_mentions, user_input)
478
+
479
+ # Max file size to embed inline (256KB, matching mistral-vibe)
480
+ # Larger files get a reference instead - use read_file tool to view them
481
+ max_embed_bytes = 256 * 1024
482
+
483
+ if mentioned_files:
484
+ context_parts = [prompt_text, "\n\n## Referenced Files\n"]
485
+ for file_path in mentioned_files:
486
+ try:
487
+ part = await asyncio.to_thread(_read_mentioned_file, file_path, max_embed_bytes)
488
+ context_parts.append(part)
489
+ except Exception as e: # noqa: BLE001 # Resilient adapter error handling
490
+ context_parts.append(f"\n### {file_path.name}\n[Error reading file: {e}]")
491
+ final_input = "\n".join(context_parts)
492
+ else:
493
+ final_input = prompt_text
494
+
495
+ # Include images and videos in the message content
496
+ images_to_send = []
497
+ videos_to_send = []
498
+ if image_tracker:
499
+ images_to_send = image_tracker.get_images()
500
+ videos_to_send = image_tracker.get_videos()
501
+ if images_to_send or videos_to_send:
502
+ message_content = create_multimodal_content(final_input, images_to_send, videos_to_send)
503
+ else:
504
+ message_content = final_input
505
+
506
+ thread_id = session_state.thread_id
507
+ config = build_stream_config(thread_id, assistant_id, sandbox_type=sandbox_type)
508
+
509
+ await dispatch_hook("session.start", {"thread_id": thread_id})
510
+
511
+ captured_input_tokens = 0
512
+ captured_output_tokens = 0
513
+ if turn_stats is None:
514
+ turn_stats = SessionStats()
515
+ start_time = time.monotonic()
516
+
517
+ # Warn if token display callbacks are only partially wired — all three
518
+ # should be set together to avoid inconsistent status-bar behavior.
519
+ token_cbs = (
520
+ adapter._on_tokens_update,
521
+ adapter._on_tokens_hide,
522
+ adapter._on_tokens_show,
523
+ )
524
+ if any(token_cbs) and not all(token_cbs):
525
+ logger.warning(
526
+ "Token callbacks partially wired (update=%s, hide=%s, show=%s); token display may behave inconsistently",
527
+ adapter._on_tokens_update is not None,
528
+ adapter._on_tokens_hide is not None,
529
+ adapter._on_tokens_show is not None,
530
+ )
531
+
532
+ # Show spinner
533
+ if adapter._set_spinner:
534
+ await adapter._set_spinner("Thinking")
535
+
536
+ # Hide token display during streaming (will be shown with accurate count at end)
537
+ if adapter._on_tokens_hide:
538
+ adapter._on_tokens_hide()
539
+
540
+ file_op_tracker = FileOpTracker(assistant_id=assistant_id)
541
+ displayed_tool_ids: set[str] = set()
542
+ tool_call_buffers: dict[str | int, dict] = {}
543
+
544
+ # Track pending text and assistant messages PER NAMESPACE to avoid interleaving
545
+ # when multiple subagents stream in parallel
546
+ pending_text_by_namespace: dict[tuple, str] = {}
547
+ assistant_message_by_namespace: dict[tuple, Any] = {}
548
+
549
+ # Clear media from tracker after creating the message
550
+ if image_tracker:
551
+ image_tracker.clear()
552
+
553
+ user_msg: dict[str, Any] = {"role": "user", "content": message_content}
554
+ if message_kwargs:
555
+ user_msg.update(message_kwargs)
556
+ stream_input: dict | Command = {"messages": [user_msg]}
557
+
558
+ # Track summarization lifecycle so spinner status and notification stay in sync.
559
+ summarization_in_progress = False
560
+
561
+ try:
562
+ while True:
563
+ interrupt_occurred = False
564
+ suppress_resumed_output = False
565
+ pending_interrupts: dict[str, HITLRequest] = {}
566
+ pending_ask_user: dict[str, AskUserRequest] = {}
567
+
568
+ if daemon_session is None:
569
+ chunk_source = agent.astream(
570
+ stream_input,
571
+ stream_mode=["messages", "updates"],
572
+ subgraphs=True,
573
+ config=config,
574
+ context=context,
575
+ durability="exit",
576
+ )
577
+ else:
578
+ if isinstance(stream_input, Command):
579
+ resume_data = getattr(stream_input, "resume", None)
580
+ if not isinstance(resume_data, dict):
581
+ raise ValueError("Invalid daemon resume payload")
582
+ await daemon_session.resume_interrupts(resume_data)
583
+ chunk_source = daemon_session.iter_turn_chunks()
584
+ elif skip_daemon_send_turn:
585
+ chunk_source = daemon_session.iter_turn_chunks()
586
+ else:
587
+ daemon_text = (
588
+ message_content if isinstance(message_content, str) else final_input
589
+ )
590
+ ctx_model = context.get("model") if context else None
591
+ raw_mp = context.get("model_params") if context else None
592
+ mp = raw_mp if isinstance(raw_mp, dict) else None
593
+ await daemon_session.send_turn(
594
+ daemon_text,
595
+ interactive=True,
596
+ model=ctx_model
597
+ if isinstance(ctx_model, str) and ctx_model.strip()
598
+ else None,
599
+ model_params=mp,
600
+ )
601
+ chunk_source = daemon_session.iter_turn_chunks()
602
+
603
+ async for chunk in chunk_source:
604
+ if not isinstance(chunk, tuple) or len(chunk) != 3: # noqa: PLR2004 # stream chunk is a 3-tuple (namespace, mode, data)
605
+ logger.debug("Skipping non-3-tuple chunk: %s", type(chunk).__name__)
606
+ continue
607
+
608
+ namespace, current_stream_mode, data = chunk
609
+
610
+ # Convert namespace to hashable tuple for dict keys
611
+ ns_key = tuple(namespace) if namespace else ()
612
+
613
+ # Filter out subagent outputs - only show main agent (empty
614
+ # namespace). Subagents run via Task tool and should only
615
+ # report back to the main agent
616
+ is_main_agent = ns_key == ()
617
+
618
+ # Handle UPDATES stream - for interrupts and todos
619
+ if current_stream_mode == "updates":
620
+ if not isinstance(data, dict):
621
+ continue
622
+
623
+ # Check for interrupts
624
+ if "__interrupt__" in data:
625
+ interrupts: list[Interrupt] = data["__interrupt__"]
626
+ if interrupts:
627
+ for interrupt_obj in interrupts:
628
+ iv = interrupt_obj.value
629
+ if isinstance(iv, dict) and iv.get("type") == "ask_user":
630
+ try:
631
+ validated_ask_user = ask_user_adapter.validate_python(iv)
632
+ pending_ask_user[interrupt_obj.id] = validated_ask_user
633
+ interrupt_occurred = True
634
+ await dispatch_hook("input.required", {})
635
+ except ValidationError:
636
+ logger.exception("Invalid ask_user interrupt payload")
637
+ raise
638
+ else:
639
+ try:
640
+ validated_request = hitl_request_adapter.validate_python(iv)
641
+ pending_interrupts[interrupt_obj.id] = validated_request
642
+ interrupt_occurred = True
643
+ await dispatch_hook("input.required", {})
644
+ except ValidationError: # noqa: TRY203 # Re-raise preserves exception context in handler
645
+ raise
646
+
647
+ # Check for todo updates (not yet implemented in Textual UI)
648
+ chunk_data = next(iter(data.values())) if data else None
649
+ if chunk_data and isinstance(chunk_data, dict) and "todos" in chunk_data:
650
+ pass # Future: render todo list widget
651
+
652
+ # Handle MESSAGES stream - for content and tool calls
653
+ elif current_stream_mode == "messages":
654
+ # Skip subagent outputs - only render main agent content in chat
655
+ if not is_main_agent:
656
+ logger.debug("Skipping subagent message ns=%s", ns_key)
657
+ continue
658
+
659
+ if not isinstance(data, tuple) or len(data) != 2: # noqa: PLR2004 # message stream data is a 2-tuple (message, metadata)
660
+ logger.debug(
661
+ "Skipping non-2-tuple message data: type=%s",
662
+ type(data).__name__,
663
+ )
664
+ continue
665
+
666
+ message, metadata = data
667
+ logger.debug(
668
+ "Processing message: type=%s id=%s has_content_blocks=%s",
669
+ type(message).__name__,
670
+ getattr(message, "id", None),
671
+ hasattr(message, "content_blocks"),
672
+ )
673
+
674
+ # Filter out summarization model output, but keep UI feedback.
675
+ # The summarization model streams AIMessage chunks tagged
676
+ # with lc_source="summarization" in the callback metadata.
677
+ # These are hidden from the user; only the spinner and a
678
+ # notification widget provide feedback.
679
+ if _is_summarization_chunk(metadata):
680
+ if not summarization_in_progress:
681
+ summarization_in_progress = True
682
+ if adapter._set_spinner:
683
+ await adapter._set_spinner("Offloading")
684
+ continue
685
+
686
+ # Regular (non-summarization) chunks resumed — summarization
687
+ # has finished. Mount the notification and reset the spinner.
688
+ if summarization_in_progress:
689
+ summarization_in_progress = False
690
+ try:
691
+ await adapter._mount_message(SummarizationMessage())
692
+ except Exception:
693
+ logger.debug(
694
+ "Failed to mount summarization notification",
695
+ exc_info=True,
696
+ )
697
+ if adapter._set_spinner and not adapter._current_tool_messages:
698
+ await adapter._set_spinner("Thinking")
699
+
700
+ if isinstance(message, HumanMessage):
701
+ content = message.text
702
+ # Flush pending text for this namespace
703
+ pending_text = pending_text_by_namespace.get(ns_key, "")
704
+ if content and pending_text:
705
+ await _flush_assistant_text_ns(
706
+ adapter,
707
+ pending_text,
708
+ ns_key,
709
+ assistant_message_by_namespace,
710
+ )
711
+ pending_text_by_namespace[ns_key] = ""
712
+ continue
713
+
714
+ if isinstance(message, ToolMessage):
715
+ tool_name = getattr(message, "name", "")
716
+ tool_status = getattr(message, "status", "success")
717
+ tool_content = format_tool_message_content(message.content)
718
+ record = file_op_tracker.complete_with_message(message)
719
+
720
+ # Update tool call status with output
721
+ tool_id = getattr(message, "tool_call_id", None)
722
+ if tool_id and tool_id in adapter._current_tool_messages:
723
+ # Pop before widget calls so the dict drains even
724
+ # if set_success/set_error raises.
725
+ tool_msg = adapter._current_tool_messages.pop(tool_id)
726
+ output_str = str(tool_content) if tool_content else ""
727
+ if tool_status == "success":
728
+ tool_msg.set_success(output_str)
729
+ else:
730
+ tool_msg.set_error(output_str or "Error")
731
+ await dispatch_hook(
732
+ "tool.error",
733
+ {"tool_names": [tool_msg._tool_name]},
734
+ )
735
+ elif tool_id:
736
+ logger.debug(
737
+ "ToolMessage tool_call_id=%s not in "
738
+ "_current_tool_messages; spinner gating "
739
+ "may be stale",
740
+ tool_id,
741
+ )
742
+
743
+ # Reshow spinner only when all in-flight tools have
744
+ # completed (avoids premature "Thinking..." when
745
+ # parallel tool calls are active).
746
+ if adapter._set_spinner and not adapter._current_tool_messages:
747
+ await adapter._set_spinner("Thinking")
748
+
749
+ # Show file operation results - always show diffs in chat
750
+ if record:
751
+ pending_text = pending_text_by_namespace.get(ns_key, "")
752
+ if pending_text:
753
+ await _flush_assistant_text_ns(
754
+ adapter,
755
+ pending_text,
756
+ ns_key,
757
+ assistant_message_by_namespace,
758
+ )
759
+ pending_text_by_namespace[ns_key] = ""
760
+ if record.diff:
761
+ await adapter._mount_message(
762
+ DiffMessage(record.diff, record.display_path)
763
+ )
764
+ continue
765
+
766
+ # Extract token usage (before content_blocks check
767
+ # - usage may be on any chunk)
768
+ if hasattr(message, "usage_metadata"):
769
+ usage = message.usage_metadata
770
+ if usage:
771
+ input_toks = usage.get("input_tokens", 0)
772
+ output_toks = usage.get("output_tokens", 0)
773
+ total_toks = usage.get("total_tokens", 0)
774
+ from soothe_cli.tui.config import settings
775
+
776
+ active_model = settings.model_name or ""
777
+ if input_toks or output_toks:
778
+ # Model gives split counts — preferred path
779
+ turn_stats.record_request(active_model, input_toks, output_toks)
780
+ captured_input_tokens = max(
781
+ captured_input_tokens, input_toks + output_toks
782
+ )
783
+ elif total_toks:
784
+ # Fallback: model gives only total (no split)
785
+ turn_stats.record_request(active_model, total_toks, 0)
786
+ captured_input_tokens = max(captured_input_tokens, total_toks)
787
+
788
+ # Check if this is an AIMessageChunk with content
789
+ if not hasattr(message, "content_blocks"):
790
+ logger.debug(
791
+ "Message has no content_blocks: type=%s",
792
+ type(message).__name__,
793
+ )
794
+ continue
795
+
796
+ # Process content blocks
797
+ blocks = message.content_blocks
798
+ logger.debug(
799
+ "content_blocks count=%d blocks=%s",
800
+ len(blocks),
801
+ repr(blocks)[:500],
802
+ )
803
+ for block in blocks:
804
+ block_type = block.get("type")
805
+
806
+ if block_type == "text":
807
+ text = block.get("text", "")
808
+ if text:
809
+ # Track accumulated text for reference
810
+ pending_text = pending_text_by_namespace.get(ns_key, "")
811
+ pending_text += text
812
+ pending_text_by_namespace[ns_key] = pending_text
813
+
814
+ # Get or create assistant message for this namespace
815
+ current_msg = assistant_message_by_namespace.get(ns_key)
816
+ if current_msg is None:
817
+ # Hide spinner when assistant starts responding
818
+ if adapter._set_spinner:
819
+ await adapter._set_spinner(None)
820
+ msg_id = f"asst-{uuid.uuid4().hex[:8]}"
821
+ # Mark active BEFORE mounting so pruning
822
+ # (triggered by mount) won't remove it
823
+ # (_mount_message can trigger
824
+ # _prune_old_messages if the window exceeds
825
+ # WINDOW_SIZE.)
826
+ if adapter._set_active_message:
827
+ adapter._set_active_message(msg_id)
828
+ current_msg = AssistantMessage(id=msg_id)
829
+ await adapter._mount_message(current_msg)
830
+ assistant_message_by_namespace[ns_key] = current_msg
831
+
832
+ # Append just the new text chunk for smoother
833
+ # streaming (uses MarkdownStream internally for
834
+ # better performance)
835
+ await current_msg.append_content(text)
836
+
837
+ elif block_type in {"tool_call_chunk", "tool_call"}:
838
+ chunk_name = block.get("name")
839
+ chunk_args = block.get("args")
840
+ chunk_id = block.get("id")
841
+ chunk_index = block.get("index")
842
+
843
+ buffer_key: str | int
844
+ if chunk_index is not None:
845
+ buffer_key = chunk_index
846
+ elif chunk_id is not None:
847
+ buffer_key = chunk_id
848
+ else:
849
+ buffer_key = f"unknown-{len(tool_call_buffers)}"
850
+
851
+ buffer = tool_call_buffers.setdefault(
852
+ buffer_key,
853
+ {
854
+ "name": None,
855
+ "id": None,
856
+ "args": None,
857
+ "args_parts": [],
858
+ },
859
+ )
860
+
861
+ if chunk_name:
862
+ buffer["name"] = chunk_name
863
+ if chunk_id:
864
+ buffer["id"] = chunk_id
865
+
866
+ if isinstance(chunk_args, dict):
867
+ buffer["args"] = chunk_args
868
+ buffer["args_parts"] = []
869
+ elif isinstance(chunk_args, str):
870
+ if chunk_args:
871
+ parts: list[str] = buffer.setdefault("args_parts", [])
872
+ if not parts or chunk_args != parts[-1]:
873
+ parts.append(chunk_args)
874
+ buffer["args"] = "".join(parts)
875
+ elif chunk_args is not None:
876
+ buffer["args"] = chunk_args
877
+
878
+ buffer_name = buffer.get("name")
879
+ buffer_id = buffer.get("id")
880
+ if buffer_name is None:
881
+ continue
882
+
883
+ parsed_args = buffer.get("args")
884
+ if isinstance(parsed_args, str):
885
+ if not parsed_args:
886
+ continue
887
+ try:
888
+ parsed_args = json.loads(parsed_args)
889
+ except json.JSONDecodeError:
890
+ continue
891
+ elif parsed_args is None:
892
+ continue
893
+
894
+ if not isinstance(parsed_args, dict):
895
+ parsed_args = {"value": parsed_args}
896
+
897
+ # Flush pending text before tool call
898
+ pending_text = pending_text_by_namespace.get(ns_key, "")
899
+ if pending_text:
900
+ await _flush_assistant_text_ns(
901
+ adapter,
902
+ pending_text,
903
+ ns_key,
904
+ assistant_message_by_namespace,
905
+ )
906
+ pending_text_by_namespace[ns_key] = ""
907
+ assistant_message_by_namespace.pop(ns_key, None)
908
+
909
+ logger.debug(
910
+ "Tool call buffer: name=%s id=%s args=%s",
911
+ buffer_name,
912
+ buffer_id,
913
+ repr(parsed_args)[:200],
914
+ )
915
+ if buffer_id is not None and buffer_id not in displayed_tool_ids:
916
+ displayed_tool_ids.add(buffer_id)
917
+ file_op_tracker.start_operation(buffer_name, parsed_args, buffer_id)
918
+
919
+ # Hide spinner before showing tool call
920
+ if adapter._set_spinner:
921
+ await adapter._set_spinner(None)
922
+
923
+ # Mount tool call message
924
+ logger.debug(
925
+ "Mounting ToolCallMessage: %s(%s)",
926
+ buffer_name,
927
+ repr(parsed_args)[:200],
928
+ )
929
+ tool_msg = ToolCallMessage(buffer_name, parsed_args)
930
+ await adapter._mount_message(tool_msg)
931
+ adapter._current_tool_messages[buffer_id] = tool_msg
932
+
933
+ tool_call_buffers.pop(buffer_key, None)
934
+
935
+ if getattr(message, "chunk_position", None) == "last":
936
+ pending_text = pending_text_by_namespace.get(ns_key, "")
937
+ if pending_text:
938
+ await _flush_assistant_text_ns(
939
+ adapter,
940
+ pending_text,
941
+ ns_key,
942
+ assistant_message_by_namespace,
943
+ )
944
+ pending_text_by_namespace[ns_key] = ""
945
+ assistant_message_by_namespace.pop(ns_key, None)
946
+
947
+ elif current_stream_mode == "custom":
948
+ if isinstance(data, dict):
949
+ event_type = str(data.get("type", ""))
950
+ if event_type.startswith("soothe.error"):
951
+ error_text = str(
952
+ data.get("error") or data.get("message") or "Agent error"
953
+ )
954
+ await adapter._mount_message(AppMessage(error_text))
955
+ if adapter._set_spinner:
956
+ await adapter._set_spinner(None)
957
+ continue
958
+ if output_text := _extract_custom_output_text(data):
959
+ pending_text = pending_text_by_namespace.get(ns_key, "")
960
+ if pending_text:
961
+ await _flush_assistant_text_ns(
962
+ adapter,
963
+ pending_text,
964
+ ns_key,
965
+ assistant_message_by_namespace,
966
+ )
967
+ pending_text_by_namespace[ns_key] = ""
968
+ assistant_message_by_namespace.pop(ns_key, None)
969
+ output_widget = AssistantMessage(
970
+ output_text, id=f"asst-{uuid.uuid4().hex[:8]}"
971
+ )
972
+ await adapter._mount_message(output_widget)
973
+ await output_widget.write_initial_content()
974
+ if adapter._sync_message_content and output_widget.id:
975
+ adapter._sync_message_content(output_widget.id, output_text)
976
+ if adapter._set_active_message:
977
+ adapter._set_active_message(None)
978
+ if adapter._set_spinner:
979
+ await adapter._set_spinner(None)
980
+ continue
981
+ progress_lines = _format_progress_event_lines_for_tui(
982
+ data,
983
+ ns_key,
984
+ pipeline=progress_pipeline,
985
+ )
986
+ if progress_lines:
987
+ pending_text = pending_text_by_namespace.get(ns_key, "")
988
+ if pending_text:
989
+ await _flush_assistant_text_ns(
990
+ adapter,
991
+ pending_text,
992
+ ns_key,
993
+ assistant_message_by_namespace,
994
+ )
995
+ pending_text_by_namespace[ns_key] = ""
996
+ assistant_message_by_namespace.pop(ns_key, None)
997
+ for progress_line in progress_lines:
998
+ await adapter._mount_message(AppMessage(progress_line))
999
+ continue
1000
+
1001
+ # Reset summarization state if stream ended mid-summarization
1002
+ # (e.g. middleware error, stream exhausted before regular chunks).
1003
+ if summarization_in_progress:
1004
+ summarization_in_progress = False
1005
+ try:
1006
+ await adapter._mount_message(SummarizationMessage())
1007
+ except Exception:
1008
+ logger.debug(
1009
+ "Failed to mount summarization notification",
1010
+ exc_info=True,
1011
+ )
1012
+ if adapter._set_spinner and not adapter._current_tool_messages:
1013
+ await adapter._set_spinner("Thinking")
1014
+
1015
+ # Flush any remaining text from all namespaces
1016
+ for ns_key, pending_text in list(pending_text_by_namespace.items()):
1017
+ if pending_text:
1018
+ await _flush_assistant_text_ns(
1019
+ adapter, pending_text, ns_key, assistant_message_by_namespace
1020
+ )
1021
+ pending_text_by_namespace.clear()
1022
+ assistant_message_by_namespace.clear()
1023
+
1024
+ # Handle HITL after stream completes
1025
+ if interrupt_occurred:
1026
+ any_rejected = False
1027
+ resume_payload: dict[str, Any] = {}
1028
+
1029
+ for interrupt_id, ask_req in list(pending_ask_user.items()):
1030
+ questions = ask_req["questions"]
1031
+
1032
+ if adapter._request_ask_user:
1033
+ if adapter._set_spinner:
1034
+ await adapter._set_spinner(None)
1035
+ result: dict[str, Any] = {
1036
+ "type": "error",
1037
+ "error": "ask_user callback returned no response",
1038
+ }
1039
+ try:
1040
+ future = await adapter._request_ask_user(questions)
1041
+ except Exception:
1042
+ logger.exception("Failed to mount ask_user widget")
1043
+ result = {
1044
+ "type": "error",
1045
+ "error": "failed to display ask_user prompt",
1046
+ }
1047
+ future = None
1048
+
1049
+ if future is None:
1050
+ logger.error("ask_user callback returned no Future; reporting as error")
1051
+ else:
1052
+ try:
1053
+ future_result = await future
1054
+ if isinstance(future_result, dict):
1055
+ result = future_result
1056
+ else:
1057
+ logger.error(
1058
+ "ask_user future returned non-dict result: %s",
1059
+ type(future_result).__name__,
1060
+ )
1061
+ result = {
1062
+ "type": "error",
1063
+ "error": "invalid ask_user widget result",
1064
+ }
1065
+ except Exception:
1066
+ logger.exception(
1067
+ "ask_user future resolution failed; reporting as error"
1068
+ )
1069
+ result = {
1070
+ "type": "error",
1071
+ "error": "failed to receive ask_user response",
1072
+ }
1073
+
1074
+ result_type = result.get("type")
1075
+ if result_type == "answered":
1076
+ answers = result.get("answers", [])
1077
+ if isinstance(answers, list):
1078
+ resume_payload[interrupt_id] = {"answers": answers}
1079
+ tool_id = ask_req["tool_call_id"]
1080
+ if tool_id in adapter._current_tool_messages:
1081
+ tool_msg = adapter._current_tool_messages[tool_id]
1082
+ tool_msg.set_success("User answered")
1083
+ adapter._current_tool_messages.pop(tool_id, None)
1084
+ else:
1085
+ logger.error(
1086
+ "ask_user answered payload had non-list answers: %s",
1087
+ type(answers).__name__,
1088
+ )
1089
+ resume_payload[interrupt_id] = {
1090
+ "status": "error",
1091
+ "error": "invalid ask_user answers payload",
1092
+ "answers": ["" for _ in questions],
1093
+ }
1094
+ any_rejected = True
1095
+ elif result_type == "cancelled":
1096
+ resume_payload[interrupt_id] = {
1097
+ "status": "cancelled",
1098
+ "answers": ["" for _ in questions],
1099
+ }
1100
+ any_rejected = True
1101
+ else:
1102
+ error_text = result.get("error")
1103
+ if not isinstance(error_text, str) or not error_text:
1104
+ error_text = "ask_user interaction failed"
1105
+ resume_payload[interrupt_id] = {
1106
+ "status": "error",
1107
+ "error": error_text,
1108
+ "answers": ["" for _ in questions],
1109
+ }
1110
+ any_rejected = True
1111
+ else:
1112
+ logger.warning(
1113
+ "ask_user interrupt received but no UI callback is registered; reporting as error"
1114
+ )
1115
+ resume_payload[interrupt_id] = {
1116
+ "status": "error",
1117
+ "error": "ask_user not supported by this UI",
1118
+ "answers": ["" for _ in questions],
1119
+ }
1120
+
1121
+ for interrupt_id, hitl_request in list(pending_interrupts.items()):
1122
+ action_requests = hitl_request["action_requests"]
1123
+
1124
+ if session_state.auto_approve:
1125
+ decisions: list[HITLDecision] = [
1126
+ ApproveDecision(type="approve") for _ in action_requests
1127
+ ]
1128
+ resume_payload[interrupt_id] = {"decisions": decisions}
1129
+ for tool_msg in list(adapter._current_tool_messages.values()):
1130
+ tool_msg.set_running()
1131
+ else:
1132
+ # Batch approval - one dialog for all parallel tool calls
1133
+ await dispatch_hook(
1134
+ "permission.request",
1135
+ {"tool_names": [r.get("name", "") for r in action_requests]},
1136
+ )
1137
+ future = await adapter._request_approval(action_requests, assistant_id)
1138
+ decision = await future
1139
+
1140
+ if isinstance(decision, dict):
1141
+ decision_type = decision.get("type")
1142
+
1143
+ if decision_type == "auto_approve_all":
1144
+ session_state.auto_approve = True
1145
+ if adapter._on_auto_approve_enabled:
1146
+ adapter._on_auto_approve_enabled()
1147
+ decisions = [
1148
+ ApproveDecision(type="approve") for _ in action_requests
1149
+ ]
1150
+ tool_msgs = list(adapter._current_tool_messages.values())
1151
+ for tool_msg in tool_msgs:
1152
+ tool_msg.set_running()
1153
+ for action_request in action_requests:
1154
+ tool_name = action_request.get("name")
1155
+ if tool_name in {
1156
+ "write_file",
1157
+ "edit_file",
1158
+ }:
1159
+ args = action_request.get("args", {})
1160
+ if isinstance(args, dict):
1161
+ file_op_tracker.mark_hitl_approved(tool_name, args)
1162
+
1163
+ elif decision_type == "approve":
1164
+ decisions = [
1165
+ ApproveDecision(type="approve") for _ in action_requests
1166
+ ]
1167
+ tool_msgs = list(adapter._current_tool_messages.values())
1168
+ for tool_msg in tool_msgs:
1169
+ tool_msg.set_running()
1170
+ for action_request in action_requests:
1171
+ tool_name = action_request.get("name")
1172
+ if tool_name in {
1173
+ "write_file",
1174
+ "edit_file",
1175
+ }:
1176
+ args = action_request.get("args", {})
1177
+ if isinstance(args, dict):
1178
+ file_op_tracker.mark_hitl_approved(tool_name, args)
1179
+
1180
+ elif decision_type == "reject":
1181
+ decisions = [RejectDecision(type="reject") for _ in action_requests]
1182
+ tool_msgs = list(adapter._current_tool_messages.values())
1183
+ for tool_msg in tool_msgs:
1184
+ tool_msg.set_rejected()
1185
+ adapter._current_tool_messages.clear()
1186
+ any_rejected = True
1187
+ else:
1188
+ logger.warning(
1189
+ "Unexpected HITL decision type: %s",
1190
+ decision_type,
1191
+ )
1192
+ decisions = [RejectDecision(type="reject") for _ in action_requests]
1193
+ for tool_msg in list(adapter._current_tool_messages.values()):
1194
+ tool_msg.set_rejected()
1195
+ adapter._current_tool_messages.clear()
1196
+ any_rejected = True
1197
+ else:
1198
+ logger.warning(
1199
+ "HITL decision was not a dict: %s",
1200
+ type(decision).__name__,
1201
+ )
1202
+ decisions = [RejectDecision(type="reject") for _ in action_requests]
1203
+ for tool_msg in list(adapter._current_tool_messages.values()):
1204
+ tool_msg.set_rejected()
1205
+ adapter._current_tool_messages.clear()
1206
+ any_rejected = True
1207
+
1208
+ resume_payload[interrupt_id] = {"decisions": decisions}
1209
+
1210
+ if any_rejected:
1211
+ break
1212
+
1213
+ suppress_resumed_output = any_rejected
1214
+
1215
+ if interrupt_occurred and resume_payload:
1216
+ if suppress_resumed_output and not pending_ask_user:
1217
+ await adapter._mount_message(
1218
+ AppMessage("Command rejected. Tell the agent what you'd like instead.")
1219
+ )
1220
+ turn_stats.wall_time_seconds = time.monotonic() - start_time
1221
+ return turn_stats
1222
+
1223
+ stream_input = Command(resume=resume_payload)
1224
+ else:
1225
+ await dispatch_hook("task.complete", {"thread_id": thread_id})
1226
+ break
1227
+
1228
+ except (asyncio.CancelledError, KeyboardInterrupt):
1229
+ await _handle_interrupt_cleanup(
1230
+ adapter=adapter,
1231
+ agent=agent,
1232
+ config=config,
1233
+ pending_text_by_namespace=pending_text_by_namespace,
1234
+ captured_input_tokens=captured_input_tokens,
1235
+ captured_output_tokens=captured_output_tokens,
1236
+ turn_stats=turn_stats,
1237
+ start_time=start_time,
1238
+ )
1239
+ return turn_stats
1240
+
1241
+ # Update token count and return stats
1242
+ turn_stats.wall_time_seconds = time.monotonic() - start_time
1243
+ await _report_and_persist_tokens(
1244
+ adapter,
1245
+ agent,
1246
+ config,
1247
+ captured_input_tokens,
1248
+ captured_output_tokens,
1249
+ )
1250
+ return turn_stats
1251
+
1252
+
1253
async def _handle_interrupt_cleanup(
    *,
    adapter: TextualUIAdapter,
    agent: Any,  # noqa: ANN401  # Dynamic agent graph type
    config: RunnableConfig,
    pending_text_by_namespace: dict[tuple, str],
    captured_input_tokens: int,
    captured_output_tokens: int,
    turn_stats: SessionStats,
    start_time: float,
) -> None:
    """Tear down a turn that ended with CancelledError or KeyboardInterrupt.

    The steps run in a fixed order: release the active-message marker, hide
    the spinner, tell the user the turn was interrupted, save whatever was
    produced so far into the graph state, mark every in-flight tool widget as
    rejected, and finally report/persist the token count.

    Args:
        adapter: UI adapter whose display callbacks are invoked.
        agent: The LangGraph agent; its `aupdate_state` receives the
            interrupted messages.
        config: Runnable config with `thread_id`.
        pending_text_by_namespace: Text accumulated per namespace that had
            not been flushed when the interrupt arrived.
        captured_input_tokens: Input tokens captured before the interrupt.
        captured_output_tokens: Output tokens captured before the interrupt.
        turn_stats: Stats object for this turn; its `wall_time_seconds` is
            set here.
        start_time: Monotonic timestamp taken when the turn began.
    """
    from langchain_core.messages import HumanMessage

    # Drop the "active message" marker up front. While it is set the message
    # store keeps protecting that message from pruning, which can stall
    # get_messages_to_prune() and with it all later pruning.
    if adapter._set_active_message:
        adapter._set_active_message(None)

    # The spinner could otherwise keep showing a status that no longer applies.
    if adapter._set_spinner:
        await adapter._set_spinner(None)

    await adapter._mount_message(AppMessage("Interrupted by user"))

    # Snapshot the partial output while the tool-message dict is still intact.
    partial_ai_message = _build_interrupted_ai_message(
        pending_text_by_namespace,
        adapter._current_tool_messages,
    )

    # Best-effort save: a failed state update is logged, never allowed to
    # abort the remaining cleanup.
    try:
        if partial_ai_message:
            await agent.aupdate_state(config, {"messages": [partial_ai_message]})

        await agent.aupdate_state(
            config,
            {
                "messages": [
                    HumanMessage(
                        content="[SYSTEM] Task interrupted by user. Previous operation was cancelled."
                    )
                ]
            },
        )
    except Exception:
        logger.warning("Failed to save interrupted state", exc_info=True)

    # Only now — after the state save — flip the tool widgets to rejected.
    for pending_widget in tuple(adapter._current_tool_messages.values()):
        pending_widget.set_rejected()
    adapter._current_tool_messages.clear()

    elapsed = time.monotonic() - start_time
    turn_stats.wall_time_seconds = elapsed

    # The count is flagged as approximate whenever interrupted state was
    # captured, which also covers tool-only turns where the assistant text
    # had already been flushed.
    await _report_and_persist_tokens(
        adapter,
        agent,
        config,
        captured_input_tokens,
        captured_output_tokens,
        shield=True,
        approximate=partial_ai_message is not None,
    )
1328
+
1329
+
1330
async def _persist_context_tokens(
    agent: Any,  # noqa: ANN401  # Dynamic agent graph type
    config: RunnableConfig,
    tokens: int,
) -> None:
    """Write the context token count into graph state, tolerating failure.

    Any `Exception` raised by the state update is logged as a warning and
    swallowed, so callers never see a failure from this helper.

    Args:
        agent: The LangGraph agent (must support `aupdate_state`).
        config: Runnable config with `thread_id`.
        tokens: Total context tokens to store under `_context_tokens`.
    """
    state_patch = {"_context_tokens": tokens}
    try:
        await agent.aupdate_state(config, state_patch)
    except Exception:  # non-critical; stale count on resume is acceptable
        logger.warning(
            "Failed to persist _context_tokens=%d; token count may be stale on resume",
            tokens,
            exc_info=True,
        )
1350
+
1351
+
1352
async def _report_and_persist_tokens(
    adapter: TextualUIAdapter,
    agent: Any,  # noqa: ANN401  # Dynamic agent graph type
    config: RunnableConfig,
    captured_input_tokens: int,
    captured_output_tokens: int,
    *,
    shield: bool = False,
    approximate: bool = False,
) -> None:
    """Refresh the token display and persist the count to graph state.

    When neither counter is non-zero, only the "show" callback (if any) is
    invoked and nothing is persisted. Otherwise the "update" callback (if any)
    receives `captured_input_tokens`, and that same value is persisted.

    Args:
        adapter: UI adapter with token callbacks.
        agent: The LangGraph agent.
        config: Runnable config with `thread_id` in its configurable dict.
        captured_input_tokens: Total input tokens captured during the turn;
            this is the value displayed and persisted.
        captured_output_tokens: Total output tokens captured during the turn;
            used here only to decide whether anything was captured.
        shield: When `True`, swallow exceptions and `CancelledError` raised by
            the persist call so interrupt handlers can await this safely.
        approximate: When `True`, tell the UI the count is stale (e.g. after
            an interrupted generation) so it can append "+".
    """
    nothing_captured = not (captured_input_tokens or captured_output_tokens)
    if nothing_captured:
        # No fresh numbers: just ask the UI to show what it already has.
        if adapter._on_tokens_show:
            adapter._on_tokens_show(approximate=approximate)
        return

    if adapter._on_tokens_update:
        adapter._on_tokens_update(captured_input_tokens, approximate=approximate)

    if not shield:
        await _persist_context_tokens(agent, config, captured_input_tokens)
        return

    # Shielded path: used from interrupt cleanup, where a second cancellation
    # or any other failure must not escape.
    try:
        await _persist_context_tokens(agent, config, captured_input_tokens)
    except (Exception, asyncio.CancelledError):
        logger.debug(
            "Token persist suppressed during interrupt cleanup",
            exc_info=True,
        )
1390
+
1391
+
1392
async def _flush_assistant_text_ns(
    adapter: TextualUIAdapter,
    text: str,
    ns_key: tuple,
    assistant_message_by_namespace: dict[tuple, Any],
) -> None:
    """Finalize the assistant text accumulated for one namespace.

    Whitespace-only text is ignored. If a widget is already registered for
    `ns_key`, its stream is stopped; otherwise a new `AssistantMessage` is
    built from the full text, mounted, rendered, and registered under
    `ns_key`. In both cases the widget content is then synced to the store
    and the active-message marker is cleared (when those callbacks exist).

    Args:
        adapter: UI adapter providing mount/sync/active-message callbacks.
        text: Accumulated assistant text for the namespace.
        ns_key: Namespace key identifying the stream.
        assistant_message_by_namespace: Map of namespace key to the widget
            currently showing that namespace's assistant text; updated in
            place when a widget is created here.
    """
    if not text.strip():
        return

    widget = assistant_message_by_namespace.get(ns_key)
    if widget is not None:
        # A widget was streaming already — stop the stream to finalize it.
        await widget.stop_stream()
    else:
        # Nothing was mounted during streaming; build one from the full text.
        widget = AssistantMessage(text, id=f"asst-{uuid.uuid4().hex[:8]}")
        await adapter._mount_message(widget)
        await widget.write_initial_content()
        assistant_message_by_namespace[ns_key] = widget

    # A streamed widget is recorded in the MessageStore at mount time, before
    # any content exists, so the stored MessageData still has `content=""`
    # even though the widget's `_content` now holds the full text. Copy the
    # final text back; otherwise a pruned-then-rehydrated message would be
    # rebuilt by `to_widget()` from that stale empty string.
    if adapter._sync_message_content and widget.id:
        adapter._sync_message_content(widget.id, widget._content)

    # Streaming for this message is over, so it no longer needs to be active.
    if adapter._set_active_message:
        adapter._set_active_message(None)