oncoordinate 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oncoordinate/HallmarkPathGMT/HALLMARK_ADIPOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ALLOGRAFT_REJECTION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ANDROGEN_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ANGIOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_JUNCTION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_SURFACE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_APOPTOSIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_BILE_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_CHOLESTEROL_HOMEOSTASIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_COAGULATION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_COMPLEMENT.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_DNA_REPAIR.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_E2F_TARGETS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_EARLY.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_LATE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_FATTY_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_G2M_CHECKPOINT.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_GLYCOLYSIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_HEDGEHOG_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_HEME_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_HYPOXIA.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_IL2_STAT5_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_IL6_JAK_STAT3_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_INFLAMMATORY_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_ALPHA_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_GAMMA_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_DN.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_UP.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MITOTIC_SPINDLE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MTORC1_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V1.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V2.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MYOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_NOTCH_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_OXIDATIVE_PHOSPHORYLATION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_P53_PATHWAY.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PANCREAS_BETA_CELLS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PEROXISOME.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PI3K_AKT_MTOR_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PROTEIN_SECRETION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_SPERMATOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_TGF_BETA_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_TNFA_SIGNALING_VIA_NFKB.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_UNFOLDED_PROTEIN_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_DN.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_UP.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_WNT_BETA_CATENIN_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_XENOBIOTIC_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/REACTOME_CELL_CYCLE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/REACTOME_SIGNALING_BY_EGFR_IN_CANCER.v2023.2.Hs.gmt +1 -0
- oncoordinate/__init__.py +0 -0
- oncoordinate/lt.py +472 -0
- oncoordinate/oncoordinate.joblib +0 -0
- oncoordinate/sc.py +729 -0
- oncoordinate/sp.py +513 -0
- oncoordinate-0.1.7.dist-info/METADATA +93 -0
- oncoordinate-0.1.7.dist-info/RECORD +62 -0
- oncoordinate-0.1.7.dist-info/WHEEL +5 -0
- oncoordinate-0.1.7.dist-info/licenses/LICENSE +21 -0
- oncoordinate-0.1.7.dist-info/top_level.txt +1 -0
oncoordinate/sc.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from importlib import resources
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Iterable, Dict, Optional, Union, List
|
|
7
|
+
|
|
8
|
+
import anndata as ad
|
|
9
|
+
import joblib
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import scanpy as _sc
|
|
13
|
+
# Import-time side effect: set scanpy's global verbosity setting to 0.
_sc.settings.verbosity = 0
|
|
14
|
+
|
|
15
|
+
from scipy import sparse as sp
|
|
16
|
+
from sklearn.preprocessing import MinMaxScaler
|
|
17
|
+
import os
|
|
18
|
+
import inspect
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
# NOTE: getLogger() with no name returns the *root* logger, so this import-time
# call raises the level to ERROR for every logger that inherits from root in
# the importing process, not only for oncoordinate's own loggers.
logging.getLogger().setLevel(logging.ERROR)
|
|
22
|
+
# Import-time side effect: switch on anndata's `allow_write_nullable_strings`
# setting. NOTE(review): presumably needed so annotate(save_path=...) can write
# obs columns holding pandas nullable strings -- confirm against anndata docs.
ad.settings.allow_write_nullable_strings = True
|
|
23
|
+
|
|
24
|
+
def _configure_debug_log():
    """Attach a DEBUG-level file handler for the ``oncoordinate`` logger.

    The log file is ``debug_log.txt`` in the directory of the running script
    (``__main__.__file__``); when that cannot be determined, the current
    working directory is used instead. If the logger already has a
    ``FileHandler`` pointing at that exact path, nothing is added. Failure to
    create the handler is reported as a warning and never raised.
    """
    try:
        import __main__

        entry_script = getattr(__main__, "__file__", None)
        target_dir = Path(entry_script).resolve().parent if entry_script else Path.cwd()
    except Exception:
        target_dir = Path.cwd()

    log_path = target_dir / "debug_log.txt"
    pkg_logger = logging.getLogger("oncoordinate")
    pkg_logger.setLevel(logging.DEBUG)

    def _already_writes_here(candidate) -> bool:
        # True only for FileHandlers whose target is exactly log_path; a handler
        # whose filename cannot be turned into a Path is treated as "no match".
        if not isinstance(candidate, logging.FileHandler):
            return False
        try:
            return Path(getattr(candidate, "baseFilename", "")) == log_path
        except Exception:
            return False

    if any(_already_writes_here(existing) for existing in pkg_logger.handlers):
        return

    try:
        file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
        file_handler.setFormatter(
            logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
        )
        file_handler.setLevel(logging.DEBUG)
        pkg_logger.addHandler(file_handler)
        pkg_logger.debug("Initialized oncoordinate debug logging to %s", log_path)
    except Exception as e:
        logger.warning("Failed to configure debug_log.txt logging: %s", e)
|
|
58
|
+
|
|
59
|
+
# Runs at import time: merely importing this module sets the "oncoordinate"
# logger to DEBUG and tries to attach the debug_log.txt file handler.
_configure_debug_log()
|
|
60
|
+
|
|
61
|
+
def _iter_gmt_files() -> Iterable[Path]:
    """Yield filesystem paths of the bundled ``*.gmt`` gene-set files.

    The ``HallmarkPathGMT`` directory is looked up through
    ``importlib.resources`` first. Each path is yielded while its
    ``as_file`` context is still open, so the consumer must use the path
    before advancing the generator. If the resource lookup raises, the
    ``HallmarkPathGMT`` folder next to this file is scanned instead (in sorted
    order); if that folder is missing as well, a warning is logged and nothing
    is yielded.
    """
    try:
        gmt_dir = resources.files("oncoordinate").joinpath("HallmarkPathGMT")
        for entry in gmt_dir.iterdir():
            if not entry.name.endswith(".gmt"):
                continue
            with resources.as_file(entry) as extracted:
                yield Path(extracted)
    except Exception:
        # Deliberate best effort: fall through to the on-disk lookup below.
        pass
    else:
        return

    fallback_dir = Path(__file__).resolve().parent / "HallmarkPathGMT"
    if not fallback_dir.is_dir():
        logger.warning("HallmarkPathGMT directory not found; no pathway scores will be computed.")
        return

    yield from sorted(fallback_dir.glob("*.gmt"))
|
|
78
|
+
|
|
79
|
+
def _safe_joblib_load(path: Path, device: Optional[str] = None) -> Dict:
    """Load a joblib file, forcing torch tensors onto the CPU when required.

    Args:
        path: Location of the joblib file.
        device: ``"cpu"``, ``"cuda"``/``"gpu"`` or ``"auto"`` (also used for
            ``None``, empty and non-string values). Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        Whatever ``joblib.load`` returns for ``path``.

    When torch cannot be imported the file is loaded as-is. When CUDA is
    available and the caller did not ask for ``"cpu"``, the file is also loaded
    as-is. In every other case ``torch.load`` is temporarily replaced by a
    wrapper that defaults ``map_location`` to ``"cpu"`` for the duration of
    ``joblib.load`` and is restored afterwards.
    """
    # NOTE(review): joblib.load unpickles its input; the callers visible in this
    # file only pass the model file shipped with the package.
    requested = device or "auto"
    requested = requested.strip().lower() if isinstance(requested, str) else "auto"
    if requested == "gpu":
        requested = "cuda"

    try:
        import torch
    except Exception as e:
        logger.info(
            "torch import failed or unavailable (%s); attempting joblib.load without torch mapping.",
            e,
        )
        return joblib.load(path)

    try:
        cuda_ok = bool(torch.cuda.is_available())
    except Exception as e:
        logger.info("torch.cuda.is_available() check failed (%s); assuming CUDA unavailable.", e)
        cuda_ok = False

    # Only an explicit "cpu" request vetoes the GPU; "cuda" and "auto" both
    # simply follow CUDA availability.
    if requested != "cpu" and cuda_ok:
        logger.info(
            "Loading %s preferring GPU (device=%s, torch.cuda.is_available=%s).",
            path.name,
            requested,
            cuda_ok,
        )
        return joblib.load(path)

    logger.info(
        "Forcing CPU load for %s (device=%s). Will patch torch.load to set map_location='cpu' during joblib.load.",
        path.name,
        requested,
    )

    real_torch_load = getattr(torch, "load", None)
    if real_torch_load is None:
        logger.info(
            "torch.load not found; falling back to joblib.load() without torch map_location patch."
        )
        return joblib.load(path)

    def _cpu_mapped_load(*args, **kwargs):
        # Respect an explicit map_location; otherwise pin tensors to the CPU.
        kwargs.setdefault("map_location", "cpu")
        return real_torch_load(*args, **kwargs)

    try:
        # Process-wide monkeypatch, undone in the finally clause below.
        torch.load = _cpu_mapped_load
        return joblib.load(path)
    finally:
        try:
            torch.load = real_torch_load
        except Exception:
            logger.warning("Could not restore original torch.load after joblib.load().")
|
|
141
|
+
|
|
142
|
+
def _load_model_bundle(device: Optional[str] = None) -> Dict:
    """Load the packaged ``oncoordinate.joblib`` model bundle.

    The file is first resolved through ``importlib.resources``; if that lookup
    or the load itself fails, the copy located next to this module is tried.

    Args:
        device: Forwarded unchanged to ``_safe_joblib_load``.

    Returns:
        The object stored in the joblib file.

    Raises:
        FileNotFoundError: If the file is available from neither location. Any
            error from the package-resource attempt is attached as
            ``__cause__`` so the original failure is not lost.
    """
    resource_error: Optional[Exception] = None
    try:
        res = resources.files("oncoordinate").joinpath("oncoordinate.joblib")
        with resources.as_file(res) as p:
            path = Path(p)
            if path.is_file():
                return _safe_joblib_load(path, device=device)
    except Exception as exc:
        # Keep the best-effort fallback, but record why the first attempt
        # failed instead of silently discarding the error.
        resource_error = exc
        logger.debug(
            "Loading 'oncoordinate.joblib' from package resources failed; "
            "trying the copy next to sc.py.",
            exc_info=True,
        )

    local_path = Path(__file__).resolve().parent / "oncoordinate.joblib"
    if not local_path.is_file():
        raise FileNotFoundError(
            "Could not find 'oncoordinate.joblib' in package resources or next to sc.py"
        ) from resource_error

    return _safe_joblib_load(local_path, device=device)
|
|
159
|
+
|
|
160
|
+
def _read_gmt(path: Path) -> Dict[str, List[str]]:
|
|
161
|
+
gene_sets: Dict[str, List[str]] = {}
|
|
162
|
+
with path.open() as f:
|
|
163
|
+
for line in f:
|
|
164
|
+
line = line.strip()
|
|
165
|
+
if not line or line.startswith("#"):
|
|
166
|
+
continue
|
|
167
|
+
parts = line.split("\t")
|
|
168
|
+
if len(parts) < 3:
|
|
169
|
+
continue
|
|
170
|
+
name = parts[0]
|
|
171
|
+
genes = parts[2:]
|
|
172
|
+
gene_sets[name] = genes
|
|
173
|
+
return gene_sets
|
|
174
|
+
|
|
175
|
+
class sc:
    """Single-cell pipeline: preprocess, embed, score pathways, classify stage.

    An instance wraps one ``AnnData`` object. ``annotate()`` is the public
    entry point; the underscore-prefixed methods are the individual pipeline
    stages it chains together.
    """

    def __init__(
        self,
        adata: Optional[ad.AnnData] = None,
        *,
        sample_key: str = "sample",
        batch_key: Optional[str] = None,
        user_celltype_vectors: Optional[pd.DataFrame] = None,
        device: Optional[str] = None,
    ):
        """Store the data set and pipeline options.

        Args:
            adata: Data to annotate. Its ``obs``/``var`` indexes are converted
                to ``str`` in place.
            sample_key: ``obs`` column that identifies samples.
            batch_key: ``obs`` column used for batch correction; the sample
                column is used when this is ``None``.
            user_celltype_vectors: Optional per-cell table indexed by cell
                name. Rows matching ``adata.obs_names`` are coerced to numbers
                (non-numeric -> 0.0), columns get a ``ctv_`` prefix unless they
                already have one, and the result is left-joined onto
                ``adata.obs``.
            device: ``"cpu"``, ``"cuda"``/``"gpu"`` or ``"auto"``; stored
                lower-cased and stripped, with ``"auto"`` as the default.
        """
        self.adata = adata
        self.sample_key = sample_key
        self.batch_key = batch_key
        self.device = (device.strip().lower() if isinstance(device, str) else None) or "auto"

        if self.adata is None:
            return

        self.adata.obs.index = self.adata.obs.index.astype(str)
        self.adata.var.index = self.adata.var.index.astype(str)

        if user_celltype_vectors is not None:
            inter = user_celltype_vectors.loc[
                user_celltype_vectors.index.intersection(self.adata.obs_names)
            ].copy()
            inter = inter.apply(pd.to_numeric, errors="coerce").fillna(0.0)
            inter.columns = [
                c if str(c).startswith("ctv_") else f"ctv_{c}" for c in inter.columns
            ]
            self.adata.obs = self.adata.obs.join(inter, how="left")

    @staticmethod
    def _ensure_cpu_safe_model(model, device: Optional[str] = None):
        """Point a model at the CPU when the GPU must not or cannot be used.

        Nothing is changed when torch is not importable, or when CUDA is
        available and ``device`` is not ``"cpu"``. Otherwise, for models whose
        class name starts with ``tabnet`` (case-insensitive), ``device_name``
        and ``device`` are set to the CPU, and for any model with a
        ``network`` attribute that network is moved to the CPU. Each step is
        best effort and its errors are ignored.

        Returns:
            The same ``model`` object.
        """
        try:
            import torch
        except Exception:
            return model

        requested = device or "auto"
        requested = requested.strip().lower() if isinstance(requested, str) else "auto"

        try:
            has_cuda = bool(torch.cuda.is_available())
        except Exception:
            has_cuda = False

        if requested != "cpu" and has_cuda:
            return model

        if type(model).__name__.lower().startswith("tabnet"):
            try:
                model.device_name = "cpu"
            except Exception:
                pass
            try:
                model.device = torch.device("cpu")
            except Exception:
                pass

        net = getattr(model, "network", None)
        if net is not None:
            try:
                net.to(torch.device("cpu"))
            except Exception:
                pass

        return model

    @staticmethod
    def _preprocess(
        adata: ad.AnnData,
        *,
        min_genes: int = 200,
        min_cells: int = 10,
    ) -> ad.AnnData:
        """Run QC, filtering, NaN clean-up, normalisation and log1p on a copy.

        Args:
            adata: Input data; it is copied, never modified.
            min_genes: Threshold passed to ``scanpy.pp.filter_cells``.
            min_cells: Keep genes with a positive value in at least this many
                cells. If that would drop every gene, the threshold is relaxed
                to one cell; if all genes are zero, gene filtering is skipped.
                ``None`` or ``<= 0`` disables gene filtering.

        Returns:
            A new ``AnnData`` normalised to ``target_sum=1e4`` and log1p
            transformed.

        Raises:
            ValueError: If no genes remain after filtering.
        """
        adata = adata.copy()
        adata.obs.index = adata.obs.index.astype(str)
        adata.var.index = adata.var.index.astype(str)
        adata.var_names = adata.var_names.astype(str)
        adata.var["mt"] = adata.var_names.str.upper().str.startswith("MT-")

        _sc.pp.calculate_qc_metrics(
            adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
        )
        _sc.pp.filter_cells(adata, min_genes=min_genes)

        if min_cells is not None and min_cells > 0 and adata.n_obs > 0:
            if sp.issparse(adata.X):
                gene_nonzero = np.asarray((adata.X > 0).sum(axis=0)).ravel()
            else:
                gene_nonzero = np.count_nonzero(adata.X > 0, axis=0)

            survivors = gene_nonzero >= min_cells
            if survivors.any():
                adata = adata[:, survivors]
            else:
                relaxed = gene_nonzero >= 1
                n_relaxed = int(relaxed.sum())
                if n_relaxed == 0:
                    logger.warning(
                        "After filter_cells (min_genes=%d), all genes have zero counts "
                        "for this sample (n_obs=%d). Skipping gene filtering.",
                        min_genes,
                        adata.n_obs,
                    )
                else:
                    logger.warning(
                        "filter_genes(min_cells=%d) would remove ALL genes for this "
                        "sample (n_obs=%d). Relaxing to min_cells=1; keeping %d genes.",
                        min_cells,
                        adata.n_obs,
                        n_relaxed,
                    )
                    adata = adata[:, relaxed]

        if adata.n_vars == 0:
            raise ValueError(
                f"After preprocessing (min_genes={min_genes}, min_cells={min_cells}), "
                f"no genes remain for this sample (n_obs={adata.n_obs}). "
                "Try lowering min_cells and/or min_genes."
            )

        # Gene filtering above leaves a *view*. Materialise it before the NaN
        # scrub: on a view, `adata.X` is rebuilt on each access, so assigning
        # to `adata.X.data` would modify a temporary and be lost.
        if adata.is_view:
            adata = adata.copy()

        if sp.issparse(adata.X):
            if hasattr(adata.X, "data"):
                adata.X.data = np.nan_to_num(adata.X.data)
        else:
            adata.X = np.nan_to_num(adata.X)

        _sc.pp.normalize_total(adata, target_sum=1e4)
        _sc.pp.log1p(adata)

        return adata

    def _run_embeddings(
        self,
        adata: ad.AnnData,
        *,
        pca_n_comps: Optional[int] = None,
        neighbors_n_pcs: Optional[int] = None,
        neighbors_k: Optional[int] = None,
        use_batch_correction: bool = False,
        batch_key: Optional[str] = None,
    ) -> ad.AnnData:
        """Compute PCA, a neighbour graph, UMAP and a diffusion map on a copy.

        Args:
            adata: Preprocessed data; it is copied, never modified.
            pca_n_comps: Number of PCs; default ``max(2, min(50, rank - 1))``
                where ``rank = min(n_obs, n_vars)``.
            neighbors_n_pcs: PCs used for the graph; default ``min(40, available)``.
            neighbors_k: Neighbour count; default 15. Always clamped to
                ``[2, max(2, n_obs - 1)]``.
            use_batch_correction: Try BBKNN instead of plain neighbours. If
                BBKNN raises, a warning is logged and plain neighbours are used.
            batch_key: ``obs`` column holding the batch label. Defaults to
                ``self.batch_key`` and then ``self.sample_key``. BBKNN is only
                attempted when this column exists.

        Raises:
            ValueError: If fewer than two PCs are available, or
                ``neighbors_n_pcs`` exceeds the number of PCs computed.
        """
        adata = adata.copy()

        max_rank = max(0, min(adata.n_obs, adata.n_vars) - 1)
        if pca_n_comps is not None:
            n_comps = int(pca_n_comps)
        else:
            n_comps = max(2, min(50, max_rank))
        _sc.tl.pca(adata, n_comps=n_comps, svd_solver="arpack")

        n_pcs_avail = int(adata.obsm["X_pca"].shape[1])
        if n_pcs_avail < 2:
            raise ValueError(
                f"Too few PCs after preprocessing (got {n_pcs_avail}). "
                "Consider lowering filtering thresholds or increasing pca_n_comps."
            )

        if neighbors_n_pcs is not None:
            n_pcs_use = int(neighbors_n_pcs)
        else:
            n_pcs_use = min(40, n_pcs_avail)
        if n_pcs_use > n_pcs_avail:
            raise ValueError(
                f"Requested neighbors_n_pcs={n_pcs_use} > available PCs ({n_pcs_avail})."
            )

        k_upper = max(2, adata.n_obs - 1)
        k = int(neighbors_k) if neighbors_k is not None else min(15, k_upper)
        k = max(2, min(k, k_upper))

        def _plain_neighbors() -> None:
            _sc.pp.neighbors(
                adata,
                n_neighbors=k,
                n_pcs=n_pcs_use,
                use_rep="X_pca",
            )

        batch_key = batch_key or self.batch_key or self.sample_key
        if use_batch_correction and batch_key in adata.obs.columns:
            try:
                import scanpy.external as sce

                logger.info(
                    "Running BBKNN batch correction with batch_key='%s'", batch_key
                )
                sce.pp.bbknn(adata, batch_key=batch_key, n_pcs=n_pcs_use)
            except Exception as e:
                logger.warning(
                    "BBKNN batch correction failed (%s); falling back to standard neighbors.",
                    e,
                )
                _plain_neighbors()
        else:
            _plain_neighbors()

        _sc.tl.umap(adata)
        _sc.tl.diffmap(adata)

        return adata

    @staticmethod
    def _score_pathways(adata: ad.AnnData) -> ad.AnnData:
        """Add one ``obs`` score column per bundled gene set, on a copy.

        Gene names are upper-cased and made unique first. Each pathway is
        scored with ``scanpy.tl.score_genes`` on the genes present in the
        data; pathways with no matching gene are skipped, as are pathways for
        which scanpy reports that no control genes were found.
        """
        adata = adata.copy()
        adata.var_names = adata.var_names.astype(str).str.upper()
        adata.var_names_make_unique()
        # A set keeps the per-gene membership test O(1).
        genes_in_adata = set(adata.var_names)

        score_genes = _sc.tl.score_genes
        try:
            supports_ctrl_as_ref = "ctrl_as_ref" in inspect.signature(score_genes).parameters
        except Exception:
            supports_ctrl_as_ref = False

        for gmt_path in _iter_gmt_files():
            logger.info("Scoring pathways from %s", gmt_path.name)

            for pathway, genes in _read_gmt(gmt_path).items():
                if not genes:
                    continue

                genes_inter = [g for g in (name.upper() for name in genes) if g in genes_in_adata]
                if not genes_inter:
                    continue

                kwargs = dict(
                    gene_list=genes_inter,
                    score_name=pathway,
                    use_raw=False,
                )
                # Only pass ctrl_as_ref to scanpy versions whose signature has it.
                if supports_ctrl_as_ref:
                    kwargs["ctrl_as_ref"] = False

                try:
                    score_genes(adata, **kwargs)
                except RuntimeError as e:
                    msg = str(e)
                    if "No control genes found in any cut" not in msg:
                        raise
                    logger.warning(
                        "score_genes failed for pathway '%s' due to lack of control "
                        "genes (message: %s). Skipping this pathway. "
                        "Consider adjusting gene sets or ctrl_size if needed.",
                        pathway,
                        msg,
                    )

        return adata

    @staticmethod
    def _build_feature_matrix(
        obs: pd.DataFrame,
        feature_names: List[str],
    ) -> pd.DataFrame:
        """Assemble the model input table from ``obs``.

        Args:
            obs: Per-cell metadata.
            feature_names: Columns the model expects, in order.

        Returns:
            A ``float32`` DataFrame indexed like ``obs`` with one column per
            feature. Values that cannot be parsed as numbers become 0.0, and
            features absent from ``obs`` become all-zero columns.
        """
        n_rows = len(obs.index)
        columns = {}
        for feat in feature_names:
            if feat in obs.columns:
                numeric = pd.to_numeric(obs[feat], errors="coerce").fillna(0.0)
                columns[feat] = numeric.astype(np.float32).to_numpy()
            else:
                # pd.to_numeric() on a scalar returns a plain number without
                # .fillna(), so missing features get an explicit zero column.
                columns[feat] = np.zeros(n_rows, dtype=np.float32)

        # Building the frame in one step avoids repeated column inserts.
        return pd.DataFrame(columns, index=obs.index)

    def _annotate_with_model(self, adata: ad.AnnData, device: Optional[str] = None) -> ad.AnnData:
        """Classify every cell with the bundled model, on a copy.

        Adds ``oncoordinate_stage_idx`` (argmax column of ``predict_proba``),
        ``oncoordinate_stage`` (ordered categorical label) and one
        ``oncoordinate_proba_<label>`` column per class to ``obs``. If the
        bundle lists no features, a warning is logged and the copy is returned
        unannotated.

        Raises:
            RuntimeError: If the model has no ``predict_proba``.
            ValueError: If ``predict_proba`` does not return a 2-D array.
        """
        adata = adata.copy()

        bundle = _load_model_bundle(device=device)
        model = bundle["model"]
        feature_list = list(bundle.get("features", []))
        scaler = bundle.get("scaler", None)
        label_map = bundle.get(
            "label_map",
            {0: "normal", 1: "abnormal", 2: "pre-malignant", 3: "malignant"},
        )

        if not feature_list:
            logger.warning(
                "Model bundle has empty 'features'; skipping oncoordinate annotation."
            )
            return adata

        X_df_float = self._build_feature_matrix(adata.obs.copy(), feature_list).astype(np.float32)

        if scaler is not None:
            try:
                X_scaled = scaler.transform(X_df_float)
            except Exception as e:
                logger.warning(
                    "Scaler transform failed (%s); falling back to MinMaxScaler fit on the fly.",
                    e,
                )
                X_scaled = MinMaxScaler().fit_transform(X_df_float)
        else:
            X_scaled = MinMaxScaler().fit_transform(X_df_float)

        X_scaled = np.nan_to_num(X_scaled, nan=0.0, posinf=0.0, neginf=0.0)

        # Unwrap search objects (anything exposing best_estimator_).
        base_model = getattr(model, "best_estimator_", model)
        base_model = self._ensure_cpu_safe_model(base_model, device=device)

        if not hasattr(base_model, "predict_proba"):
            raise RuntimeError(
                "Loaded oncoordinate model does not support predict_proba(). "
                "Cannot annotate cells."
            )

        proba = np.asarray(base_model.predict_proba(X_scaled), dtype=float)
        if proba.ndim != 2:
            raise ValueError(
                f"Model.predict_proba returned array with shape {proba.shape}, expected 2D."
            )

        classes_attr = getattr(base_model, "classes_", None)
        if classes_attr is not None:
            try:
                class_names = [label_map.get(int(c), str(c)) for c in classes_attr]
            except Exception:
                class_names = [str(c) for c in classes_attr]
        else:
            class_names = [label_map.get(i, str(i)) for i in range(proba.shape[1])]

        pred_idx = np.argmax(proba, axis=1)
        pred_labels = [class_names[i] for i in pred_idx]

        adata.obs["oncoordinate_stage_idx"] = pred_idx.astype(int)
        adata.obs["oncoordinate_stage"] = pd.Categorical(
            pred_labels,
            categories=class_names,
            ordered=True,
        )

        for j, cname in enumerate(class_names):
            adata.obs[f"oncoordinate_proba_{cname}"] = proba[:, j]

        return adata

    def _process_one_sample(
        self,
        adata_s: ad.AnnData,
        *,
        pca_n_comps: Optional[int],
        neighbors_n_pcs: Optional[int],
        neighbors_k: Optional[int],
        min_cells: int,
        min_genes: int,
        use_batch_correction: bool,
        device: Optional[str] = None,
        run_embeddings: bool = True,
        batch_key: Optional[str] = None,
    ) -> ad.AnnData:
        """Run the full per-sample pipeline and return the annotated result.

        Stages: ``_preprocess`` -> (optional) ``_run_embeddings`` ->
        ``_score_pathways`` -> ``_annotate_with_model``. The thresholds and
        device used are recorded in ``uns["oncoordinate_params"]``.

        Args:
            run_embeddings: Skip PCA/neighbours/UMAP/diffmap when ``False``.
            batch_key: Forwarded to ``_run_embeddings``.
        """
        # Named `result` (not `ad`) so the anndata module alias is not shadowed.
        result = self._preprocess(
            adata_s,
            min_cells=min_cells,
            min_genes=min_genes,
        )

        if run_embeddings:
            result = self._run_embeddings(
                result,
                pca_n_comps=pca_n_comps,
                neighbors_n_pcs=neighbors_n_pcs,
                neighbors_k=neighbors_k,
                use_batch_correction=use_batch_correction,
                batch_key=batch_key,
            )

        result = self._score_pathways(result)
        result = self._annotate_with_model(result, device=device)

        result.uns.setdefault("oncoordinate_params", {})
        result.uns["oncoordinate_params"].update(
            dict(
                min_cells=int(min_cells),
                min_genes=int(min_genes),
                device=device or self.device,
            )
        )

        return result

    def annotate(
        self,
        *,
        sample_key: Optional[str] = None,
        pca_n_comps: Optional[int] = None,
        neighbors_n_pcs: Optional[int] = None,
        neighbors_k: Optional[int] = None,
        min_cells: int = 3,
        min_genes: int = 200,
        use_batch_correction: bool = False,
        save_path: Optional[Union[str, Path]] = None,
        device: Optional[str] = None,
    ) -> ad.AnnData:
        """Annotate ``self.adata`` and return the annotated copy.

        If the sample column is absent or holds a single value, the whole
        data set goes through ``_process_one_sample`` once. With several
        samples, each one is processed separately and the results are
        concatenated (outer join on genes). With ``use_batch_correction`` the
        per-sample embedding step is skipped and embeddings are computed once,
        with BBKNN, on the concatenated data.

        Args:
            sample_key: Overrides the instance's ``sample_key`` for this call.
            pca_n_comps, neighbors_n_pcs, neighbors_k: See ``_run_embeddings``.
            min_cells, min_genes: See ``_preprocess``.
            use_batch_correction: Enable BBKNN (see above).
            save_path: If given, the result is written there as ``.h5ad``;
                missing parent directories are created.
            device: Overrides the instance's ``device`` for this call.

        Raises:
            AttributeError: If no ``AnnData`` was supplied to the constructor.
            ValueError: If the sample column exists but yields no samples.
        """
        if self.adata is None:
            raise AttributeError("OnCoordinateSC.adata is None – please provide an AnnData.")

        adata = self.adata.copy()
        adata.obs.index = adata.obs.index.astype(str)
        adata.var.index = adata.var.index.astype(str)

        skey = sample_key or self.sample_key
        # Honour a per-call sample_key override when choosing the batch column.
        batch_key = self.batch_key or skey
        run_device = device or self.device
        shared_kwargs = dict(
            pca_n_comps=pca_n_comps,
            neighbors_n_pcs=neighbors_n_pcs,
            neighbors_k=neighbors_k,
            min_cells=min_cells,
            min_genes=min_genes,
            device=run_device,
            batch_key=batch_key,
        )

        unique_samples: List[str] = []
        if skey not in adata.obs.columns:
            logger.info(
                "'%s' not found in adata.obs – processing as a single sample.",
                skey,
            )
            single_sample = True
        else:
            col = adata.obs[skey].astype(str).str.strip()
            adata.obs[skey] = col
            unique_samples = col.unique().tolist()
            single_sample = len(unique_samples) == 1
            if single_sample:
                logger.info(
                    "Single sample detected (sample_key='%s'). Processing whole AnnData.",
                    skey,
                )

        if single_sample:
            annotated = self._process_one_sample(
                adata,
                use_batch_correction=use_batch_correction,
                **shared_kwargs,
            )
        else:
            logger.info(
                "Multiple samples detected in '%s': %s (use_batch_correction=%s)",
                skey,
                unique_samples,
                use_batch_correction,
            )
            annotated_list: List[ad.AnnData] = []
            kept_ids: List[str] = []

            for sid in unique_samples:
                mask = adata.obs[skey] == sid
                if mask.sum() == 0:
                    continue

                # With batch correction, embeddings are computed once on the
                # concatenated data below, so skip them per sample.
                ann_s = self._process_one_sample(
                    adata[mask, :].copy(),
                    use_batch_correction=False,
                    run_embeddings=not use_batch_correction,
                    **shared_kwargs,
                )
                ann_s.obs[skey] = sid
                annotated_list.append(ann_s)
                kept_ids.append(str(sid))

            if not annotated_list:
                raise ValueError(
                    f"No samples contained any cells after filtering (sample_key='{skey}')."
                )

            annotated = ad.concat(
                annotated_list,
                label="__sample__",
                keys=kept_ids,
                join="outer",
                index_unique=None,
            )

            if use_batch_correction:
                logger.info(
                    "Running BBKNN batch correction once on concatenated AnnData "
                    "(n_obs=%d, n_vars=%d, batch_key='%s').",
                    annotated.n_obs,
                    annotated.n_vars,
                    batch_key,
                )
                annotated = self._run_embeddings(
                    annotated,
                    pca_n_comps=pca_n_comps,
                    neighbors_n_pcs=neighbors_n_pcs,
                    neighbors_k=neighbors_k,
                    use_batch_correction=True,
                    batch_key=batch_key,
                )

            # The concat call above passes no uns_merge, so the parameters are
            # recorded again here for every multi-sample result.
            annotated.uns.setdefault("oncoordinate_params", {})
            annotated.uns["oncoordinate_params"].update(
                dict(
                    min_cells=int(min_cells),
                    min_genes=int(min_genes),
                    device=run_device,
                )
            )

        if save_path is not None:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            annotated.write_h5ad(save_path)
            logger.info("Wrote oncoordinate-annotated AnnData to: %s", save_path)

        return annotated
|