celltype-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. celltype_cli-0.1.0.dist-info/METADATA +267 -0
  2. celltype_cli-0.1.0.dist-info/RECORD +89 -0
  3. celltype_cli-0.1.0.dist-info/WHEEL +4 -0
  4. celltype_cli-0.1.0.dist-info/entry_points.txt +2 -0
  5. celltype_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  6. ct/__init__.py +3 -0
  7. ct/agent/__init__.py +0 -0
  8. ct/agent/case_studies.py +426 -0
  9. ct/agent/config.py +523 -0
  10. ct/agent/doctor.py +544 -0
  11. ct/agent/knowledge.py +523 -0
  12. ct/agent/loop.py +99 -0
  13. ct/agent/mcp_server.py +478 -0
  14. ct/agent/orchestrator.py +733 -0
  15. ct/agent/runner.py +656 -0
  16. ct/agent/sandbox.py +481 -0
  17. ct/agent/session.py +145 -0
  18. ct/agent/system_prompt.py +186 -0
  19. ct/agent/trace_store.py +228 -0
  20. ct/agent/trajectory.py +169 -0
  21. ct/agent/types.py +182 -0
  22. ct/agent/workflows.py +462 -0
  23. ct/api/__init__.py +1 -0
  24. ct/api/app.py +211 -0
  25. ct/api/config.py +120 -0
  26. ct/api/engine.py +124 -0
  27. ct/cli.py +1448 -0
  28. ct/data/__init__.py +0 -0
  29. ct/data/compute_providers.json +59 -0
  30. ct/data/cro_database.json +395 -0
  31. ct/data/downloader.py +238 -0
  32. ct/data/loaders.py +252 -0
  33. ct/kb/__init__.py +5 -0
  34. ct/kb/benchmarks.py +147 -0
  35. ct/kb/governance.py +106 -0
  36. ct/kb/ingest.py +415 -0
  37. ct/kb/reasoning.py +129 -0
  38. ct/kb/schema_monitor.py +162 -0
  39. ct/kb/substrate.py +387 -0
  40. ct/models/__init__.py +0 -0
  41. ct/models/llm.py +370 -0
  42. ct/tools/__init__.py +195 -0
  43. ct/tools/_compound_resolver.py +297 -0
  44. ct/tools/biomarker.py +368 -0
  45. ct/tools/cellxgene.py +282 -0
  46. ct/tools/chemistry.py +1371 -0
  47. ct/tools/claude.py +390 -0
  48. ct/tools/clinical.py +1153 -0
  49. ct/tools/clue.py +249 -0
  50. ct/tools/code.py +1069 -0
  51. ct/tools/combination.py +397 -0
  52. ct/tools/compute.py +402 -0
  53. ct/tools/cro.py +413 -0
  54. ct/tools/data_api.py +2114 -0
  55. ct/tools/design.py +295 -0
  56. ct/tools/dna.py +575 -0
  57. ct/tools/experiment.py +604 -0
  58. ct/tools/expression.py +655 -0
  59. ct/tools/files.py +957 -0
  60. ct/tools/genomics.py +1387 -0
  61. ct/tools/http_client.py +146 -0
  62. ct/tools/imaging.py +319 -0
  63. ct/tools/intel.py +223 -0
  64. ct/tools/literature.py +743 -0
  65. ct/tools/network.py +422 -0
  66. ct/tools/notification.py +111 -0
  67. ct/tools/omics.py +3330 -0
  68. ct/tools/ops.py +1230 -0
  69. ct/tools/parity.py +649 -0
  70. ct/tools/pk.py +245 -0
  71. ct/tools/protein.py +678 -0
  72. ct/tools/regulatory.py +643 -0
  73. ct/tools/remote_data.py +179 -0
  74. ct/tools/report.py +181 -0
  75. ct/tools/repurposing.py +376 -0
  76. ct/tools/safety.py +1280 -0
  77. ct/tools/shell.py +178 -0
  78. ct/tools/singlecell.py +533 -0
  79. ct/tools/statistics.py +552 -0
  80. ct/tools/structure.py +882 -0
  81. ct/tools/target.py +901 -0
  82. ct/tools/translational.py +123 -0
  83. ct/tools/viability.py +218 -0
  84. ct/ui/__init__.py +0 -0
  85. ct/ui/markdown.py +31 -0
  86. ct/ui/status.py +258 -0
  87. ct/ui/suggestions.py +567 -0
  88. ct/ui/terminal.py +1456 -0
  89. ct/ui/traces.py +112 -0
@@ -0,0 +1,297 @@
1
+ """Compound name resolver — maps drug names to dataset-specific IDs.
2
+
3
+ The ct datasets use proprietary compound IDs:
4
+ - PRISM/L1000: YU-codes (e.g., YU254653)
5
+ - Proteomics: Cmpd format (e.g., Cmpd18_B10)
6
+
7
+ This module resolves common drug names (lenalidomide, pomalidomide, etc.)
8
+ to the most structurally similar compound in each dataset via Tanimoto similarity.
9
+ """
10
+
11
+ import csv
12
+ import os
13
+ import re
14
+ from functools import lru_cache
15
+
16
+ # Data file paths
17
+ _DATA_DIR = "/mnt2/bronze/molecular_glue/crews_library"
18
+ _SMILES_CSV = os.path.join(_DATA_DIR, "all_compounds_smiles.csv")
19
+ _PROT_MAPPING_CSV = os.path.join(_DATA_DIR, "proteomics_to_yu_mapping.csv")
20
+
21
+ # Regex patterns
22
+ _YU_PATTERN = re.compile(r"^YU\d{6}$")
23
+ _CMPD_PATTERN = re.compile(r"^Cmpd\d+")
24
+
25
+ # Module-level caches (populated on first use)
26
+ _yu_smiles: dict | None = None
27
+ _prot_to_yu: dict | None = None
28
+ _yu_to_prot: dict | None = None
29
+
30
+
31
def _load_yu_smiles() -> dict:
    """Return the cached YU-compound -> SMILES map, loading it on first call.

    Reads ``all_compounds_smiles.csv`` once; a missing file yields (and
    caches) an empty mapping so repeated calls stay cheap.
    """
    global _yu_smiles
    if _yu_smiles is None:
        mapping: dict = {}
        if os.path.exists(_SMILES_CSV):
            with open(_SMILES_CSV) as fh:
                for record in csv.DictReader(fh):
                    mapping[record["compound"]] = record["smiles"]
        _yu_smiles = mapping
    return _yu_smiles
44
+
45
+
46
def _load_prot_mapping() -> tuple[dict, dict]:
    """Return cached (Cmpd -> YU, YU -> Cmpd) maps, loading them on first call.

    Reads ``proteomics_to_yu_mapping.csv`` once; a missing file yields
    (and caches) two empty mappings.
    """
    global _prot_to_yu, _yu_to_prot
    if _prot_to_yu is None:
        forward: dict = {}
        backward: dict = {}
        if os.path.exists(_PROT_MAPPING_CSV):
            with open(_PROT_MAPPING_CSV) as fh:
                for record in csv.DictReader(fh):
                    forward[record["cmpd_id"]] = record["yu_id"]
                    backward[record["yu_id"]] = record["cmpd_id"]
        _prot_to_yu, _yu_to_prot = forward, backward
    return _prot_to_yu, _yu_to_prot
61
+
62
+
63
+ @lru_cache(maxsize=64)
64
+ def _tanimoto_search(smiles: str, candidate_ids: frozenset) -> tuple[str, float] | None:
65
+ """Find most similar YU compound to a SMILES string by Tanimoto similarity."""
66
+ try:
67
+ from rdkit import Chem
68
+ from rdkit.Chem import AllChem, DataStructs
69
+ except ImportError:
70
+ return None
71
+
72
+ mol = Chem.MolFromSmiles(smiles)
73
+ if mol is None:
74
+ return None
75
+ query_fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
76
+
77
+ yu_smiles = _load_yu_smiles()
78
+ best_id, best_sim = None, 0.0
79
+
80
+ for yu_id in candidate_ids:
81
+ smi = yu_smiles.get(yu_id)
82
+ if not smi:
83
+ continue
84
+ m = Chem.MolFromSmiles(smi)
85
+ if m is None:
86
+ continue
87
+ fp = AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=2048)
88
+ sim = DataStructs.TanimotoSimilarity(query_fp, fp)
89
+ if sim > best_sim:
90
+ best_sim = sim
91
+ best_id = yu_id
92
+
93
+ if best_id is None:
94
+ return None
95
+ return best_id, round(best_sim, 4)
96
+
97
+
98
def resolve_to_smiles(name_or_smiles: str) -> str:
    """Resolve a drug name or SMILES string to a canonical SMILES string.

    The lookup cascade is:

    1. RDKit parse — input already a valid SMILES (heuristic character
       check when RDKit is not installed).
    2. PubChem name lookup.
    3. ChEMBL molecule lookup.

    Parameters
    ----------
    name_or_smiles : str
        Drug name (e.g. "lenalidomide") or SMILES string.

    Returns
    -------
    str
        Canonical SMILES string.

    Raises
    ------
    ValueError
        If the input is empty/non-string or every lookup fails.
    """
    if not name_or_smiles or not isinstance(name_or_smiles, str):
        raise ValueError(f"Invalid input: {name_or_smiles}")

    name_or_smiles = name_or_smiles.strip()
    if not name_or_smiles:
        raise ValueError("Empty input")

    # Step 1: is the input itself a parsable SMILES?
    try:
        from rdkit import Chem
    except ImportError:
        # No RDKit — fall back to a character heuristic for SMILES-ish input.
        if any(ch in name_or_smiles for ch in "=#()[]"):
            return name_or_smiles
    else:
        if Chem.MolFromSmiles(name_or_smiles) is not None:
            return name_or_smiles

    # Step 2: PubChem name lookup (best effort — any failure falls through).
    try:
        from ct.tools.chemistry import pubchem_lookup

        props = pubchem_lookup(name_or_smiles, query_type="name").get("properties") or {}
        candidate = props.get("canonical_smiles")
        if candidate:
            return candidate
    except Exception:
        pass

    # Step 3: ChEMBL molecule lookup (best effort).
    try:
        from ct.tools.literature import chembl_query

        hits = chembl_query(name_or_smiles, query_type="molecule", max_results=1).get("molecules", [])
        if hits and hits[0].get("smiles"):
            return hits[0]["smiles"]
    except Exception:
        pass

    raise ValueError(
        f"Could not resolve '{name_or_smiles}' to a SMILES string. "
        "Tried: RDKit SMILES parse, PubChem, ChEMBL."
    )
167
+
168
+
169
def resolve_compound(name_or_id: str, dataset: str = "prism") -> str:
    """Map a drug name, YU code, or Cmpd ID onto the ID scheme of *dataset*.

    Parameters
    ----------
    name_or_id : str
        Drug name (e.g. "lenalidomide"), YU code, or Cmpd ID.
    dataset : str
        Target dataset: "prism", "l1000", or "proteomics".

    Returns
    -------
    str
        Resolved compound ID for the target dataset. For an L1000 matrix
        indexed by compound names, the matching index entry is returned
        directly; for proteomics the full ``Cmpd_well`` column name is
        returned when available. On any failure the original input is
        returned unchanged.
    """
    if not name_or_id or not isinstance(name_or_id, str):
        return name_or_id

    name_or_id = name_or_id.strip()

    # L1000 may be indexed by plain compound names (new format) rather than
    # legacy YU codes — detect via the first index entry.
    if dataset == "l1000":
        try:
            from ct.data.loaders import load_l1000

            l1000 = load_l1000()
            first_entry = str(l1000.index[0]) if len(l1000) > 0 else ""
            if not _YU_PATTERN.match(first_entry):
                query = name_or_id.lower().strip()
                by_lower = {entry.lower(): entry for entry in l1000.index}
                exact = by_lower.get(query)
                if exact is not None:
                    return exact
                # Fall back to substring matching in either direction.
                for lowered, entry in by_lower.items():
                    if query in lowered or lowered in query:
                        return entry
                return name_or_id
        except (FileNotFoundError, ImportError):
            pass

    # YU codes pass through for PRISM/L1000; translate for proteomics.
    if _YU_PATTERN.match(name_or_id):
        if dataset == "proteomics":
            return _yu_to_proteomics_col(name_or_id) or name_or_id
        return name_or_id

    # Cmpd IDs pass through for proteomics; translate back to YU otherwise.
    if _CMPD_PATTERN.match(name_or_id):
        if dataset == "proteomics":
            return name_or_id
        prot_to_yu, _ = _load_prot_mapping()
        mapped = prot_to_yu.get(name_or_id.split("_")[0])
        return mapped if mapped else name_or_id

    # Free-text drug name: resolve to SMILES, then nearest-neighbour search
    # against the dataset's compound library.
    try:
        drug_smiles = resolve_to_smiles(name_or_id)
    except ValueError:
        return name_or_id  # Cannot resolve to SMILES, return as-is

    candidates = _get_dataset_compounds(dataset)
    if not candidates:
        return name_or_id

    match = _tanimoto_search(drug_smiles, frozenset(candidates))
    if match is None:
        return name_or_id

    yu_id, similarity = match
    # Below 0.65 Tanimoto the structural proxy is unreliable; return the
    # original name so the tool reports "not found" and synthesis falls
    # back to LLM knowledge instead of misleading data.
    if similarity < 0.65:
        return name_or_id
    if dataset == "proteomics":
        return _yu_to_proteomics_col(yu_id) or yu_id
    return yu_id
251
+
252
+
253
def resolve_proteomics_id(yu_id: str) -> str | None:
    """Public alias: YU compound ID -> proteomics ``Cmpd_well`` column name.

    Returns None when the YU ID has no proteomics mapping.
    """
    return _yu_to_proteomics_col(yu_id)
259
+
260
+
261
def _yu_to_proteomics_col(yu_id: str) -> str | None:
    """Resolve a YU ID to its full proteomics column name (``Cmpd##_well``).

    Returns None when the YU ID has no proteomics mapping at all, and the
    bare Cmpd ID when the proteomics matrix itself cannot be loaded.
    """
    _, yu_to_prot = _load_prot_mapping()
    base = yu_to_prot.get(yu_id)
    if base is None:
        return None

    # Locate the fully-qualified column (base ID + well suffix) if possible.
    prefix = base + "_"
    try:
        from ct.data.loaders import load_proteomics

        matches = [c for c in load_proteomics().columns if c.startswith(prefix)]
        if matches:
            return matches[0]
    except (FileNotFoundError, ImportError):
        pass

    return base
279
+
280
+
281
+ def _get_dataset_compounds(dataset: str) -> set:
282
+ """Get the set of YU compound IDs available in a dataset."""
283
+ try:
284
+ if dataset == "prism":
285
+ from ct.data.loaders import load_prism
286
+ prism = load_prism()
287
+ return set(prism["pert_name"].unique())
288
+ elif dataset == "l1000":
289
+ from ct.data.loaders import load_l1000
290
+ l1000 = load_l1000()
291
+ return set(l1000.index.tolist())
292
+ elif dataset == "proteomics":
293
+ _, yu_to_prot = _load_prot_mapping()
294
+ return set(yu_to_prot.keys())
295
+ except (FileNotFoundError, ImportError):
296
+ pass
297
+ return set()
ct/tools/biomarker.py ADDED
@@ -0,0 +1,368 @@
1
+ """
2
+ Biomarker tools: mutation sensitivity, resistance profiling, dependency validation.
3
+ """
4
+
5
+ import pandas as pd
6
+ import numpy as np
7
+ from scipy import stats
8
+ from ct.tools import registry
9
+
10
+
11
@registry.register(
    name="biomarker.mutation_sensitivity",
    description="Test whether specific mutations sensitize or confer resistance to a compound",
    category="biomarker",
    parameters={"compound_id": "Compound YU ID", "gene": "Gene to test (or 'all' for genome-wide)"},
    requires_data=["prism", "depmap_mutations", "depmap_model"],
    usage_guide="You want to find predictive biomarkers — which mutations make cells more or less sensitive to a compound. Use for patient stratification and clinical trial design.",
)
def mutation_sensitivity(compound_id: str, gene: str = "all", **kwargs) -> dict:
    """Test mutation-sensitivity associations.

    For each candidate gene, compares PRISM log-fold-change (LFC) at the
    compound's highest tested dose between mutant and wild-type cell lines
    with a two-sided Mann-Whitney U test, then applies Benjamini-Hochberg
    FDR correction across the tested genes.

    Parameters
    ----------
    compound_id : str
        Compound name or YU ID (resolved against the PRISM dataset).
    gene : str
        Single gene symbol to test, or "all" for every gene with >= 3
        mutated models.

    Returns
    -------
    dict
        On success: summary, compound, significant_mutations (top 20 by
        p-value among p < 0.05, each with mut/wt mean LFC, delta,
        direction, pval, fdr, sample sizes) and n_tested.
        On failure: an "error"/"summary" pair.
    """
    from ct.data.loaders import load_prism, load_mutations, load_model_metadata
    from ct.tools._compound_resolver import resolve_compound

    compound_id = resolve_compound(compound_id, dataset="prism")

    prism = load_prism()
    mutations = load_mutations()
    model = load_model_metadata()

    # Map PRISM cell lines (CCLE names) to DepMap ModelIDs.
    ccle_to_model = {}
    for _, row in model.iterrows():
        ccle = row.get("CCLEName", "")
        mid = row.get("ModelID", "")
        if pd.notna(ccle) and pd.notna(mid):
            ccle_to_model[ccle] = mid

    # Get compound sensitivity at highest dose (mean LFC per cell line).
    cpd = prism[prism["pert_name"] == compound_id]
    if len(cpd) == 0:
        return {"error": f"Compound {compound_id} not in PRISM", "summary": f"Compound {compound_id} not found in PRISM data"}

    max_dose = cpd["pert_dose"].max()
    cpd_hd = cpd[cpd["pert_dose"] == max_dose].groupby("ccle_name")["LFC"].mean()

    # Map to ModelIDs, keeping only lines present in the mutation matrix.
    sensitivity = {}
    for ccle, lfc in cpd_hd.items():
        mid = ccle_to_model.get(ccle)
        if mid and mid in mutations.index:
            sensitivity[mid] = lfc

    if len(sensitivity) < 20:
        return {"error": f"Insufficient overlap: only {len(sensitivity)} cell lines mapped", "summary": f"Insufficient cell line overlap ({len(sensitivity)} < 20 required)"}

    # Test genes: either the requested one, or all genes with >= 3 mutated models.
    genes_to_test = [gene] if gene != "all" else [g for g in mutations.columns if mutations[g].sum() >= 3]
    model_ids = list(sensitivity.keys())
    lfc_values = pd.Series(sensitivity)

    results = []
    for g in genes_to_test:
        if g not in mutations.columns:
            continue

        # Mutation status aligned to the mapped model IDs; missing -> wild-type.
        mut_status = mutations.loc[model_ids, g].reindex(model_ids).fillna(0)
        mutant_ids = [m for m in model_ids if mut_status.loc[m] > 0]
        wt_ids = [m for m in model_ids if mut_status.loc[m] == 0]

        # Require at least 3 samples per group for a meaningful rank test.
        if len(mutant_ids) < 3 or len(wt_ids) < 3:
            continue

        mut_lfc = lfc_values[mutant_ids]
        wt_lfc = lfc_values[wt_ids]
        stat, pval = stats.mannwhitneyu(mut_lfc, wt_lfc, alternative="two-sided")

        results.append({
            "gene": g,
            "mut_mean_lfc": round(float(mut_lfc.mean()), 3),
            "wt_mean_lfc": round(float(wt_lfc.mean()), 3),
            "delta": round(float(mut_lfc.mean() - wt_lfc.mean()), 3),
            # Lower LFC in mutants = stronger killing = sensitizing mutation.
            "direction": "sensitizing" if mut_lfc.mean() < wt_lfc.mean() else "resistance",
            "pval": float(pval),
            "n_mutant": len(mutant_ids),
            "n_wt": len(wt_ids),
        })

    if not results:
        return {
            "summary": (
                f"Mutation sensitivity for {compound_id}: {len(genes_to_test)} genes tested, "
                f"0 significant (p<0.05). No genes met minimum sample size (3 mutant, 3 WT)."
            ),
            "compound": compound_id,
            "significant_mutations": [],
            "n_tested": 0,
        }

    df = pd.DataFrame(results).sort_values("pval")
    if len(df) > 0:
        # Benjamini-Hochberg FDR: p * m / rank (monotonicity enforced).
        # method="first" breaks p-value ties by position in the sorted frame.
        ranks = df["pval"].rank(method="first")
        df["fdr"] = (df["pval"] * len(df) / ranks).clip(upper=1.0)
        # Enforce monotonicity: walk backwards so FDR is non-decreasing with p-value.
        fdr_arr = df["fdr"].values.copy()
        for i in range(len(fdr_arr) - 2, -1, -1):
            fdr_arr[i] = min(fdr_arr[i], fdr_arr[i + 1])
        df["fdr"] = fdr_arr

    sig = df[df["pval"] < 0.05]

    return {
        "summary": (
            f"Mutation sensitivity for {compound_id}: {len(genes_to_test)} genes tested, "
            f"{len(sig)} significant (p<0.05)"
        ),
        "compound": compound_id,
        "significant_mutations": sig.head(20).to_dict("records") if len(sig) > 0 else [],
        "n_tested": len(results),
    }
121
+
122
+
123
@registry.register(
    name="biomarker.resistance_profile",
    description="Profile resistance mechanisms for a compound (lineage, mutation, dependency enrichment)",
    category="biomarker",
    parameters={"compound_id": "Compound YU ID"},
    requires_data=["prism", "depmap_crispr", "depmap_mutations", "depmap_model"],
    usage_guide="You want to understand why some cell lines resist a compound — lineage effects, specific mutations, or dependency patterns. Use to anticipate clinical resistance mechanisms.",
)
def resistance_profile(compound_id: str, **kwargs) -> dict:
    """Comprehensive resistance profiling for a compound.

    Classifies cell lines by PRISM log-fold-change at the compound's
    highest tested dose (LFC < -0.5 sensitive, LFC > -0.1 resistant,
    otherwise intermediate) and tabulates sensitive/resistant counts per
    OncoTree lineage.

    Parameters
    ----------
    compound_id : str
        Compound name or YU ID (resolved against the PRISM dataset).

    Returns
    -------
    dict
        summary, compound, n_sensitive/n_resistant/n_intermediate, and
        lineage_profiles mapping lineage -> {sensitive, resistant, total}
        for lineages with at least 3 classified lines — or an
        "error"/"summary" pair when the compound is absent from PRISM.
    """
    # Counter replaces repeated list.count() calls, which were O(n) per
    # lineage (accidentally quadratic overall).
    from collections import Counter

    from ct.data.loaders import load_prism, load_model_metadata
    from ct.tools._compound_resolver import resolve_compound

    compound_id = resolve_compound(compound_id, dataset="prism")

    prism = load_prism()
    model = load_model_metadata()

    cpd = prism[prism["pert_name"] == compound_id]
    if len(cpd) == 0:
        return {"error": f"Compound {compound_id} not in PRISM", "summary": f"Compound {compound_id} not found in PRISM data"}

    # Mean LFC per cell line at the highest tested dose.
    max_dose = cpd["pert_dose"].max()
    cpd_hd = cpd[cpd["pert_dose"] == max_dose].groupby("ccle_name")["LFC"].mean()

    # Classification thresholds: < -0.5 sensitive, > -0.1 resistant.
    n_sensitive = (cpd_hd < -0.5).sum()
    n_resistant = (cpd_hd > -0.1).sum()
    n_intermediate = len(cpd_hd) - n_sensitive - n_resistant

    # Lineage annotation: CCLE name -> OncoTree lineage.
    ccle_to_lineage = {}
    for _, row in model.iterrows():
        ccle = row.get("CCLEName", "")
        lin = row.get("OncotreeLineage", "Unknown")
        if pd.notna(ccle) and pd.notna(lin):
            ccle_to_lineage[ccle] = lin

    # Count sensitive / resistant lines per lineage in a single pass each.
    sens_counts = Counter(ccle_to_lineage.get(c, "Unknown") for c in cpd_hd[cpd_hd < -0.5].index)
    res_counts = Counter(ccle_to_lineage.get(c, "Unknown") for c in cpd_hd[cpd_hd > -0.1].index)

    lineage_counts = {}
    for lin in set(sens_counts) | set(res_counts):
        if lin == "Unknown":
            continue
        s = sens_counts[lin]
        r = res_counts[lin]
        # Only report lineages with enough classified lines to be meaningful.
        if s + r >= 3:
            lineage_counts[lin] = {"sensitive": s, "resistant": r, "total": s + r}

    return {
        "summary": (
            f"Resistance profile for {compound_id}:\n"
            f"  Sensitive: {n_sensitive}, Intermediate: {n_intermediate}, Resistant: {n_resistant}\n"
            f"  {len(lineage_counts)} lineages profiled"
        ),
        "compound": compound_id,
        "n_sensitive": int(n_sensitive),
        "n_resistant": int(n_resistant),
        "n_intermediate": int(n_intermediate),
        "lineage_profiles": lineage_counts,
    }
184
+
185
+
186
@registry.register(
    name="biomarker.panel_select",
    description="ML-based biomarker panel selection: identify top predictive mutations for compound sensitivity",
    category="biomarker",
    parameters={
        "compound_id": "Compound ID (PRISM pert_name)",
        "n_features": "Number of top biomarker features to return (default 10)",
        "method": "Feature selection method: 'mutual_info', 'lasso', or 'random_forest' (default 'mutual_info')",
    },
    requires_data=["prism", "depmap_mutations", "depmap_model"],
    usage_guide="You want to select the best biomarker panel for predicting compound response — "
    "which mutations best predict sensitivity. Use for patient stratification, companion "
    "diagnostic design, and clinical trial enrichment. Methods: mutual_info (fast, nonlinear), "
    "lasso (sparse linear), random_forest (handles interactions).",
)
def panel_select(
    compound_id: str,
    n_features: int = 10,
    method: str = "mutual_info",
    **kwargs,
) -> dict:
    """ML-based biomarker panel selection using sklearn.

    Uses PRISM sensitivity as target (LFC < -0.5 = sensitive) and the DepMap
    mutation matrix as features. Supports three methods:
    - mutual_info: mutual information classification (fast, captures nonlinear)
    - lasso: LassoCV with L1 regularization (sparse, linear; fit on continuous LFC)
    - random_forest: random forest feature importances (handles interactions)

    Returns ranked list of biomarker genes with importance scores and
    cross-validation AUC. All error returns carry both "error" and
    "summary" keys, consistent with the rest of this module.

    Parameters
    ----------
    compound_id : str
        Compound name or YU ID (resolved against the PRISM dataset).
    n_features : int
        Number of top biomarker features to return.
    method : str
        One of 'mutual_info', 'lasso', 'random_forest'.
    """
    from ct.data.loaders import load_prism, load_mutations, load_model_metadata
    from ct.tools._compound_resolver import resolve_compound

    compound_id = resolve_compound(compound_id, dataset="prism")

    valid_methods = ("mutual_info", "lasso", "random_forest")
    if method not in valid_methods:
        msg = f"Unknown method '{method}'. Choose from: {', '.join(valid_methods)}"
        return {"error": msg, "summary": msg}
    prism = load_prism()
    mutations = load_mutations()
    model = load_model_metadata()

    # --- Map PRISM cell lines (CCLE names) to DepMap ModelIDs ---
    ccle_to_model = {}
    for _, row in model.iterrows():
        ccle = row.get("CCLEName", "")
        mid = row.get("ModelID", "")
        if pd.notna(ccle) and pd.notna(mid):
            ccle_to_model[ccle] = mid

    # --- Get compound sensitivity at highest dose ---
    cpd = prism[prism["pert_name"] == compound_id]
    if len(cpd) == 0:
        msg = f"Compound {compound_id} not found in PRISM data"
        return {"error": msg, "summary": msg}
    max_dose = cpd["pert_dose"].max()
    cpd_hd = cpd[cpd["pert_dose"] == max_dose].groupby("ccle_name")["LFC"].mean()

    # Map to ModelIDs and build target vector
    sensitivity = {}
    for ccle, lfc in cpd_hd.items():
        mid = ccle_to_model.get(ccle)
        if mid and mid in mutations.index:
            sensitivity[mid] = lfc

    if len(sensitivity) < 20:
        # Fix: previously this (and the two checks below) returned "error"
        # without "summary", unlike every other error path in the module.
        msg = f"Insufficient overlap: only {len(sensitivity)} cell lines mapped between PRISM and mutation data (need >=20)"
        return {"error": msg, "summary": msg}

    common_ids = list(sensitivity.keys())
    y_lfc = np.array([sensitivity[mid] for mid in common_ids])
    y_binary = (y_lfc < -0.5).astype(int)  # 1 = sensitive

    n_sensitive = int(y_binary.sum())
    n_resistant = int(len(y_binary) - n_sensitive)

    if n_sensitive < 3 or n_resistant < 3:
        msg = f"Insufficient class balance: {n_sensitive} sensitive, {n_resistant} resistant (need >=3 each)"
        return {"error": msg, "summary": msg}

    # --- Build feature matrix (mutation status) ---
    # Filter to genes mutated in at least 3 samples and wild-type in at
    # least 3 (uninformative constant columns are excluded either way).
    X_full = mutations.loc[common_ids].fillna(0)
    gene_counts = (X_full > 0).sum()
    usable_genes = gene_counts[(gene_counts >= 3) & (gene_counts <= len(common_ids) - 3)].index.tolist()

    if len(usable_genes) < 3:
        msg = f"Only {len(usable_genes)} genes with sufficient mutation frequency for feature selection"
        return {"error": msg, "summary": msg}

    X = X_full[usable_genes].values
    feature_names = usable_genes

    # --- Feature selection ---
    importances = np.zeros(len(feature_names))

    if method == "mutual_info":
        from sklearn.feature_selection import mutual_info_classif
        importances = mutual_info_classif(X, y_binary, random_state=42)

    elif method == "lasso":
        from sklearn.linear_model import LassoCV
        from sklearn.preprocessing import StandardScaler

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        lasso = LassoCV(cv=min(5, n_sensitive, n_resistant), random_state=42, max_iter=5000)
        lasso.fit(X_scaled, y_lfc)  # Use continuous LFC for lasso
        importances = np.abs(lasso.coef_)

    elif method == "random_forest":
        from sklearn.ensemble import RandomForestClassifier
        rf = RandomForestClassifier(
            n_estimators=100,
            max_depth=5,
            random_state=42,
            class_weight="balanced",
        )
        rf.fit(X, y_binary)
        importances = rf.feature_importances_

    # --- Rank features (descending importance) ---
    ranked_idx = np.argsort(importances)[::-1]
    top_idx = ranked_idx[:n_features]

    biomarkers = []
    for i in top_idx:
        gene_name = feature_names[i]
        imp = float(importances[i])
        if imp <= 0 and method != "lasso":
            continue  # Skip zero-importance features
        n_mut = int((X[:, i] > 0).sum())
        biomarkers.append({
            "gene": gene_name,
            "importance": round(imp, 6),
            "n_mutated": n_mut,
            "mutation_frequency": round(n_mut / len(common_ids), 4),
        })

    # --- Cross-validation AUC using top features ---
    cv_auc = None
    if biomarkers:
        from sklearn.model_selection import cross_val_score
        from sklearn.ensemble import RandomForestClassifier as RFC

        # Fix: O(1) dict lookup instead of list.index() per biomarker
        # (previously accidentally quadratic).
        name_to_idx = {g: i for i, g in enumerate(feature_names)}
        top_gene_idx = [name_to_idx[b["gene"]] for b in biomarkers if b["gene"] in name_to_idx]
        if len(top_gene_idx) >= 1:
            X_top = X[:, top_gene_idx]
            cv_folds = min(5, n_sensitive, n_resistant)
            if cv_folds >= 2:
                clf = RFC(n_estimators=50, max_depth=3, random_state=42, class_weight="balanced")
                try:
                    scores = cross_val_score(clf, X_top, y_binary, cv=cv_folds, scoring="roc_auc")
                    cv_auc = round(float(scores.mean()), 4)
                except ValueError:
                    # A fold may lack both classes; report no AUC rather than fail.
                    cv_auc = None

    # --- Summary ---
    top_genes_str = ", ".join(
        f"{b['gene']}({b['importance']:.4f})" for b in biomarkers[:5]
    )
    auc_str = f", CV-AUC={cv_auc:.3f}" if cv_auc is not None else ""
    summary = (
        f"Biomarker panel for {compound_id} ({method}): "
        f"{len(biomarkers)} features selected from {len(usable_genes)} candidates. "
        f"Top: {top_genes_str}{auc_str}"
    )

    return {
        "summary": summary,
        "compound": compound_id,
        "method": method,
        "n_cell_lines": len(common_ids),
        "n_sensitive": n_sensitive,
        "n_resistant": n_resistant,
        "n_features_tested": len(usable_genes),
        "biomarkers": biomarkers,
        "cv_auc": cv_auc,
    }