oncoordinate 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oncoordinate/HallmarkPathGMT/HALLMARK_ADIPOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ALLOGRAFT_REJECTION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ANDROGEN_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ANGIOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_JUNCTION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_SURFACE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_APOPTOSIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_BILE_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_CHOLESTEROL_HOMEOSTASIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_COAGULATION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_COMPLEMENT.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_DNA_REPAIR.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_E2F_TARGETS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_EARLY.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_LATE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_FATTY_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_G2M_CHECKPOINT.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_GLYCOLYSIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_HEDGEHOG_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_HEME_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_HYPOXIA.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_IL2_STAT5_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_IL6_JAK_STAT3_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_INFLAMMATORY_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_ALPHA_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_GAMMA_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_DN.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_UP.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MITOTIC_SPINDLE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MTORC1_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V1.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V2.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_MYOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_NOTCH_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_OXIDATIVE_PHOSPHORYLATION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_P53_PATHWAY.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PANCREAS_BETA_CELLS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PEROXISOME.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PI3K_AKT_MTOR_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_PROTEIN_SECRETION.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_SPERMATOGENESIS.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_TGF_BETA_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_TNFA_SIGNALING_VIA_NFKB.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_UNFOLDED_PROTEIN_RESPONSE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_DN.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_UP.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_WNT_BETA_CATENIN_SIGNALING.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/HALLMARK_XENOBIOTIC_METABOLISM.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/REACTOME_CELL_CYCLE.v2023.2.Hs.gmt +1 -0
- oncoordinate/HallmarkPathGMT/REACTOME_SIGNALING_BY_EGFR_IN_CANCER.v2023.2.Hs.gmt +1 -0
- oncoordinate/__init__.py +0 -0
- oncoordinate/lt.py +472 -0
- oncoordinate/oncoordinate.joblib +0 -0
- oncoordinate/sc.py +729 -0
- oncoordinate/sp.py +513 -0
- oncoordinate-0.1.7.dist-info/METADATA +93 -0
- oncoordinate-0.1.7.dist-info/RECORD +62 -0
- oncoordinate-0.1.7.dist-info/WHEEL +5 -0
- oncoordinate-0.1.7.dist-info/licenses/LICENSE +21 -0
- oncoordinate-0.1.7.dist-info/top_level.txt +1 -0
oncoordinate/lt.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from importlib import resources
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, Union, Sequence
|
|
7
|
+
|
|
8
|
+
import anndata as ad
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import scanpy as sc
|
|
12
|
+
from scipy import sparse
|
|
13
|
+
import scvi
|
|
14
|
+
|
|
15
|
+
# Module-level logger, named after this module so that applications can
# configure or silence its output independently of other loggers.
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
def _load_reference_adata(
    reference: Optional[Union[str, Path, ad.AnnData]] = None,
) -> ad.AnnData:
    """Resolve ``reference`` into an in-memory AnnData object.

    Parameters
    ----------
    reference
        One of:

        * an ``AnnData`` instance -- a copy of it is returned;
        * a path (``str`` or ``Path``) to an ``.h5ad`` file -- it is read;
        * ``None`` -- the resource ``reference_sc.h5ad`` is looked up inside
          the installed ``oncoordinate`` package and read.

    Returns
    -------
    ad.AnnData
        The copied or freshly loaded reference dataset.

    Raises
    ------
    FileNotFoundError
        If the given path is not an existing file, or if loading the packaged
        default fails for any reason (the original error is chained as the
        cause).
    """
    if isinstance(reference, ad.AnnData):
        return reference.copy()

    if reference is None:
        # No explicit reference: fall back to the resource bundled with the
        # package. Any failure here is reported as a missing default.
        try:
            packaged = resources.files("oncoordinate").joinpath("reference_sc.h5ad")
            with resources.as_file(packaged) as local_path:
                return ad.read_h5ad(local_path)
        except Exception as e:
            raise FileNotFoundError(
                "No reference AnnData provided and default 'reference_sc.h5ad' "
                "not found within the oncoordinate package."
            ) from e

    ref_path = Path(reference)
    if ref_path.is_file():
        return ad.read_h5ad(ref_path)
    raise FileNotFoundError(f"Reference AnnData not found: {ref_path}")
|
|
39
|
+
|
|
40
|
+
def _load_spatial_adata(
    spatial: Union[str, Path, ad.AnnData],
) -> ad.AnnData:
    """Resolve ``spatial`` into an in-memory AnnData object.

    Parameters
    ----------
    spatial
        Either an ``AnnData`` instance (a copy is returned) or a path to an
        ``.h5ad`` file (which is read).

    Returns
    -------
    ad.AnnData
        The copied or freshly loaded spatial dataset.

    Raises
    ------
    FileNotFoundError
        If ``spatial`` is a path that does not point to an existing file.
    """
    if not isinstance(spatial, ad.AnnData):
        h5ad_path = Path(spatial)
        if h5ad_path.is_file():
            return ad.read_h5ad(h5ad_path)
        raise FileNotFoundError(f"Spatial AnnData not found: {h5ad_path}")
    return spatial.copy()
|
|
49
|
+
|
|
50
|
+
def _get_counts_layer(
|
|
51
|
+
adata: ad.AnnData,
|
|
52
|
+
counts_layer: Optional[str] = "counts",
|
|
53
|
+
) -> np.ndarray:
|
|
54
|
+
if counts_layer is not None and counts_layer in adata.layers:
|
|
55
|
+
X = adata.layers[counts_layer]
|
|
56
|
+
else:
|
|
57
|
+
X = adata.X
|
|
58
|
+
|
|
59
|
+
if sparse.issparse(X):
|
|
60
|
+
X = X.tocsr(copy=False)
|
|
61
|
+
X.data = np.rint(np.clip(X.data, 0, None)).astype(np.int32)
|
|
62
|
+
else:
|
|
63
|
+
X = np.rint(np.clip(np.asarray(X), 0, None)).astype(np.int32)
|
|
64
|
+
return X
|
|
65
|
+
|
|
66
|
+
def _make_pseudospots(
    adata_ref: ad.AnnData,
    *,
    celltype_key: str,
    stage_key: str,
    pseudospot_size: int = 10,
) -> ad.AnnData:
    """Aggregate reference cells into summed-count "pseudospots".

    Cells are grouped by the combined label ``"<celltype>::<stage>"``. Within
    each group the cells are shuffled (fixed seed 0, so results are
    reproducible) and consecutive chunks of ``pseudospot_size`` cells are
    summed into one pseudospot. A group smaller than ``pseudospot_size``
    still yields a single pseudospot built from all its cells; otherwise
    cells left over after the last full chunk are not used.

    Parameters
    ----------
    adata_ref
        Reference data; must provide ``layers["counts"]`` and the two obs
        columns named by ``celltype_key`` and ``stage_key``.
    celltype_key, stage_key
        Names of the obs columns holding cell type and stage labels.
    pseudospot_size
        Number of cells summed per pseudospot; must be at least 1.

    Returns
    -------
    ad.AnnData
        One row per pseudospot, with ``X`` and ``layers["counts"]`` holding
        the summed ``int32`` counts and obs columns ``celltype_key``,
        ``stage_key``, ``"cell_state"`` and ``"tech"`` (= ``"reference"``).

    Raises
    ------
    ValueError
        If ``pseudospot_size`` is smaller than 1.
    RuntimeError
        If no pseudospot could be built.
    """
    if pseudospot_size < 1:
        # A zero size would divide by zero below; a negative one would
        # produce meaningless slices.
        raise ValueError(
            f"pseudospot_size must be a positive integer, got {pseudospot_size!r}."
        )

    X = adata_ref.layers["counts"]
    if sparse.issparse(X):
        X = X.tocsr(copy=False)

    obs = adata_ref.obs[[celltype_key, stage_key]].astype(str).copy()
    obs["cell_state"] = obs[celltype_key] + "::" + obs[stage_key]

    # groupby(...).indices maps each label to positional row indices.
    groups = obs.groupby("cell_state", observed=True).indices
    celltypes = obs[celltype_key].to_numpy()
    stages = obs[stage_key].to_numpy()

    X_ps_list = []
    obs_ps_rows = []
    rng = np.random.default_rng(0)

    for cs, idxs in groups.items():
        # Copy so that shuffling does not alter the arrays owned by groupby.
        idxs = np.array(idxs)
        if idxs.size == 0:
            continue

        # Read the labels from the source columns instead of splitting the
        # combined string, which would mis-split a cell type containing "::".
        ct = celltypes[idxs[0]]
        st = stages[idxs[0]]

        rng.shuffle(idxs)
        n_spots = max(1, idxs.size // pseudospot_size)

        for k in range(n_spots):
            sel = idxs[k * pseudospot_size:(k + 1) * pseudospot_size]
            if sel.size == 0:
                continue

            # Sparse sums come back as a 2-D matrix; flatten either way.
            v = np.asarray(X[sel].sum(axis=0)).ravel()

            X_ps_list.append(v)
            obs_ps_rows.append((ct, st, cs))

    if not X_ps_list:
        raise RuntimeError(
            "No pseudospots could be created. Check celltype_key/stage_key "
            "and pseudospot_size."
        )

    X_ps = np.vstack(X_ps_list).astype(np.int32)
    obs_ps = pd.DataFrame(
        obs_ps_rows,
        columns=[celltype_key, stage_key, "cell_state"],
        index=[f"ps_{i}" for i in range(len(X_ps))],
    )
    obs_ps["tech"] = "reference"

    adata_ps = ad.AnnData(X=X_ps, obs=obs_ps, var=adata_ref.var.copy())
    adata_ps.layers["counts"] = X_ps
    return adata_ps
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _match_library_sizes(
|
|
128
|
+
adata_ref_ps: ad.AnnData,
|
|
129
|
+
adata_sp: ad.AnnData,
|
|
130
|
+
*,
|
|
131
|
+
layer: str = "counts",
|
|
132
|
+
) -> None:
|
|
133
|
+
|
|
134
|
+
def _libsizes(a: ad.AnnData) -> np.ndarray:
|
|
135
|
+
Xc = a.layers[layer]
|
|
136
|
+
return (
|
|
137
|
+
np.asarray(Xc.sum(axis=1)).ravel()
|
|
138
|
+
if sparse.issparse(Xc)
|
|
139
|
+
else np.asarray(Xc.sum(axis=1)).ravel()
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
lib_ref = _libsizes(adata_ref_ps)
|
|
143
|
+
lib_sp = _libsizes(adata_sp)
|
|
144
|
+
|
|
145
|
+
med_ref = float(np.median(lib_ref)) if lib_ref.size else 0.0
|
|
146
|
+
med_sp = float(np.median(lib_sp)) if lib_sp.size else 0.0
|
|
147
|
+
|
|
148
|
+
if med_ref <= 0 or med_sp <= 0:
|
|
149
|
+
logger.warning("Unable to estimate library sizes; skipping scaling.")
|
|
150
|
+
return
|
|
151
|
+
|
|
152
|
+
ratio = med_sp / med_ref
|
|
153
|
+
Xc = adata_ref_ps.layers[layer]
|
|
154
|
+
if sparse.issparse(Xc):
|
|
155
|
+
Xc = Xc.tocsr(copy=False)
|
|
156
|
+
Xc.data = np.rint(Xc.data * ratio).astype(np.int32)
|
|
157
|
+
else:
|
|
158
|
+
Xc = np.rint(Xc * ratio).astype(np.int32)
|
|
159
|
+
adata_ref_ps.layers[layer] = Xc
|
|
160
|
+
|
|
161
|
+
def _run_transfer_single(
    adata_spatial_in: ad.AnnData,
    adata_ref_in: ad.AnnData,
    *,
    celltype_key: str = "celltype",
    stage_key: str = "oncoordinate_stage",
    counts_layer: Optional[str] = "counts",
    stage_order: Optional[Sequence[str]] = ("normal", "abnormal", "pre-malignant", "malignant"),
    max_reference_cells: int = 900_000,
    pseudospot_size: int = 10,
    n_hvg: int = 2000,
    scvi_max_epochs: int = 250,
    scanvi_max_epochs: int = 250,
    scvi_model_dir: Optional[Union[str, Path]] = None,
    scanvi_model_dir: Optional[Union[str, Path]] = None,
) -> ad.AnnData:
    """Transfer ``celltype::stage`` labels from a reference onto one spatial dataset.

    Steps, in order:

    1. Copy both inputs, upper-case and de-duplicate gene names, and restrict
       both to their shared genes.
    2. Store clipped, integer counts in ``layers["counts"]`` of both.
    3. If the reference exceeds ``max_reference_cells``, subsample it
       proportionally per (celltype, stage) group (fixed seed 0).
    4. Drop (celltype, stage) groups with fewer than 3 cells.
    5. Build pseudospots from the reference and scale their counts to the
       spatial median library size.
    6. Concatenate pseudospots and spatial spots, select ``n_hvg`` highly
       variable genes, and fit (or load) SCVI followed by SCANVI, with the
       spatial spots labelled ``"unlabeled"``.
    7. Write predictions back onto a fresh copy of ``adata_spatial_in``.

    Parameters
    ----------
    adata_spatial_in, adata_ref_in
        Spatial and reference datasets; neither is modified.
    celltype_key, stage_key
        Obs columns of the reference holding cell type and stage labels.
    counts_layer
        Layer to read raw counts from; ``X`` is used when it is ``None`` or
        absent.
    stage_order
        Ordered stage categories. ``None`` uses the sorted set of predicted
        stages. Predicted stages missing from this list are stored as missing
        values in ``oncoordinate_tl_stage`` (a warning is logged).
    max_reference_cells, pseudospot_size, n_hvg
        Reference size cap, cells per pseudospot, number of variable genes.
    scvi_max_epochs, scanvi_max_epochs
        Training epoch limits for the two models.
    scvi_model_dir, scanvi_model_dir
        Optional directories. If a directory exists the model is loaded from
        it; otherwise the model is trained and, when a directory was given,
        saved there.

    Returns
    -------
    ad.AnnData
        Copy of ``adata_spatial_in`` with added obs columns
        ``oncoordinate_tl_cell_state``, ``oncoordinate_tl_celltype``,
        ``oncoordinate_tl_stage``, ``oncoordinate_tl_state_confidence`` and
        one ``oncoordinate_tl_stage_proba_<stage>`` per stage, plus
        ``obsm["X_oncoordinate_scanvi"]`` and ``uns["oncoordinate_tl_params"]``.

    Raises
    ------
    KeyError
        If ``celltype_key`` or ``stage_key`` is missing from the reference obs.
    ValueError
        If the two datasets share no genes.
    RuntimeError
        If subsampling or small-group filtering leaves no reference cells.
    """
    adata_ref = adata_ref_in.copy()
    adata_spatial = adata_spatial_in.copy()

    if celltype_key not in adata_ref.obs.columns:
        raise KeyError(f"celltype_key '{celltype_key}' not found in reference.obs")
    if stage_key not in adata_ref.obs.columns:
        raise KeyError(f"stage_key '{stage_key}' not found in reference.obs")

    # Normalise identifiers so that gene matching is case-insensitive.
    for a in (adata_ref, adata_spatial):
        a.obs.index = a.obs.index.astype(str)
        a.var.index = a.var.index.astype(str)
        a.var_names = a.var_names.astype(str).str.upper()
        a.var_names_make_unique()

    shared_genes = np.intersect1d(adata_ref.var_names, adata_spatial.var_names)
    if shared_genes.size == 0:
        raise ValueError("No overlapping genes between reference and spatial data.")

    adata_ref = adata_ref[:, shared_genes].copy()
    adata_sp = adata_spatial[:, shared_genes].copy()
    adata_ref.layers["counts"] = _get_counts_layer(adata_ref, counts_layer)
    adata_sp.layers["counts"] = _get_counts_layer(adata_sp, counts_layer)

    if adata_ref.n_obs > max_reference_cells:
        logger.info(
            "Reference has %d cells; downsampling to ~%d.",
            adata_ref.n_obs,
            max_reference_cells,
        )
        rng = np.random.default_rng(0)
        groups = adata_ref.obs.groupby([celltype_key, stage_key], observed=True).indices
        total = adata_ref.n_obs
        keep_idx = []
        for idxs in groups.values():
            idxs = np.asarray(idxs)
            if idxs.size == 0:
                continue
            # Each group keeps its share of the cap, rounded down.
            n = min(int(idxs.size / total * max_reference_cells), idxs.size)
            if n > 0:
                keep_idx.append(rng.choice(idxs, size=n, replace=False))
        if not keep_idx:
            raise RuntimeError("Subsampling removed all reference cells.")
        keep_idx = np.concatenate(keep_idx)
        adata_ref = adata_ref[keep_idx].copy()

    # Discard (celltype, stage) combinations with fewer than 3 cells.
    idx_multi = pd.MultiIndex.from_frame(
        adata_ref.obs[[celltype_key, stage_key]].astype(str)
    )
    counts = idx_multi.value_counts()
    valid_states = counts[counts >= 3].index
    mask_valid = idx_multi.isin(valid_states)
    adata_ref = adata_ref[mask_valid].copy()
    if adata_ref.n_obs == 0:
        raise RuntimeError(
            "No reference cells left after filtering small (celltype, stage) groups."
        )

    adata_ps = _make_pseudospots(
        adata_ref,
        celltype_key=celltype_key,
        stage_key=stage_key,
        pseudospot_size=pseudospot_size,
    )

    # adata_sp is already a private copy (made when subsetting to shared
    # genes), so it can be annotated directly.
    adata_sp.obs["tech"] = "spatial"
    adata_sp.obs["cell_state"] = "unlabeled"

    _match_library_sizes(adata_ps, adata_sp, layer="counts")

    adata_comb = ad.concat(
        [adata_ps, adata_sp],
        join="inner",
        merge="same",
        label="__source__",
        keys=["reference", "spatial"],
        index_unique=None,
    )

    sc.pp.highly_variable_genes(
        adata_comb,
        layer="counts",
        n_top_genes=n_hvg,
        flavor="seurat_v3",
        batch_key="tech",
        inplace=True,
    )
    adata_comb = adata_comb[:, adata_comb.var["highly_variable"].values].copy()

    scvi.model.SCVI.setup_anndata(
        adata_comb,
        batch_key="tech",
        labels_key="cell_state",
        layer="counts",
    )

    scvi_dir = Path(scvi_model_dir) if scvi_model_dir is not None else None
    if scvi_dir is not None and scvi_dir.exists():
        logger.info("Loading SCVI model from %s", scvi_dir)
        scvi_model = scvi.model.SCVI.load(scvi_dir, adata=adata_comb)
    else:
        scvi_model = scvi.model.SCVI(adata_comb, n_layers=3, n_latent=32)
        scvi_model.train(max_epochs=scvi_max_epochs, batch_size=512)
        if scvi_dir is not None:
            scvi_model.save(scvi_dir, overwrite=True)

    adata_comb.obsm["X_oncoordinate_scvi"] = scvi_model.get_latent_representation()

    scanvi_dir = Path(scanvi_model_dir) if scanvi_model_dir is not None else None
    if scanvi_dir is not None and scanvi_dir.exists():
        logger.info("Loading SCANVI model from %s", scanvi_dir)
        scanvi_model = scvi.model.SCANVI.load(scanvi_dir, adata=adata_comb)
    else:
        scanvi_model = scvi.model.SCANVI.from_scvi_model(
            scvi_model,
            unlabeled_category="unlabeled",
        )
        scanvi_model.train(
            max_epochs=scanvi_max_epochs,
            batch_size=256,
            plan_kwargs={"lr": 5e-4},
            gradient_clip_val=10.0,
        )
        if scanvi_dir is not None:
            scanvi_model.save(scanvi_dir, overwrite=True)

    adata_comb.obsm["X_oncoordinate_scanvi"] = scanvi_model.get_latent_representation(
        adata_comb
    )

    state_pred = scanvi_model.predict(adata_comb)
    adata_comb.obs["oncoordinate_tl_cell_state"] = state_pred.astype(str)
    soft = scanvi_model.predict(adata_comb, soft=True)

    # Soft predictions may arrive as a labelled DataFrame or as a bare array;
    # bring both into a DataFrame indexed by observation name.
    if isinstance(soft, pd.DataFrame):
        soft_df = soft.loc[adata_comb.obs_names]
    else:
        soft_arr = np.asarray(soft)
        if soft_arr.ndim == 1:
            soft_arr = soft_arr.reshape(-1, 1)
        soft_df = pd.DataFrame(
            soft_arr,
            index=adata_comb.obs_names,
        )

    adata_comb.obsm["oncoordinate_tl_state_proba"] = soft_df.to_numpy()
    adata_comb.obs["oncoordinate_tl_state_confidence"] = soft_df.max(axis=1)

    def _split_state(s: str) -> tuple[str, str]:
        # "celltype::stage" -> (celltype, stage); labels without the
        # separator get the placeholder stage "NA".
        if "::" in s:
            ct, st = s.split("::", 1)
            return ct, st
        return s, "NA"

    ct_pred, st_pred = zip(
        *[_split_state(s) for s in adata_comb.obs["oncoordinate_tl_cell_state"].astype(str)]
    )
    adata_comb.obs["oncoordinate_tl_celltype"] = list(ct_pred)

    if stage_order is None:
        stage_order_use = sorted(set(st_pred))
    else:
        stage_order_use = list(stage_order)

    # Values absent from the category list become missing values in a
    # Categorical; surface that instead of losing the labels silently.
    unexpected_stages = sorted(set(st_pred) - set(stage_order_use))
    if unexpected_stages:
        logger.warning(
            "Predicted stages %s are not listed in stage_order and will be "
            "stored as missing values in 'oncoordinate_tl_stage'.",
            unexpected_stages,
        )

    adata_comb.obs["oncoordinate_tl_stage"] = pd.Categorical(
        list(st_pred),
        categories=stage_order_use,
        ordered=True,
    )

    if isinstance(soft, pd.DataFrame):
        # Stage probability = sum over all "<celltype>::<stage>" columns.
        for stage in stage_order_use:
            cols_stage = [c for c in soft_df.columns if c.endswith(f"::{stage}")]
            if cols_stage:
                series = soft_df[cols_stage].sum(axis=1)
            else:
                series = pd.Series(0.0, index=soft_df.index)
            adata_comb.obs[f"oncoordinate_tl_stage_proba_{stage}"] = series
    else:
        # Without column labels the per-stage split is unknown.
        for stage in stage_order_use:
            adata_comb.obs[f"oncoordinate_tl_stage_proba_{stage}"] = 0.0

    is_spatial = adata_comb.obs["tech"].values == "spatial"
    idx_spatial_comb = adata_comb.obs.index[is_spatial]

    # Results go onto a fresh copy of the caller's data so the original gene
    # names and full gene set are preserved.
    spatial_out = adata_spatial_in.copy()
    pred_obs = adata_comb.obs.loc[idx_spatial_comb]
    pred_obs = pred_obs.reindex(spatial_out.obs_names)

    cols_to_copy = [
        "oncoordinate_tl_cell_state",
        "oncoordinate_tl_celltype",
        "oncoordinate_tl_stage",
        "oncoordinate_tl_state_confidence",
    ] + [f"oncoordinate_tl_stage_proba_{stage}" for stage in stage_order_use]

    for col in cols_to_copy:
        spatial_out.obs[col] = pred_obs[col]

    latent_all = adata_comb.obsm["X_oncoordinate_scanvi"]
    latent_spatial = latent_all[is_spatial, :]
    latent_df = pd.DataFrame(latent_spatial, index=idx_spatial_comb)
    latent_df = latent_df.reindex(spatial_out.obs_names)
    spatial_out.obsm["X_oncoordinate_scanvi"] = latent_df.to_numpy()

    spatial_out.uns.setdefault("oncoordinate_tl_params", {})
    spatial_out.uns["oncoordinate_tl_params"].update(
        dict(
            celltype_key=celltype_key,
            stage_key=stage_key,
            counts_layer=counts_layer,
            max_reference_cells=int(max_reference_cells),
            pseudospot_size=int(pseudospot_size),
            n_hvg=int(n_hvg),
            stage_order=list(stage_order_use),
        )
    )

    return spatial_out
|
|
390
|
+
|
|
391
|
+
def run_transfer(
    spatial: Union[str, Path, ad.AnnData],
    reference: Optional[Union[str, Path, ad.AnnData]] = None,
    *,
    celltype_key: str = "celltype",
    stage_key: str = "oncoordinate_stage",
    counts_layer: Optional[str] = "counts",
    stage_order: Optional[Sequence[str]] = ("normal", "abnormal", "pre-malignant", "malignant"),
    max_reference_cells: int = 900_000,
    pseudospot_size: int = 10,
    n_hvg: int = 2000,
    scvi_max_epochs: int = 250,
    scanvi_max_epochs: int = 250,
    scvi_model_dir: Optional[Union[str, Path]] = None,
    scanvi_model_dir: Optional[Union[str, Path]] = None,
    per_sample: bool = True,
    sample_key: str = "sample",
) -> ad.AnnData:
    """Transfer reference cell-state labels onto spatial data.

    Loads both datasets, then either runs one transfer over all spots or,
    when ``per_sample`` is true and ``obs[sample_key]`` holds more than one
    sample, runs an independent transfer per sample and concatenates the
    results.

    Parameters
    ----------
    spatial
        Spatial ``AnnData`` or path to an ``.h5ad`` file.
    reference
        Reference ``AnnData``, path to an ``.h5ad`` file, or ``None`` for the
        packaged default.
    celltype_key, stage_key, counts_layer, stage_order, max_reference_cells,
    pseudospot_size, n_hvg, scvi_max_epochs, scanvi_max_epochs
        Forwarded unchanged to ``_run_transfer_single``.
    scvi_model_dir, scanvi_model_dir
        Optional directories used to load or save the trained models. In
        per-sample mode each sample uses its own subdirectory
        ``<dir>/<sample name>``, because every sample trains a separate model
        on different data. In single-run mode the directory is used as given.
    per_sample
        Whether to split the spatial data by ``sample_key``.
    sample_key
        Obs column identifying samples.

    Returns
    -------
    ad.AnnData
        Annotated spatial data. In per-sample mode rows are ordered by sample
        group, and ``uns["oncoordinate_tl_params"]`` is taken from the first
        sample's result.
    """

    def _sample_dir(base: Optional[Union[str, Path]], sample_name: object) -> Optional[Path]:
        # One model directory per sample; path separators in the sample name
        # are replaced so the name stays a single path component.
        if base is None:
            return None
        safe_name = str(sample_name).replace("/", "_").replace("\\", "_")
        return Path(base) / safe_name

    adata_ref = _load_reference_adata(reference)
    adata_spatial_in = _load_spatial_adata(spatial)

    # Arguments that are identical for every call to the worker.
    shared_kwargs = dict(
        celltype_key=celltype_key,
        stage_key=stage_key,
        counts_layer=counts_layer,
        stage_order=stage_order,
        max_reference_cells=max_reference_cells,
        pseudospot_size=pseudospot_size,
        n_hvg=n_hvg,
        scvi_max_epochs=scvi_max_epochs,
        scanvi_max_epochs=scanvi_max_epochs,
    )

    if (
        per_sample
        and sample_key is not None
        and sample_key in adata_spatial_in.obs.columns
    ):
        groups = adata_spatial_in.obs.groupby(sample_key, observed=True).indices
        if len(groups) > 1:
            logger.info(
                "Found %d spatial samples in obs['%s']; running transfer per sample.",
                len(groups),
                sample_key,
            )
            outs = []
            for sample_name, idx in groups.items():
                sub = adata_spatial_in[idx].copy()
                logger.info(
                    "Running transfer for sample '%s' with %d spots.",
                    sample_name,
                    sub.n_obs,
                )
                # Sharing one directory across samples would make every
                # sample after the first load a model fitted to another
                # sample's data.
                out_sub = _run_transfer_single(
                    sub,
                    adata_ref,
                    scvi_model_dir=_sample_dir(scvi_model_dir, sample_name),
                    scanvi_model_dir=_sample_dir(scanvi_model_dir, sample_name),
                    **shared_kwargs,
                )
                outs.append(out_sub)

            combined = ad.concat(
                outs,
                join="outer",
                merge="same",
                label=None,
                index_unique=None,
            )
            # ad.concat discards .uns by default, so restore the recorded
            # run parameters. With stage_order=None the stage list can
            # differ between samples; the first sample's is kept.
            params = outs[0].uns.get("oncoordinate_tl_params")
            if params is not None:
                combined.uns["oncoordinate_tl_params"] = dict(params)
            return combined

    return _run_transfer_single(
        adata_spatial_in,
        adata_ref,
        scvi_model_dir=scvi_model_dir,
        scanvi_model_dir=scanvi_model_dir,
        **shared_kwargs,
    )
|
|
Binary file
|