celltype-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celltype_cli-0.1.0.dist-info/METADATA +267 -0
- celltype_cli-0.1.0.dist-info/RECORD +89 -0
- celltype_cli-0.1.0.dist-info/WHEEL +4 -0
- celltype_cli-0.1.0.dist-info/entry_points.txt +2 -0
- celltype_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- ct/__init__.py +3 -0
- ct/agent/__init__.py +0 -0
- ct/agent/case_studies.py +426 -0
- ct/agent/config.py +523 -0
- ct/agent/doctor.py +544 -0
- ct/agent/knowledge.py +523 -0
- ct/agent/loop.py +99 -0
- ct/agent/mcp_server.py +478 -0
- ct/agent/orchestrator.py +733 -0
- ct/agent/runner.py +656 -0
- ct/agent/sandbox.py +481 -0
- ct/agent/session.py +145 -0
- ct/agent/system_prompt.py +186 -0
- ct/agent/trace_store.py +228 -0
- ct/agent/trajectory.py +169 -0
- ct/agent/types.py +182 -0
- ct/agent/workflows.py +462 -0
- ct/api/__init__.py +1 -0
- ct/api/app.py +211 -0
- ct/api/config.py +120 -0
- ct/api/engine.py +124 -0
- ct/cli.py +1448 -0
- ct/data/__init__.py +0 -0
- ct/data/compute_providers.json +59 -0
- ct/data/cro_database.json +395 -0
- ct/data/downloader.py +238 -0
- ct/data/loaders.py +252 -0
- ct/kb/__init__.py +5 -0
- ct/kb/benchmarks.py +147 -0
- ct/kb/governance.py +106 -0
- ct/kb/ingest.py +415 -0
- ct/kb/reasoning.py +129 -0
- ct/kb/schema_monitor.py +162 -0
- ct/kb/substrate.py +387 -0
- ct/models/__init__.py +0 -0
- ct/models/llm.py +370 -0
- ct/tools/__init__.py +195 -0
- ct/tools/_compound_resolver.py +297 -0
- ct/tools/biomarker.py +368 -0
- ct/tools/cellxgene.py +282 -0
- ct/tools/chemistry.py +1371 -0
- ct/tools/claude.py +390 -0
- ct/tools/clinical.py +1153 -0
- ct/tools/clue.py +249 -0
- ct/tools/code.py +1069 -0
- ct/tools/combination.py +397 -0
- ct/tools/compute.py +402 -0
- ct/tools/cro.py +413 -0
- ct/tools/data_api.py +2114 -0
- ct/tools/design.py +295 -0
- ct/tools/dna.py +575 -0
- ct/tools/experiment.py +604 -0
- ct/tools/expression.py +655 -0
- ct/tools/files.py +957 -0
- ct/tools/genomics.py +1387 -0
- ct/tools/http_client.py +146 -0
- ct/tools/imaging.py +319 -0
- ct/tools/intel.py +223 -0
- ct/tools/literature.py +743 -0
- ct/tools/network.py +422 -0
- ct/tools/notification.py +111 -0
- ct/tools/omics.py +3330 -0
- ct/tools/ops.py +1230 -0
- ct/tools/parity.py +649 -0
- ct/tools/pk.py +245 -0
- ct/tools/protein.py +678 -0
- ct/tools/regulatory.py +643 -0
- ct/tools/remote_data.py +179 -0
- ct/tools/report.py +181 -0
- ct/tools/repurposing.py +376 -0
- ct/tools/safety.py +1280 -0
- ct/tools/shell.py +178 -0
- ct/tools/singlecell.py +533 -0
- ct/tools/statistics.py +552 -0
- ct/tools/structure.py +882 -0
- ct/tools/target.py +901 -0
- ct/tools/translational.py +123 -0
- ct/tools/viability.py +218 -0
- ct/ui/__init__.py +0 -0
- ct/ui/markdown.py +31 -0
- ct/ui/status.py +258 -0
- ct/ui/suggestions.py +567 -0
- ct/ui/terminal.py +1456 -0
- ct/ui/traces.py +112 -0
ct/models/llm.py
ADDED
@@ -0,0 +1,370 @@

"""
Unified LLM client: supports Anthropic, OpenAI, local models, and CellType models.

Provides a consistent interface regardless of backend. CellType's own models
(GlueLM, C2S, etc.) are imported directly as Python modules when available,
falling back to API calls if served remotely.
"""

from dataclasses import dataclass, field
from typing import Optional, Generator
import logging
import os
import time

logger = logging.getLogger("ct.llm")


@dataclass
class LLMResponse:
    """Standardized response from any LLM backend."""
    content: str
    model: str
    usage: Optional[dict] = None
    raw: object = None
    content_blocks: Optional[list] = None  # Raw content blocks from API (for tool use)


# Pricing per million tokens (USD) — updated Feb 2026
MODEL_PRICING = {
    # Anthropic
    "claude-sonnet-4-5-20250929": {"input": 3.00, "output": 15.00},
    "claude-haiku-4-5-20251001": {"input": 0.80, "output": 4.00},
    "claude-opus-4-6": {"input": 15.00, "output": 75.00},
    # OpenAI
    "gpt-4o": {"input": 2.50, "output": 10.00},
    "gpt-4o-mini": {"input": 0.15, "output": 0.60},
}


@dataclass
class UsageTracker:
    """Tracks cumulative token usage and cost across LLM calls."""
    calls: list = field(default_factory=list)

    @property
    def total_input_tokens(self) -> int:
        return sum(c.get("input", 0) for c in self.calls)

    @property
    def total_output_tokens(self) -> int:
        return sum(c.get("output", 0) for c in self.calls)

    @property
    def total_tokens(self) -> int:
        return self.total_input_tokens + self.total_output_tokens

    @property
    def total_cost(self) -> float:
        return sum(c.get("cost", 0.0) for c in self.calls)

    def record(self, model: str, usage: dict):
        """Record a single LLM call's usage."""
        if not usage:
            return
        cost = self._estimate_cost(model, usage)
        self.calls.append({
            "model": model,
            "input": usage.get("input", 0),
            "output": usage.get("output", 0),
            "cost": cost,
        })

    def _estimate_cost(self, model: str, usage: dict) -> float:
        pricing = MODEL_PRICING.get(model)
        if not pricing:
            return 0.0
        input_cost = (usage.get("input", 0) / 1_000_000) * pricing["input"]
        output_cost = (usage.get("output", 0) / 1_000_000) * pricing["output"]
        return input_cost + output_cost

    def summary(self) -> str:
        """Human-readable usage summary."""
        if not self.calls:
            return "No LLM calls made."
        models_used = sorted(set(c["model"] for c in self.calls))
        return (
            f"{len(self.calls)} LLM calls | "
            f"{self.total_input_tokens:,} in + {self.total_output_tokens:,} out tokens | "
            f"${self.total_cost:.2f} | "
            f"models: {', '.join(models_used)}"
        )

    def reset(self):
        self.calls.clear()
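
# Example (illustrative, derived from the pricing table above): costs are per
# million tokens, so 2,000,000 input plus 500,000 output tokens on gpt-4o-mini
# should come to (2 * 0.15) + (0.5 * 0.60) = $0.60:
#
#     tracker = UsageTracker()
#     tracker.record("gpt-4o-mini", {"input": 2_000_000, "output": 500_000})
#     assert abs(tracker.total_cost - 0.60) < 1e-9
#     assert tracker.total_tokens == 2_500_000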


class LLMClient:
    """Unified LLM client supporting multiple providers."""

    # Default models per provider
    DEFAULT_MODELS = {
        "anthropic": "claude-sonnet-4-5-20250929",
        "openai": "gpt-4o",
        "local": None,   # User must specify
        "gluelm": None,  # CellType's own model
    }

    def __init__(self, provider: str = "anthropic", model: Optional[str] = None,
                 api_key: Optional[str] = None):
        self.provider = provider
        self.model = model or self.DEFAULT_MODELS.get(provider)
        self.api_key = api_key
        self._client = None
        self.usage = UsageTracker()

    def _get_client(self):
        """Lazily initialize the appropriate client."""
        if self._client is not None:
            return self._client

        if self.provider == "anthropic":
            import anthropic
            self._client = anthropic.Anthropic(
                api_key=self.api_key or os.environ.get("ANTHROPIC_API_KEY")
            )

        elif self.provider == "openai":
            import openai
            self._client = openai.OpenAI(
                api_key=self.api_key or os.environ.get("OPENAI_API_KEY")
            )

        elif self.provider == "local":
            # Local model via vLLM, ollama, or direct transformers
            self._client = self._init_local()

        elif self.provider == "gluelm":
            # CellType's own model — direct Python import
            self._client = self._init_gluelm()

        else:
            raise ValueError(f"Unknown provider: {self.provider}")

        return self._client

    def chat(self, system: str, messages: list[dict], temperature: float = 0.1,
             max_tokens: int = 4096, tools: list[dict] | None = None) -> LLMResponse:
        """Send a chat completion request.

        Args:
            tools: Optional list of tool definitions (Anthropic tool_use format).
                When provided, the response may contain tool_use content blocks
                accessible via ``response.content_blocks``.
        """
        client = self._get_client()

        if self.provider == "anthropic":
            resp = self._chat_anthropic(client, system, messages, temperature, max_tokens, tools=tools)
        elif self.provider == "openai":
            resp = self._chat_openai(client, system, messages, temperature, max_tokens)
        elif self.provider == "local":
            resp = self._chat_local(client, system, messages, temperature, max_tokens)
        elif self.provider == "gluelm":
            resp = self._chat_gluelm(client, system, messages, temperature, max_tokens)
        else:
            raise ValueError(f"Unknown provider: {self.provider}")

        # Track usage
        if resp.usage:
            self.usage.record(resp.model, resp.usage)

        return resp
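
    # Example (illustrative): a minimal chat call. Provider and model are
    # resolved as defined above; the prompt text here is made up.
    #
    #     llm = LLMClient(provider="anthropic")
    #     resp = llm.chat(
    #         system="You are a terse assistant.",
    #         messages=[{"role": "user", "content": "Name one kinase."}],
    #     )
    #     print(resp.content)
    #     print(llm.usage.summary())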

    def stream(self, system: str, messages: list[dict], temperature: float = 0.1,
               max_tokens: int = 4096) -> Generator[str, None, None]:
        """Stream a chat completion, yielding text chunks.

        Yields individual text deltas. After the generator is exhausted,
        usage has been recorded on ``self.usage``; join the yielded chunks
        to recover the full response text.

        Usage:
            chunks = []
            for chunk in llm.stream(system, messages):
                print(chunk, end="", flush=True)
                chunks.append(chunk)
            full_text = "".join(chunks)  # full response after iteration
        """
        client = self._get_client()

        if self.provider == "anthropic":
            yield from self._stream_anthropic(client, system, messages, temperature, max_tokens)
        elif self.provider == "openai":
            yield from self._stream_openai(client, system, messages, temperature, max_tokens)
        else:
            # Fallback: non-streaming providers just yield the full response
            resp = self.chat(system, messages, temperature, max_tokens)
            yield resp.content

    def _chat_anthropic(self, client, system, messages, temperature, max_tokens, tools=None):
        return self._retry(
            lambda: self._call_anthropic(client, system, messages, temperature, max_tokens, tools=tools)
        )

    def _call_anthropic(self, client, system, messages, temperature, max_tokens, tools=None):
        kwargs = dict(
            model=self.model,
            system=system,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        if tools:
            kwargs["tools"] = tools
        response = client.messages.create(**kwargs)
        # Guard against empty content array (e.g., content filtering)
        if not response.content:
            content_text = ""
        else:
            # Extract text parts only (skip tool_use blocks)
            text_parts = [b.text for b in response.content if hasattr(b, "text")]
            content_text = "\n".join(text_parts) if text_parts else ""
        return LLMResponse(
            content=content_text,
            model=self.model,
            usage={"input": response.usage.input_tokens, "output": response.usage.output_tokens},
            raw=response,
            content_blocks=list(response.content) if response.content else [],
        )
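
    # Example (illustrative): detecting tool calls in a response. With the
    # Anthropic SDK, each block in ``content_blocks`` carries a ``type``
    # attribute, so a caller can dispatch tool_use blocks like this
    # (``dispatch`` is a caller-defined, hypothetical function):
    #
    #     resp = llm.chat(system, messages, tools=tools)
    #     for block in resp.content_blocks:
    #         if getattr(block, "type", None) == "tool_use":
    #             result = dispatch(block.name, block.input)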

    def _retry(self, fn, max_retries: int = 3, base_delay: float = 2.0):
        """Retry a function with exponential backoff on transient errors."""
        for attempt in range(1, max_retries + 1):
            try:
                return fn()
            except Exception as e:
                err_str = str(e).lower()
                is_transient = any(w in err_str for w in (
                    "rate_limit", "rate limit", "429", "overloaded",
                    "529", "500", "502", "503", "connection", "timeout",
                ))
                if is_transient and attempt < max_retries:
                    delay = base_delay * (2 ** (attempt - 1))
                    logger.warning("LLM call failed (attempt %d/%d): %s — retrying in %.1fs",
                                   attempt, max_retries, e, delay)
                    time.sleep(delay)
                else:
                    raise
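
    # Worked example of the backoff schedule with the defaults (max_retries=3,
    # base_delay=2.0): attempt 1 fails -> sleep 2.0 * 2**0 = 2.0s; attempt 2
    # fails -> sleep 2.0 * 2**1 = 4.0s; attempt 3 fails -> re-raise.
    # Non-transient errors re-raise immediately.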

    def _stream_anthropic(self, client, system, messages, temperature, max_tokens):
        """Stream from Anthropic API, yielding text deltas."""
        with client.messages.stream(
            model=self.model,
            system=system,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        ) as stream:
            try:
                for text in stream.text_stream:
                    yield text
            finally:
                # Record usage even if stream is interrupted (Ctrl+C)
                try:
                    response = stream.get_final_message()
                    usage = {"input": response.usage.input_tokens, "output": response.usage.output_tokens}
                    self.usage.record(self.model, usage)
                except Exception:
                    logger.debug("Could not record usage after stream interrupt")

    def _stream_openai(self, client, system, messages, temperature, max_tokens):
        """Stream from OpenAI API, yielding text deltas."""
        all_messages = [{"role": "system", "content": system}] + messages
        stream = client.chat.completions.create(
            model=self.model,
            messages=all_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
            stream_options={"include_usage": True},
        )
        usage = None
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
            if chunk.usage:
                usage = {"input": chunk.usage.prompt_tokens, "output": chunk.usage.completion_tokens}

        if usage:
            self.usage.record(self.model, usage)

    def _chat_openai(self, client, system, messages, temperature, max_tokens):
        return self._retry(
            lambda: self._call_openai(client, system, messages, temperature, max_tokens)
        )

    def _call_openai(self, client, system, messages, temperature, max_tokens):
        all_messages = [{"role": "system", "content": system}] + messages
        response = client.chat.completions.create(
            model=self.model,
            messages=all_messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return LLMResponse(
            content=response.choices[0].message.content,
            model=self.model,
            usage={"input": response.usage.prompt_tokens, "output": response.usage.completion_tokens},
            raw=response,
        )

    def _init_local(self):
        """Initialize local model (vLLM or transformers)."""
        # Try vLLM first (fastest for local inference)
        try:
            from vllm import LLM
            return LLM(model=self.model)
        except ImportError:
            pass

        # Fall back to transformers
        try:
            from transformers import pipeline
            return pipeline("text-generation", model=self.model, device_map="auto")
        except ImportError:
            raise ImportError("Install vllm or transformers for local model support")

    def _chat_local(self, client, system, messages, temperature, max_tokens):
        """Chat with local model."""
        # Format for local model
        prompt = f"System: {system}\n\n"
        for msg in messages:
            role = msg["role"].capitalize()
            prompt += f"{role}: {msg['content']}\n\n"
        prompt += "Assistant: "

        if hasattr(client, 'generate'):
            # vLLM
            from vllm import SamplingParams
            params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
            outputs = client.generate([prompt], params)
            text = outputs[0].outputs[0].text
        else:
            # transformers pipeline
            outputs = client(prompt, max_new_tokens=max_tokens, temperature=temperature)
            text = outputs[0]["generated_text"][len(prompt):]

        return LLMResponse(content=text, model=self.model or "local")

    def _init_gluelm(self):
        """Initialize CellType's GlueLM model."""
        try:
            from gluelm import GlueLMModel
            return GlueLMModel.from_pretrained(self.model)
        except ImportError:
            raise ImportError(
                "GlueLM not installed. Install from CellType/GlueLM or "
                "set llm.provider to 'anthropic' for cloud inference."
            )

    def _chat_gluelm(self, client, system, messages, temperature, max_tokens):
        """Chat with GlueLM — specialized for degradation queries."""
        # GlueLM is a domain-specific model, not a general chat model.
        # Route degradation-specific queries to it, general queries to fallback.
        query = messages[-1]["content"] if messages else ""
        result = client.predict(query)
        return LLMResponse(
            content=str(result),
            model="gluelm",
        )
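
Taken together, the client is provider-agnostic at the call site. A minimal end-to-end sketch of the intended streaming usage, assuming ANTHROPIC_API_KEY is set in the environment (the prompt text is illustrative):

    from ct.models.llm import LLMClient

    llm = LLMClient(provider="anthropic")
    for chunk in llm.stream(system="You are concise.",
                            messages=[{"role": "user", "content": "Summarize BRAF biology."}]):
        print(chunk, end="", flush=True)
    print()
    print(llm.usage.summary())
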
ct/tools/__init__.py
ADDED
@@ -0,0 +1,195 @@

"""
Tool registry for ct.

Each tool is a Python function registered via the ``@registry.register``
decorator so the agent can invoke it. Tools are organized by category
(target, structure, chemistry, etc.).
"""

from dataclasses import dataclass, field
import importlib
import logging
from typing import Callable, Optional
from rich.table import Table


EXPERIMENTAL_CATEGORIES = frozenset({"compute", "cro"})
_TOOL_MODULES = (
    "target",
    "structure",
    "chemistry",
    "expression",
    "viability",
    "biomarker",
    "combination",
    "clinical",
    "intel",
    "translational",
    "regulatory",
    "pk",
    "report",
    "literature",
    "safety",
    "cro",
    "compute",
    "experiment",
    "notification",
    "code",
    "files",
    "claude",
    "network",
    "genomics",
    "statistics",
    "repurposing",
    "design",
    "singlecell",
    "protein",
    "imaging",
    "data_api",
    "parity",
    "ops",
    "dna",
    "omics",
    "shell",
    "cellxgene",
    "clue",
    "remote_data",
)


@dataclass
class Tool:
    """A registered tool that the agent can invoke."""
    name: str                  # e.g., "target.neosubstrate_score"
    description: str           # Human-readable description
    category: str              # e.g., "target", "structure", "chemistry"
    function: Callable         # The actual Python function
    parameters: dict = field(default_factory=dict)     # Parameter descriptions
    requires_data: list = field(default_factory=list)  # Required datasets
    usage_guide: str = ""      # When/why to use this tool (injected into planner prompt)

    def run(self, **kwargs):
        """Execute the tool."""
        return self.function(**kwargs)


class ToolRegistry:
    """Central registry of all available tools."""

    def __init__(self):
        self._tools: dict[str, Tool] = {}

    def register(self, name: str, description: str, category: str,
                 parameters: dict = None, requires_data: list = None,
                 usage_guide: str = ""):
        """Decorator to register a function as a tool."""
        def decorator(func):
            self._tools[name] = Tool(
                name=name,
                description=description,
                category=category,
                function=func,
                parameters=parameters or {},
                requires_data=requires_data or [],
                usage_guide=usage_guide,
            )
            return func
        return decorator
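
    # Example registration (hypothetical tool name and parameters, for
    # illustration only; the real tools live in the ct.tools.* modules):
    #
    #     @registry.register(
    #         name="target.lookup",
    #         description="Look up basic annotation for a gene symbol.",
    #         category="target",
    #         parameters={"gene": "HGNC symbol, e.g. 'BRAF'"},
    #     )
    #     def lookup(gene: str) -> dict:
    #         ...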

    def get_tool(self, name: str) -> Optional[Tool]:
        """Look up a tool by name."""
        return self._tools.get(name)

    def list_tools(self, category: str = None) -> list[Tool]:
        """List all tools, optionally filtered by category."""
        tools = list(self._tools.values())
        if category:
            tools = [t for t in tools if t.category == category]
        return sorted(tools, key=lambda t: t.name)

    def list_tools_table(self) -> Table:
        """Render tool list as a rich table."""
        table = Table(title="ct Tools")
        table.add_column("Tool", style="cyan")
        table.add_column("Status")
        table.add_column("Description")
        table.add_column("Data Required", style="dim")

        for tool in self.list_tools():
            data_str = ", ".join(tool.requires_data) if tool.requires_data else "-"
            if tool.name == "claude.code":
                status = "[yellow]guarded (opt-in)[/yellow]"
            elif tool.category in EXPERIMENTAL_CATEGORIES:
                status = "[yellow]experimental / TODO[/yellow]"
            else:
                status = "[green]stable[/green]"
            table.add_row(tool.name, status, tool.description, data_str)

        return table

    def categories(self) -> list[str]:
        """List all tool categories."""
        return sorted(set(t.category for t in self._tools.values()))

    def tool_descriptions_for_llm(
        self,
        exclude_categories: set[str] | None = None,
        exclude_tools: set[str] | None = None,
    ) -> str:
        """Generate tool descriptions for the LLM planner."""
        exclude_categories = exclude_categories or set()
        exclude_tools = exclude_tools or set()
        lines = []
        for cat in self.categories():
            if cat in exclude_categories:
                continue
            cat_tools = [t for t in self.list_tools(cat) if t.name not in exclude_tools]
            if not cat_tools:
                continue
            lines.append(f"\n## {cat}")
            for tool in cat_tools:
                params = ", ".join(f"{k}: {v}" for k, v in tool.parameters.items())
                lines.append(f"- **{tool.name}**({params}): {tool.description}")
                if tool.usage_guide:
                    lines.append(f"  USE WHEN: {tool.usage_guide}")
                if tool.category in EXPERIMENTAL_CATEGORIES:
                    lines.append("  NOTE: Experimental/TODO category. Outputs may be placeholder or limited.")
        return "\n".join(lines)


# Global registry instance
registry = ToolRegistry()


# Import tool modules to trigger registration
def _load_tools():
    """Import all tool modules to register them."""
    logger = logging.getLogger("ct.tools")
    errors = {}

    for module_name in _TOOL_MODULES:
        import_name = f"ct.tools.{module_name}"
        try:
            importlib.import_module(import_name)
        except Exception as exc:
            errors[module_name] = str(exc)
            logger.warning("Failed to load tool module %s: %s", import_name, exc)

    return errors


# Lazy loading — tools are registered on first access
_loaded = False
_load_errors: dict[str, str] = {}


def ensure_loaded():
    global _loaded, _load_errors
    if not _loaded:
        _load_errors = _load_tools()
        _loaded = True


def tool_load_errors() -> dict[str, str]:
    """Return module import failures from tool loading."""
    return dict(_load_errors)