celltype-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celltype_cli-0.1.0.dist-info/METADATA +267 -0
- celltype_cli-0.1.0.dist-info/RECORD +89 -0
- celltype_cli-0.1.0.dist-info/WHEEL +4 -0
- celltype_cli-0.1.0.dist-info/entry_points.txt +2 -0
- celltype_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- ct/__init__.py +3 -0
- ct/agent/__init__.py +0 -0
- ct/agent/case_studies.py +426 -0
- ct/agent/config.py +523 -0
- ct/agent/doctor.py +544 -0
- ct/agent/knowledge.py +523 -0
- ct/agent/loop.py +99 -0
- ct/agent/mcp_server.py +478 -0
- ct/agent/orchestrator.py +733 -0
- ct/agent/runner.py +656 -0
- ct/agent/sandbox.py +481 -0
- ct/agent/session.py +145 -0
- ct/agent/system_prompt.py +186 -0
- ct/agent/trace_store.py +228 -0
- ct/agent/trajectory.py +169 -0
- ct/agent/types.py +182 -0
- ct/agent/workflows.py +462 -0
- ct/api/__init__.py +1 -0
- ct/api/app.py +211 -0
- ct/api/config.py +120 -0
- ct/api/engine.py +124 -0
- ct/cli.py +1448 -0
- ct/data/__init__.py +0 -0
- ct/data/compute_providers.json +59 -0
- ct/data/cro_database.json +395 -0
- ct/data/downloader.py +238 -0
- ct/data/loaders.py +252 -0
- ct/kb/__init__.py +5 -0
- ct/kb/benchmarks.py +147 -0
- ct/kb/governance.py +106 -0
- ct/kb/ingest.py +415 -0
- ct/kb/reasoning.py +129 -0
- ct/kb/schema_monitor.py +162 -0
- ct/kb/substrate.py +387 -0
- ct/models/__init__.py +0 -0
- ct/models/llm.py +370 -0
- ct/tools/__init__.py +195 -0
- ct/tools/_compound_resolver.py +297 -0
- ct/tools/biomarker.py +368 -0
- ct/tools/cellxgene.py +282 -0
- ct/tools/chemistry.py +1371 -0
- ct/tools/claude.py +390 -0
- ct/tools/clinical.py +1153 -0
- ct/tools/clue.py +249 -0
- ct/tools/code.py +1069 -0
- ct/tools/combination.py +397 -0
- ct/tools/compute.py +402 -0
- ct/tools/cro.py +413 -0
- ct/tools/data_api.py +2114 -0
- ct/tools/design.py +295 -0
- ct/tools/dna.py +575 -0
- ct/tools/experiment.py +604 -0
- ct/tools/expression.py +655 -0
- ct/tools/files.py +957 -0
- ct/tools/genomics.py +1387 -0
- ct/tools/http_client.py +146 -0
- ct/tools/imaging.py +319 -0
- ct/tools/intel.py +223 -0
- ct/tools/literature.py +743 -0
- ct/tools/network.py +422 -0
- ct/tools/notification.py +111 -0
- ct/tools/omics.py +3330 -0
- ct/tools/ops.py +1230 -0
- ct/tools/parity.py +649 -0
- ct/tools/pk.py +245 -0
- ct/tools/protein.py +678 -0
- ct/tools/regulatory.py +643 -0
- ct/tools/remote_data.py +179 -0
- ct/tools/report.py +181 -0
- ct/tools/repurposing.py +376 -0
- ct/tools/safety.py +1280 -0
- ct/tools/shell.py +178 -0
- ct/tools/singlecell.py +533 -0
- ct/tools/statistics.py +552 -0
- ct/tools/structure.py +882 -0
- ct/tools/target.py +901 -0
- ct/tools/translational.py +123 -0
- ct/tools/viability.py +218 -0
- ct/ui/__init__.py +0 -0
- ct/ui/markdown.py +31 -0
- ct/ui/status.py +258 -0
- ct/ui/suggestions.py +567 -0
- ct/ui/terminal.py +1456 -0
- ct/ui/traces.py +112 -0
ct/tools/code.py
ADDED
|
@@ -0,0 +1,1069 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Code generation and execution tool for ct.
|
|
3
|
+
|
|
4
|
+
Generates Python code from natural language goals and executes it
|
|
5
|
+
in a sandboxed environment with access to loaded datasets and
|
|
6
|
+
scientific Python libraries.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
from ct.tools import registry
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
CODE_GEN_SYSTEM_PROMPT = """You are a Python code generator for celltype-cli, a drug discovery research agent.
|
|
16
|
+
|
|
17
|
+
Write Python code to accomplish the user's analysis goal. The code will be executed in a sandbox.
|
|
18
|
+
|
|
19
|
+
{namespace_description}
|
|
20
|
+
|
|
21
|
+
## Rules
|
|
22
|
+
1. Do NOT import libraries that are already in the namespace (pd, np, plt, sns, scipy_stats, etc.)
|
|
23
|
+
2. Save plots to OUTPUT_DIR: `plt.savefig(OUTPUT_DIR / "filename.png", dpi=150, bbox_inches="tight")`
|
|
24
|
+
3. Always call `plt.close()` after saving a plot.
|
|
25
|
+
4. Save data exports to OUTPUT_DIR: `df.to_csv(OUTPUT_DIR / "filename.csv")`
|
|
26
|
+
5. Assign your final result to a variable called `result` — it must be a dict with at least a `"summary"` key.
|
|
27
|
+
6. The `result["summary"]` should be a human-readable string describing what was found.
|
|
28
|
+
7. Use print() for intermediate output; it will be captured.
|
|
29
|
+
8. Keep code concise and focused on the goal.
|
|
30
|
+
|
|
31
|
+
## Data access patterns
|
|
32
|
+
- CRISPR data (`crispr`): rows=cell_lines, cols=genes. Values are gene effect scores (negative = dependency).
|
|
33
|
+
- PRISM data (`prism`): drug sensitivity. Columns include pert_name, pert_dose, ccle_name, LFC.
|
|
34
|
+
- L1000 data (`l1000`): rows=compounds, cols=genes. Values are log-fold-change expression.
|
|
35
|
+
- Proteomics (`proteomics`): rows=proteins/genes, cols=compounds. Values are protein abundance LFC.
|
|
36
|
+
- Mutations (`mutations`): binary matrix. rows=cell_lines, cols=genes. 1=damaging mutation.
|
|
37
|
+
- Model metadata (`model_metadata`): cell line info. Columns include CCLEName, OncotreeLineage.
|
|
38
|
+
|
|
39
|
+
## Example result format
|
|
40
|
+
```python
|
|
41
|
+
result = {{
|
|
42
|
+
"summary": "Found 15 significantly correlated genes (p < 0.05). Top hit: BRCA1 (r=-0.72, p=1.2e-8).",
|
|
43
|
+
"top_genes": [...],
|
|
44
|
+
"n_significant": 15,
|
|
45
|
+
}}
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
Write ONLY the Python code. No explanation, no markdown fences.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
BIOINFORMATICS_CODE_GEN_PROMPT = """You are an expert bioinformatics data analyst. Write precise Python code to answer the question.
|
|
52
|
+
|
|
53
|
+
{namespace_description}
|
|
54
|
+
|
|
55
|
+
## Available Data Files
|
|
56
|
+
{data_files_description}
|
|
57
|
+
|
|
58
|
+
## RULES
|
|
59
|
+
1. Read data from provided paths. Use pd.read_csv(), pd.read_excel(), etc.
|
|
60
|
+
2. Do NOT import libraries already in namespace (pd, np, plt, sns, scipy_stats, zipfile, glob, io, tempfile, gzip, csv, struct, os).
|
|
61
|
+
3. Assign result: `result = {{"summary": "...", "answer": "PRECISE_ANSWER"}}`
|
|
62
|
+
4. The "answer" MUST be short and precise (number, gene name, ratio, etc.)
|
|
63
|
+
5. print() intermediate results to verify correctness.
|
|
64
|
+
6. Save plots to OUTPUT_DIR: `plt.savefig(OUTPUT_DIR / "filename.png", dpi=150, bbox_inches="tight")`
|
|
65
|
+
7. Always call `plt.close()` after saving a plot.
|
|
66
|
+
8. Save data exports to OUTPUT_DIR: `df.to_csv(OUTPUT_DIR / "filename.csv")`
|
|
67
|
+
|
|
68
|
+
## DATA LOADING
|
|
69
|
+
- **ZIP files**: Extract first! Capsules often contain ZIPs:
|
|
70
|
+
```python
|
|
71
|
+
with zipfile.ZipFile(path, "r") as zf:
|
|
72
|
+
zf.extractall("/tmp/my_extract")
|
|
73
|
+
print("Files:", [n for n in zf.namelist() if not n.endswith("/")])
|
|
74
|
+
```
|
|
75
|
+
- **RDS files**: If a .csv exists next to the .rds, use CSV. Otherwise: `import pyreadr; data = pyreadr.read_r(path)`
|
|
76
|
+
- **Excel .xls**: `pd.read_excel(path, engine='xlrd')`. Check all sheets: `pd.ExcelFile(path).sheet_names`.
|
|
77
|
+
**Multi-row headers**: If columns look wrong (e.g., dates in column names), try `skiprows=1`.
|
|
78
|
+
- **Excel .xlsx**: `pd.read_excel(path)`. Check all sheets.
|
|
79
|
+
- **FASTA (.faa, .fa)**: `from Bio import SeqIO; records = list(SeqIO.parse(path, "fasta"))`
|
|
80
|
+
- **Newick trees (.treefile, .nwk)**: `from Bio import Phylo; tree = Phylo.read(path, "newick")`
|
|
81
|
+
- **BAM files**: `import pysam; bam = pysam.AlignmentFile(path, "rb")`
|
|
82
|
+
- **GMT gene sets**: Each line = `name\\tdescription\\tgene1\\tgene2\\t...`
|
|
83
|
+
- **GZ files**: `pd.read_csv(path, compression='gzip')` or `with gzip.open(path) as f: ...`
|
|
84
|
+
|
|
85
|
+
## MANDATORY DATA EXPLORATION (DO THIS FIRST!)
|
|
86
|
+
```python
|
|
87
|
+
print("Columns:", df.columns.tolist())
|
|
88
|
+
print("Shape:", df.shape)
|
|
89
|
+
print("Head:\\n", df.head(3))
|
|
90
|
+
print("Dtypes:\\n", df.dtypes)
|
|
91
|
+
if 'Unnamed: 0' in df.columns:
|
|
92
|
+
df = df.set_index('Unnamed: 0')
|
|
93
|
+
```
|
|
94
|
+
**When filtering returns 0 rows**: your column names or logic is WRONG. Print the column, check unique values.
|
|
95
|
+
|
|
96
|
+
## FILE DISCOVERY
|
|
97
|
+
Data may be in ZIP files (extract first!) OR already-extracted flat directories. Check both!
|
|
98
|
+
```python
|
|
99
|
+
import glob, os
|
|
100
|
+
from collections import Counter, defaultdict
|
|
101
|
+
all_files = sorted(glob.glob(str(data_dir) + "/**/*", recursive=True))
|
|
102
|
+
all_files = [f for f in all_files if os.path.isfile(f)]
|
|
103
|
+
dir_counts = Counter(os.path.dirname(f) for f in all_files)
|
|
104
|
+
print(f"Directories with files: {{dict(dir_counts)}}")
|
|
105
|
+
ext_counts = Counter(os.path.splitext(f)[1].lower() for f in all_files)
|
|
106
|
+
print(f"File extensions: {{dict(ext_counts)}}")
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
## DIFFERENTIAL EXPRESSION (DESeq2)
|
|
110
|
+
|
|
111
|
+
**IMPORTANT: Always use R (via run_r) for DESeq2 analysis, NOT pydeseq2.**
|
|
112
|
+
R's DESeq2 is the reference implementation. Only fall back to pydeseq2 if the question explicitly asks for it or if R is unavailable.
|
|
113
|
+
|
|
114
|
+
### Pre-computed DESeq2 results
|
|
115
|
+
Columns: gene_id (often 'Unnamed: 0'), baseMean, log2FoldChange, lfcSE, stat, pvalue, padj.
|
|
116
|
+
- Load: `df = pd.read_csv(path); if 'Unnamed: 0' in df.columns: df = df.set_index('Unnamed: 0')`
|
|
117
|
+
- Significant: `padj < 0.05` AND `abs(log2FoldChange) > threshold`
|
|
118
|
+
- Up-regulated: `log2FoldChange > 0 & padj < 0.05`; Down-regulated: `log2FoldChange < 0 & padj < 0.05`
|
|
119
|
+
- Volcano plot: `plt.scatter(df['log2FoldChange'], -np.log10(df['pvalue']), alpha=0.5, s=3)`
|
|
120
|
+
|
|
121
|
+
### DESeq2 from raw counts (pyDESeq2)
|
|
122
|
+
```python
|
|
123
|
+
from pydeseq2.dds import DeseqDataSet
|
|
124
|
+
from pydeseq2.ds import DeseqStats
|
|
125
|
+
|
|
126
|
+
# counts_df: rows=genes, cols=samples (raw integer counts)
|
|
127
|
+
# metadata: rows=samples, cols=[condition, ...] — index must match counts_df.columns
|
|
128
|
+
counts_df = counts_df.T # pyDESeq2 wants samples-as-rows
|
|
129
|
+
dds = DeseqDataSet(counts=counts_df, metadata=metadata, design="~condition")
|
|
130
|
+
# For batch correction: design="~batch+condition"
|
|
131
|
+
# For paired designs: design="~patient+condition"
|
|
132
|
+
dds.deseq2()
|
|
133
|
+
stat_res = DeseqStats(dds, contrast=["condition", "treatment", "control"])
|
|
134
|
+
stat_res.summary()
|
|
135
|
+
results_df = stat_res.results_df
|
|
136
|
+
# LFC shrinkage (optional but recommended): stat_res.lfc_shrink(coeff="condition_treatment_vs_control")
|
|
137
|
+
```
|
|
138
|
+
**CRITICAL sample alignment check (must pass before DESeq2):**
|
|
139
|
+
```python
|
|
140
|
+
# metadata rows must exactly match count columns (same sample IDs, same order)
|
|
141
|
+
print("n_count_samples:", counts_df.shape[1])
|
|
142
|
+
print("n_metadata_rows:", metadata.shape[0])
|
|
143
|
+
missing = [s for s in counts_df.columns if s not in metadata.index]
|
|
144
|
+
extra = [s for s in metadata.index if s not in counts_df.columns]
|
|
145
|
+
print("missing_in_metadata:", missing[:10], "count=", len(missing))
|
|
146
|
+
print("extra_in_metadata:", extra[:10], "count=", len(extra))
|
|
147
|
+
if missing or extra:
|
|
148
|
+
# Debug spreadsheet parsing issues before continuing
|
|
149
|
+
xls = pd.ExcelFile(metadata_path)
|
|
150
|
+
print("sheets:", xls.sheet_names)
|
|
151
|
+
for sk in (0, 1, 2):
|
|
152
|
+
try:
|
|
153
|
+
tmp = pd.read_excel(metadata_path, sheet_name=xls.sheet_names[0], skiprows=sk)
|
|
154
|
+
print(f"skiprows={sk} shape={tmp.shape} cols={tmp.columns.tolist()[:8]}")
|
|
155
|
+
except Exception as e:
|
|
156
|
+
print(f"skiprows={sk} read error:", e)
|
|
157
|
+
raise ValueError("Metadata/sample mismatch: fix parsing before DESeq2")
|
|
158
|
+
```
|
|
159
|
+
**Treatment name matching**: Print `metadata['condition'].unique()` and match EXACTLY (case-sensitive!).
|
|
160
|
+
**Group selection**: Follow the question's instructions about which groups to include.
|
|
161
|
+
- If question EXPLICITLY lists groups, include EXACTLY those groups.
|
|
162
|
+
- If not specified, use only the 2 comparison groups.
|
|
163
|
+
- For combination treatments (e.g., "drugA/drugB"), find group with BOTH names; pick SHORTEST.
|
|
164
|
+
- Print ALL group names: `print("Groups:", metadata['condition'].unique())`
|
|
165
|
+
**LFC shrinkage** (when question asks): `stat_res.lfc_shrink(coeff=target_coeff)`.
|
|
166
|
+
Find coefficient name first: `dds.varm['LFC'].columns.tolist()`.
|
|
167
|
+
Pattern: 'condition_Treatment_vs_Control'. Skip intercept/batch coefficients.
|
|
168
|
+
**Design with covariates**: `DeseqDataSet(counts=counts_df, metadata=metadata, design_factors=['batch', 'condition'])`
|
|
169
|
+
**Prefer modern pydeseq2 API**: use `design='~ covariate + condition'` instead of deprecated
|
|
170
|
+
`design_factors`/`ref_level`, and set categorical levels explicitly before fitting:
|
|
171
|
+
```python
|
|
172
|
+
metadata['condition'] = pd.Categorical(metadata['condition'], categories=['Control', 'Treatment'])
|
|
173
|
+
metadata['sex'] = pd.Categorical(metadata['sex'])
|
|
174
|
+
dds = DeseqDataSet(counts=counts_df, metadata=metadata, design='~ sex + condition')
|
|
175
|
+
```
|
|
176
|
+
Then use `DeseqStats(dds, contrast=['condition', 'Treatment', 'Control'])`.
|
|
177
|
+
**Multiple result files** (res_1vs97.csv, res_1vs98.csv): Read metadata to map condition codes to names.
|
|
178
|
+
|
|
179
|
+
## ENRICHMENT ANALYSIS
|
|
180
|
+
- **gseapy**: `import gseapy; enr = gseapy.enrich(gene_list=genes, gene_sets='KEGG_2021_Human', outdir=None)`
|
|
181
|
+
- **gseapy library names** (exact strings): 'KEGG_2021_Human', 'Reactome_2022', 'WikiPathways_2019_Mouse', 'GO_Biological_Process_2021', etc.
|
|
182
|
+
To check available: `gseapy.get_library_name()` returns all valid names.
|
|
183
|
+
- Result in `enr.results` — columns: Term, Overlap, P-value, Adjusted P-value, Odds Ratio, Combined Score, Genes.
|
|
184
|
+
- "Overlap" format: "3/49" (string) — parse with `overlap.split("/")`.
|
|
185
|
+
- "Odds Ratio" is a FLOAT (e.g., 5.81) — this is DIFFERENT from Overlap.
|
|
186
|
+
- **ANSWER what the QUESTION asks**: "odds ratio" → Odds Ratio column (float).
|
|
187
|
+
"overlap ratio" → Overlap column (string "8/49"). These are DIFFERENT columns!
|
|
188
|
+
- **Specific pathway lookup**: Search in ALL results (not just filtered):
|
|
189
|
+
```python
|
|
190
|
+
target = enr.results[enr.results['Term'].str.contains('TP53', case=False)]
|
|
191
|
+
if len(target) > 0:
|
|
192
|
+
print(f"Odds Ratio: {{target.iloc[0]['Odds Ratio']}}")
|
|
193
|
+
print(f"Overlap: {{target.iloc[0]['Overlap']}}")
|
|
194
|
+
```
|
|
195
|
+
- If the question names a specific pathway term, first use exact case-insensitive
|
|
196
|
+
term matching and report that term's metrics. Only use fuzzy/contains matching
|
|
197
|
+
when exact matching returns no rows.
|
|
198
|
+
- **Gene ID conversion**: gseapy uses GENE SYMBOLS, not Ensembl IDs. Convert:
|
|
199
|
+
```python
|
|
200
|
+
import mygene
|
|
201
|
+
mg = mygene.MyGeneInfo()
|
|
202
|
+
result = mg.querymany(ensembl_ids, scopes='ensembl.gene', fields='symbol', species='mouse')
|
|
203
|
+
gene_symbols = [r.get('symbol') for r in result if 'symbol' in r]
|
|
204
|
+
```
|
|
205
|
+
- **Directionality**: Run SEPARATE enrichment for up (log2FC > 0) and down (log2FC < 0).
|
|
206
|
+
- **KEGG REST ORA** (non-human): Use urllib + Fisher's exact test + BH correction.
|
|
207
|
+
|
|
208
|
+
## CRISPR / ESSENTIALITY
|
|
209
|
+
- Negative gene effect = essential. essentiality = -gene_effect.
|
|
210
|
+
- Columns are gene names; rows are cell lines.
|
|
211
|
+
- Common pattern: rank genes by median effect across cell lines.
|
|
212
|
+
- For expression-vs-essentiality correlation questions, ALWAYS correlate expression
|
|
213
|
+
against `essentiality = -gene_effect` (not raw gene effect values).
|
|
214
|
+
- Normalize gene labels before matching across tables (e.g., strip ` (1234)` suffixes).
|
|
215
|
+
- Sign interpretation guardrail:
|
|
216
|
+
- "most negative correlation with essentiality" means most negative correlation
|
|
217
|
+
with `-gene_effect` (equivalently, most positive with raw gene effect).
|
|
218
|
+
- "most positive correlation with essentiality" means most positive correlation
|
|
219
|
+
with `-gene_effect` (equivalently, most negative with raw gene effect).
|
|
220
|
+
|
|
221
|
+
## VCF / TS-TV
|
|
222
|
+
- For Ts/Tv, count SNPs only (len(REF)==1 and len(ALT)==1) and handle multi-allelic ALT carefully.
|
|
223
|
+
- For **raw bacterial VCFs**, compute a high-confidence Ts/Tv using sample FORMAT depth:
|
|
224
|
+
keep sites with `DP >= 12` when DP is available, then compute Ts/Tv and round to 2 decimals.
|
|
225
|
+
- If both raw and DP-filtered Ts/Tv are available, report the DP-filtered value as final answer unless the question explicitly asks for unfiltered.
|
|
226
|
+
|
|
227
|
+
## PCA ANALYSIS
|
|
228
|
+
```python
|
|
229
|
+
from sklearn.decomposition import PCA
|
|
230
|
+
|
|
231
|
+
# log10 transform with pseudocount (common for gene expression)
|
|
232
|
+
log_data = np.log10(expression_matrix + 1) # samples as rows, genes as columns
|
|
233
|
+
# PCA — DO NOT scale, just center (sklearn PCA centers by default)
|
|
234
|
+
pca = PCA(n_components=100)
|
|
235
|
+
pca.fit(log_data)
|
|
236
|
+
pc1_variance_pct = pca.explained_variance_ratio_[0] * 100
|
|
237
|
+
print("Variance explained:", pca.explained_variance_ratio_[:5])
|
|
238
|
+
|
|
239
|
+
# Plot PC1 vs PC2
|
|
240
|
+
fig, ax = plt.subplots(figsize=(8, 6))
|
|
241
|
+
ax.scatter(pcs[:, 0], pcs[:, 1], alpha=0.7)
|
|
242
|
+
ax.set_xlabel(f"PC1 ({{pca.explained_variance_ratio_[0]*100:.1f}}%)")
|
|
243
|
+
ax.set_ylabel(f"PC2 ({{pca.explained_variance_ratio_[1]*100:.1f}}%)")
|
|
244
|
+
plt.savefig(OUTPUT_DIR / "pca.png", dpi=150, bbox_inches="tight")
|
|
245
|
+
plt.close()
|
|
246
|
+
|
|
247
|
+
# Top loadings for PC1
|
|
248
|
+
loadings = pd.Series(pca.components_[0], index=X_clean.columns)
|
|
249
|
+
top_pos = loadings.nlargest(10)
|
|
250
|
+
top_neg = loadings.nsmallest(10)
|
|
251
|
+
```
|
|
252
|
+
|
|
253
|
+
## COLONY / SWARMING DATA
|
|
254
|
+
- StrainNumber values are STRINGS. Always compare as strings, never as integers. Check df.dtypes.
|
|
255
|
+
- **Ratio** column: strings like '1:0', '1:3', '1:1', '2:1', '5:1', '10:1', '50:1', etc.
|
|
256
|
+
- **Mixed cultures** have compound StrainNumber values at various Ratios. Pure strains have Ratio '1:0'.
|
|
257
|
+
- **Percentage calculations**: percent reduction = `(reference - test) / reference * 100`. Result is POSITIVE if test < reference.
|
|
258
|
+
- When finding 'most similar', normalize metrics with different scales using MinMaxScaler:
|
|
259
|
+
```python
|
|
260
|
+
from sklearn.preprocessing import MinMaxScaler
|
|
261
|
+
# Compute means per ratio, include reference for normalization
|
|
262
|
+
all_vals = pd.concat([means, pd.DataFrame({{'Area': [ref_area], 'Circularity': [ref_circ]}}, index=['ref'])])
|
|
263
|
+
scaled = pd.DataFrame(MinMaxScaler().fit_transform(all_vals), index=all_vals.index, columns=all_vals.columns)
|
|
264
|
+
ref_scaled = scaled.loc['ref']
|
|
265
|
+
distances = {{r: np.sqrt((scaled.loc[r, 'Area'] - ref_scaled['Area'])**2 + (scaled.loc[r, 'Circularity'] - ref_scaled['Circularity'])**2) for r in means.index}}
|
|
266
|
+
closest = min(distances, key=distances.get)
|
|
267
|
+
```
|
|
268
|
+
- **Ratio to proportion**: Parse "A:B" -> `a/(a+b)` for model fitting.
|
|
269
|
+
```python
|
|
270
|
+
def ratio_to_prop(r):
|
|
271
|
+
a, b = map(float, r.split(':'))
|
|
272
|
+
return a / (a + b) if (a + b) > 0 else 0
|
|
273
|
+
```
|
|
274
|
+
- For swarm/mixture model questions asking for the maximum colony area at the optimal
|
|
275
|
+
frequency, compute both:
|
|
276
|
+
1) model-predicted optimum on a fine grid, and
|
|
277
|
+
2) mean Area at each observed ratio level.
|
|
278
|
+
If the optimum is near an observed ratio, report the observed-ratio mean area as the
|
|
279
|
+
final peak-area value (this is the stable estimate used in benchmark-style readouts).
|
|
280
|
+
|
|
281
|
+
## MODEL FITTING (rpy2 / statsmodels)
|
|
282
|
+
- **When question says "Use R"**: You MUST use rpy2 for R model fitting. Python patsy gives DIFFERENT results!
|
|
283
|
+
```python
|
|
284
|
+
import rpy2.robjects as ro
|
|
285
|
+
ro.globalenv['x'] = ro.FloatVector(x_data.tolist())
|
|
286
|
+
ro.globalenv['y'] = ro.FloatVector(y_data.tolist())
|
|
287
|
+
result = ro.r('''
|
|
288
|
+
library(splines)
|
|
289
|
+
model_quad = lm(y ~ poly(x, 2))
|
|
290
|
+
model_ns = lm(y ~ ns(x, df=4))
|
|
291
|
+
r2 = c(summary(model_quad)$r.squared, summary(model_ns)$r.squared)
|
|
292
|
+
x_fine = seq(min(x), max(x), length.out=10000)
|
|
293
|
+
pred = predict(lm(y ~ ns(x, df=4)), newdata=data.frame(x=x_fine))
|
|
294
|
+
idx = which.max(pred)
|
|
295
|
+
c(r2, optimal_x=x_fine[idx], max_y=pred[idx])
|
|
296
|
+
''')
|
|
297
|
+
```
|
|
298
|
+
**CRITICAL**: R's `ns()` and Python's `patsy.cr()` produce SIGNIFICANTLY different fits.
|
|
299
|
+
- For Python-only (no R):
|
|
300
|
+
```python
|
|
301
|
+
import statsmodels.api as sm
|
|
302
|
+
import statsmodels.formula.api as smf
|
|
303
|
+
model = smf.ols('response ~ treatment + covariate', data=df).fit()
|
|
304
|
+
```
|
|
305
|
+
|
|
306
|
+
## PHYLO EVOLUTIONARY RATE
|
|
307
|
+
- For PhyKIT `evo_rate` comparisons "across all genes", compute rates for all available
|
|
308
|
+
genes in each kingdom/group and run Mann-Whitney U on the two full distributions.
|
|
309
|
+
- Only restrict to shared gene IDs if the question explicitly says "shared genes".
|
|
310
|
+
- For PhyKIT `rcv` comparisons, follow the same rule: use all available orthologs in
|
|
311
|
+
each group unless the question explicitly asks for shared/intersected IDs only.
|
|
312
|
+
|
|
313
|
+
## BIOINFORMATICS CLI TOOLS
|
|
314
|
+
Use `safe_subprocess_run()` (pre-imported) for CLI tools:
|
|
315
|
+
```python
|
|
316
|
+
# BWA-MEM alignment
|
|
317
|
+
result = safe_subprocess_run(["bwa", "mem", "-t", "4", "-R", read_group, ref_path, r1_path, r2_path])
|
|
318
|
+
sam_path = "/tmp/aligned.sam"
|
|
319
|
+
with open(sam_path, "w") as f:
|
|
320
|
+
f.write(result.stdout)
|
|
321
|
+
safe_subprocess_run(["samtools", "view", "-bS", sam_path, "-o", "/tmp/aligned.bam"])
|
|
322
|
+
safe_subprocess_run(["samtools", "sort", "/tmp/aligned.bam", "-o", "/tmp/sorted.bam"])
|
|
323
|
+
safe_subprocess_run(["samtools", "index", "/tmp/sorted.bam"])
|
|
324
|
+
# Coverage depth
|
|
325
|
+
result = safe_subprocess_run(["samtools", "depth", "-a", "/tmp/sorted.bam"])
|
|
326
|
+
# Parse: each line is "chrom\\tpos\\tdepth"
|
|
327
|
+
|
|
328
|
+
# BUSCO analysis
|
|
329
|
+
result = safe_subprocess_run(["busco", "-i", proteome_path, "-m", "protein",
|
|
330
|
+
"-l", "eukaryota_odb10", "-o", "busco_out", "--out_path", "/tmp/"])
|
|
331
|
+
# Parse BUSCO output: look for "Complete and single-copy BUSCOs" line
|
|
332
|
+
```
|
|
333
|
+
|
|
334
|
+
## STATISTICAL TESTS
|
|
335
|
+
- **Mann-Whitney U**: `scipy_stats.mannwhitneyu(x, y, alternative='two-sided')`.
|
|
336
|
+
Report the SMALLER of the two U values: `min(U, n1*n2 - U)` to match R's wilcox.test.
|
|
337
|
+
- **t-test**: `scipy_stats.ttest_ind(x, y)` (independent) or `scipy_stats.ttest_rel(x, y)` (paired).
|
|
338
|
+
- **Fisher's exact**: `scipy_stats.fisher_exact(contingency_table)`.
|
|
339
|
+
- **BH correction**: `from statsmodels.stats.multitest import multipletests; _, padj, _, _ = multipletests(pvals, method='fdr_bh')`
|
|
340
|
+
- **Correlation**: `scipy_stats.pearsonr(x, y)` or `scipy_stats.spearmanr(x, y)`.
|
|
341
|
+
- **Chi-squared**: `scipy_stats.chi2_contingency(table)`.
|
|
342
|
+
|
|
343
|
+
## PERCENTAGE & RATIO CALCULATIONS
|
|
344
|
+
- Always clarify denominator: "percentage of X" = count(X) / total * 100.
|
|
345
|
+
- For proportions: check if question asks for fraction (0-1) or percentage (0-100).
|
|
346
|
+
- When comparing groups: compute metric per group, then compare.
|
|
347
|
+
|
|
348
|
+
## COMMON PITFALLS
|
|
349
|
+
1. Column names are CASE-SENSITIVE. Always print columns first.
|
|
350
|
+
2. Mann-Whitney U: report min(U, n1*n2 - U) to match R's wilcox.test.
|
|
351
|
+
3. NEVER return "N/A", "Unable to determine", "Error", or "UNABLE_TO_DETERMINE". Debug and find the answer.
|
|
352
|
+
4. When 0 results: your approach is WRONG! Print intermediate values, verify data.
|
|
353
|
+
5. String matching: use `.str.contains(pattern, case=False, na=False)` not `==`.
|
|
354
|
+
6. NaN handling: `df.dropna(subset=['col'])` before statistical tests.
|
|
355
|
+
7. Gene IDs: Ensembl (ENSG...) ≠ symbols (TP53). Convert as needed.
|
|
356
|
+
8. Index confusion: after `set_index()`, the column is no longer in `df.columns`.
|
|
357
|
+
9. For ZIP files: extract to /tmp/ first, find files matching gene/sample ID.
|
|
358
|
+
10. When PhyKIT is mentioned, use the EXACT formulas in the phylo tool.
|
|
359
|
+
11. For BUSCO: use `safe_subprocess_run(["busco", ...])`. Count "Complete and single-copy" BUSCOs.
|
|
360
|
+
12. `compute_pi_percentage(seqs)` is PRE-IMPORTED. DO NOT redefine it. Process ALL alignment files.
|
|
361
|
+
13. For ratios like "5:1", return the string "5:1".
|
|
362
|
+
14. Percent reduction: `(reference - test) / reference * 100`. Result is POSITIVE if test < reference.
|
|
363
|
+
15. If the query requires pre-filtering across ALL samples, apply filtering before subgrouping.
|
|
364
|
+
16. Do not proceed with DESeq2 if metadata/sample alignment is incomplete for the intended cohort.
|
|
365
|
+
17. For pydeseq2, avoid deprecated `design_factors`/`ref_level`; use `design='~ ...'` with explicit categories.
|
|
366
|
+
|
|
367
|
+
Write ONLY the Python code. No explanation, no markdown fences.
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
# ── Multi-turn agentic tool definition ──────────────────────────────────────
|
|
371
|
+
|
|
372
|
+
RUN_PYTHON_TOOL = {
|
|
373
|
+
"name": "run_python",
|
|
374
|
+
"description": (
|
|
375
|
+
"Execute Python code in the sandbox. Variables persist between calls. "
|
|
376
|
+
"Print output to see results. When done, assign a dict to `result` "
|
|
377
|
+
"with at least 'summary' and 'answer' keys."
|
|
378
|
+
),
|
|
379
|
+
"input_schema": {
|
|
380
|
+
"type": "object",
|
|
381
|
+
"properties": {
|
|
382
|
+
"code": {
|
|
383
|
+
"type": "string",
|
|
384
|
+
"description": "Python code to execute",
|
|
385
|
+
}
|
|
386
|
+
},
|
|
387
|
+
"required": ["code"],
|
|
388
|
+
},
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
AGENTIC_CODE_ADDENDUM = """
|
|
392
|
+
## Execution mode
|
|
393
|
+
Use the run_python tool to execute Python code step by step.
|
|
394
|
+
- Start by exploring the data (list files, read headers, check shapes/dtypes).
|
|
395
|
+
- Build your analysis incrementally, verifying intermediate results with print().
|
|
396
|
+
- Variables persist between run_python calls — no need to reload data.
|
|
397
|
+
- Re-read the goal carefully before each major step. Follow the stated parameters EXACTLY
|
|
398
|
+
(e.g., specific sample groups, thresholds, column names, library arguments). When the
|
|
399
|
+
goal lists specific items, use ONLY those items — do not add extras.
|
|
400
|
+
- Print a short `constraint_check` section before finalizing, marking each explicit
|
|
401
|
+
query requirement as PASS/FAIL. If any requirement is FAIL, continue iterating.
|
|
402
|
+
- After obtaining your answer, verify it makes sense (non-empty results, reasonable range).
|
|
403
|
+
- When you are finished, make one final run_python call that assigns:
|
|
404
|
+
`result = {"summary": "...", "answer": "YOUR_ANSWER"}`
|
|
405
|
+
This is REQUIRED — do not just print the answer.
|
|
406
|
+
- Do NOT output bare code — always use the run_python tool.
|
|
407
|
+
"""
|
|
408
|
+
|
|
409
|
+
SCRIPT_GEN_SYSTEM_PROMPT = """You write standalone Python scripts for users.
|
|
410
|
+
|
|
411
|
+
Return ONLY valid Python source code for a single script file.
|
|
412
|
+
|
|
413
|
+
Rules:
|
|
414
|
+
1. Output only Python code (no markdown fences, no explanation).
|
|
415
|
+
2. Script must be syntactically valid Python 3.
|
|
416
|
+
3. Include robust error handling for network/API requests.
|
|
417
|
+
4. Keep dependencies minimal and standard where possible.
|
|
418
|
+
5. Include a `main()` function and `if __name__ == "__main__":` block.
|
|
419
|
+
6. Do not execute anything now; only provide script contents.
|
|
420
|
+
"""
|
|
421
|
+
|
|
422
|
+
ERROR_RETRY_PROMPT = """Your previous code produced an error. Fix the code.
|
|
423
|
+
|
|
424
|
+
Previous code:
|
|
425
|
+
```python
|
|
426
|
+
{code}
|
|
427
|
+
```
|
|
428
|
+
|
|
429
|
+
Error:
|
|
430
|
+
```
|
|
431
|
+
{error}
|
|
432
|
+
```
|
|
433
|
+
|
|
434
|
+
## Common Fixes
|
|
435
|
+
- **ImportError / ModuleNotFoundError**: The library may not be installed. Use an alternative:
|
|
436
|
+
- No `pyreadr`? Check if a .csv exists alongside the .rds file.
|
|
437
|
+
- No `xlrd`? Use `engine='openpyxl'` for .xlsx. For .xls, try `pip install xlrd` first.
|
|
438
|
+
- No `pysam`? Parse BAM header with samtools subprocess instead.
|
|
439
|
+
- No `rpy2`? Use statsmodels or scipy equivalents.
|
|
440
|
+
- **FileNotFoundError**: The path is wrong. Print available files with `glob.glob()` and `os.listdir()`.
|
|
441
|
+
If data is in a ZIP, extract it first with `zipfile.ZipFile`.
|
|
442
|
+
- **KeyError / column not found**: Print `df.columns.tolist()` and `df.head()` to see actual names.
|
|
443
|
+
Check for case sensitivity, extra spaces, and 'Unnamed: 0' index columns.
|
|
444
|
+
- **PermissionError on write**: Write to /tmp/ or OUTPUT_DIR instead of the data directory.
|
|
445
|
+
- **Empty DataFrame / 0 results**: Your filter logic is wrong. Print the column values with
|
|
446
|
+
`df['col'].unique()` before filtering. Check for NaN, case mismatches, dtype issues.
|
|
447
|
+
- **NameError**: Variable not defined. Check spelling. Libraries in namespace: pd, np, plt, sns,
|
|
448
|
+
scipy_stats, zipfile, glob, io, tempfile, gzip, csv, struct, os.
|
|
449
|
+
- **ValueError (shapes)**: Print `.shape` of all arrays/DataFrames before operations.
|
|
450
|
+
- **xlrd.biffh.XLRDError**: File is .xlsx not .xls. Use `engine='openpyxl'`.
|
|
451
|
+
- **gzip/zlib error**: File may be empty (0 bytes). Check `os.path.getsize(path)` first.
|
|
452
|
+
|
|
453
|
+
Write ONLY the corrected Python code. No explanation, no markdown fences.
|
|
454
|
+
"""
|
|
455
|
+
|
|
456
|
+
REFLECTION_PROMPT = """Your code executed without errors. Review the output to check if the result is correct.
|
|
457
|
+
|
|
458
|
+
Goal: {goal}
|
|
459
|
+
|
|
460
|
+
Code:
|
|
461
|
+
```python
|
|
462
|
+
{code}
|
|
463
|
+
```
|
|
464
|
+
|
|
465
|
+
Stdout:
|
|
466
|
+
```
|
|
467
|
+
{stdout}
|
|
468
|
+
```
|
|
469
|
+
|
|
470
|
+
Result: {result}
|
|
471
|
+
|
|
472
|
+
Check carefully:
|
|
473
|
+
- Are there empty lists, zero counts, or missing matches that suggest a data loading or filtering bug?
|
|
474
|
+
- Did column/sample mapping work correctly? (e.g., correct group labels, not raw IDs)
|
|
475
|
+
- Does the final answer make sense given the data and question?
|
|
476
|
+
- Are there printed warnings like "0 genes found" or "no matching samples" that indicate a logic error?
|
|
477
|
+
|
|
478
|
+
If the output looks correct, respond with exactly: LGTM
|
|
479
|
+
If there is a problem, respond with the corrected Python code (no explanation, no markdown fences).
|
|
480
|
+
"""
|
|
481
|
+
|
|
482
|
+
SCRIPT_RETRY_PROMPT = """Your previous script had a syntax error. Fix it.
|
|
483
|
+
|
|
484
|
+
Previous script:
|
|
485
|
+
```python
|
|
486
|
+
{code}
|
|
487
|
+
```
|
|
488
|
+
|
|
489
|
+
Syntax error:
|
|
490
|
+
```
|
|
491
|
+
{error}
|
|
492
|
+
```
|
|
493
|
+
|
|
494
|
+
Write ONLY corrected Python code. No explanation, no markdown fences.
|
|
495
|
+
"""
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def _extract_code(text: str) -> str:
|
|
499
|
+
"""Strip markdown code fences from LLM response if present."""
|
|
500
|
+
text = text.strip()
|
|
501
|
+
# Remove ```python ... ``` fences
|
|
502
|
+
if text.startswith("```"):
|
|
503
|
+
lines = text.split("\n")
|
|
504
|
+
# Remove first line (```python or ```)
|
|
505
|
+
lines = lines[1:]
|
|
506
|
+
# Remove last line if it's ```
|
|
507
|
+
if lines and lines[-1].strip() == "```":
|
|
508
|
+
lines = lines[:-1]
|
|
509
|
+
return "\n".join(lines)
|
|
510
|
+
return text
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def _is_script_authoring_goal(goal: str) -> bool:
|
|
514
|
+
"""Return True when the goal is about writing/saving a standalone script file."""
|
|
515
|
+
g = (goal or "").lower()
|
|
516
|
+
if not g:
|
|
517
|
+
return False
|
|
518
|
+
explicit_script = any(
|
|
519
|
+
phrase in g
|
|
520
|
+
for phrase in (
|
|
521
|
+
"write a python script",
|
|
522
|
+
"create a python script",
|
|
523
|
+
"save the script",
|
|
524
|
+
"standalone file",
|
|
525
|
+
"standalone script",
|
|
526
|
+
)
|
|
527
|
+
)
|
|
528
|
+
has_py_target = ".py" in g and any(word in g for word in ("script", "save", "write", "create", "generate"))
|
|
529
|
+
return explicit_script or has_py_target
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def _extract_script_filename(goal: str) -> str:
|
|
533
|
+
"""Extract target script filename from the goal, defaulting safely."""
|
|
534
|
+
text = goal or ""
|
|
535
|
+
|
|
536
|
+
quoted = re.findall(r"""['"]([^'"]+\.py)['"]""", text, flags=re.IGNORECASE)
|
|
537
|
+
for candidate in quoted:
|
|
538
|
+
c = candidate.strip().rstrip(".,;:)")
|
|
539
|
+
if c and not c.lower().startswith(("http://", "https://")):
|
|
540
|
+
return c
|
|
541
|
+
|
|
542
|
+
bare = re.findall(r"""\b([A-Za-z0-9_\-./]+\.py)\b""", text, flags=re.IGNORECASE)
|
|
543
|
+
for candidate in bare:
|
|
544
|
+
c = candidate.strip().rstrip(".,;:)")
|
|
545
|
+
if c and not c.lower().startswith(("http://", "https://")):
|
|
546
|
+
return c
|
|
547
|
+
|
|
548
|
+
return "generated_script.py"
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def _resolve_script_path(path_str: str) -> tuple[Path | None, str | None]:
|
|
552
|
+
"""Resolve a script path and enforce CWD containment."""
|
|
553
|
+
p = Path(path_str).expanduser()
|
|
554
|
+
if p.is_absolute():
|
|
555
|
+
return None, "Absolute paths are not allowed for generated scripts."
|
|
556
|
+
resolved = (Path.cwd() / p).resolve()
|
|
557
|
+
try:
|
|
558
|
+
resolved.relative_to(Path.cwd().resolve())
|
|
559
|
+
except ValueError:
|
|
560
|
+
return None, "Path traversal detected for generated script path."
|
|
561
|
+
if resolved.suffix.lower() != ".py":
|
|
562
|
+
return None, "Generated script path must end with .py."
|
|
563
|
+
return resolved, None
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def _generate_and_save_script(
    *,
    goal: str,
    llm,
    max_retries: int,
    session,
) -> dict:
    """Generate a standalone Python script and save it in the working directory.

    Flow: derive the target filename from *goal*, validate it is relative,
    CWD-contained and ends in ``.py``, then ask the LLM for a script.
    Generated text is syntax-checked with ``compile()`` before being written
    to disk; syntax failures trigger up to *max_retries* fix-up rounds using
    ``SCRIPT_RETRY_PROMPT``.

    Args:
        goal: Natural-language request describing the script to write.
        llm: Chat client; assumed to expose ``chat(system=..., messages=...,
            temperature=...)`` returning an object with ``.content``.
        max_retries: Number of fix-up attempts after the initial generation.
        session: Active ct session (used here only for ``session.console``).

    Returns:
        Result dict: on success ``summary``/``path``/``script``/``lines``/
        ``size``/``exports``/``result``; on failure ``summary`` and ``error``.
    """
    filename = _extract_script_filename(goal)
    script_path, path_error = _resolve_script_path(filename)
    if path_error:
        # Refuse unsafe targets (absolute paths, traversal, non-.py suffix).
        return {
            "summary": f"Script generation failed: {path_error}",
            "error": path_error,
        }

    script_text = ""
    last_error = None
    # 1 initial generation + max_retries syntax-fix rounds.
    for attempt in range(1, max_retries + 2):
        if attempt == 1:
            user_msg = (
                f"User request:\n{goal}\n\n"
                f"Write a complete standalone Python script for this request.\n"
                f"Target filename: {script_path.name}\n"
                f"The script must be directly runnable with `python {script_path.name}`."
            )
        else:
            # Feed the previous script plus its syntax error back for repair.
            user_msg = SCRIPT_RETRY_PROMPT.format(code=script_text, error=last_error or "Unknown syntax error")

        with session.console.status(
            f"[green]{'Generating' if attempt == 1 else 'Fixing'} script...[/green]",
            spinner="dots",
        ):
            response = llm.chat(
                system=SCRIPT_GEN_SYSTEM_PROMPT,
                messages=[{"role": "user", "content": user_msg}],
                temperature=0.2,
            )

        script_text = _extract_code(response.content)

        # Validate syntax before writing to disk.
        try:
            compile(script_text, str(script_path), "exec")
        except SyntaxError as e:
            last_error = f"{e.msg} (line {e.lineno})"
            if attempt > max_retries:
                break  # out of retries — fall through to the failure return
            continue
        except Exception as e:
            # compile() can raise non-SyntaxError too (e.g. ValueError on
            # null bytes); treat the same as a syntax failure and retry.
            last_error = str(e)
            if attempt > max_retries:
                break
            continue

        try:
            script_path.parent.mkdir(parents=True, exist_ok=True)
            script_path.write_text(script_text, encoding="utf-8")
        except Exception as e:
            # Disk write failed — report immediately; retrying won't help.
            return {
                "summary": f"Script generation failed while writing {script_path}: {e}",
                "error": str(e),
                "path": str(script_path),
            }

        lines = script_text.count("\n") + 1
        return {
            "summary": f"Generated standalone Python script: {script_path.name} ({lines} lines).",
            "path": str(script_path),
            "script": script_text,
            "lines": lines,
            "size": len(script_text),
            "exports": [str(script_path)],
            "result": {
                "summary": f"Script saved to {script_path}",
                "path": str(script_path),
            },
        }

    # Every attempt produced syntactically invalid code.
    return {
        "summary": f"Script generation failed after {max_retries + 1} attempts: {last_error}",
        "error": last_error or "unknown_script_generation_error",
        "path": str(script_path),
        "script": script_text,
    }
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def _agentic_code_loop(
    *,
    goal: str,
    system_prompt: str,
    llm,
    sandbox,
    session,
    max_turns: int,
) -> dict:
    """Multi-turn agentic code execution loop.

    The LLM calls ``run_python`` repeatedly, seeing output after each
    execution, and stops when it has no more tool calls (``end_turn``).

    Args:
        goal: Natural-language analysis goal shown to the model.
        system_prompt: Full system prompt (already includes the agentic addendum).
        llm: Chat client supporting ``tools=`` and content blocks.
        sandbox: Execution sandbox; ``execute(code)`` returns a dict with
            ``stdout``/``error`` (and ``plots``/``exports`` — assumed, per the
            single-shot path; confirm against Sandbox).
        session: Active ct session (console + config access).
        max_turns: Hard cap on tool-call turns.

    Returns:
        Standard tool result dict (``summary``, ``code``, ``stdout``,
        ``result``, ``plots``, ``exports``), or an error dict when the model
        never invoked the tool.
    """
    messages = [{"role": "user", "content": f"Goal: {goal}"}]
    all_code: list[str] = []
    all_stdout: list[str] = []
    last_exec_result: dict | None = None

    for turn in range(max_turns):
        with session.console.status(
            f"[green]Agent turn {turn + 1}/{max_turns}...[/green]",
            spinner="dots",
        ):
            response = llm.chat(
                system=system_prompt,
                messages=messages,
                tools=[RUN_PYTHON_TOOL],
                temperature=0.2,
            )

        # Check for tool calls in the content blocks
        content_blocks = response.content_blocks or []
        tool_calls = [b for b in content_blocks if getattr(b, "type", None) == "tool_use"]

        if not tool_calls:
            # LLM is done (end_turn) — no more tool calls
            break

        # Process the first tool call (additional parallel calls are ignored)
        tool_call = tool_calls[0]
        code = (tool_call.input or {}).get("code", "")
        all_code.append(code)

        exec_result = sandbox.execute(code)
        last_exec_result = exec_result

        # Collect stdout
        if exec_result.get("stdout"):
            all_stdout.append(exec_result["stdout"])

        # Build tool result content
        tool_output_parts = []
        if exec_result.get("stdout"):
            tool_output_parts.append(exec_result["stdout"])
        if exec_result.get("error"):
            tool_output_parts.append(f"Error:\n{exec_result['error']}")
        tool_output = "\n".join(tool_output_parts) if tool_output_parts else "(no output)"

        # Truncate to avoid blowing up context
        tool_output = tool_output[:5000]

        # Append assistant message (full content blocks) and tool result.
        # The assistant turn must precede the tool_result turn referencing
        # its tool_use_id, per the tool-use message protocol.
        messages.append({"role": "assistant", "content": content_blocks})
        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": tool_call.id,
                    "content": tool_output,
                }
            ],
        })

    # Extract result from sandbox namespace (convention: generated code
    # assigns its final answer to a variable named `result`)
    result_obj = sandbox.get_variable("result")
    combined_code = "\n\n# --- next turn ---\n\n".join(all_code)
    combined_stdout = "\n".join(all_stdout)

    # Reflection pass for multi-turn mode (parity with single-shot path):
    # when execution succeeds but constraints may be violated, ask the model
    # to self-check and optionally provide corrected code.
    max_reflect = int(session.config.get("sandbox.max_reflect", 2))
    for _ in range(max_reflect):
        if not combined_code:
            break
        result_preview = str(result_obj)[:1000] if result_obj else "(no result dict)"
        reflect_msg = REFLECTION_PROMPT.format(
            goal=goal,
            code=combined_code,
            stdout=combined_stdout[:3000],
            result=result_preview,
        )
        with session.console.status(
            "[cyan]Reviewing output...[/cyan]",
            spinner="dots",
        ):
            reflect_response = llm.chat(
                system=system_prompt,
                messages=[{"role": "user", "content": reflect_msg}],
                temperature=0.2,
            )
        reflect_text = (reflect_response.content or "").strip()
        if reflect_text.upper().startswith("LGTM"):
            break

        fixed_code = _extract_code(reflect_text)
        if not fixed_code:
            break

        exec_result = sandbox.execute(fixed_code)
        last_exec_result = exec_result
        all_code.append(fixed_code)
        if exec_result.get("stdout"):
            all_stdout.append(exec_result["stdout"])

        combined_code = "\n\n# --- next turn ---\n\n".join(all_code)
        combined_stdout = "\n".join(all_stdout)
        if exec_result.get("error"):
            # Corrected code errored — stop reflecting and report what we have.
            break
        result_obj = sandbox.get_variable("result")

    # Collect plots/exports from the last execution (sandbox output dir)
    plots = []
    exports = []
    if last_exec_result:
        plots = last_exec_result.get("plots", [])
        exports = last_exec_result.get("exports", [])

    # Prefer the structured result's own summary; fall back to the tail of
    # captured stdout, then a generic success message.
    if result_obj and isinstance(result_obj, dict):
        summary = result_obj.get("summary", "")
    elif combined_stdout:
        summary = combined_stdout[-500:]
    else:
        summary = "Code executed successfully."

    if not result_obj and not combined_stdout and not all_code:
        return {
            "summary": "Agent loop completed but no code was executed.",
            "error": "LLM did not call run_python tool.",
            "code": "",
            "stdout": "",
        }

    return {
        "summary": summary,
        "code": combined_code,
        "stdout": combined_stdout,
        "result": result_obj,
        "plots": plots,
        "exports": exports,
    }
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
def _generate_and_execute_code(
    goal: str,
    system_prompt_template: str,
    session,
    prior_results=None,
) -> dict:
    """Shared code-gen helper: LLM code generation -> sandbox execution -> retry loop.

    Domain tools call this with a focused system prompt template containing
    ``{namespace_description}`` which gets filled with the sandbox's namespace
    description at runtime.

    Two execution modes, selected by the ``sandbox.max_turns`` config key:
    a multi-turn agentic loop (``_agentic_code_loop``) when > 0, otherwise a
    legacy single-shot generate/execute/retry path with a reflection pass.

    Args:
        goal: Natural language description of the analysis to perform.
        system_prompt_template: System prompt with ``{namespace_description}``
            placeholder (and optionally ``{data_files_description}``).
        session: Active ct session (provides config, LLM, console).
        prior_results: Dict of prior step results to inject into the sandbox.

    Returns:
        Standard tool result dict with ``summary``, ``code``, ``stdout``, etc.
    """
    if session is None:
        return {
            "summary": "Code execution unavailable: no active session.",
            "error": "No session provided.",
        }

    # Local import avoids a module-level dependency cycle with ct.agent.
    from ct.agent.sandbox import Sandbox

    config = session.config
    timeout = int(config.get("sandbox.timeout", 30))
    output_dir = config.get("sandbox.output_dir")
    max_retries = int(config.get("sandbox.max_retries", 2))
    llm = session.get_llm()

    # Collect extra read directories (e.g., capsule data dirs for bioinformatics mode)
    extra_read_dirs = []
    extra_read_str = config.get("sandbox.extra_read_dirs")
    if extra_read_str:
        # Config value is a comma-separated list of paths; silently drop
        # entries that don't exist on disk.
        for d in str(extra_read_str).split(","):
            d = d.strip()
            if d and Path(d).exists():
                extra_read_dirs.append(Path(d))

    # Create sandbox and load datasets
    sandbox = Sandbox(
        timeout=timeout,
        output_dir=output_dir,
        max_retries=max_retries,
        extra_read_dirs=extra_read_dirs or None,
    )
    sandbox.load_datasets()
    if prior_results:
        sandbox.inject_prior_results(prior_results)

    # Build the system prompt with namespace info
    ns_desc = sandbox.describe_namespace()
    format_kwargs = {"namespace_description": ns_desc}

    # Provide a data_files_description if the template expects one
    if "{data_files_description}" in system_prompt_template:
        format_kwargs["data_files_description"] = _describe_data_files(extra_dirs=extra_read_dirs)

    system_prompt = system_prompt_template.format(**format_kwargs)

    # ── Multi-turn agentic path ─────────────────────────────────────────
    max_turns = int(config.get("sandbox.max_turns", 0))
    if max_turns > 0:
        agentic_prompt = system_prompt + AGENTIC_CODE_ADDENDUM
        return _agentic_code_loop(
            goal=goal,
            system_prompt=agentic_prompt,
            llm=llm,
            sandbox=sandbox,
            session=session,
            max_turns=max_turns,
        )

    # ── Legacy single-shot path ─────────────────────────────────────────
    code = None
    exec_result = {"error": "No code generated"}

    for attempt in range(1, max_retries + 2):  # 1 initial + max_retries fixes
        if attempt == 1:
            user_msg = f"Goal: {goal}"
        else:
            # Feed the failed code and its runtime error back for repair.
            user_msg = ERROR_RETRY_PROMPT.format(code=code, error=exec_result["error"])

        with session.console.status(
            f"[green]{'Generating' if attempt == 1 else 'Fixing'} code...[/green]",
            spinner="dots",
        ):
            response = llm.chat(
                system=system_prompt,
                messages=[{"role": "user", "content": user_msg}],
                temperature=0.2,
            )

        code = _extract_code(response.content)
        exec_result = sandbox.execute(code)

        if exec_result["error"] is None:
            # Reflection: let the LLM review its own output for logical errors
            max_reflect = int(config.get("sandbox.max_reflect", 2))
            for reflect_turn in range(max_reflect):
                stdout_text = exec_result.get("stdout", "") or ""
                result_obj = exec_result.get("result")
                # Only reflect if there's meaningful output to check
                if not stdout_text and not result_obj:
                    break
                result_preview = str(result_obj)[:1000] if result_obj else "(no result dict)"
                reflect_msg = REFLECTION_PROMPT.format(
                    goal=goal,
                    code=code,
                    stdout=stdout_text[:3000],
                    result=result_preview,
                )
                with session.console.status(
                    f"[cyan]Reviewing output (turn {reflect_turn + 1})...[/cyan]",
                    spinner="dots",
                ):
                    reflect_response = llm.chat(
                        system=system_prompt,
                        messages=[{"role": "user", "content": reflect_msg}],
                        temperature=0.2,
                    )
                reflect_text = reflect_response.content.strip()
                if reflect_text.upper().startswith("LGTM"):
                    break  # LLM says output is correct
                # LLM returned corrected code — execute it
                fixed_code = _extract_code(reflect_text)
                if not fixed_code or fixed_code == code:
                    break  # No new code or same code — stop
                code = fixed_code
                exec_result = sandbox.execute(code)
                if exec_result["error"] is not None:
                    break  # New code errored — fall through to error retry

            # Return result (either original or reflection-fixed)
            if exec_result["error"] is None:
                summary = ""
                if exec_result["result"] and isinstance(exec_result["result"], dict):
                    summary = exec_result["result"].get("summary", "")
                if not summary and exec_result["stdout"]:
                    summary = exec_result["stdout"][:500]
                if not summary:
                    summary = "Code executed successfully."

                return {
                    "summary": summary,
                    "code": code,
                    "stdout": exec_result["stdout"],
                    "result": exec_result["result"],
                    "plots": exec_result["plots"],
                    "exports": exec_result["exports"],
                }

        if attempt > max_retries:
            break

    # Exhausted the retry budget without a clean run.
    return {
        "summary": f"Code execution failed after {max_retries + 1} attempts: {exec_result['error'][:200]}",
        "error": exec_result["error"],
        "code": code,
        "stdout": exec_result.get("stdout", ""),
    }
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def _describe_data_files(extra_dirs: list[Path] | None = None) -> str:
|
|
978
|
+
"""List data files in CWD and extra directories for domain tool prompts."""
|
|
979
|
+
data_exts = {
|
|
980
|
+
".csv", ".tsv", ".xlsx", ".xls", ".parquet",
|
|
981
|
+
".vcf", ".bed", ".bam", ".fasta", ".fa", ".faa",
|
|
982
|
+
".fastq", ".gff", ".gtf", ".nwk", ".nex", ".tree",
|
|
983
|
+
".mafft", ".clipkit", ".aln", ".phy", ".gz",
|
|
984
|
+
".zip", ".rds", ".rdata", ".gmt", ".json",
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
def _scan_dir(directory: Path, label: str) -> list[str]:
|
|
988
|
+
entries = []
|
|
989
|
+
if not directory.exists():
|
|
990
|
+
return entries
|
|
991
|
+
try:
|
|
992
|
+
for f in sorted(directory.rglob("*")):
|
|
993
|
+
if f.is_file() and (f.suffix.lower() in data_exts):
|
|
994
|
+
size = f.stat().st_size
|
|
995
|
+
if size > 1_000_000:
|
|
996
|
+
size_str = f"{size / 1_000_000:.1f} MB"
|
|
997
|
+
elif size > 1_000:
|
|
998
|
+
size_str = f"{size / 1_000:.1f} KB"
|
|
999
|
+
else:
|
|
1000
|
+
size_str = f"{size} bytes"
|
|
1001
|
+
try:
|
|
1002
|
+
rel = f.relative_to(directory)
|
|
1003
|
+
except ValueError:
|
|
1004
|
+
rel = f.name
|
|
1005
|
+
entries.append(f"- {rel} ({size_str})")
|
|
1006
|
+
except PermissionError:
|
|
1007
|
+
pass
|
|
1008
|
+
return entries[:80]
|
|
1009
|
+
|
|
1010
|
+
sections = []
|
|
1011
|
+
cwd = Path.cwd()
|
|
1012
|
+
cwd_files = _scan_dir(cwd, "working directory")
|
|
1013
|
+
if cwd_files:
|
|
1014
|
+
sections.append("Files in working directory:\n" + "\n".join(cwd_files))
|
|
1015
|
+
|
|
1016
|
+
for d in (extra_dirs or []):
|
|
1017
|
+
d_files = _scan_dir(d, str(d))
|
|
1018
|
+
if d_files:
|
|
1019
|
+
sections.append(f"Files in {d}:\n" + "\n".join(d_files))
|
|
1020
|
+
sections.append(f"\nData directory accessible for reading: {d}")
|
|
1021
|
+
|
|
1022
|
+
if not sections:
|
|
1023
|
+
return "No data files found in the working directory."
|
|
1024
|
+
return "\n\n".join(sections)
|
|
1025
|
+
|
|
1026
|
+
|
|
1027
|
+
@registry.register(
    name="code.execute",
    description="Generate and execute custom Python analysis code",
    category="code",
    parameters={"goal": "Natural language description of the analysis to perform"},
    usage_guide=(
        "Use ONLY when no pre-built tool covers the analysis. Good for: custom visualizations, "
        "statistical tests, data exploration, combining/filtering data in novel ways, generating plots. "
        "Pre-built tools are preferred — this is the escape hatch."
    ),
)
def execute(goal: str, _session=None, _prior_results=None, **kwargs) -> dict:
    """Generate and execute Python code for a custom analysis goal."""
    if _session is None:
        return {
            "summary": "Code execution unavailable: no active session.",
            "error": "No session provided. code.execute requires an active ct session.",
        }

    # Handle "write script to file" goals directly (outside sandbox execution path).
    if _is_script_authoring_goal(goal):
        return _generate_and_save_script(
            goal=goal,
            llm=_session.get_llm(),
            max_retries=int(_session.config.get("sandbox.max_retries", 2)),
            session=_session,
        )

    # Use bioinformatics prompt when data files are present in the sandbox,
    # otherwise use the generic code-gen prompt. Domain-specific knowledge
    # (phylo, KEGG ORA, variant classification) lives in the dedicated domain
    # tools (phylo.analyze, omics.kegg_ora, genomics.variant_classify) which
    # the planner selects directly.
    if _session.config.get("agent.bioinformatics_mode"):
        selected_prompt = BIOINFORMATICS_CODE_GEN_PROMPT
    else:
        selected_prompt = CODE_GEN_SYSTEM_PROMPT

    return _generate_and_execute_code(
        goal=goal,
        system_prompt_template=selected_prompt,
        session=_session,
        prior_results=_prior_results,
    )
|