oncoordinate 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. oncoordinate/HallmarkPathGMT/HALLMARK_ADIPOGENESIS.v2023.2.Hs.gmt +1 -0
  2. oncoordinate/HallmarkPathGMT/HALLMARK_ALLOGRAFT_REJECTION.v2023.2.Hs.gmt +1 -0
  3. oncoordinate/HallmarkPathGMT/HALLMARK_ANDROGEN_RESPONSE.v2023.2.Hs.gmt +1 -0
  4. oncoordinate/HallmarkPathGMT/HALLMARK_ANGIOGENESIS.v2023.2.Hs.gmt +1 -0
  5. oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_JUNCTION.v2023.2.Hs.gmt +1 -0
  6. oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_SURFACE.v2023.2.Hs.gmt +1 -0
  7. oncoordinate/HallmarkPathGMT/HALLMARK_APOPTOSIS.v2023.2.Hs.gmt +1 -0
  8. oncoordinate/HallmarkPathGMT/HALLMARK_BILE_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
  9. oncoordinate/HallmarkPathGMT/HALLMARK_CHOLESTEROL_HOMEOSTASIS.v2023.2.Hs.gmt +1 -0
  10. oncoordinate/HallmarkPathGMT/HALLMARK_COAGULATION.v2023.2.Hs.gmt +1 -0
  11. oncoordinate/HallmarkPathGMT/HALLMARK_COMPLEMENT.v2023.2.Hs.gmt +1 -0
  12. oncoordinate/HallmarkPathGMT/HALLMARK_DNA_REPAIR.v2023.2.Hs.gmt +1 -0
  13. oncoordinate/HallmarkPathGMT/HALLMARK_E2F_TARGETS.v2023.2.Hs.gmt +1 -0
  14. oncoordinate/HallmarkPathGMT/HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION.v2023.2.Hs.gmt +1 -0
  15. oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_EARLY.v2023.2.Hs.gmt +1 -0
  16. oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_LATE.v2023.2.Hs.gmt +1 -0
  17. oncoordinate/HallmarkPathGMT/HALLMARK_FATTY_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
  18. oncoordinate/HallmarkPathGMT/HALLMARK_G2M_CHECKPOINT.v2023.2.Hs.gmt +1 -0
  19. oncoordinate/HallmarkPathGMT/HALLMARK_GLYCOLYSIS.v2023.2.Hs.gmt +1 -0
  20. oncoordinate/HallmarkPathGMT/HALLMARK_HEDGEHOG_SIGNALING.v2023.2.Hs.gmt +1 -0
  21. oncoordinate/HallmarkPathGMT/HALLMARK_HEME_METABOLISM.v2023.2.Hs.gmt +1 -0
  22. oncoordinate/HallmarkPathGMT/HALLMARK_HYPOXIA.v2023.2.Hs.gmt +1 -0
  23. oncoordinate/HallmarkPathGMT/HALLMARK_IL2_STAT5_SIGNALING.v2023.2.Hs.gmt +1 -0
  24. oncoordinate/HallmarkPathGMT/HALLMARK_IL6_JAK_STAT3_SIGNALING.v2023.2.Hs.gmt +1 -0
  25. oncoordinate/HallmarkPathGMT/HALLMARK_INFLAMMATORY_RESPONSE.v2023.2.Hs.gmt +1 -0
  26. oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_ALPHA_RESPONSE.v2023.2.Hs.gmt +1 -0
  27. oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_GAMMA_RESPONSE.v2023.2.Hs.gmt +1 -0
  28. oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_DN.v2023.2.Hs.gmt +1 -0
  29. oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_UP.v2023.2.Hs.gmt +1 -0
  30. oncoordinate/HallmarkPathGMT/HALLMARK_MITOTIC_SPINDLE.v2023.2.Hs.gmt +1 -0
  31. oncoordinate/HallmarkPathGMT/HALLMARK_MTORC1_SIGNALING.v2023.2.Hs.gmt +1 -0
  32. oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V1.v2023.2.Hs.gmt +1 -0
  33. oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V2.v2023.2.Hs.gmt +1 -0
  34. oncoordinate/HallmarkPathGMT/HALLMARK_MYOGENESIS.v2023.2.Hs.gmt +1 -0
  35. oncoordinate/HallmarkPathGMT/HALLMARK_NOTCH_SIGNALING.v2023.2.Hs.gmt +1 -0
  36. oncoordinate/HallmarkPathGMT/HALLMARK_OXIDATIVE_PHOSPHORYLATION.v2023.2.Hs.gmt +1 -0
  37. oncoordinate/HallmarkPathGMT/HALLMARK_P53_PATHWAY.v2023.2.Hs.gmt +1 -0
  38. oncoordinate/HallmarkPathGMT/HALLMARK_PANCREAS_BETA_CELLS.v2023.2.Hs.gmt +1 -0
  39. oncoordinate/HallmarkPathGMT/HALLMARK_PEROXISOME.v2023.2.Hs.gmt +1 -0
  40. oncoordinate/HallmarkPathGMT/HALLMARK_PI3K_AKT_MTOR_SIGNALING.v2023.2.Hs.gmt +1 -0
  41. oncoordinate/HallmarkPathGMT/HALLMARK_PROTEIN_SECRETION.v2023.2.Hs.gmt +1 -0
  42. oncoordinate/HallmarkPathGMT/HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY.v2023.2.Hs.gmt +1 -0
  43. oncoordinate/HallmarkPathGMT/HALLMARK_SPERMATOGENESIS.v2023.2.Hs.gmt +1 -0
  44. oncoordinate/HallmarkPathGMT/HALLMARK_TGF_BETA_SIGNALING.v2023.2.Hs.gmt +1 -0
  45. oncoordinate/HallmarkPathGMT/HALLMARK_TNFA_SIGNALING_VIA_NFKB.v2023.2.Hs.gmt +1 -0
  46. oncoordinate/HallmarkPathGMT/HALLMARK_UNFOLDED_PROTEIN_RESPONSE.v2023.2.Hs.gmt +1 -0
  47. oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_DN.v2023.2.Hs.gmt +1 -0
  48. oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_UP.v2023.2.Hs.gmt +1 -0
  49. oncoordinate/HallmarkPathGMT/HALLMARK_WNT_BETA_CATENIN_SIGNALING.v2023.2.Hs.gmt +1 -0
  50. oncoordinate/HallmarkPathGMT/HALLMARK_XENOBIOTIC_METABOLISM.v2023.2.Hs.gmt +1 -0
  51. oncoordinate/HallmarkPathGMT/REACTOME_CELL_CYCLE.v2023.2.Hs.gmt +1 -0
  52. oncoordinate/HallmarkPathGMT/REACTOME_SIGNALING_BY_EGFR_IN_CANCER.v2023.2.Hs.gmt +1 -0
  53. oncoordinate/__init__.py +0 -0
  54. oncoordinate/lt.py +472 -0
  55. oncoordinate/oncoordinate.joblib +0 -0
  56. oncoordinate/sc.py +729 -0
  57. oncoordinate/sp.py +513 -0
  58. oncoordinate-0.1.7.dist-info/METADATA +93 -0
  59. oncoordinate-0.1.7.dist-info/RECORD +62 -0
  60. oncoordinate-0.1.7.dist-info/WHEEL +5 -0
  61. oncoordinate-0.1.7.dist-info/licenses/LICENSE +21 -0
  62. oncoordinate-0.1.7.dist-info/top_level.txt +1 -0
oncoordinate/sc.py ADDED
@@ -0,0 +1,729 @@
+ from __future__ import annotations
+
+ import logging
+ from importlib import resources
+ from pathlib import Path
+ from typing import Iterable, Dict, Optional, Union, List
+
+ import anndata as ad
+ import joblib
+ import numpy as np
+ import pandas as pd
+ import scanpy as _sc
+ _sc.settings.verbosity = 0
+
+ from scipy import sparse as sp
+ from sklearn.preprocessing import MinMaxScaler
+ import os
+ import inspect
+
+ logger = logging.getLogger(__name__)
+ logging.getLogger().setLevel(logging.ERROR)
+ ad.settings.allow_write_nullable_strings = True
+
+ def _configure_debug_log():
+     try:
+         import __main__
+         main_file = getattr(__main__, "__file__", None)
+         if main_file:
+             log_dir = Path(main_file).resolve().parent
+         else:
+             log_dir = Path.cwd()
+     except Exception:
+         log_dir = Path.cwd()
+
+     log_path = log_dir / "debug_log.txt"
+     pkg_logger = logging.getLogger("oncoordinate")
+     pkg_logger.setLevel(logging.DEBUG)
+
+     for h in pkg_logger.handlers:
+         if isinstance(h, logging.FileHandler):
+             try:
+                 if Path(getattr(h, "baseFilename", "")) == log_path:
+                     break
+             except Exception:
+                 continue
+     else:
+         try:
+             handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
+             formatter = logging.Formatter(
+                 "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
+             )
+             handler.setFormatter(formatter)
+             handler.setLevel(logging.DEBUG)
+             pkg_logger.addHandler(handler)
+             pkg_logger.debug("Initialized oncoordinate debug logging to %s", log_path)
+         except Exception as e:
+             logger.warning("Failed to configure debug_log.txt logging: %s", e)
+
+ _configure_debug_log()
+
+ def _iter_gmt_files() -> Iterable[Path]:
+     try:
+         pkg_root = resources.files("oncoordinate").joinpath("HallmarkPathGMT")
+         for res in pkg_root.iterdir():
+             if res.name.endswith(".gmt"):
+                 with resources.as_file(res) as p:
+                     yield Path(p)
+         return
+     except Exception:
+         pass
+
+     local_root = Path(__file__).resolve().parent / "HallmarkPathGMT"
+     if local_root.is_dir():
+         for p in sorted(local_root.glob("*.gmt")):
+             yield p
+     else:
+         logger.warning("HallmarkPathGMT directory not found; no pathway scores will be computed.")
+
+ def _safe_joblib_load(path: Path, device: Optional[str] = None) -> Dict:
+     desired_raw = device or "auto"
+     desired = desired_raw.strip().lower() if isinstance(desired_raw, str) else "auto"
+     if desired == "gpu":
+         desired = "cuda"
+
+     try:
+         import torch
+     except Exception as e:
+         logger.info(
+             "torch import failed or unavailable (%s); attempting joblib.load without torch mapping.",
+             e,
+         )
+         return joblib.load(path)
+
+     try:
+         has_cuda = bool(torch.cuda.is_available())
+     except Exception as e:
+         logger.info("torch.cuda.is_available() check failed (%s); assuming CUDA unavailable.", e)
+         has_cuda = False
+
+     if desired == "cpu":
+         use_gpu = False
+     elif desired == "cuda":
+         use_gpu = has_cuda
+     else:
+         use_gpu = has_cuda
+
+     if use_gpu:
+         logger.info(
+             "Loading %s preferring GPU (device=%s, torch.cuda.is_available=%s).",
+             path.name,
+             desired,
+             has_cuda,
+         )
+         return joblib.load(path)
+
+     logger.info(
+         "Forcing CPU load for %s (device=%s). Will patch torch.load to set map_location='cpu' during joblib.load.",
+         path.name,
+         desired,
+     )
+     orig_torch_load = getattr(torch, "load", None)
+
+     def _torch_load_cpu(*args, **kwargs):
+         kwargs.setdefault("map_location", "cpu")
+         return orig_torch_load(*args, **kwargs)
+
+     if orig_torch_load is None:
+         logger.info(
+             "torch.load not found; falling back to joblib.load() without torch map_location patch."
+         )
+         return joblib.load(path)
+
+     try:
+         torch.load = _torch_load_cpu
+         return joblib.load(path)
+     finally:
+         try:
+             torch.load = orig_torch_load
+         except Exception:
+             logger.warning("Could not restore original torch.load after joblib.load().")
+
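Review note: the CPU-forcing branch works because pickled CUDA tensors route their storage deserialization through torch.load while joblib unpickles the bundle, so temporarily patching in map_location='cpu' has the same effect as the direct call below (the checkpoint file name is hypothetical):

    import torch

    # Tensors saved on a CUDA device are remapped onto the CPU at load time,
    # so the file can be opened on a machine without a GPU.
    state = torch.load("checkpoint.pt", map_location="cpu")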
+ def _load_model_bundle(device: Optional[str] = None) -> Dict:
+     try:
+         res = resources.files("oncoordinate").joinpath("oncoordinate.joblib")
+         with resources.as_file(res) as p:
+             path = Path(p)
+             if path.is_file():
+                 return _safe_joblib_load(path, device=device)
+     except Exception:
+         pass
+
+     local_path = Path(__file__).resolve().parent / "oncoordinate.joblib"
+     if not local_path.is_file():
+         raise FileNotFoundError(
+             "Could not find 'oncoordinate.joblib' in package resources or next to sc.py"
+         )
+
+     return _safe_joblib_load(local_path, device=device)
+
+ def _read_gmt(path: Path) -> Dict[str, List[str]]:
+     gene_sets: Dict[str, List[str]] = {}
+     with path.open() as f:
+         for line in f:
+             line = line.strip()
+             if not line or line.startswith("#"):
+                 continue
+             parts = line.split("\t")
+             if len(parts) < 3:
+                 continue
+             name = parts[0]
+             genes = parts[2:]
+             gene_sets[name] = genes
+     return gene_sets
+
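Review note: _read_gmt expects the standard tab-separated GMT layout used by MSigDB exports such as the files under HallmarkPathGMT: set name in the first field, a description/URL in the second (parts[1], skipped here), and gene symbols in the remaining fields. An illustrative line with the gene list truncated (fields are tab-separated):

    HALLMARK_APOPTOSIS	<description-or-URL>	CASP3	CASP8	BAX	FAS	...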
+ class sc:
+     def __init__(
+         self,
+         adata: Optional[ad.AnnData] = None,
+         *,
+         sample_key: str = "sample",
+         batch_key: Optional[str] = None,
+         user_celltype_vectors: Optional[pd.DataFrame] = None,
+         device: Optional[str] = None,
+     ):
+
+         self.adata = adata
+         self.sample_key = sample_key
+         self.batch_key = batch_key
+         self.device = (device.strip().lower() if isinstance(device, str) else None) or "auto"
+
+         if self.adata is not None:
+             self.adata.obs.index = self.adata.obs.index.astype(str)
+             self.adata.var.index = self.adata.var.index.astype(str)
+
+         if self.adata is not None and user_celltype_vectors is not None:
+             inter = user_celltype_vectors.loc[
+                 user_celltype_vectors.index.intersection(self.adata.obs_names)
+             ].copy()
+             inter = inter.apply(pd.to_numeric, errors="coerce").fillna(0.0)
+             inter.columns = [
+                 c if str(c).startswith("ctv_") else f"ctv_{c}" for c in inter.columns
+             ]
+             self.adata.obs = self.adata.obs.join(inter, how="left")
+
+     @staticmethod
+     def _ensure_cpu_safe_model(model, device: Optional[str] = None):
+         try:
+             import torch
+         except Exception:
+             return model
+
+         requested = device or "auto"
+         if isinstance(requested, str):
+             requested = requested.strip().lower()
+         else:
+             requested = "auto"
+
+         try:
+             has_cuda = bool(torch.cuda.is_available())
+         except Exception:
+             has_cuda = False
+
+         force_cpu = (requested == "cpu") or (not has_cuda)
+         if not force_cpu:
+             return model
+
+         cls_name = type(model).__name__.lower()
+         if cls_name.startswith("tabnet"):
+             try:
+                 model.device_name = "cpu"
+             except Exception:
+                 pass
+             try:
+                 model.device = torch.device("cpu")
+             except Exception:
+                 pass
+
+         net = getattr(model, "network", None)
+         if net is not None:
+             try:
+                 net.to(torch.device("cpu"))
+             except Exception:
+                 pass
+
+         return model
+
+     @staticmethod
+     def _preprocess(
+         adata: ad.AnnData,
+         *,
+         min_genes: int = 200,
+         min_cells: int = 10,
+     ) -> ad.AnnData:
+
+         adata = adata.copy()
+         adata.obs.index = adata.obs.index.astype(str)
+         adata.var.index = adata.var.index.astype(str)
+         adata.var_names = adata.var_names.astype(str)
+         adata.var["mt"] = adata.var_names.str.upper().str.startswith("MT-")
+
+         _sc.pp.calculate_qc_metrics(
+             adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
+         )
+
+         _sc.pp.filter_cells(adata, min_genes=min_genes)
+
+         if min_cells is not None and min_cells > 0 and adata.n_obs > 0:
+             if sp.issparse(adata.X):
+                 gene_nonzero = np.asarray((adata.X > 0).sum(axis=0)).ravel()
+             else:
+                 gene_nonzero = np.count_nonzero(adata.X > 0, axis=0)
+
+             survivors = gene_nonzero >= min_cells
+             n_survivors = int(survivors.sum())
+
+             if n_survivors == 0:
+                 relaxed = gene_nonzero >= 1
+                 n_relaxed = int(relaxed.sum())
+                 if n_relaxed == 0:
+                     logger.warning(
+                         "After filter_cells (min_genes=%d), all genes have zero counts "
+                         "for this sample (n_obs=%d). Skipping gene filtering.",
+                         min_genes,
+                         adata.n_obs,
+                     )
+                 else:
+                     logger.warning(
+                         "filter_genes(min_cells=%d) would remove ALL genes for this "
+                         "sample (n_obs=%d). Relaxing to min_cells=1; keeping %d genes.",
+                         min_cells,
+                         adata.n_obs,
+                         n_relaxed,
+                     )
+                     adata = adata[:, relaxed]
+             else:
+                 adata = adata[:, survivors]
+
+         if adata.n_vars == 0:
+             raise ValueError(
+                 f"After preprocessing (min_genes={min_genes}, min_cells={min_cells}), "
+                 f"no genes remain for this sample (n_obs={adata.n_obs}). "
+                 "Try lowering min_cells and/or min_genes."
+             )
+
+         if sp.issparse(adata.X):
+             if hasattr(adata.X, "data"):
+                 adata.X.data = np.nan_to_num(adata.X.data)
+         else:
+             adata.X = np.nan_to_num(adata.X)
+
+         if adata.is_view:
+             adata = adata.copy()
+
+         _sc.pp.normalize_total(adata, target_sum=1e4)
+         _sc.pp.log1p(adata)
+
+         return adata
+
+     def _run_embeddings(
+         self,
+         adata: ad.AnnData,
+         *,
+         pca_n_comps: Optional[int] = None,
+         neighbors_n_pcs: Optional[int] = None,
+         neighbors_k: Optional[int] = None,
+         use_batch_correction: bool = False,
+     ) -> ad.AnnData:
+
+         adata = adata.copy()
+
+         max_rank = max(0, min(adata.n_obs, adata.n_vars) - 1)
+         n_comps = int(pca_n_comps) if pca_n_comps is not None else max(
+             2, min(50, max_rank)
+         )
+         _sc.tl.pca(adata, n_comps=n_comps, svd_solver="arpack")
+
+         n_pcs_avail = int(adata.obsm["X_pca"].shape[1])
+         if n_pcs_avail < 2:
+             raise ValueError(
+                 f"Too few PCs after preprocessing (got {n_pcs_avail}). "
+                 "Consider lowering filtering thresholds or increasing pca_n_comps."
+             )
+
+         n_pcs_use = int(neighbors_n_pcs) if neighbors_n_pcs is not None else min(
+             40, n_pcs_avail
+         )
+         if n_pcs_use > n_pcs_avail:
+             raise ValueError(
+                 f"Requested neighbors_n_pcs={n_pcs_use} > available PCs ({n_pcs_avail})."
+             )
+
+         k = int(neighbors_k) if neighbors_k is not None else min(
+             15, max(2, adata.n_obs - 1)
+         )
+         k = max(2, min(k, max(2, adata.n_obs - 1)))
+
+         batch_key = self.batch_key or self.sample_key
+         if use_batch_correction and batch_key in adata.obs.columns:
+             try:
+                 import scanpy.external as sce
+
+                 logger.info(
+                     f"Running BBKNN batch correction with batch_key='{batch_key}'"
+                 )
+                 sce.pp.bbknn(adata, batch_key=batch_key, n_pcs=n_pcs_use)
+             except Exception as e:
+                 logger.warning(
+                     f"BBKNN batch correction failed ({e}); falling back to standard neighbors."
+                 )
+                 _sc.pp.neighbors(
+                     adata,
+                     n_neighbors=k,
+                     n_pcs=n_pcs_use,
+                     use_rep="X_pca",
+                 )
+         else:
+             _sc.pp.neighbors(
+                 adata,
+                 n_neighbors=k,
+                 n_pcs=n_pcs_use,
+                 use_rep="X_pca",
+             )
+
+         _sc.tl.umap(adata)
+         _sc.tl.diffmap(adata)
+
+         return adata
+
+     @staticmethod
+     def _score_pathways(adata: ad.AnnData) -> ad.AnnData:
+         adata = adata.copy()
+         adata.var_names = adata.var_names.astype(str).str.upper()
+         adata.var_names_make_unique()
+         genes_in_adata = list(adata.var_names)
+
+         score_genes = _sc.tl.score_genes
+         try:
+             sig = inspect.signature(score_genes)
+             supports_ctrl_as_ref = "ctrl_as_ref" in sig.parameters
+         except Exception:
+             supports_ctrl_as_ref = False
+
+         for gmt_path in _iter_gmt_files():
+             logger.info(f"Scoring pathways from {gmt_path.name}")
+             gene_sets = _read_gmt(gmt_path)
+
+             for pathway, genes in gene_sets.items():
+                 if not genes:
+                     continue
+
+                 genes_upper = [g.upper() for g in genes]
+                 genes_inter = [g for g in genes_upper if g in genes_in_adata]
+                 if not genes_inter:
+                     continue
+
+                 kwargs = dict(
+                     gene_list=genes_inter,
+                     score_name=pathway,
+                     use_raw=False,
+                 )
+
+                 if supports_ctrl_as_ref:
+                     kwargs["ctrl_as_ref"] = False
+
+                 try:
+                     score_genes(adata, **kwargs)
+                 except RuntimeError as e:
+                     msg = str(e)
+                     if "No control genes found in any cut" in msg:
+                         logger.warning(
+                             "score_genes failed for pathway '%s' due to lack of control "
+                             "genes (message: %s). Skipping this pathway. "
+                             "Consider adjusting gene sets or ctrl_size if needed.",
+                             pathway,
+                             msg,
+                         )
+                         continue
+                     raise
+
+         return adata
+
+     @staticmethod
+     def _build_feature_matrix(
+         obs: pd.DataFrame,
+         feature_names: List[str],
+     ) -> pd.DataFrame:
+         X = pd.DataFrame(index=obs.index)
+
+         for feat in feature_names:
+             if feat in obs.columns:
+                 col = obs[feat]
+             else:
+                 col = pd.Series(0.0, index=obs.index)
+
+             col = pd.to_numeric(col, errors="coerce").fillna(0.0).astype(np.float32)
+             X[feat] = col
+
+         return X
+
+     def _annotate_with_model(self, adata: ad.AnnData, device: Optional[str] = None) -> ad.AnnData:
+         adata = adata.copy()
+
+         bundle = _load_model_bundle(device=device)
+         model = bundle["model"]
+         feature_list = list(bundle.get("features", []))
+         scaler = bundle.get("scaler", None)
+         label_map = bundle.get(
+             "label_map",
+             {0: "normal", 1: "abnormal", 2: "pre-malignant", 3: "malignant"},
+         )
+
+         if not feature_list:
+             logger.warning(
+                 "Model bundle has empty 'features'; skipping oncoordinate annotation."
+             )
+             return adata
+
+         meta = adata.obs.copy()
+         X_df = self._build_feature_matrix(meta, feature_list)
+         X_df_float = X_df.astype(np.float32)
+
+         if scaler is not None:
+             try:
+                 X_scaled = scaler.transform(X_df_float)
+             except Exception as e:
+                 logger.warning(
+                     f"Scaler transform failed ({e}); falling back to MinMaxScaler fit on the fly."
+                 )
+                 X_scaled = MinMaxScaler().fit_transform(X_df_float)
+         else:
+             X_scaled = MinMaxScaler().fit_transform(X_df_float)
+
+         X_scaled = np.nan_to_num(X_scaled, nan=0.0, posinf=0.0, neginf=0.0)
+
+         base_model = model
+         if hasattr(base_model, "best_estimator_"):
+             base_model = base_model.best_estimator_
+
+         base_model = self._ensure_cpu_safe_model(base_model, device=device)
+
+         if not hasattr(base_model, "predict_proba"):
+             raise RuntimeError(
+                 "Loaded oncoordinate model does not support predict_proba(). "
+                 "Cannot annotate cells."
+             )
+
+         proba = np.asarray(base_model.predict_proba(X_scaled), dtype=float)
+
+         if proba.ndim != 2:
+             raise ValueError(
+                 f"Model.predict_proba returned array with shape {proba.shape}, expected 2D."
+             )
+
+         classes_attr = getattr(base_model, "classes_", None)
+         if classes_attr is not None:
+             try:
+                 class_names = [label_map.get(int(c), str(c)) for c in classes_attr]
+             except Exception:
+                 class_names = [str(c) for c in classes_attr]
+         else:
+             class_names = [
+                 label_map.get(i, str(i)) for i in range(proba.shape[1])
+             ]
+
+         pred_idx = np.argmax(proba, axis=1)
+         pred_labels = [class_names[i] for i in pred_idx]
+
+         adata.obs["oncoordinate_stage_idx"] = pred_idx.astype(int)
+         adata.obs["oncoordinate_stage"] = pd.Categorical(
+             pred_labels,
+             categories=class_names,
+             ordered=True,
+         )
+
+         for j, cname in enumerate(class_names):
+             col_name = f"oncoordinate_proba_{cname}"
+             adata.obs[col_name] = proba[:, j]
+
+         return adata
+
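Review note: with the default label_map above, this method leaves one ordered categorical column plus one probability column per class in adata.obs. A quick post-hoc inspection, assuming `annotated` is the AnnData returned by annotate():

    # Collect the per-class probability columns written by _annotate_with_model
    proba_cols = [c for c in annotated.obs.columns if c.startswith("oncoordinate_proba_")]
    print(annotated.obs[["oncoordinate_stage"] + proba_cols].head())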
+     def _process_one_sample(
+         self,
+         adata_s: ad.AnnData,
+         *,
+         pca_n_comps: Optional[int],
+         neighbors_n_pcs: Optional[int],
+         neighbors_k: Optional[int],
+         min_cells: int,
+         min_genes: int,
+         use_batch_correction: bool,
+         device: Optional[str] = None,
+         run_embeddings: bool = True,
+     ) -> ad.AnnData:
+
+         adata_s = self._preprocess(
+             adata_s,
+             min_cells=min_cells,
+             min_genes=min_genes,
+         )
+
+         if run_embeddings:
+             adata_s = self._run_embeddings(
+                 adata_s,
+                 pca_n_comps=pca_n_comps,
+                 neighbors_n_pcs=neighbors_n_pcs,
+                 neighbors_k=neighbors_k,
+                 use_batch_correction=use_batch_correction,
+             )
+
+         adata_s = self._score_pathways(adata_s)
+         adata_s = self._annotate_with_model(adata_s, device=device)
+
+         adata_s.uns.setdefault("oncoordinate_params", {})
+         adata_s.uns["oncoordinate_params"].update(
+             dict(
+                 min_cells=int(min_cells),
+                 min_genes=int(min_genes),
+                 device=device or self.device,
+             )
+         )
+
+         return adata_s
+
+
+     def annotate(
+         self,
+         *,
+         sample_key: Optional[str] = None,
+         pca_n_comps: Optional[int] = None,
+         neighbors_n_pcs: Optional[int] = None,
+         neighbors_k: Optional[int] = None,
+         min_cells: int = 3,
+         min_genes: int = 200,
+         use_batch_correction: bool = False,
+         save_path: Optional[Union[str, Path]] = None,
+         device: Optional[str] = None,
+     ) -> ad.AnnData:
+         if self.adata is None:
+             raise AttributeError("OnCoordinateSC.adata is None – please provide an AnnData.")
+
+         adata = self.adata.copy()
+         adata.obs.index = adata.obs.index.astype(str)
+         adata.var.index = adata.var.index.astype(str)
+
+         skey = sample_key or self.sample_key
+
+         if skey not in adata.obs.columns:
+             logger.info(
+                 "'%s' not found in adata.obs – processing as a single sample.",
+                 skey,
+             )
+             annotated = self._process_one_sample(
+                 adata,
+                 pca_n_comps=pca_n_comps,
+                 neighbors_n_pcs=neighbors_n_pcs,
+                 neighbors_k=neighbors_k,
+                 min_cells=min_cells,
+                 min_genes=min_genes,
+                 use_batch_correction=use_batch_correction,
+                 device=(device or self.device),
+             )
+         else:
+             col = adata.obs[skey].astype(str).str.strip()
+             adata.obs[skey] = col
+
+             unique_samples = col.unique().tolist()
+
+             if len(unique_samples) == 1:
+                 logger.info(
+                     "Single sample detected (sample_key='%s'). Processing whole AnnData.",
+                     skey,
+                 )
+                 annotated = self._process_one_sample(
+                     adata,
+                     pca_n_comps=pca_n_comps,
+                     neighbors_n_pcs=neighbors_n_pcs,
+                     neighbors_k=neighbors_k,
+                     min_cells=min_cells,
+                     min_genes=min_genes,
+                     use_batch_correction=use_batch_correction,
+                     device=(device or self.device),
+                 )
+             else:
+                 logger.info(
+                     "Multiple samples detected in '%s': %s (use_batch_correction=%s)",
+                     skey,
+                     unique_samples,
+                     use_batch_correction,
+                 )
+                 annotated_list: List[ad.AnnData] = []
+
+                 for sid in unique_samples:
+                     mask = adata.obs[skey] == sid
+                     if mask.sum() == 0:
+                         continue
+                     ad_s = adata[mask, :].copy()
+
+                     if use_batch_correction:
+                         ann_s = self._process_one_sample(
+                             ad_s,
+                             pca_n_comps=pca_n_comps,
+                             neighbors_n_pcs=neighbors_n_pcs,
+                             neighbors_k=neighbors_k,
+                             min_cells=min_cells,
+                             min_genes=min_genes,
+                             use_batch_correction=False,
+                             device=(device or self.device),
+                             run_embeddings=False,
+                         )
+                     else:
+                         ann_s = self._process_one_sample(
+                             ad_s,
+                             pca_n_comps=pca_n_comps,
+                             neighbors_n_pcs=neighbors_n_pcs,
+                             neighbors_k=neighbors_k,
+                             min_cells=min_cells,
+                             min_genes=min_genes,
+                             use_batch_correction=False,
+                             device=(device or self.device),
+                         )
+
+                     ann_s.obs[skey] = sid
+                     annotated_list.append(ann_s)
+
+                 if not annotated_list:
+                     raise ValueError(
+                         f"No samples contained any cells after filtering (sample_key='{skey}')."
+                     )
+
+                 annotated = ad.concat(
+                     annotated_list,
+                     label="__sample__",
+                     keys=[str(a.obs[skey].unique()[0]) for a in annotated_list],
+                     join="outer",
+                     index_unique=None,
+                 )
+
+                 if use_batch_correction:
+                     batch_key = self.batch_key or skey
+                     logger.info(
+                         "Running BBKNN batch correction once on concatenated AnnData "
+                         "(n_obs=%d, n_vars=%d, batch_key='%s').",
+                         annotated.n_obs,
+                         annotated.n_vars,
+                         batch_key,
+                     )
+                     annotated = self._run_embeddings(
+                         annotated,
+                         pca_n_comps=pca_n_comps,
+                         neighbors_n_pcs=neighbors_n_pcs,
+                         neighbors_k=neighbors_k,
+                         use_batch_correction=True,
+                     )
+         annotated.uns.setdefault("oncoordinate_params", {})
+         annotated.uns["oncoordinate_params"].update(
+             dict(
+                 min_cells=int(min_cells),
+                 min_genes=int(min_genes),
+                 device=device or self.device,
+             )
+         )
+
+         if save_path is not None:
+             save_path = Path(save_path)
+             save_path.parent.mkdir(parents=True, exist_ok=True)
+             annotated.write_h5ad(save_path)
+             logger.info("Wrote oncoordinate-annotated AnnData to: %s", save_path)
+
+         return annotated
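
For orientation, a minimal end-to-end sketch against the API added in this file. The .h5ad path and the sample IDs are placeholders; the rest follows the constructor and annotate() signatures above:

    import scanpy
    from oncoordinate.sc import sc as OnCoordinateSC

    adata = scanpy.read_h5ad("my_cells.h5ad")            # placeholder input file
    runner = OnCoordinateSC(adata, sample_key="sample")  # obs column holding sample IDs
    annotated = runner.annotate(
        min_cells=3,
        min_genes=200,
        use_batch_correction=False,
        device="cpu",                                    # "cuda", "gpu", or "auto" also accepted
        save_path="my_cells_annotated.h5ad",             # optional on-disk copy
    )
    print(annotated.obs["oncoordinate_stage"].value_counts())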