celltype-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celltype_cli-0.1.0.dist-info/METADATA +267 -0
- celltype_cli-0.1.0.dist-info/RECORD +89 -0
- celltype_cli-0.1.0.dist-info/WHEEL +4 -0
- celltype_cli-0.1.0.dist-info/entry_points.txt +2 -0
- celltype_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- ct/__init__.py +3 -0
- ct/agent/__init__.py +0 -0
- ct/agent/case_studies.py +426 -0
- ct/agent/config.py +523 -0
- ct/agent/doctor.py +544 -0
- ct/agent/knowledge.py +523 -0
- ct/agent/loop.py +99 -0
- ct/agent/mcp_server.py +478 -0
- ct/agent/orchestrator.py +733 -0
- ct/agent/runner.py +656 -0
- ct/agent/sandbox.py +481 -0
- ct/agent/session.py +145 -0
- ct/agent/system_prompt.py +186 -0
- ct/agent/trace_store.py +228 -0
- ct/agent/trajectory.py +169 -0
- ct/agent/types.py +182 -0
- ct/agent/workflows.py +462 -0
- ct/api/__init__.py +1 -0
- ct/api/app.py +211 -0
- ct/api/config.py +120 -0
- ct/api/engine.py +124 -0
- ct/cli.py +1448 -0
- ct/data/__init__.py +0 -0
- ct/data/compute_providers.json +59 -0
- ct/data/cro_database.json +395 -0
- ct/data/downloader.py +238 -0
- ct/data/loaders.py +252 -0
- ct/kb/__init__.py +5 -0
- ct/kb/benchmarks.py +147 -0
- ct/kb/governance.py +106 -0
- ct/kb/ingest.py +415 -0
- ct/kb/reasoning.py +129 -0
- ct/kb/schema_monitor.py +162 -0
- ct/kb/substrate.py +387 -0
- ct/models/__init__.py +0 -0
- ct/models/llm.py +370 -0
- ct/tools/__init__.py +195 -0
- ct/tools/_compound_resolver.py +297 -0
- ct/tools/biomarker.py +368 -0
- ct/tools/cellxgene.py +282 -0
- ct/tools/chemistry.py +1371 -0
- ct/tools/claude.py +390 -0
- ct/tools/clinical.py +1153 -0
- ct/tools/clue.py +249 -0
- ct/tools/code.py +1069 -0
- ct/tools/combination.py +397 -0
- ct/tools/compute.py +402 -0
- ct/tools/cro.py +413 -0
- ct/tools/data_api.py +2114 -0
- ct/tools/design.py +295 -0
- ct/tools/dna.py +575 -0
- ct/tools/experiment.py +604 -0
- ct/tools/expression.py +655 -0
- ct/tools/files.py +957 -0
- ct/tools/genomics.py +1387 -0
- ct/tools/http_client.py +146 -0
- ct/tools/imaging.py +319 -0
- ct/tools/intel.py +223 -0
- ct/tools/literature.py +743 -0
- ct/tools/network.py +422 -0
- ct/tools/notification.py +111 -0
- ct/tools/omics.py +3330 -0
- ct/tools/ops.py +1230 -0
- ct/tools/parity.py +649 -0
- ct/tools/pk.py +245 -0
- ct/tools/protein.py +678 -0
- ct/tools/regulatory.py +643 -0
- ct/tools/remote_data.py +179 -0
- ct/tools/report.py +181 -0
- ct/tools/repurposing.py +376 -0
- ct/tools/safety.py +1280 -0
- ct/tools/shell.py +178 -0
- ct/tools/singlecell.py +533 -0
- ct/tools/statistics.py +552 -0
- ct/tools/structure.py +882 -0
- ct/tools/target.py +901 -0
- ct/tools/translational.py +123 -0
- ct/tools/viability.py +218 -0
- ct/ui/__init__.py +0 -0
- ct/ui/markdown.py +31 -0
- ct/ui/status.py +258 -0
- ct/ui/suggestions.py +567 -0
- ct/ui/terminal.py +1456 -0
- ct/ui/traces.py +112 -0
ct/agent/runner.py
ADDED
|
@@ -0,0 +1,656 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AgentRunner: query entry point using the Claude Agent SDK.
|
|
3
|
+
|
|
4
|
+
Replaces the Plan-then-Execute architecture (Planner → Executor → Synthesis)
|
|
5
|
+
with a single agentic loop where Claude directly orchestrates all domain tools.
|
|
6
|
+
|
|
7
|
+
Uses ``ClaudeSDKClient`` (not ``query()``) because only the client supports
|
|
8
|
+
custom MCP tools.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import time
|
|
15
|
+
import traceback
|
|
16
|
+
|
|
17
|
+
from ct.agent.types import ExecutionResult, Plan, Step
|
|
18
|
+
|
|
19
|
+
# Module logger. It uses the explicit name "ct.runner" (not __name__), so
# logging configuration must target "ct.runner" to control this output.
logger = logging.getLogger("ct.runner")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ------------------------------------------------------------------
|
|
23
|
+
# Testable message processing (extracted from _run_async)
|
|
24
|
+
# ------------------------------------------------------------------
|
|
25
|
+
|
|
26
|
+
async def process_messages(
    messages_iter,
    trace_renderer=None,
    headless=False,
    trace_events: list[dict] | None = None,
    thinking_status=None,
    runner=None,
    on_activity=None,
):
    """Process an async iterable of SDK messages into structured results.

    This is extracted from ``AgentRunner._run_async`` so it can be tested
    with mock message streams without a live SDK client.

    Args:
        messages_iter: Async iterable of SDK messages.
        trace_renderer: Optional TraceRenderer for console output.
        headless: If True, suppress console output.
        trace_events: Optional list to append trace events to. When provided,
            each TextBlock, ToolUseBlock, and ToolResultBlock produces a
            trace event dict for downstream notebook/export consumers.
        thinking_status: Optional ThinkingStatus spinner owned by the caller.
            It is stopped at the first TextBlock or at the ResultMessage. If
            neither arrives, stopping it remains the caller's job.
        runner: Optional object whose ``_active_spinner`` attribute is kept
            pointing at the spinner currently animating (or None).
        on_activity: Optional callback that receives a short status string
            (a reasoning snippet or a tool name) as messages arrive.

    Returns:
        dict with the following keys:

        - ``full_text``: list of TextBlock texts.
        - ``tool_calls``: list of ``{name, input}`` dicts. Each also gets
          ``result_text`` and ``duration_s`` once its matching
          ToolResultBlock has been seen.
        - ``result_msg``: the final ResultMessage, or None.
        - ``streamed_len``: number of streamed text-delta characters. It is
          reset whenever a TextBlock is rendered to the console.
    """
    # Lazy imports — these may not be available in unit tests without
    # the SDK installed, but callers pass mock objects anyway.
    try:
        from claude_agent_sdk import (
            AssistantMessage,
            ResultMessage,
            TextBlock,
            ToolUseBlock,
            ToolResultBlock,
            StreamEvent,
        )
    except ImportError:
        # SDK builds without ToolResultBlock / StreamEvent: the branches that
        # need them are skipped when they are None.
        from claude_agent_sdk import (
            AssistantMessage,
            ResultMessage,
            TextBlock,
            ToolUseBlock,
        )
        ToolResultBlock = None
        StreamEvent = None

    full_text: list[str] = []
    tool_calls: list[dict] = []
    # tool_use_id -> {name, input, start_time, call}. "call" is the exact
    # tool_calls entry created for that ToolUseBlock. This lets a result be
    # attached to its own invocation even when the same tool runs twice.
    inflight: dict[str, dict] = {}
    result_msg = None
    streamed_len = 0  # characters received via StreamEvent text deltas
    # The spinner handed in belongs to the caller, which stops it itself.
    caller_status = thinking_status

    try:
        async for message in messages_iter:

            # --- StreamEvent (partial streaming) ---
            if StreamEvent is not None and isinstance(message, StreamEvent):
                event = getattr(message, "event", None) or {}
                if isinstance(event, dict):
                    delta = event.get("delta", {})
                    if isinstance(delta, dict) and delta.get("type") == "text_delta":
                        text = delta.get("text", "")
                        if text:
                            # Track streamed length but don't print raw text —
                            # the full TextBlock will be rendered as markdown
                            streamed_len += len(text)
                continue

            # --- AssistantMessage ---
            if isinstance(message, AssistantMessage):
                for block in (message.content or []):
                    if isinstance(block, TextBlock):
                        # Stop the spinner when showing complete text block
                        if thinking_status is not None:
                            _stop_spinner(thinking_status, runner)
                            thinking_status = None

                        text = block.text or ""
                        full_text.append(text)
                        # Trace capture
                        if trace_events is not None and text.strip():
                            trace_events.append({
                                "type": "text",
                                "content": text,
                                "timestamp": time.time(),
                            })
                        # Render as markdown (streamed deltas are tracked but not printed)
                        if not headless and trace_renderer:
                            streamed_len = 0  # reset for next turn
                            trace_renderer.render_reasoning(text)
                        # Activity callback — show snippet of reasoning
                        if on_activity and text.strip():
                            snippet = text.strip().replace("\n", " ")[:40]
                            on_activity(snippet)

                    elif isinstance(block, ToolUseBlock):
                        # Restart spinner while waiting for tool result
                        if thinking_status is None and not headless and trace_renderer:
                            try:
                                from ct.ui.status import ThinkingStatus

                                thinking_status = ThinkingStatus(
                                    trace_renderer.console, phase="evaluating"
                                )
                                thinking_status.__enter__()
                                thinking_status.start_async_refresh()
                                if runner is not None:
                                    runner._active_spinner = thinking_status
                            except ImportError:
                                pass

                        block_id = getattr(block, "id", "") or ""
                        now = time.time()
                        clean_name = block.name.replace("mcp__ct-tools__", "")
                        call = {
                            "name": block.name,
                            "input": block.input,
                        }
                        tool_calls.append(call)
                        inflight[block_id] = {
                            "name": block.name,
                            "input": block.input,
                            "start_time": now,
                            "call": call,
                        }
                        # Trace capture
                        if trace_events is not None:
                            trace_events.append({
                                "type": "tool_start",
                                "tool": clean_name,
                                "input": block.input,
                                "tool_use_id": block_id,
                                "timestamp": now,
                            })
                        if not headless and trace_renderer:
                            trace_renderer.render_tool_start(block.name, block.input)
                        # Activity callback — show tool name
                        if on_activity:
                            on_activity(f"\u25b8 {clean_name}")

                    elif ToolResultBlock is not None and isinstance(block, ToolResultBlock):
                        tool_use_id = getattr(block, "tool_use_id", "") or ""
                        is_error = getattr(block, "is_error", False)
                        result_text = _extract_tool_result_text(
                            getattr(block, "content", None)
                        )

                        # Match to inflight tracker
                        tracked = inflight.pop(tool_use_id, None)
                        duration = 0.0
                        tool_name = ""
                        tool_input = {}
                        if tracked:
                            duration = time.time() - tracked["start_time"]
                            tool_name = tracked["name"]
                            tool_input = tracked["input"]
                            # Attach the result to the exact call that produced
                            # it. Name-based matching would swap results
                            # between parallel calls to the same tool.
                            tracked["call"]["result_text"] = result_text
                            tracked["call"]["duration_s"] = duration
                        else:
                            logger.warning(
                                "Orphan ToolResultBlock with tool_use_id=%s",
                                tool_use_id,
                            )

                        # Trace capture
                        if trace_events is not None:
                            trace_events.append({
                                "type": "tool_result",
                                "tool": tool_name.replace("mcp__ct-tools__", ""),
                                "tool_use_id": tool_use_id,
                                "result_text": result_text,
                                "is_error": is_error,
                                "duration_s": duration,
                                "timestamp": time.time(),
                            })

                        if not headless and trace_renderer:
                            if is_error:
                                trace_renderer.render_tool_error(
                                    tool_name or "unknown", result_text
                                )
                            else:
                                trace_renderer.render_tool_complete(
                                    tool_name or "unknown",
                                    tool_input,
                                    result_text,
                                    duration,
                                )

            # --- ResultMessage ---
            elif isinstance(message, ResultMessage):
                # Final message, make sure animation is stopped
                if thinking_status is not None:
                    _stop_spinner(thinking_status, runner)
                    thinking_status = None

                result_msg = message
    finally:
        # A spinner started here (while waiting on a tool) would otherwise
        # keep animating if the stream ends or fails before a TextBlock or
        # ResultMessage stops it. The caller's own spinner is left to the
        # caller.
        if thinking_status is not None and thinking_status is not caller_status:
            thinking_status.stop()
            if runner is not None and getattr(runner, "_active_spinner", None) is thinking_status:
                runner._active_spinner = None

    return {
        "full_text": full_text,
        "tool_calls": tool_calls,
        "result_msg": result_msg,
        "streamed_len": streamed_len,
    }


def _stop_spinner(status, runner) -> None:
    """Stop *status* and clear ``runner._active_spinner`` when a runner is given."""
    status.stop()
    if runner is not None:
        runner._active_spinner = None


def _extract_tool_result_text(content) -> str:
    """Return the text carried by a ToolResultBlock's ``content``.

    A string is returned as-is. For a list, the ``"text"`` values of dict
    items whose ``type`` is ``"text"`` are concatenated. Anything else
    (including None) yields an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        )
    return ""
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class AgentRunner:
    """Run queries via the Claude Agent SDK agentic loop.

    The ct domain tools registered by ``create_ct_mcp_server`` are exposed to
    Claude as MCP tools. Claude handles planning, execution, error recovery,
    and synthesis in one conversation.
    """

    # MCP server name the ct tools are registered under. Tool names given to
    # Claude carry the matching "mcp__ct-tools__" prefix.
    _MCP_SERVER_NAME = "ct-tools"
    _TOOL_PREFIX = "mcp__ct-tools__"
    # Tools whose rich results (code, stdout, plots) are read from the
    # code_trace_buffer returned by create_ct_mcp_server.
    _CODE_TOOLS = ("run_python", "run_r")
    # Stripped from the SDK subprocess environment — presumably so it is not
    # treated as nested inside an outer Claude Code session (TODO confirm).
    _STRIP_ENV_VARS = frozenset({
        "CLAUDECODE",
        "CLAUDE_CODE_SESSION_ID",
        "CLAUDE_CODE_PARENT_SESSION_ID",
    })

    def __init__(
        self,
        session,
        trajectory=None,
        headless: bool = False,
        trace_store=None,
    ):
        self.session = session
        self.trajectory = trajectory
        self._headless = headless
        self.trace_store = trace_store
        # Spinner currently animating, or None. It is shared with
        # process_messages() and the plan-approval hook, so the hook can stop
        # whichever spinner is live before prompting the user.
        self._active_spinner = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(
        self,
        query: str,
        context: dict | None = None,
        progress_callback=None,
    ) -> ExecutionResult:
        """Execute a query synchronously (blocking wrapper around async)."""
        return asyncio.run(
            self._run_async(query, context, progress_callback)
        )

    async def _run_async(
        self,
        query: str,
        context: dict | None = None,
        progress_callback=None,
    ) -> ExecutionResult:
        """Execute a query using the Agent SDK agentic loop.

        Uses ``ClaudeSDKClient`` (bidirectional client) which supports custom
        MCP tools, unlike ``query()`` which does not.

        Exceptions raised while running the SDK client are logged and turned
        into an error ``ExecutionResult``. Exceptions raised while building
        the MCP server, prompt, or options propagate. Every spinner is
        stopped in all cases.
        """
        from claude_agent_sdk import ClaudeSDKClient

        # Start spinner immediately — user should see feedback the moment they hit Enter
        thinking_status = None
        if not self._headless:
            from ct.ui.status import ThinkingStatus

            thinking_status = ThinkingStatus(self.session.console, phase="planning")
            thinking_status.__enter__()
            thinking_status.start_async_refresh()
            self._active_spinner = thinking_status

        t0 = time.time()
        ctx = context or {}

        try:
            from ct.ui.traces import TraceRenderer

            options, model, sandbox, code_trace_buffer = self._build_options(ctx)
            user_prompt = self._build_user_prompt(query, ctx)
            trace_renderer = TraceRenderer(self.session.console)
            # Trace events are only collected when there is a store to flush to.
            trace_events = [] if self.trace_store is not None else None

            # ----- Run the agentic loop via ClaudeSDKClient -----
            try:
                async with ClaudeSDKClient(options=options) as client:
                    await client.query(user_prompt)
                    result = await process_messages(
                        client.receive_response(),
                        trace_renderer=trace_renderer,
                        headless=self._headless,
                        trace_events=trace_events,
                        thinking_status=thinking_status,
                        runner=self,
                        on_activity=progress_callback,
                    )
            except Exception as e:
                logger.error("Agent SDK query failed: %s\n%s", e, traceback.format_exc())
                return self._make_error_result(query, str(e), time.time() - t0)
        finally:
            # Ensure animation is cleaned up even on error. This includes a
            # spinner that process_messages started mid-stream.
            self._stop_spinners(thinking_status)

        duration = time.time() - t0

        full_text = result["full_text"]
        tool_calls = result["tool_calls"]
        result_msg = result["result_msg"]

        # ----- Build ExecutionResult -----
        summary = "\n".join(full_text).strip() or "(Agent produced no text output)"

        answer = None
        if sandbox:
            result_var = sandbox.get_variable("result")
            if isinstance(result_var, dict):
                answer = result_var.get("answer")

        steps = []
        for i, tc in enumerate(tool_calls, 1):
            step = Step(
                id=i,
                tool=tc["name"].replace(self._TOOL_PREFIX, ""),
                description=f"Called {tc['name']}",
                tool_args=tc.get("input", {}),
            )
            step.status = "completed"
            steps.append(step)

        plan = Plan(query=query, steps=steps)

        cost_usd = 0.0
        if result_msg:
            cost_usd = getattr(result_msg, "total_cost_usd", 0.0) or 0.0

        exec_result = ExecutionResult(
            plan=plan,
            summary=summary,
            raw_results={"tool_calls": tool_calls, "answer": answer},
            duration_s=duration,
            iterations=1,
            metadata={
                "sdk_cost_usd": cost_usd,
                "sdk_turns": getattr(result_msg, "num_turns", 0) if result_msg else 0,
                "sdk_duration_ms": getattr(result_msg, "duration_ms", 0) if result_msg else 0,
                "tool_call_count": len(tool_calls),
            },
        )

        if trace_events:
            trace_events = self._enrich_trace_events(
                trace_events, tool_calls, code_trace_buffer
            )
        self._flush_trace(trace_events, query, model, duration, cost_usd)

        if not self._headless and result_msg:
            self._print_usage(result_msg, duration)

        return exec_result

    # ------------------------------------------------------------------
    # Query setup helpers
    # ------------------------------------------------------------------

    def _build_options(self, ctx: dict):
        """Build the MCP server, system prompt and ``ClaudeAgentOptions``.

        Args:
            ctx: Query context dict. ``ctx["data_dir"]``, when set, is added
                to the system prompt and written to the
                ``sandbox.extra_read_dirs`` config key.

        Returns:
            Tuple ``(options, model, sandbox, code_trace_buffer)``.
            ``sandbox`` and ``code_trace_buffer`` come from
            ``create_ct_mcp_server`` and are consulted after the run.
        """
        from claude_agent_sdk import ClaudeAgentOptions

        from ct.agent.mcp_server import create_ct_mcp_server
        from ct.agent.system_prompt import build_system_prompt

        config = self.session.config

        # ----- Build MCP server -----
        exclude_cats = set()
        if not config.get("agent.enable_experimental_tools", False):
            from ct.tools import EXPERIMENTAL_CATEGORIES

            exclude_cats = set(EXPERIMENTAL_CATEGORIES)

        server, sandbox, tool_names, code_trace_buffer = create_ct_mcp_server(
            self.session,
            exclude_categories=exclude_cats,
        )

        # ----- Build system prompt -----
        data_context = None
        data_dir = ctx.get("data_dir")
        if data_dir:
            data_context = f"Data directory: {data_dir}\n"
            # NOTE(review): this config write happens after the sandbox was
            # created above. Confirm the sandbox reads extra_read_dirs lazily,
            # otherwise data_dir only takes effect on the next query.
            config.set("sandbox.extra_read_dirs", str(data_dir))

        history = None
        if self.trajectory and self.trajectory.turns:
            history = self.trajectory.context_for_planner()

        system_prompt = build_system_prompt(
            self.session,
            tool_names=tool_names,
            data_context=data_context,
            history=history,
        )

        # ----- Configure Agent SDK -----
        model = config.get("llm.model") or "claude-sonnet-4-5-20250929"
        max_turns = int(config.get("agent.max_sdk_turns", 30))
        allowed_tools = [f"{self._TOOL_PREFIX}{name}" for name in tool_names]

        clean_env = {
            k: v for k, v in os.environ.items()
            if k not in self._STRIP_ENV_VARS
        }
        api_key = config.llm_api_key("anthropic")
        if api_key:
            clean_env["ANTHROPIC_API_KEY"] = api_key
        # Suppress warnings in SDK subprocess (matplotlib, pydeseq2, numpy, etc.)
        clean_env["PYTHONWARNINGS"] = "ignore"

        # Plan mode: use SDK's built-in plan permission mode. Claude outputs a
        # plan then calls ExitPlanMode, which the approval hook intercepts.
        use_plan_mode = bool(config.get("agent.plan_preview", False)) and not self._headless

        options_kwargs = dict(
            system_prompt=system_prompt,
            model=model,
            max_turns=max_turns,
            mcp_servers={self._MCP_SERVER_NAME: server},
            allowed_tools=allowed_tools,
            permission_mode="plan" if use_plan_mode else "bypassPermissions",
            env=clean_env,
            hooks={},  # Disable inherited hooks (e.g. from Claude Code)
        )
        if use_plan_mode:
            options_kwargs["can_use_tool"] = self._plan_approval_hook()

        # Try to enable partial message streaming (graceful fallback)
        try:
            options = ClaudeAgentOptions(
                include_partial_messages=True,
                **options_kwargs,
            )
        except TypeError:
            # SDK version doesn't support include_partial_messages
            logger.info("SDK does not support include_partial_messages, using non-streaming")
            options = ClaudeAgentOptions(**options_kwargs)

        return options, model, sandbox, code_trace_buffer

    @staticmethod
    def _build_user_prompt(query: str, ctx: dict) -> str:
        """Append a "Context:" section built from known ``ctx`` keys to *query*."""
        context_parts = []
        if ctx.get("compound_smiles"):
            context_parts.append(f"Compound SMILES: {ctx['compound_smiles']}")
        if ctx.get("target"):
            context_parts.append(f"Target: {ctx['target']}")
        if ctx.get("indication"):
            context_parts.append(f"Indication: {ctx['indication']}")
        # Inject mention context if present
        if ctx.get("mention_context"):
            context_parts.append(ctx["mention_context"])
        if not context_parts:
            return query
        return query + "\n\nContext:\n" + "\n".join(context_parts)

    # ------------------------------------------------------------------
    # Trace helpers
    # ------------------------------------------------------------------

    @classmethod
    def _enrich_trace_events(cls, trace_events, tool_calls, code_trace_buffer):
        """Insert a ``tool_result`` event right after each ``tool_start``.

        The SDK stream typically does NOT include ToolResultBlock messages,
        so process_messages() often produces only tool_start events.

        - Code tools: MCP handlers write structured results (code, stdout,
          plots) to ``code_trace_buffer``, consumed in order. If the stream
          already produced a tool_result for the same ``tool_use_id``, the
          richer buffer-based event replaces it instead of duplicating it.
        - Other tools: the k-th tool_start is paired with ``tool_calls[k]``
          (process_messages appends one of each per ToolUseBlock). An event
          is added only when that call has ``result_text`` and the stream
          did not already emit a tool_result for it.

        ``tool_calls`` is not modified. Returns a new list.
        """
        streamed = {
            ev["tool_use_id"]: ev
            for ev in trace_events
            if ev.get("type") == "tool_result" and ev.get("tool_use_id")
        }
        buffer_iter = iter(code_trace_buffer or ())
        superseded: set[str] = set()
        enriched: list[dict] = []
        start_index = 0

        for event in trace_events:
            if event.get("type") == "tool_result" and event.get("tool_use_id") in superseded:
                continue  # replaced by the code-tool result inserted earlier
            enriched.append(event)
            if event.get("type") != "tool_start":
                continue

            call = tool_calls[start_index] if start_index < len(tool_calls) else None
            start_index += 1

            tool = event.get("tool", "")
            tool_use_id = event.get("tool_use_id", "")
            prior = streamed.get(tool_use_id) if tool_use_id else None

            if tool in cls._CODE_TOOLS:
                meta = next(buffer_iter, None)
                if not meta:
                    continue
                enriched.append({
                    "type": "tool_result",
                    "tool": tool,
                    "tool_use_id": tool_use_id,
                    "result_text": meta.get("stdout", ""),
                    "is_error": bool(meta.get("error")),
                    "duration_s": prior.get("duration_s", 0.0) if prior else 0.0,
                    "code": meta.get("code", ""),
                    "stdout": meta.get("stdout", ""),
                    "plots": meta.get("plots", []),
                    "exports": meta.get("exports", []),
                    "error": meta.get("error"),
                    "timestamp": time.time(),
                })
                if prior is not None:
                    superseded.add(tool_use_id)
                continue

            # Non-code tools: only fill in results the stream did not already emit.
            if prior is not None or call is None or "result_text" not in call:
                continue
            if call["name"].replace(cls._TOOL_PREFIX, "") != tool:
                continue
            enriched.append({
                "type": "tool_result",
                "tool": tool,
                "tool_use_id": tool_use_id,
                "result_text": call["result_text"],
                "is_error": False,
                "duration_s": call.get("duration_s", 0.0),
                "timestamp": time.time(),
            })

        return enriched

    def _flush_trace(self, trace_events, query, model, duration, cost_usd) -> None:
        """Write *trace_events* to the trace store (best effort; failures are logged)."""
        if self.trace_store is None or not trace_events:
            return
        try:
            self.trace_store.add_events(
                trace_events,
                query=query,
                model=model,
                duration_s=duration,
                cost_usd=cost_usd,
            )
            self.trace_store.flush()
        except Exception as e:
            logger.warning("Failed to flush trace: %s", e)

    def _stop_spinners(self, initial_status) -> None:
        """Stop the query's initial spinner and any spinner still marked active.

        process_messages() may replace the initial spinner with a new one
        while a tool runs. That replacement is reachable only through
        ``self._active_spinner``.
        """
        if initial_status is not None:
            initial_status.stop()
        active = self._active_spinner
        self._active_spinner = None
        if active is not None and active is not initial_status:
            active.stop()

    # ------------------------------------------------------------------
    # Plan mode
    # ------------------------------------------------------------------

    def _plan_approval_hook(self):
        """Return a can_use_tool callback for SDK plan mode.

        The callback intercepts the ExitPlanMode call, shows Claude's plan
        and asks the user for approval. All other tool calls are
        auto-allowed.

        Decisions are returned as the SDK's ``PermissionResultAllow`` /
        ``PermissionResultDeny`` objects when the installed SDK exports them.
        Otherwise they are returned as the legacy ``{"allow": ...}`` dicts.
        """
        console = self.session.console
        # self._active_spinner is deliberately NOT reset here: _run_async
        # registers the initial spinner before building this hook, and the
        # hook needs that reference to stop it before calling input().

        try:
            from claude_agent_sdk import PermissionResultAllow, PermissionResultDeny
        except ImportError:
            PermissionResultAllow = PermissionResultDeny = None

        def _allow(input_data):
            if PermissionResultAllow is None:
                return {"allow": True, "updated_input": input_data}
            return PermissionResultAllow(updated_input=input_data)

        def _deny(message):
            if PermissionResultDeny is None:
                return {"allow": False, "message": message}
            return PermissionResultDeny(message=message)

        async def _hook(tool_name, input_data, context):
            if tool_name != "ExitPlanMode":
                # All other tools: allow
                return _allow(input_data)

            # Stop the spinner so it doesn't interfere with input()
            if self._active_spinner is not None:
                self._active_spinner.stop()
                self._active_spinner = None

            # Claude is requesting to exit plan mode and start executing
            console.print("\n [bold cyan]Proposed Plan[/bold cyan]")
            # The plan text may be in the input data or in Claude's
            # preceding text output (which the user already saw).
            if isinstance(input_data, dict):
                for key in ("plan", "description", "summary"):
                    if key in input_data and input_data[key]:
                        # markup=False: model-written text may contain
                        # [brackets] that rich would treat as style tags.
                        console.print(f" {input_data[key]}", markup=False)
                        break
            console.print()

            try:
                answer = input(" Execute this plan? [Y/n] ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                answer = "n"

            if answer in ("", "y", "yes"):
                return _allow(input_data)

            # Ask what to change so Claude can revise the plan
            try:
                feedback = input(" What would you change? ").strip()
            except (EOFError, KeyboardInterrupt):
                feedback = ""

            msg = f"User rejected the plan. Feedback: {feedback}" if feedback else "User rejected the plan."
            return _deny(msg)

        return _hook

    # ------------------------------------------------------------------
    # Console output helpers
    # ------------------------------------------------------------------

    def _print_usage(self, result_msg, duration: float):
        """Print cost and usage summary."""
        cost = getattr(result_msg, "total_cost_usd", 0)
        turns = getattr(result_msg, "num_turns", 0)
        parts = []
        if cost:
            parts.append(f"${cost:.2f}")
        if turns:
            parts.append(f"{turns} turns")
        if duration >= 60:
            mins = int(duration // 60)
            secs = int(duration % 60)
            parts.append(f"{mins}m {secs}s")
        else:
            parts.append(f"{duration:.1f}s")
        self.session.console.print(f"\n [dim]{' | '.join(parts)}[/dim]")

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    @staticmethod
    def _make_error_result(query: str, error: str, duration: float) -> ExecutionResult:
        """Build an ExecutionResult representing a failed query."""
        plan = Plan(query=query, steps=[])
        return ExecutionResult(
            plan=plan,
            summary=f"Agent SDK error: {error}",
            raw_results={"error": error},
            duration_s=duration,
            iterations=1,
        )
|