oncoordinate 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. oncoordinate/HallmarkPathGMT/HALLMARK_ADIPOGENESIS.v2023.2.Hs.gmt +1 -0
  2. oncoordinate/HallmarkPathGMT/HALLMARK_ALLOGRAFT_REJECTION.v2023.2.Hs.gmt +1 -0
  3. oncoordinate/HallmarkPathGMT/HALLMARK_ANDROGEN_RESPONSE.v2023.2.Hs.gmt +1 -0
  4. oncoordinate/HallmarkPathGMT/HALLMARK_ANGIOGENESIS.v2023.2.Hs.gmt +1 -0
  5. oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_JUNCTION.v2023.2.Hs.gmt +1 -0
  6. oncoordinate/HallmarkPathGMT/HALLMARK_APICAL_SURFACE.v2023.2.Hs.gmt +1 -0
  7. oncoordinate/HallmarkPathGMT/HALLMARK_APOPTOSIS.v2023.2.Hs.gmt +1 -0
  8. oncoordinate/HallmarkPathGMT/HALLMARK_BILE_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
  9. oncoordinate/HallmarkPathGMT/HALLMARK_CHOLESTEROL_HOMEOSTASIS.v2023.2.Hs.gmt +1 -0
  10. oncoordinate/HallmarkPathGMT/HALLMARK_COAGULATION.v2023.2.Hs.gmt +1 -0
  11. oncoordinate/HallmarkPathGMT/HALLMARK_COMPLEMENT.v2023.2.Hs.gmt +1 -0
  12. oncoordinate/HallmarkPathGMT/HALLMARK_DNA_REPAIR.v2023.2.Hs.gmt +1 -0
  13. oncoordinate/HallmarkPathGMT/HALLMARK_E2F_TARGETS.v2023.2.Hs.gmt +1 -0
  14. oncoordinate/HallmarkPathGMT/HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION.v2023.2.Hs.gmt +1 -0
  15. oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_EARLY.v2023.2.Hs.gmt +1 -0
  16. oncoordinate/HallmarkPathGMT/HALLMARK_ESTROGEN_RESPONSE_LATE.v2023.2.Hs.gmt +1 -0
  17. oncoordinate/HallmarkPathGMT/HALLMARK_FATTY_ACID_METABOLISM.v2023.2.Hs.gmt +1 -0
  18. oncoordinate/HallmarkPathGMT/HALLMARK_G2M_CHECKPOINT.v2023.2.Hs.gmt +1 -0
  19. oncoordinate/HallmarkPathGMT/HALLMARK_GLYCOLYSIS.v2023.2.Hs.gmt +1 -0
  20. oncoordinate/HallmarkPathGMT/HALLMARK_HEDGEHOG_SIGNALING.v2023.2.Hs.gmt +1 -0
  21. oncoordinate/HallmarkPathGMT/HALLMARK_HEME_METABOLISM.v2023.2.Hs.gmt +1 -0
  22. oncoordinate/HallmarkPathGMT/HALLMARK_HYPOXIA.v2023.2.Hs.gmt +1 -0
  23. oncoordinate/HallmarkPathGMT/HALLMARK_IL2_STAT5_SIGNALING.v2023.2.Hs.gmt +1 -0
  24. oncoordinate/HallmarkPathGMT/HALLMARK_IL6_JAK_STAT3_SIGNALING.v2023.2.Hs.gmt +1 -0
  25. oncoordinate/HallmarkPathGMT/HALLMARK_INFLAMMATORY_RESPONSE.v2023.2.Hs.gmt +1 -0
  26. oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_ALPHA_RESPONSE.v2023.2.Hs.gmt +1 -0
  27. oncoordinate/HallmarkPathGMT/HALLMARK_INTERFERON_GAMMA_RESPONSE.v2023.2.Hs.gmt +1 -0
  28. oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_DN.v2023.2.Hs.gmt +1 -0
  29. oncoordinate/HallmarkPathGMT/HALLMARK_KRAS_SIGNALING_UP.v2023.2.Hs.gmt +1 -0
  30. oncoordinate/HallmarkPathGMT/HALLMARK_MITOTIC_SPINDLE.v2023.2.Hs.gmt +1 -0
  31. oncoordinate/HallmarkPathGMT/HALLMARK_MTORC1_SIGNALING.v2023.2.Hs.gmt +1 -0
  32. oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V1.v2023.2.Hs.gmt +1 -0
  33. oncoordinate/HallmarkPathGMT/HALLMARK_MYC_TARGETS_V2.v2023.2.Hs.gmt +1 -0
  34. oncoordinate/HallmarkPathGMT/HALLMARK_MYOGENESIS.v2023.2.Hs.gmt +1 -0
  35. oncoordinate/HallmarkPathGMT/HALLMARK_NOTCH_SIGNALING.v2023.2.Hs.gmt +1 -0
  36. oncoordinate/HallmarkPathGMT/HALLMARK_OXIDATIVE_PHOSPHORYLATION.v2023.2.Hs.gmt +1 -0
  37. oncoordinate/HallmarkPathGMT/HALLMARK_P53_PATHWAY.v2023.2.Hs.gmt +1 -0
  38. oncoordinate/HallmarkPathGMT/HALLMARK_PANCREAS_BETA_CELLS.v2023.2.Hs.gmt +1 -0
  39. oncoordinate/HallmarkPathGMT/HALLMARK_PEROXISOME.v2023.2.Hs.gmt +1 -0
  40. oncoordinate/HallmarkPathGMT/HALLMARK_PI3K_AKT_MTOR_SIGNALING.v2023.2.Hs.gmt +1 -0
  41. oncoordinate/HallmarkPathGMT/HALLMARK_PROTEIN_SECRETION.v2023.2.Hs.gmt +1 -0
  42. oncoordinate/HallmarkPathGMT/HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY.v2023.2.Hs.gmt +1 -0
  43. oncoordinate/HallmarkPathGMT/HALLMARK_SPERMATOGENESIS.v2023.2.Hs.gmt +1 -0
  44. oncoordinate/HallmarkPathGMT/HALLMARK_TGF_BETA_SIGNALING.v2023.2.Hs.gmt +1 -0
  45. oncoordinate/HallmarkPathGMT/HALLMARK_TNFA_SIGNALING_VIA_NFKB.v2023.2.Hs.gmt +1 -0
  46. oncoordinate/HallmarkPathGMT/HALLMARK_UNFOLDED_PROTEIN_RESPONSE.v2023.2.Hs.gmt +1 -0
  47. oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_DN.v2023.2.Hs.gmt +1 -0
  48. oncoordinate/HallmarkPathGMT/HALLMARK_UV_RESPONSE_UP.v2023.2.Hs.gmt +1 -0
  49. oncoordinate/HallmarkPathGMT/HALLMARK_WNT_BETA_CATENIN_SIGNALING.v2023.2.Hs.gmt +1 -0
  50. oncoordinate/HallmarkPathGMT/HALLMARK_XENOBIOTIC_METABOLISM.v2023.2.Hs.gmt +1 -0
  51. oncoordinate/HallmarkPathGMT/REACTOME_CELL_CYCLE.v2023.2.Hs.gmt +1 -0
  52. oncoordinate/HallmarkPathGMT/REACTOME_SIGNALING_BY_EGFR_IN_CANCER.v2023.2.Hs.gmt +1 -0
  53. oncoordinate/__init__.py +0 -0
  54. oncoordinate/lt.py +472 -0
  55. oncoordinate/oncoordinate.joblib +0 -0
  56. oncoordinate/sc.py +729 -0
  57. oncoordinate/sp.py +513 -0
  58. oncoordinate-0.1.7.dist-info/METADATA +93 -0
  59. oncoordinate-0.1.7.dist-info/RECORD +62 -0
  60. oncoordinate-0.1.7.dist-info/WHEEL +5 -0
  61. oncoordinate-0.1.7.dist-info/licenses/LICENSE +21 -0
  62. oncoordinate-0.1.7.dist-info/top_level.txt +1 -0
oncoordinate/lt.py ADDED
@@ -0,0 +1,472 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from importlib import resources
5
+ from pathlib import Path
6
+ from typing import Optional, Union, Sequence
7
+
8
+ import anndata as ad
9
+ import numpy as np
10
+ import pandas as pd
11
+ import scanpy as sc
12
+ from scipy import sparse
13
+ import scvi
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
def _load_reference_adata(
    reference: Optional[Union[str, Path, ad.AnnData]] = None,
) -> ad.AnnData:
    """Resolve the single-cell reference into an in-memory AnnData.

    Args:
        reference: An AnnData (a copy is returned), a path to an ``.h5ad``
            file, or ``None`` to fall back to the ``reference_sc.h5ad``
            resource looked up inside the ``oncoordinate`` package.

    Returns:
        The loaded (or copied) reference AnnData.

    Raises:
        FileNotFoundError: If the given path is not a file, or if loading the
            packaged default fails for any reason (the original error is
            chained as ``__cause__``).
    """
    if reference is None:
        # No explicit reference: try the resource bundled with the package.
        try:
            bundled = resources.files("oncoordinate").joinpath("reference_sc.h5ad")
            with resources.as_file(bundled) as local_path:
                return ad.read_h5ad(local_path)
        except Exception as err:
            raise FileNotFoundError(
                "No reference AnnData provided and default 'reference_sc.h5ad' "
                "not found within the oncoordinate package."
            ) from err

    if isinstance(reference, ad.AnnData):
        # Hand back a copy so later processing never touches the caller's object.
        return reference.copy()

    path = Path(reference)
    if path.is_file():
        return ad.read_h5ad(path)
    raise FileNotFoundError(f"Reference AnnData not found: {path}")
39
+
40
def _load_spatial_adata(
    spatial: Union[str, Path, ad.AnnData],
) -> ad.AnnData:
    """Resolve the spatial input into an in-memory AnnData.

    Args:
        spatial: An AnnData (a copy is returned) or a path to an ``.h5ad`` file.

    Returns:
        The loaded (or copied) spatial AnnData.

    Raises:
        FileNotFoundError: If ``spatial`` is a path that is not an existing file.
    """
    if not isinstance(spatial, ad.AnnData):
        h5ad_path = Path(spatial)
        if h5ad_path.is_file():
            return ad.read_h5ad(h5ad_path)
        raise FileNotFoundError(f"Spatial AnnData not found: {h5ad_path}")

    # Already in memory: copy so the caller's object stays untouched.
    return spatial.copy()
49
+
50
def _get_counts_layer(
    adata: ad.AnnData,
    counts_layer: Optional[str] = "counts",
) -> Union[np.ndarray, sparse.spmatrix]:
    """Return a non-negative int32 count matrix taken from ``adata``.

    The matrix is read from ``adata.layers[counts_layer]`` when that layer
    exists, otherwise from ``adata.X``. Values are clipped at zero and rounded
    to the nearest integer.

    Args:
        adata: Object exposing ``.layers`` (mapping) and ``.X``.
        counts_layer: Name of the preferred layer; ``None`` forces ``.X``.

    Returns:
        A new int32 matrix: CSR if the source was sparse, a dense
        ``numpy.ndarray`` otherwise. The source matrix is never modified.
    """
    if counts_layer is not None and counts_layer in adata.layers:
        X = adata.layers[counts_layer]
    else:
        X = adata.X

    if sparse.issparse(X):
        # copy=True matters when X is already CSR: tocsr(copy=False) would hand
        # back the very same object and the assignment below would then rewrite
        # the caller's data in place.
        X = X.tocsr(copy=True)
        X.data = np.rint(np.clip(X.data, 0, None)).astype(np.int32)
        return X

    return np.rint(np.clip(np.asarray(X), 0, None)).astype(np.int32)
65
+
66
def _make_pseudospots(
    adata_ref: ad.AnnData,
    *,
    celltype_key: str,
    stage_key: str,
    pseudospot_size: int = 10,
) -> ad.AnnData:
    """Aggregate reference cells into summed-count "pseudospots".

    Cells are grouped by the string ``"<celltype>::<stage>"``. Within each
    group the cells are shuffled with a fixed seed and summed in consecutive
    chunks of ``pseudospot_size``. A group smaller than ``pseudospot_size``
    yields one pseudospot built from all of its cells; cells left over after
    the last full chunk of a larger group are not used.

    Args:
        adata_ref: Reference data; counts are read from ``layers["counts"]``.
        celltype_key: ``obs`` column holding the cell type.
        stage_key: ``obs`` column holding the stage.
        pseudospot_size: Number of cells summed into one pseudospot (>= 1).

    Returns:
        AnnData with one row per pseudospot, int32 counts in both ``X`` and
        ``layers["counts"]``, and ``obs`` columns ``celltype_key``,
        ``stage_key``, ``"cell_state"`` and ``"tech"`` (always ``"reference"``).

    Raises:
        ValueError: If ``pseudospot_size`` is smaller than 1.
        RuntimeError: If no pseudospot could be created.
    """
    if pseudospot_size < 1:
        raise ValueError(f"pseudospot_size must be >= 1, got {pseudospot_size}")

    X = adata_ref.layers["counts"]
    if sparse.issparse(X):
        X = X.tocsr(copy=False)  # CSR gives cheap row selection below

    obs = adata_ref.obs[[celltype_key, stage_key]].astype(str).copy()
    obs["cell_state"] = obs[celltype_key] + "::" + obs[stage_key]
    celltypes = obs[celltype_key].to_numpy()
    stages = obs[stage_key].to_numpy()

    groups = obs.groupby("cell_state", observed=True).indices

    X_ps_list = []
    obs_ps_rows = []
    rng = np.random.default_rng(0)  # fixed seed: pseudospots are reproducible

    for cs, idxs in groups.items():
        idxs = np.array(idxs)  # private copy; shuffled in place below
        if idxs.size == 0:
            continue

        # Read the labels from the data rather than re-splitting the joined
        # string, which would misassign them if a cell type contains "::".
        ct = celltypes[idxs[0]]
        st = stages[idxs[0]]

        rng.shuffle(idxs)
        n_spots = max(1, idxs.size // pseudospot_size)

        for k in range(n_spots):
            sel = idxs[k * pseudospot_size:(k + 1) * pseudospot_size]
            # Sparse sums come back as a 2-D matrix, dense sums as 1-D;
            # asarray + ravel normalises both to a flat vector.
            v = np.asarray(X[sel].sum(axis=0)).ravel()
            X_ps_list.append(v)
            obs_ps_rows.append((ct, st, cs))

    if not X_ps_list:
        raise RuntimeError(
            "No pseudospots could be created. Check celltype_key/stage_key "
            "and pseudospot_size."
        )

    X_ps = np.vstack(X_ps_list).astype(np.int32)
    obs_ps = pd.DataFrame(
        obs_ps_rows,
        columns=[celltype_key, stage_key, "cell_state"],
        index=[f"ps_{i}" for i in range(len(X_ps))],
    )
    obs_ps["tech"] = "reference"

    adata_ps = ad.AnnData(X=X_ps, obs=obs_ps, var=adata_ref.var.copy())
    adata_ps.layers["counts"] = X_ps
    return adata_ps
125
+
126
+
127
def _match_library_sizes(
    adata_ref_ps: ad.AnnData,
    adata_sp: ad.AnnData,
    *,
    layer: str = "counts",
) -> None:
    """Rescale reference counts so their median library size matches spatial.

    Every value of ``adata_ref_ps.layers[layer]`` is multiplied by
    ``median(spatial library size) / median(reference library size)``, rounded
    and stored back as int32. The layer is replaced by a newly allocated
    matrix; the previous matrix object is left unchanged, so anything else
    still referring to it keeps the unscaled values. ``adata_sp`` is only read.

    If either median is not positive, a warning is logged and nothing changes.

    Args:
        adata_ref_ps: Reference (pseudospot) data; its layer is replaced.
        adata_sp: Spatial data providing the target library size.
        layer: Name of the count layer present in both objects.
    """

    def _libsizes(a: ad.AnnData) -> np.ndarray:
        # Row sums; asarray + ravel flattens the (n, 1) result sparse sums give.
        return np.asarray(a.layers[layer].sum(axis=1)).ravel()

    lib_ref = _libsizes(adata_ref_ps)
    lib_sp = _libsizes(adata_sp)

    med_ref = float(np.median(lib_ref)) if lib_ref.size else 0.0
    med_sp = float(np.median(lib_sp)) if lib_sp.size else 0.0

    if med_ref <= 0 or med_sp <= 0:
        logger.warning("Unable to estimate library sizes; skipping scaling.")
        return

    ratio = med_sp / med_ref
    Xc = adata_ref_ps.layers[layer]
    if sparse.issparse(Xc):
        # copy=True: never rewrite the existing buffer, mirroring the dense
        # branch, which also produces a new array.
        Xc = Xc.tocsr(copy=True)
        Xc.data = np.rint(Xc.data * ratio).astype(np.int32)
    else:
        Xc = np.rint(Xc * ratio).astype(np.int32)
    adata_ref_ps.layers[layer] = Xc
160
+
161
def _run_transfer_single(
    adata_spatial_in: ad.AnnData,
    adata_ref_in: ad.AnnData,
    *,
    celltype_key: str = "celltype",
    stage_key: str = "oncoordinate_stage",
    counts_layer: Optional[str] = "counts",
    stage_order: Optional[Sequence[str]] = ("normal", "abnormal", "pre-malignant", "malignant"),
    max_reference_cells: int = 900_000,
    pseudospot_size: int = 10,
    n_hvg: int = 2000,
    scvi_max_epochs: int = 250,
    scanvi_max_epochs: int = 250,
    scvi_model_dir: Optional[Union[str, Path]] = None,
    scanvi_model_dir: Optional[Union[str, Path]] = None,
) -> ad.AnnData:
    """Transfer (cell type, stage) labels from a reference onto one spatial dataset.

    Pipeline: restrict both inputs to shared (upper-cased) genes, optionally
    downsample the reference, drop (celltype, stage) groups with fewer than 3
    cells, build pseudospots, match library sizes, select highly variable
    genes, fit/load SCVI then SCANVI, and copy the predictions for the spatial
    observations onto a copy of ``adata_spatial_in``.

    Args:
        adata_spatial_in: Spatial data to annotate (not modified).
        adata_ref_in: Labelled reference data (not modified).
        celltype_key: Reference ``obs`` column with cell types.
        stage_key: Reference ``obs`` column with stages.
        counts_layer: Layer holding raw counts; ``.X`` is used if it is absent.
        stage_order: Ordered stage categories; ``None`` uses the sorted set of
            predicted stages.
        max_reference_cells: Reference size above which it is downsampled.
        pseudospot_size: Cells summed per pseudospot.
        n_hvg: Number of highly variable genes to keep.
        scvi_max_epochs: Training epochs for SCVI.
        scanvi_max_epochs: Training epochs for SCANVI.
        scvi_model_dir: If it exists the SCVI model is loaded from it,
            otherwise the trained model is saved there.
        scanvi_model_dir: Same as ``scvi_model_dir`` for the SCANVI model.

    Returns:
        A copy of ``adata_spatial_in`` with ``obs`` columns
        ``oncoordinate_tl_cell_state``, ``oncoordinate_tl_celltype``,
        ``oncoordinate_tl_stage``, ``oncoordinate_tl_state_confidence`` and one
        ``oncoordinate_tl_stage_proba_<stage>`` per stage, the latent space in
        ``obsm["X_oncoordinate_scanvi"]`` and the run parameters in
        ``uns["oncoordinate_tl_params"]``.

    Raises:
        KeyError: If ``celltype_key`` or ``stage_key`` is missing from the reference.
        ValueError: If reference and spatial data share no genes.
        RuntimeError: If downsampling or small-group filtering leaves no cells.
    """
    adata_ref = adata_ref_in.copy()
    adata_spatial = adata_spatial_in.copy()

    if celltype_key not in adata_ref.obs.columns:
        raise KeyError(f"celltype_key '{celltype_key}' not found in reference.obs")
    if stage_key not in adata_ref.obs.columns:
        raise KeyError(f"stage_key '{stage_key}' not found in reference.obs")

    # Normalise indices: string obs names, upper-cased unique gene names so the
    # two datasets can be intersected regardless of gene-symbol casing.
    for a in (adata_ref, adata_spatial):
        a.obs.index = a.obs.index.astype(str)
        a.var.index = a.var.index.astype(str)
        a.var_names = a.var_names.astype(str).str.upper()
        a.var_names_make_unique()

    shared_genes = np.intersect1d(adata_ref.var_names, adata_spatial.var_names)
    if shared_genes.size == 0:
        raise ValueError("No overlapping genes between reference and spatial data.")

    adata_ref = adata_ref[:, shared_genes].copy()
    adata_sp = adata_spatial[:, shared_genes].copy()
    adata_ref.layers["counts"] = _get_counts_layer(adata_ref, counts_layer)
    adata_sp.layers["counts"] = _get_counts_layer(adata_sp, counts_layer)

    if adata_ref.n_obs > max_reference_cells:
        logger.info(
            "Reference has %s cells; downsampling to ~%s.",
            adata_ref.n_obs,
            max_reference_cells,
        )
        rng = np.random.default_rng(0)
        groups = adata_ref.obs.groupby([celltype_key, stage_key], observed=True).indices
        total = adata_ref.n_obs
        keep_idx = []
        for _, idxs in groups.items():
            idxs = np.asarray(idxs)
            if idxs.size == 0:
                continue
            # Proportional share of the budget; groups whose share rounds down
            # to zero are dropped entirely.
            n = min(int(idxs.size / total * max_reference_cells), idxs.size)
            if n > 0:
                keep_idx.append(rng.choice(idxs, size=n, replace=False))
        if not keep_idx:
            raise RuntimeError("Subsampling removed all reference cells.")
        keep_idx = np.concatenate(keep_idx)
        adata_ref = adata_ref[keep_idx].copy()

    # Keep only (celltype, stage) combinations with at least 3 cells.
    idx_multi = pd.MultiIndex.from_frame(
        adata_ref.obs[[celltype_key, stage_key]].astype(str)
    )
    counts = idx_multi.value_counts()
    valid_states = counts[counts >= 3].index
    mask_valid = idx_multi.isin(valid_states)
    adata_ref = adata_ref[mask_valid].copy()
    if adata_ref.n_obs == 0:
        raise RuntimeError(
            "No reference cells left after filtering small (celltype, stage) groups."
        )

    adata_ps = _make_pseudospots(
        adata_ref,
        celltype_key=celltype_key,
        stage_key=stage_key,
        pseudospot_size=pseudospot_size,
    )

    # adata_sp is already a private copy (made when subsetting to shared
    # genes), so it can be annotated directly.
    adata_sp.obs["tech"] = "spatial"
    adata_sp.obs["cell_state"] = "unlabeled"

    _match_library_sizes(adata_ps, adata_sp, layer="counts")

    adata_comb = ad.concat(
        [adata_ps, adata_sp],
        join="inner",
        merge="same",
        label="__source__",
        keys=["reference", "spatial"],
        index_unique=None,
    )

    sc.pp.highly_variable_genes(
        adata_comb,
        layer="counts",
        n_top_genes=n_hvg,
        flavor="seurat_v3",
        batch_key="tech",
        inplace=True,
    )
    adata_comb = adata_comb[:, adata_comb.var["highly_variable"].values].copy()

    scvi.model.SCVI.setup_anndata(
        adata_comb,
        batch_key="tech",
        labels_key="cell_state",
        layer="counts",
    )

    scvi_dir = Path(scvi_model_dir) if scvi_model_dir is not None else None
    if scvi_dir is not None and scvi_dir.exists():
        logger.info("Loading SCVI model from %s", scvi_dir)
        scvi_model = scvi.model.SCVI.load(scvi_dir, adata=adata_comb)
    else:
        scvi_model = scvi.model.SCVI(adata_comb, n_layers=3, n_latent=32)
        scvi_model.train(max_epochs=scvi_max_epochs, batch_size=512)
        if scvi_dir is not None:
            scvi_model.save(scvi_dir, overwrite=True)

    adata_comb.obsm["X_oncoordinate_scvi"] = scvi_model.get_latent_representation()
    scanvi_dir = Path(scanvi_model_dir) if scanvi_model_dir is not None else None
    if scanvi_dir is not None and scanvi_dir.exists():
        logger.info("Loading SCANVI model from %s", scanvi_dir)
        scanvi_model = scvi.model.SCANVI.load(scanvi_dir, adata=adata_comb)
    else:
        scanvi_model = scvi.model.SCANVI.from_scvi_model(
            scvi_model,
            unlabeled_category="unlabeled",
        )
        scanvi_model.train(
            max_epochs=scanvi_max_epochs,
            batch_size=256,
            plan_kwargs={"lr": 5e-4},
            gradient_clip_val=10.0,
        )
        if scanvi_dir is not None:
            scanvi_model.save(scanvi_dir, overwrite=True)

    adata_comb.obsm["X_oncoordinate_scanvi"] = scanvi_model.get_latent_representation(
        adata_comb
    )

    state_pred = scanvi_model.predict(adata_comb)
    adata_comb.obs["oncoordinate_tl_cell_state"] = state_pred.astype(str)
    soft = scanvi_model.predict(adata_comb, soft=True)

    if isinstance(soft, pd.DataFrame):
        soft_df = soft.loc[adata_comb.obs_names]
    else:
        # Plain array: wrap it so the code below can treat both cases alike.
        soft_arr = np.asarray(soft)
        if soft_arr.ndim == 1:
            soft_arr = soft_arr.reshape(-1, 1)
        soft_df = pd.DataFrame(
            soft_arr,
            index=adata_comb.obs_names,
        )

    adata_comb.obsm["oncoordinate_tl_state_proba"] = soft_df.to_numpy()
    adata_comb.obs["oncoordinate_tl_state_confidence"] = soft_df.max(axis=1)

    def _split_state(s: str) -> tuple[str, str]:
        # "<celltype>::<stage>" -> (celltype, stage); stage is "NA" if absent.
        return tuple(s.split("::", 1)) if "::" in s else (s, "NA")

    ct_pred, st_pred = zip(
        *[_split_state(s) for s in adata_comb.obs["oncoordinate_tl_cell_state"].astype(str)]
    )
    adata_comb.obs["oncoordinate_tl_celltype"] = list(ct_pred)

    if stage_order is None:
        stage_order_use = sorted(set(st_pred))
    else:
        stage_order_use = list(stage_order)

    # Values outside the categories become NaN in the Categorical below; make
    # that visible instead of losing the labels silently.
    unknown_stages = sorted(set(st_pred) - set(stage_order_use))
    if unknown_stages:
        logger.warning(
            "Predicted stages %s are not in stage_order %s; they will be NaN "
            "in 'oncoordinate_tl_stage'.",
            unknown_stages,
            stage_order_use,
        )

    adata_comb.obs["oncoordinate_tl_stage"] = pd.Categorical(
        list(st_pred),
        categories=stage_order_use,
        ordered=True,
    )

    if isinstance(soft, pd.DataFrame):
        # Per-stage probability = sum over all "<celltype>::<stage>" columns.
        for stage in stage_order_use:
            cols_stage = [c for c in soft_df.columns if c.endswith(f"::{stage}")]
            if cols_stage:
                series = soft_df[cols_stage].sum(axis=1)
            else:
                series = pd.Series(0.0, index=soft_df.index)
            adata_comb.obs[f"oncoordinate_tl_stage_proba_{stage}"] = series
    else:
        # Without column labels the probabilities cannot be mapped to stages.
        for stage in stage_order_use:
            adata_comb.obs[f"oncoordinate_tl_stage_proba_{stage}"] = 0.0

    is_spatial = adata_comb.obs["tech"].values == "spatial"
    idx_spatial_comb = adata_comb.obs.index[is_spatial]

    # Start from the untouched input so genes/layers dropped during modelling
    # are still present in the result.
    spatial_out = adata_spatial_in.copy()
    pred_obs = adata_comb.obs.loc[idx_spatial_comb]
    pred_obs = pred_obs.reindex(spatial_out.obs_names)

    cols_to_copy = [
        "oncoordinate_tl_cell_state",
        "oncoordinate_tl_celltype",
        "oncoordinate_tl_stage",
        "oncoordinate_tl_state_confidence",
    ] + [f"oncoordinate_tl_stage_proba_{stage}" for stage in stage_order_use]

    for col in cols_to_copy:
        spatial_out.obs[col] = pred_obs[col]

    latent_all = adata_comb.obsm["X_oncoordinate_scanvi"]
    latent_spatial = latent_all[is_spatial, :]
    latent_df = pd.DataFrame(latent_spatial, index=idx_spatial_comb)
    latent_df = latent_df.reindex(spatial_out.obs_names)
    spatial_out.obsm["X_oncoordinate_scanvi"] = latent_df.to_numpy()

    spatial_out.uns.setdefault("oncoordinate_tl_params", {})
    spatial_out.uns["oncoordinate_tl_params"].update(
        dict(
            celltype_key=celltype_key,
            stage_key=stage_key,
            counts_layer=counts_layer,
            max_reference_cells=int(max_reference_cells),
            pseudospot_size=int(pseudospot_size),
            n_hvg=int(n_hvg),
            stage_order=list(stage_order_use),
        )
    )

    return spatial_out
390
+
391
def run_transfer(
    spatial: Union[str, Path, ad.AnnData],
    reference: Optional[Union[str, Path, ad.AnnData]] = None,
    *,
    celltype_key: str = "celltype",
    stage_key: str = "oncoordinate_stage",
    counts_layer: Optional[str] = "counts",
    stage_order: Optional[Sequence[str]] = ("normal", "abnormal", "pre-malignant", "malignant"),
    max_reference_cells: int = 900_000,
    pseudospot_size: int = 10,
    n_hvg: int = 2000,
    scvi_max_epochs: int = 250,
    scanvi_max_epochs: int = 250,
    scvi_model_dir: Optional[Union[str, Path]] = None,
    scanvi_model_dir: Optional[Union[str, Path]] = None,
    per_sample: bool = True,
    sample_key: str = "sample",
) -> ad.AnnData:
    """Annotate spatial data with cell type / stage labels from a reference.

    Loads both inputs and runs the label transfer. When ``per_sample`` is true
    and ``obs[sample_key]`` holds more than one sample, the transfer runs
    independently per sample and the results are concatenated in sample-group
    order; otherwise a single transfer runs on the whole dataset.

    Args:
        spatial: Spatial AnnData or path to an ``.h5ad`` file.
        reference: Reference AnnData, path to an ``.h5ad`` file, or ``None``
            for the packaged default.
        celltype_key: Reference ``obs`` column with cell types.
        stage_key: Reference ``obs`` column with stages.
        counts_layer: Layer holding raw counts.
        stage_order: Ordered stage categories, or ``None`` to derive them.
        max_reference_cells: Reference size above which it is downsampled.
        pseudospot_size: Cells summed per pseudospot.
        n_hvg: Number of highly variable genes.
        scvi_max_epochs: Training epochs for SCVI.
        scanvi_max_epochs: Training epochs for SCANVI.
        scvi_model_dir: Directory used to load/save the SCVI model. In
            per-sample mode each sample uses its own subdirectory named after
            the sample.
        scanvi_model_dir: Same as ``scvi_model_dir`` for the SCANVI model.
        per_sample: Whether to split the run by ``sample_key``.
        sample_key: ``obs`` column identifying samples.

    Returns:
        The annotated spatial AnnData.

    Raises:
        FileNotFoundError: If an input path (or the default reference) is missing.
    """
    adata_ref = _load_reference_adata(reference)
    adata_spatial_in = _load_spatial_adata(spatial)

    # Options shared by every call to the single-dataset routine.
    transfer_kwargs = dict(
        celltype_key=celltype_key,
        stage_key=stage_key,
        counts_layer=counts_layer,
        stage_order=stage_order,
        max_reference_cells=max_reference_cells,
        pseudospot_size=pseudospot_size,
        n_hvg=n_hvg,
        scvi_max_epochs=scvi_max_epochs,
        scanvi_max_epochs=scanvi_max_epochs,
    )

    def _sample_model_dir(
        base: Optional[Union[str, Path]], sample_name: object
    ) -> Optional[Path]:
        # One subdirectory per sample. Sharing a single directory would make
        # every sample after the first load the model fitted on the first one.
        if base is None:
            return None
        safe_name = str(sample_name).replace("/", "_").replace("\\", "_")
        return Path(base) / safe_name

    if (
        per_sample
        and sample_key is not None
        and sample_key in adata_spatial_in.obs.columns
    ):
        groups = adata_spatial_in.obs.groupby(sample_key, observed=True).indices
        if len(groups) > 1:
            logger.info(
                "Found %d spatial samples in obs['%s']; running transfer per sample.",
                len(groups),
                sample_key,
            )
            outs = []
            for sample_name, idx in groups.items():
                sub = adata_spatial_in[idx].copy()
                logger.info(
                    "Running transfer for sample '%s' with %d spots.",
                    sample_name,
                    sub.n_obs,
                )
                out_sub = _run_transfer_single(
                    sub,
                    adata_ref,
                    scvi_model_dir=_sample_model_dir(scvi_model_dir, sample_name),
                    scanvi_model_dir=_sample_model_dir(scanvi_model_dir, sample_name),
                    **transfer_kwargs,
                )
                outs.append(out_sub)

            # uns_merge="same" keeps the .uns entries that are identical in
            # every sample (e.g. the recorded run parameters); concat would
            # otherwise drop .uns entirely.
            combined = ad.concat(
                outs,
                join="outer",
                merge="same",
                uns_merge="same",
                label=None,
                index_unique=None,
            )
            return combined

    return _run_transfer_single(
        adata_spatial_in,
        adata_ref,
        scvi_model_dir=scvi_model_dir,
        scanvi_model_dir=scanvi_model_dir,
        **transfer_kwargs,
    )
Binary file