celltype-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. celltype_cli-0.1.0.dist-info/METADATA +267 -0
  2. celltype_cli-0.1.0.dist-info/RECORD +89 -0
  3. celltype_cli-0.1.0.dist-info/WHEEL +4 -0
  4. celltype_cli-0.1.0.dist-info/entry_points.txt +2 -0
  5. celltype_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  6. ct/__init__.py +3 -0
  7. ct/agent/__init__.py +0 -0
  8. ct/agent/case_studies.py +426 -0
  9. ct/agent/config.py +523 -0
  10. ct/agent/doctor.py +544 -0
  11. ct/agent/knowledge.py +523 -0
  12. ct/agent/loop.py +99 -0
  13. ct/agent/mcp_server.py +478 -0
  14. ct/agent/orchestrator.py +733 -0
  15. ct/agent/runner.py +656 -0
  16. ct/agent/sandbox.py +481 -0
  17. ct/agent/session.py +145 -0
  18. ct/agent/system_prompt.py +186 -0
  19. ct/agent/trace_store.py +228 -0
  20. ct/agent/trajectory.py +169 -0
  21. ct/agent/types.py +182 -0
  22. ct/agent/workflows.py +462 -0
  23. ct/api/__init__.py +1 -0
  24. ct/api/app.py +211 -0
  25. ct/api/config.py +120 -0
  26. ct/api/engine.py +124 -0
  27. ct/cli.py +1448 -0
  28. ct/data/__init__.py +0 -0
  29. ct/data/compute_providers.json +59 -0
  30. ct/data/cro_database.json +395 -0
  31. ct/data/downloader.py +238 -0
  32. ct/data/loaders.py +252 -0
  33. ct/kb/__init__.py +5 -0
  34. ct/kb/benchmarks.py +147 -0
  35. ct/kb/governance.py +106 -0
  36. ct/kb/ingest.py +415 -0
  37. ct/kb/reasoning.py +129 -0
  38. ct/kb/schema_monitor.py +162 -0
  39. ct/kb/substrate.py +387 -0
  40. ct/models/__init__.py +0 -0
  41. ct/models/llm.py +370 -0
  42. ct/tools/__init__.py +195 -0
  43. ct/tools/_compound_resolver.py +297 -0
  44. ct/tools/biomarker.py +368 -0
  45. ct/tools/cellxgene.py +282 -0
  46. ct/tools/chemistry.py +1371 -0
  47. ct/tools/claude.py +390 -0
  48. ct/tools/clinical.py +1153 -0
  49. ct/tools/clue.py +249 -0
  50. ct/tools/code.py +1069 -0
  51. ct/tools/combination.py +397 -0
  52. ct/tools/compute.py +402 -0
  53. ct/tools/cro.py +413 -0
  54. ct/tools/data_api.py +2114 -0
  55. ct/tools/design.py +295 -0
  56. ct/tools/dna.py +575 -0
  57. ct/tools/experiment.py +604 -0
  58. ct/tools/expression.py +655 -0
  59. ct/tools/files.py +957 -0
  60. ct/tools/genomics.py +1387 -0
  61. ct/tools/http_client.py +146 -0
  62. ct/tools/imaging.py +319 -0
  63. ct/tools/intel.py +223 -0
  64. ct/tools/literature.py +743 -0
  65. ct/tools/network.py +422 -0
  66. ct/tools/notification.py +111 -0
  67. ct/tools/omics.py +3330 -0
  68. ct/tools/ops.py +1230 -0
  69. ct/tools/parity.py +649 -0
  70. ct/tools/pk.py +245 -0
  71. ct/tools/protein.py +678 -0
  72. ct/tools/regulatory.py +643 -0
  73. ct/tools/remote_data.py +179 -0
  74. ct/tools/report.py +181 -0
  75. ct/tools/repurposing.py +376 -0
  76. ct/tools/safety.py +1280 -0
  77. ct/tools/shell.py +178 -0
  78. ct/tools/singlecell.py +533 -0
  79. ct/tools/statistics.py +552 -0
  80. ct/tools/structure.py +882 -0
  81. ct/tools/target.py +901 -0
  82. ct/tools/translational.py +123 -0
  83. ct/tools/viability.py +218 -0
  84. ct/ui/__init__.py +0 -0
  85. ct/ui/markdown.py +31 -0
  86. ct/ui/status.py +258 -0
  87. ct/ui/suggestions.py +567 -0
  88. ct/ui/terminal.py +1456 -0
  89. ct/ui/traces.py +112 -0
ct/agent/runner.py ADDED
@@ -0,0 +1,656 @@
1
+ """
2
+ AgentRunner: query entry point using the Claude Agent SDK.
3
+
4
+ Replaces the Plan-then-Execute architecture (Planner → Executor → Synthesis)
5
+ with a single agentic loop where Claude directly orchestrates all domain tools.
6
+
7
+ Uses ``ClaudeSDKClient`` (not ``query()``) because only the client supports
8
+ custom MCP tools.
9
+ """
10
+
11
+ import asyncio
12
+ import logging
13
+ import os
14
+ import time
15
+ import traceback
16
+
17
+ from ct.agent.types import ExecutionResult, Plan, Step
18
+
19
+ logger = logging.getLogger("ct.runner")
20
+
21
+
22
+ # ------------------------------------------------------------------
23
+ # Testable message processing (extracted from _run_async)
24
+ # ------------------------------------------------------------------
25
+
26
+ async def process_messages(
27
+ messages_iter,
28
+ trace_renderer=None,
29
+ headless=False,
30
+ trace_events: list[dict] | None = None,
31
+ thinking_status=None,
32
+ runner=None,
33
+ on_activity=None,
34
+ ):
35
+ """Process an async iterable of SDK messages into structured results.
36
+
37
+ This is extracted from ``AgentRunner._run_async`` so it can be tested
38
+ with mock message streams without a live SDK client.
39
+
40
+ Args:
41
+ messages_iter: Async iterable of SDK messages.
42
+ trace_renderer: Optional TraceRenderer for console output.
43
+ headless: If True, suppress console output.
44
+ trace_events: Optional list to append trace events to. When provided,
45
+ each TextBlock, ToolUseBlock, and ToolResultBlock produces a
46
+ trace event dict for downstream notebook/export consumers.
47
+ thinking_status: Optional ThinkingStatus to stop on first message.
48
+
49
+ Returns:
50
+ dict with keys: full_text, tool_calls, result_msg, streamed_len
51
+ """
52
+ # Lazy imports — these may not be available in unit tests without
53
+ # the SDK installed, but callers pass mock objects anyway.
54
+ try:
55
+ from claude_agent_sdk import (
56
+ AssistantMessage,
57
+ ResultMessage,
58
+ TextBlock,
59
+ ToolUseBlock,
60
+ ToolResultBlock,
61
+ StreamEvent,
62
+ )
63
+ except ImportError:
64
+ from claude_agent_sdk import (
65
+ AssistantMessage,
66
+ ResultMessage,
67
+ TextBlock,
68
+ ToolUseBlock,
69
+ )
70
+ ToolResultBlock = None
71
+ StreamEvent = None
72
+
73
+ full_text: list[str] = []
74
+ tool_calls: list[dict] = []
75
+ inflight: dict[str, dict] = {} # tool_use_id → {name, input, start_time}
76
+ result_msg = None
77
+ streamed_len = 0 # characters already displayed via StreamEvent
78
+
79
+ async for message in messages_iter:
80
+
81
+ # --- StreamEvent (partial streaming) ---
82
+ if StreamEvent is not None and isinstance(message, StreamEvent):
83
+ event = getattr(message, "event", None) or {}
84
+ if isinstance(event, dict):
85
+ delta = event.get("delta", {})
86
+ if isinstance(delta, dict) and delta.get("type") == "text_delta":
87
+ text = delta.get("text", "")
88
+ if text:
89
+ # Track streamed length but don't print raw text —
90
+ # the full TextBlock will be rendered as markdown
91
+ streamed_len += len(text)
92
+ continue
93
+
94
+ # --- AssistantMessage ---
95
+ if isinstance(message, AssistantMessage):
96
+ for block in (message.content or []):
97
+ if isinstance(block, TextBlock):
98
+ # Stop the spinner when showing complete text block
99
+ if thinking_status is not None:
100
+ thinking_status.stop()
101
+ thinking_status = None
102
+ if runner is not None:
103
+ runner._active_spinner = None
104
+
105
+ text = block.text or ""
106
+ full_text.append(text)
107
+ # Trace capture
108
+ if trace_events is not None and text.strip():
109
+ trace_events.append({
110
+ "type": "text",
111
+ "content": text,
112
+ "timestamp": time.time(),
113
+ })
114
+ # Render as markdown (streamed deltas are tracked but not printed)
115
+ if not headless and trace_renderer:
116
+ streamed_len = 0 # reset for next turn
117
+ trace_renderer.render_reasoning(text)
118
+ # Activity callback — show snippet of reasoning
119
+ if on_activity and text.strip():
120
+ snippet = text.strip().replace("\n", " ")[:40]
121
+ on_activity(snippet)
122
+
123
+ elif isinstance(block, ToolUseBlock):
124
+ # Restart spinner while waiting for tool result
125
+ if thinking_status is None and not headless and trace_renderer:
126
+ try:
127
+ from ct.ui.status import ThinkingStatus
128
+ thinking_status = ThinkingStatus(trace_renderer.console, phase="evaluating")
129
+ thinking_status.__enter__()
130
+ thinking_status.start_async_refresh()
131
+ if runner is not None:
132
+ runner._active_spinner = thinking_status
133
+ except ImportError:
134
+ pass
135
+
136
+ block_id = getattr(block, "id", "") or ""
137
+ now = time.time()
138
+ inflight[block_id] = {
139
+ "name": block.name,
140
+ "input": block.input,
141
+ "start_time": now,
142
+ }
143
+ tool_calls.append({
144
+ "name": block.name,
145
+ "input": block.input,
146
+ })
147
+ # Trace capture
148
+ if trace_events is not None:
149
+ trace_events.append({
150
+ "type": "tool_start",
151
+ "tool": block.name.replace("mcp__ct-tools__", ""),
152
+ "input": block.input,
153
+ "tool_use_id": block_id,
154
+ "timestamp": now,
155
+ })
156
+ if not headless and trace_renderer:
157
+ trace_renderer.render_tool_start(block.name, block.input)
158
+ # Activity callback — show tool name
159
+ if on_activity:
160
+ clean = block.name.replace("mcp__ct-tools__", "")
161
+ on_activity(f"\u25b8 {clean}")
162
+
163
+ elif ToolResultBlock is not None and isinstance(block, ToolResultBlock):
164
+ tool_use_id = getattr(block, "tool_use_id", "") or ""
165
+ is_error = getattr(block, "is_error", False)
166
+
167
+ # Extract result text from content
168
+ content = getattr(block, "content", None)
169
+ result_text = ""
170
+ if isinstance(content, list):
171
+ for item in content:
172
+ if isinstance(item, dict) and item.get("type") == "text":
173
+ result_text += item.get("text", "")
174
+ elif isinstance(content, str):
175
+ result_text = content
176
+
177
+ # Match to inflight tracker
178
+ tracked = inflight.pop(tool_use_id, None)
179
+ duration = 0.0
180
+ tool_name = ""
181
+ tool_input = {}
182
+ if tracked:
183
+ duration = time.time() - tracked["start_time"]
184
+ tool_name = tracked["name"]
185
+ tool_input = tracked["input"]
186
+ else:
187
+ logger.warning(
188
+ "Orphan ToolResultBlock with tool_use_id=%s",
189
+ tool_use_id,
190
+ )
191
+
192
+ # Update the matching tool_calls entry with result
193
+ for tc in reversed(tool_calls):
194
+ if tc["name"] == tool_name and "result_text" not in tc:
195
+ tc["result_text"] = result_text
196
+ tc["duration_s"] = duration
197
+ break
198
+
199
+ # Trace capture
200
+ if trace_events is not None:
201
+ clean_tool = tool_name.replace("mcp__ct-tools__", "")
202
+ trace_events.append({
203
+ "type": "tool_result",
204
+ "tool": clean_tool,
205
+ "tool_use_id": tool_use_id,
206
+ "result_text": result_text,
207
+ "is_error": is_error,
208
+ "duration_s": duration,
209
+ "timestamp": time.time(),
210
+ })
211
+
212
+ if not headless and trace_renderer:
213
+ if is_error:
214
+ trace_renderer.render_tool_error(
215
+ tool_name or "unknown", result_text
216
+ )
217
+ else:
218
+ trace_renderer.render_tool_complete(
219
+ tool_name or "unknown",
220
+ tool_input,
221
+ result_text,
222
+ duration,
223
+ )
224
+
225
+ # --- ResultMessage ---
226
+ elif isinstance(message, ResultMessage):
227
+ # Final message, make sure animation is stopped
228
+ if thinking_status is not None:
229
+ thinking_status.stop()
230
+ thinking_status = None
231
+ if runner is not None:
232
+ runner._active_spinner = None
233
+
234
+ result_msg = message
235
+
236
+ return {
237
+ "full_text": full_text,
238
+ "tool_calls": tool_calls,
239
+ "result_msg": result_msg,
240
+ "streamed_len": streamed_len,
241
+ }
242
+
243
+
244
class AgentRunner:
    """Run queries via the Claude Agent SDK agentic loop.

    All 192 domain tools are exposed as MCP tools. Claude handles planning,
    execution, error recovery, and synthesis in one conversation.
    """

    # Prefix the SDK puts in front of tools served by the "ct-tools" server.
    _MCP_PREFIX = "mcp__ct-tools__"
    # Tools whose structured results arrive through the code trace buffer.
    _CODE_TOOLS = ("run_python", "run_r")

    def __init__(
        self,
        session,
        trajectory=None,
        headless: bool = False,
        trace_store=None,
    ):
        """Store the collaborators used by :meth:`run`.

        Args:
            session: Object providing ``config`` and ``console``.
            trajectory: Optional conversation history; when it has turns,
                its ``context_for_planner()`` output is fed into the
                system prompt.
            headless: If True, no spinner, console rendering, plan
                approval prompt, or usage summary is produced.
            trace_store: Optional store receiving trace events through
                ``add_events`` followed by ``flush``.
        """
        self.session = session
        self.trajectory = trajectory
        self._headless = headless
        self.trace_store = trace_store
        # Spinner currently animating, shared with process_messages() and
        # the plan approval hook so either side can stop it.
        self._active_spinner = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(
        self,
        query: str,
        context: dict | None = None,
        progress_callback=None,
    ) -> ExecutionResult:
        """Execute a query synchronously (blocking wrapper around async).

        Uses ``asyncio.run``, so it must not be called from code that is
        already running inside an event loop.
        """
        return asyncio.run(
            self._run_async(query, context, progress_callback)
        )

    async def _run_async(
        self,
        query: str,
        context: dict | None = None,
        progress_callback=None,
    ) -> ExecutionResult:
        """Execute a query using the Agent SDK agentic loop.

        Uses ``ClaudeSDKClient`` (bidirectional client) which supports custom
        MCP tools, unlike ``query()`` which does not.

        Args:
            query: The user's request.
            context: Optional dict; the keys read here are ``data_dir``,
                ``compound_smiles``, ``target``, ``indication`` and
                ``mention_context``.
            progress_callback: Optional callable forwarded to
                ``process_messages`` as ``on_activity``.

        Returns:
            An ``ExecutionResult``. Failures of the SDK conversation itself
            are logged and converted into an error result; failures while
            preparing the run (imports, MCP server, prompt) propagate.
        """
        from claude_agent_sdk import (
            ClaudeSDKClient,
            ClaudeAgentOptions,
        )

        # Start spinner immediately — user should see feedback the moment they hit Enter
        thinking_status = None
        if not self._headless:
            from ct.ui.status import ThinkingStatus
            thinking_status = ThinkingStatus(self.session.console, phase="planning")
            thinking_status.__enter__()
            thinking_status.start_async_refresh()
            self._active_spinner = thinking_status

        try:
            from ct.agent.mcp_server import create_ct_mcp_server
            from ct.agent.system_prompt import build_system_prompt
            from ct.ui.traces import TraceRenderer

            t0 = time.time()
            config = self.session.config
            ctx = context or {}

            # ----- Build MCP server -----
            exclude_cats = set()
            if not config.get("agent.enable_experimental_tools", False):
                from ct.tools import EXPERIMENTAL_CATEGORIES
                exclude_cats = set(EXPERIMENTAL_CATEGORIES)

            server, sandbox, tool_names, code_trace_buffer = create_ct_mcp_server(
                self.session,
                exclude_categories=exclude_cats,
            )

            # ----- Build system prompt -----
            data_context = None
            data_dir = ctx.get("data_dir")
            if data_dir:
                data_context = f"Data directory: {data_dir}\n"
                # NOTE(review): this is written after the sandbox was created
                # above and persists on the shared config — confirm the
                # sandbox reads "sandbox.extra_read_dirs" lazily.
                config.set("sandbox.extra_read_dirs", str(data_dir))

            history = None
            if self.trajectory and self.trajectory.turns:
                history = self.trajectory.context_for_planner()

            system_prompt = build_system_prompt(
                self.session,
                tool_names=tool_names,
                data_context=data_context,
                history=history,
            )

            # ----- Configure Agent SDK -----
            model = config.get("llm.model") or "claude-sonnet-4-5-20250929"
            max_turns = int(config.get("agent.max_sdk_turns", 30))

            allowed_tools = [f"{self._MCP_PREFIX}{name}" for name in tool_names]

            # Environment for the SDK subprocess, minus the Claude Code
            # session markers of a possible parent process.
            _STRIP_VARS = {
                "CLAUDECODE",
                "CLAUDE_CODE_SESSION_ID",
                "CLAUDE_CODE_PARENT_SESSION_ID",
            }
            clean_env = {
                k: v for k, v in os.environ.items()
                if k not in _STRIP_VARS
            }
            api_key = config.llm_api_key("anthropic")
            if api_key:
                clean_env["ANTHROPIC_API_KEY"] = api_key
            # Suppress warnings in SDK subprocess (matplotlib, pydeseq2, numpy, etc.)
            clean_env["PYTHONWARNINGS"] = "ignore"

            # Plan mode: use SDK's built-in plan permission mode.
            # In plan mode, Claude outputs a plan then calls ExitPlanMode.
            # We intercept that to show the plan and ask for approval.
            plan_preview = bool(config.get("agent.plan_preview", False))
            interactive_plan = plan_preview and not self._headless
            permission_mode = "plan" if interactive_plan else "bypassPermissions"

            options_kwargs = dict(
                system_prompt=system_prompt,
                model=model,
                max_turns=max_turns,
                mcp_servers={"ct-tools": server},
                allowed_tools=allowed_tools,
                permission_mode=permission_mode,
                env=clean_env,
                hooks={},  # Disable inherited hooks (e.g. from Claude Code)
            )

            if interactive_plan:
                options_kwargs["can_use_tool"] = self._plan_approval_hook()

            # Try to enable partial message streaming (graceful fallback)
            try:
                options = ClaudeAgentOptions(
                    include_partial_messages=True,
                    **options_kwargs,
                )
            except TypeError:
                # SDK version doesn't support include_partial_messages
                logger.info("SDK does not support include_partial_messages, using non-streaming")
                options = ClaudeAgentOptions(**options_kwargs)

            # ----- Build user prompt -----
            user_prompt = self._build_user_prompt(query, ctx)

            # ----- Create trace renderer -----
            trace_renderer = TraceRenderer(self.session.console)

            # ----- Prepare trace capture -----
            trace_events: list[dict] | None = None
            if self.trace_store is not None:
                trace_events = []

            # ----- Run the agentic loop via ClaudeSDKClient -----
            try:
                async with ClaudeSDKClient(options=options) as client:
                    await client.query(user_prompt)
                    result = await process_messages(
                        client.receive_response(),
                        trace_renderer=trace_renderer,
                        headless=self._headless,
                        trace_events=trace_events,
                        thinking_status=thinking_status,
                        runner=self,
                        on_activity=progress_callback,
                    )
            except Exception as e:
                logger.error("Agent SDK query failed: %s\n%s", e, traceback.format_exc())
                duration = time.time() - t0
                return self._make_error_result(query, str(e), duration)
        finally:
            # Ensure animation is cleaned up even on error — both the
            # spinner started above and any spinner that was restarted
            # while a tool call was pending.
            self._stop_spinners(thinking_status)

        duration = time.time() - t0

        full_text = result["full_text"]
        tool_calls = result["tool_calls"]
        result_msg = result["result_msg"]

        # ----- Build ExecutionResult -----
        summary = "\n".join(full_text).strip()
        if not summary:
            summary = "(Agent produced no text output)"

        # A sandbox variable named "result" holding a dict may carry the
        # final answer under its "answer" key.
        answer = None
        if sandbox:
            result_var = sandbox.get_variable("result")
            if isinstance(result_var, dict):
                answer = result_var.get("answer")

        steps = []
        for i, tc in enumerate(tool_calls, 1):
            step = Step(
                id=i,
                tool=tc["name"].replace(self._MCP_PREFIX, ""),
                description=f"Called {tc['name']}",
                tool_args=tc.get("input", {}),
            )
            step.status = "completed"
            steps.append(step)

        plan = Plan(query=query, steps=steps)

        cost_usd = 0.0
        if result_msg:
            cost_usd = getattr(result_msg, "total_cost_usd", 0.0) or 0.0

        exec_result = ExecutionResult(
            plan=plan,
            summary=summary,
            raw_results={"tool_calls": tool_calls, "answer": answer},
            duration_s=duration,
            iterations=1,
            metadata={
                "sdk_cost_usd": cost_usd,
                "sdk_turns": getattr(result_msg, "num_turns", 0) if result_msg else 0,
                "sdk_duration_ms": getattr(result_msg, "duration_ms", 0) if result_msg else 0,
                "tool_call_count": len(tool_calls),
            },
        )

        # ----- Inject tool_result events from code_trace_buffer -----
        if trace_events:
            trace_events = self._enrich_trace_events(trace_events, code_trace_buffer)

        # ----- Flush trace events -----
        if self.trace_store is not None and trace_events:
            try:
                self.trace_store.add_events(
                    trace_events,
                    query=query,
                    model=model,
                    duration_s=duration,
                    cost_usd=cost_usd,
                )
                self.trace_store.flush()
            except Exception as e:
                # Tracing is best effort; never fail the query because of it.
                logger.warning("Failed to flush trace: %s", e)

        if not self._headless and result_msg:
            self._print_usage(result_msg, duration)

        return exec_result

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _stop_spinners(self, thinking_status):
        """Stop ``thinking_status`` and any other spinner still registered."""
        restarted = self._active_spinner
        if thinking_status is not None:
            thinking_status.stop()
        if restarted is not None and restarted is not thinking_status:
            restarted.stop()
        self._active_spinner = None

    @staticmethod
    def _build_user_prompt(query: str, ctx: dict) -> str:
        """Return ``query`` followed by a ``Context:`` section, if any applies."""
        context_parts = []
        if ctx.get("compound_smiles"):
            context_parts.append(f"Compound SMILES: {ctx['compound_smiles']}")
        if ctx.get("target"):
            context_parts.append(f"Target: {ctx['target']}")
        if ctx.get("indication"):
            context_parts.append(f"Indication: {ctx['indication']}")
        # Inject mention context if present
        if ctx.get("mention_context"):
            context_parts.append(ctx["mention_context"])
        if not context_parts:
            return query
        return query + "\n\nContext:\n" + "\n".join(context_parts)

    @staticmethod
    def _enrich_trace_events(trace_events: list[dict], code_trace_buffer) -> list[dict]:
        """Return the trace with a tool_result right after each tool_start.

        The SDK stream typically does NOT include ToolResultBlock messages,
        so process_messages() only produces tool_start events for code tools.
        MCP handlers write structured results (code, stdout, plots) to
        ``code_trace_buffer``; its entries are matched to code-tool
        tool_start events by sequential order.

        For other tools, a tool_result that was streamed is moved directly
        behind the tool_start with the same tool_use_id. This keeps its real
        ``is_error`` and ``duration_s`` and avoids emitting a second copy.
        Streamed results of code tools are left where they are.
        """
        code_tools = AgentRunner._CODE_TOOLS
        buffer_iter = iter(code_trace_buffer)

        # Streamed results of non-code tools, queued per tool_use_id.
        streamed: dict[str, list[dict]] = {}
        for event in trace_events:
            if event.get("type") == "tool_result" and event.get("tool") not in code_tools:
                streamed.setdefault(event.get("tool_use_id", ""), []).append(event)

        enriched: list[dict] = []
        moved: set[int] = set()  # id() of results already placed earlier
        for event in trace_events:
            if id(event) in moved:
                continue
            enriched.append(event)
            if event.get("type") != "tool_start":
                continue

            tool = event.get("tool", "")
            tool_use_id = event.get("tool_use_id", "")

            if tool in code_tools:
                meta = next(buffer_iter, None)
                if meta:
                    enriched.append({
                        "type": "tool_result",
                        "tool": tool,
                        "tool_use_id": tool_use_id,
                        "result_text": meta.get("stdout", ""),
                        "is_error": bool(meta.get("error")),
                        "duration_s": 0.0,
                        "code": meta.get("code", ""),
                        "stdout": meta.get("stdout", ""),
                        "plots": meta.get("plots", []),
                        "exports": meta.get("exports", []),
                        "error": meta.get("error"),
                        "timestamp": time.time(),
                    })
                continue

            pending = streamed.get(tool_use_id)
            if pending:
                result_event = pending.pop(0)
                moved.add(id(result_event))
                enriched.append(result_event)

        return enriched

    # ------------------------------------------------------------------
    # Plan mode
    # ------------------------------------------------------------------

    def _plan_approval_hook(self):
        """Return a can_use_tool callback for SDK plan mode.

        Intercepts the ExitPlanMode call to show Claude's plan and ask
        for user approval. All other tool calls are auto-allowed.

        The callback answers with the SDK's ``PermissionResultAllow`` /
        ``PermissionResultDeny`` objects when the installed SDK exports
        them, and with plain dicts otherwise.
        """
        console = self.session.console

        try:
            from claude_agent_sdk import PermissionResultAllow, PermissionResultDeny
        except ImportError:
            PermissionResultAllow = None
            PermissionResultDeny = None

        def _allow(input_data):
            if PermissionResultAllow is not None:
                return PermissionResultAllow(updated_input=input_data)
            return {"allow": True, "updated_input": input_data}

        def _deny(message):
            if PermissionResultDeny is not None:
                return PermissionResultDeny(message=message)
            return {"allow": False, "message": message}

        async def _hook(tool_name, input_data, context):
            if tool_name != "ExitPlanMode":
                # All other tools: allow
                return _allow(input_data)

            # Stop the spinner so it doesn't interfere with input().
            # self._active_spinner is deliberately not reset when the hook
            # is created: the runner has already registered its spinner.
            if self._active_spinner is not None:
                self._active_spinner.stop()
                self._active_spinner = None

            # Claude is requesting to exit plan mode and start executing
            console.print("\n [bold cyan]Proposed Plan[/bold cyan]")
            # The plan text may be in the input data or in Claude's
            # preceding text output (which the user already saw streamed).
            if isinstance(input_data, dict):
                for key in ("plan", "description", "summary"):
                    if key in input_data and input_data[key]:
                        console.print(f" {input_data[key]}")
                        break
            console.print()

            # NOTE(review): input() blocks the event loop while waiting.
            try:
                answer = input(" Execute this plan? [Y/n] ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                answer = "n"

            if answer in ("", "y", "yes"):
                return _allow(input_data)

            # Ask what to change so Claude can revise the plan
            try:
                feedback = input(" What would you change? ").strip()
            except (EOFError, KeyboardInterrupt):
                feedback = ""

            msg = f"User rejected the plan. Feedback: {feedback}" if feedback else "User rejected the plan."
            return _deny(msg)

        return _hook

    # ------------------------------------------------------------------
    # Console output helpers
    # ------------------------------------------------------------------

    def _print_usage(self, result_msg, duration: float):
        """Print cost and usage summary.

        Cost and turn count are shown only when non-zero; the duration is
        always shown, as minutes and seconds from one minute upwards.
        """
        cost = getattr(result_msg, "total_cost_usd", 0)
        turns = getattr(result_msg, "num_turns", 0)
        parts = []
        if cost:
            parts.append(f"${cost:.2f}")
        if turns:
            parts.append(f"{turns} turns")
        if duration >= 60:
            mins = int(duration // 60)
            secs = int(duration % 60)
            parts.append(f"{mins}m {secs}s")
        else:
            parts.append(f"{duration:.1f}s")
        self.session.console.print(f"\n [dim]{' | '.join(parts)}[/dim]")

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    @staticmethod
    def _make_error_result(query: str, error: str, duration: float) -> ExecutionResult:
        """Build an ExecutionResult representing a failed query."""
        plan = Plan(query=query, steps=[])
        return ExecutionResult(
            plan=plan,
            summary=f"Agent SDK error: {error}",
            raw_results={"error": error},
            duration_s=duration,
            iterations=1,
        )