chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cell2location deconvolution method.
|
|
3
|
+
|
|
4
|
+
Cell2location uses a two-stage training process:
|
|
5
|
+
1. Reference model (NB regression) learns cell type gene expression signatures
|
|
6
|
+
2. Cell2location model performs spatial mapping using these signatures
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import gc
|
|
10
|
+
import warnings
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import anndata as ad
|
|
18
|
+
|
|
19
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
20
|
+
|
|
21
|
+
from ...utils.dependency_manager import is_available, require
|
|
22
|
+
from ...utils.device_utils import get_device
|
|
23
|
+
from ...utils.exceptions import DataError, ProcessingError
|
|
24
|
+
from ...utils.image_utils import non_interactive_backend
|
|
25
|
+
from ...utils.mcp_utils import suppress_output
|
|
26
|
+
from .base import (
|
|
27
|
+
PreparedDeconvolutionData,
|
|
28
|
+
check_model_convergence,
|
|
29
|
+
create_deconvolution_stats,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def apply_gene_filtering(
    adata: "ad.AnnData",
    ctx: "ToolContext",
    cell_count_cutoff: int = 5,
    cell_percentage_cutoff2: float = 0.03,
    nonz_mean_cutoff: float = 1.12,
) -> "ad.AnnData":
    """Apply cell2location's official gene filtering.

    Reference: cell2location tutorial - "very permissive gene selection"

    This function is called by the preprocess hook in __init__.py before
    common gene identification.

    Note: The original filter_genes function creates a matplotlib figure.
    The call runs inside ``non_interactive_backend()`` and every figure that
    appeared during the call is closed afterwards. Figures that were already
    open before the call are deliberately left alone.

    Args:
        adata: Dataset to filter; it is not modified, a copy is returned.
        ctx: Tool context, used to emit a warning when filtering is skipped.
        cell_count_cutoff: Forwarded unchanged to ``filter_genes``.
        cell_percentage_cutoff2: Forwarded unchanged to ``filter_genes``.
        nonz_mean_cutoff: Forwarded unchanged to ``filter_genes``.

    Returns:
        A copy of ``adata`` restricted to the genes selected by
        ``filter_genes``, or an unfiltered copy when cell2location is not
        installed.
    """
    if not is_available("cell2location"):
        await ctx.warning(
            "cell2location.utils.filtering not available. "
            "Skipping gene filtering (may degrade results)."
        )
        return adata.copy()

    import matplotlib.pyplot as plt

    with non_interactive_backend():
        from cell2location.utils.filtering import filter_genes

        # Remember which figures exist now so that only the ones created by
        # filter_genes are closed, instead of wiping every open figure.
        figures_before = set(plt.get_fignums())
        try:
            selected = filter_genes(
                adata,
                cell_count_cutoff=cell_count_cutoff,
                cell_percentage_cutoff2=cell_percentage_cutoff2,
                nonz_mean_cutoff=nonz_mean_cutoff,
            )
        finally:
            # Runs on failure too, so a raising filter_genes leaks no figure.
            for fig_num in set(plt.get_fignums()) - figures_before:
                plt.close(fig_num)

    return adata[:, selected].copy()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def deconvolve(
    data: PreparedDeconvolutionData,
    ref_model_epochs: int = 250,
    n_epochs: int = 30000,
    n_cells_per_spot: int = 30,
    detection_alpha: float = 20.0,
    use_gpu: bool = False,
    batch_key: Optional[str] = None,
    categorical_covariate_keys: Optional[list[str]] = None,
    ref_model_lr: float = 0.002,
    cell2location_lr: float = 0.005,
    ref_model_train_size: float = 1.0,
    cell2location_train_size: float = 1.0,
    early_stopping: bool = False,
    early_stopping_patience: int = 45,
    early_stopping_threshold: float = 0.0,
    use_aggressive_training: bool = False,
    validation_size: float = 0.1,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using Cell2location.

    Note: Gene filtering is handled by the preprocess hook in __init__.py.
    The data parameter contains already filtered and subset data.

    Args:
        data: Prepared deconvolution data (already filtered). Its reference
            and spatial matrices are cast to float32 in place when needed.
        ref_model_epochs: Epochs for reference model (default: 250)
        n_epochs: Epochs for Cell2location model (default: 30000)
        n_cells_per_spot: Expected cells per location (default: 30)
        detection_alpha: RNA detection sensitivity (default: 20, NEW 2024)
        use_gpu: Use GPU acceleration
        batch_key: Column for batch correction; passed to both models' setup
        categorical_covariate_keys: Technical covariates; passed to both
            models' setup
        ref_model_lr: Reference model learning rate (default: 0.002)
        cell2location_lr: Cell2location learning rate (default: 0.005)
        ref_model_train_size: Training data fraction for the reference model
        cell2location_train_size: Training data fraction for Cell2location
        early_stopping: Enable early stopping; _build_train_kwargs only
            honours it together with ``use_aggressive_training``
        early_stopping_patience: Patience forwarded with early stopping
        early_stopping_threshold: Accepted for interface compatibility; it is
            not read anywhere in this function
        use_aggressive_training: Select the alternative set of training
            kwargs from _build_train_kwargs; ``train()`` is called either way
        validation_size: Validation fraction used when early stopping applies

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        ProcessingError: If any step fails; other exceptions are wrapped.
        DataError: Re-raised unchanged when raised inside the pipeline.
    """
    require("cell2location")
    from cell2location.models import Cell2location, RegressionModel

    cell_type_key = data.cell_type_key

    try:
        device = get_device(prefer_gpu=use_gpu)

        # Data already copied in prepare_deconvolution
        ref = data.reference
        sp = data.spatial

        # Ensure float32 for scvi-tools compatibility
        if ref.X.dtype != np.float32:
            ref.X = ref.X.astype(np.float32)
        if sp.X.dtype != np.float32:
            sp.X = sp.X.astype(np.float32)

        # Handle NaN in cell types
        if ref.obs[cell_type_key].isna().any():
            warnings.warn(
                f"Reference has NaN in {cell_type_key}. Excluding.",
                UserWarning,
                stacklevel=2,
            )
            ref = ref[~ref.obs[cell_type_key].isna()].copy()

        # ===== Stage 1: Train Reference Model =====
        RegressionModel.setup_anndata(
            adata=ref,
            labels_key=cell_type_key,
            batch_key=batch_key,
            categorical_covariate_keys=categorical_covariate_keys,
        )

        ref_model = RegressionModel(ref)
        with suppress_output():
            train_kwargs = _build_train_kwargs(
                epochs=ref_model_epochs,
                lr=ref_model_lr,
                train_size=ref_model_train_size,
                device=device,
                early_stopping=early_stopping,
                early_stopping_patience=early_stopping_patience,
                validation_size=validation_size,
                use_aggressive=use_aggressive_training,
            )
            ref_model.train(**train_kwargs)

        # Check convergence
        converged, warning_msg = check_model_convergence(ref_model, "ReferenceModel")
        if not converged and warning_msg:
            warnings.warn(warning_msg, UserWarning, stacklevel=2)

        # Export reference signatures
        ref = ref_model.export_posterior(
            ref, sample_kwargs={"num_samples": 1000, "batch_size": 2500}
        )
        ref_signatures = _extract_reference_signatures(ref)

        # ===== Stage 2: Train Cell2location Model =====
        Cell2location.setup_anndata(
            adata=sp,
            batch_key=batch_key,
            categorical_covariate_keys=categorical_covariate_keys,
        )

        cell2loc_model = Cell2location(
            sp,
            cell_state_df=ref_signatures,
            N_cells_per_location=n_cells_per_spot,
            detection_alpha=detection_alpha,
        )

        with suppress_output():
            train_kwargs = _build_train_kwargs(
                epochs=n_epochs,
                lr=cell2location_lr,
                train_size=cell2location_train_size,
                device=device,
                early_stopping=early_stopping,
                early_stopping_patience=early_stopping_patience,
                validation_size=validation_size,
                use_aggressive=use_aggressive_training,
            )
            cell2loc_model.train(**train_kwargs)

        # Check convergence
        converged, warning_msg = check_model_convergence(
            cell2loc_model, "Cell2location"
        )
        if not converged and warning_msg:
            warnings.warn(warning_msg, UserWarning, stacklevel=2)

        # Export results
        sp = cell2loc_model.export_posterior(
            sp, sample_kwargs={"num_samples": 1000, "batch_size": 2500}
        )

        # Extract cell abundance
        cell_abundance = _extract_cell_abundance(sp)

        # Create proportions DataFrame
        proportions = pd.DataFrame(
            cell_abundance,
            index=sp.obs_names,
            columns=ref_signatures.columns,
        )

        # Create statistics
        stats = create_deconvolution_stats(
            proportions,
            data.common_genes,
            method="Cell2location",
            device=device,
            n_epochs=n_epochs,
            n_cells_per_spot=n_cells_per_spot,
            detection_alpha=detection_alpha,
        )

        # Add model performance metrics
        history = getattr(cell2loc_model, "history", None)
        if history is not None and "elbo_train" in history:
            elbo_train = history["elbo_train"]
            if not elbo_train.empty:
                # NOTE(review): scvi-tools history entries appear to be
                # one-column DataFrames - confirm. In that case ``.iloc[-1]``
                # is a Series, and float(Series) is deprecated in recent
                # pandas, so the last scalar is taken explicitly. This works
                # for a Series or a DataFrame.
                stats["final_elbo"] = float(np.asarray(elbo_train).ravel()[-1])

        # Memory cleanup
        del cell2loc_model, ref_model
        del ref, sp, ref_signatures
        gc.collect()

        return proportions, stats

    except (ProcessingError, DataError):
        # Domain errors already carry a useful message; pass them through.
        raise
    except Exception as e:
        raise ProcessingError(f"Cell2location deconvolution failed: {e}") from e
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _build_train_kwargs(
|
|
256
|
+
epochs: int,
|
|
257
|
+
lr: float,
|
|
258
|
+
train_size: float,
|
|
259
|
+
device: str,
|
|
260
|
+
early_stopping: bool,
|
|
261
|
+
early_stopping_patience: int,
|
|
262
|
+
validation_size: float,
|
|
263
|
+
use_aggressive: bool,
|
|
264
|
+
) -> dict[str, Any]:
|
|
265
|
+
"""Build training kwargs for scvi-tools models."""
|
|
266
|
+
if use_aggressive:
|
|
267
|
+
kwargs = {"max_epochs": epochs, "lr": lr}
|
|
268
|
+
if device == "cuda":
|
|
269
|
+
kwargs["accelerator"] = "gpu"
|
|
270
|
+
if early_stopping:
|
|
271
|
+
kwargs["early_stopping"] = True
|
|
272
|
+
kwargs["early_stopping_patience"] = early_stopping_patience
|
|
273
|
+
kwargs["check_val_every_n_epoch"] = 1
|
|
274
|
+
kwargs["train_size"] = 1.0 - validation_size
|
|
275
|
+
else:
|
|
276
|
+
kwargs["train_size"] = train_size
|
|
277
|
+
else:
|
|
278
|
+
kwargs = {
|
|
279
|
+
"max_epochs": epochs,
|
|
280
|
+
"batch_size": 2500,
|
|
281
|
+
"lr": lr,
|
|
282
|
+
"train_size": train_size,
|
|
283
|
+
}
|
|
284
|
+
if device == "cuda":
|
|
285
|
+
kwargs["accelerator"] = "gpu"
|
|
286
|
+
return kwargs
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _extract_reference_signatures(ref: "ad.AnnData") -> pd.DataFrame:
|
|
290
|
+
"""Extract reference signatures from trained RegressionModel."""
|
|
291
|
+
factor_names = ref.uns["mod"]["factor_names"]
|
|
292
|
+
cols = [f"means_per_cluster_mu_fg_{i}" for i in factor_names]
|
|
293
|
+
|
|
294
|
+
if "means_per_cluster_mu_fg" in ref.varm:
|
|
295
|
+
signatures = ref.varm["means_per_cluster_mu_fg"][cols].copy()
|
|
296
|
+
else:
|
|
297
|
+
signatures = ref.var[cols].copy()
|
|
298
|
+
|
|
299
|
+
signatures.columns = factor_names
|
|
300
|
+
return signatures
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _extract_cell_abundance(sp: "ad.AnnData"):
|
|
304
|
+
"""Extract cell abundance from Cell2location results.
|
|
305
|
+
|
|
306
|
+
Cell2location stores results as DataFrames with prefixed column names like
|
|
307
|
+
'q05cell_abundance_w_sf_CellType'. We need to extract the values and
|
|
308
|
+
return them as a numpy array for consistent downstream processing.
|
|
309
|
+
"""
|
|
310
|
+
possible_keys = [
|
|
311
|
+
"q05_cell_abundance_w_sf",
|
|
312
|
+
"means_cell_abundance_w_sf",
|
|
313
|
+
"q50_cell_abundance_w_sf",
|
|
314
|
+
]
|
|
315
|
+
|
|
316
|
+
for key in possible_keys:
|
|
317
|
+
if key in sp.obsm:
|
|
318
|
+
result = sp.obsm[key]
|
|
319
|
+
if hasattr(result, "values"):
|
|
320
|
+
return result.values
|
|
321
|
+
return result
|
|
322
|
+
|
|
323
|
+
raise ProcessingError(
|
|
324
|
+
f"Cell2location did not produce expected output. "
|
|
325
|
+
f"Available keys: {list(sp.obsm.keys())}"
|
|
326
|
+
)
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DestVI deconvolution method.
|
|
3
|
+
|
|
4
|
+
DestVI performs multi-resolution deconvolution by first training a CondSCVI
|
|
5
|
+
model on reference data, then using it to initialize a DestVI model.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import gc
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
from ...utils.dependency_manager import is_available
|
|
14
|
+
from ...utils.exceptions import DataError, DependencyError, ProcessingError
|
|
15
|
+
from .base import PreparedDeconvolutionData, create_deconvolution_stats
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def deconvolve(
    data: PreparedDeconvolutionData,
    n_epochs: int = 10000,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    learning_rate: float = 1e-3,
    train_size: float = 0.9,
    vamp_prior_p: int = 15,
    l1_reg: float = 10.0,
    use_gpu: bool = False,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using DestVI from scvi-tools.

    Args:
        data: Prepared deconvolution data
        n_epochs: Epoch budget. CondSCVI trains for
            ``max(400, n_epochs // 5)`` epochs and DestVI for
            ``max(200, n_epochs // 10)``.
        n_hidden: Hidden units in neural networks
        n_latent: Latent space dimensionality
        n_layers: Number of layers
        dropout_rate: Dropout rate
        learning_rate: Learning rate, passed to both models via plan_kwargs
        train_size: Fraction for training (default: 0.9)
        vamp_prior_p: VampPrior components (default: 15)
        l1_reg: L1 regularization (default: 10.0)
        use_gpu: Use GPU acceleration

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        DependencyError: If scvi-tools is not installed.
        DataError: If the reference has fewer than 2 cell types.
        ProcessingError: If DestVI returns no proportions, a row count that
            does not match the spatial data, or any other step fails.
    """
    if not is_available("scvi-tools"):
        raise DependencyError(
            "scvi-tools is required for DestVI. Install with: pip install scvi-tools"
        )

    import scvi

    try:
        # Data already copied in prepare_deconvolution
        spatial_data = data.spatial
        ref_data = data.reference

        # Validate cell types
        if data.n_cell_types < 2:
            raise DataError(
                f"Reference needs at least 2 cell types, found {data.n_cell_types}"
            )

        # Calculate epoch distribution
        condscvi_epochs = max(400, n_epochs // 5)
        destvi_epochs = max(200, n_epochs // 10)

        # Device setting
        accelerator = "gpu" if use_gpu else "cpu"
        plan_kwargs = {"lr": learning_rate}

        # ===== Stage 1: Train CondSCVI on reference =====
        scvi.model.CondSCVI.setup_anndata(
            ref_data,
            labels_key=data.cell_type_key,
            batch_key=None,
        )

        condscvi_model = scvi.model.CondSCVI(
            ref_data,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
        )

        condscvi_model.train(
            max_epochs=condscvi_epochs,
            accelerator=accelerator,
            train_size=train_size,
            plan_kwargs=plan_kwargs,
        )

        # ===== Stage 2: Train DestVI on spatial =====
        scvi.model.DestVI.setup_anndata(spatial_data)

        destvi_model = scvi.model.DestVI.from_rna_model(
            spatial_data,
            condscvi_model,
            vamp_prior_p=vamp_prior_p,
            l1_reg=l1_reg,
        )

        destvi_model.train(
            max_epochs=destvi_epochs,
            accelerator=accelerator,
            train_size=train_size,
            plan_kwargs=plan_kwargs,
        )

        # Get proportions
        proportions = destvi_model.get_proportions()

        # Validate BEFORE relabelling the index: assigning an index of the
        # wrong length raises a pandas "Length mismatch" error, which would
        # pre-empt this dedicated check.
        if proportions.empty or len(proportions) != spatial_data.n_obs:
            raise ProcessingError("Failed to extract valid proportions from DestVI")

        proportions.index = spatial_data.obs_names

        # Create statistics
        stats = create_deconvolution_stats(
            proportions,
            data.common_genes,
            method="DestVI",
            device="gpu" if use_gpu else "cpu",
            n_epochs=n_epochs,
            condscvi_epochs=condscvi_epochs,
            destvi_epochs=destvi_epochs,
            n_hidden=n_hidden,
            n_latent=n_latent,
        )

        # Memory cleanup
        del destvi_model, condscvi_model
        del spatial_data, ref_data
        gc.collect()

        return proportions, stats

    except (DependencyError, DataError, ProcessingError):
        # Domain errors already carry a useful message; pass them through.
        raise
    except Exception as e:
        raise ProcessingError(f"DestVI deconvolution failed: {e}") from e
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FlashDeconv deconvolution method.
|
|
3
|
+
|
|
4
|
+
FlashDeconv is an ultra-fast spatial transcriptomics deconvolution method
|
|
5
|
+
that uses random sketching for O(N) time complexity.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
from ...utils.dependency_manager import is_available
|
|
13
|
+
from ...utils.exceptions import DependencyError, ProcessingError
|
|
14
|
+
from .base import PreparedDeconvolutionData, create_deconvolution_stats
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def deconvolve(
    data: PreparedDeconvolutionData,
    sketch_dim: int = 512,
    lambda_spatial: float = 5000.0,
    n_hvg: int = 2000,
    n_markers_per_type: int = 50,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Run FlashDeconv on the prepared data and collect its proportions.

    ``flashdeconv.tl.deconvolve`` is called on ``data.spatial`` with
    ``data.reference``. The result is read from ``obsm['flashdeconv']`` of
    the spatial object and returned as a DataFrame indexed by the spatial
    observation names.

    Args:
        data: Prepared deconvolution data
        sketch_dim: Dimension for random sketching (default: 512)
        lambda_spatial: Spatial regularization strength (default: 5000.0)
        n_hvg: Number of highly variable genes to use (default: 2000)
        n_markers_per_type: Number of marker genes per cell type (default: 50)

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        DependencyError: If flashdeconv is not installed.
        ProcessingError: If no output is produced or any other step fails.
    """
    if not is_available("flashdeconv"):
        raise DependencyError(
            "FlashDeconv is not available. Install with: pip install flashdeconv"
        )

    # The method parameters are forwarded to both the solver and the stats.
    method_params = {
        "sketch_dim": sketch_dim,
        "lambda_spatial": lambda_spatial,
        "n_hvg": n_hvg,
        "n_markers_per_type": n_markers_per_type,
    }

    try:
        import flashdeconv as fd

        # Data already copied in prepare_deconvolution
        spatial = data.spatial
        ref = data.reference

        fd.tl.deconvolve(
            spatial,
            ref,
            cell_type_key=data.cell_type_key,
            **method_params,
        )

        if "flashdeconv" not in spatial.obsm:
            raise ProcessingError(
                "FlashDeconv did not produce output in adata.obsm['flashdeconv']"
            )

        proportions = spatial.obsm["flashdeconv"].copy()

        if isinstance(proportions, pd.DataFrame):
            # Already tabular: only align the row labels with the spots.
            proportions.index = data.spatial.obs_names
        else:
            # Otherwise wrap the result, labelling columns by cell type.
            proportions = pd.DataFrame(
                proportions,
                index=data.spatial.obs_names,
                columns=data.cell_types,
            )

        stats = create_deconvolution_stats(
            proportions,
            data.common_genes,
            method="FlashDeconv",
            device="CPU",
            **method_params,
        )

        return proportions, stats

    except (DependencyError, ProcessingError):
        raise
    except Exception as e:
        raise ProcessingError(f"FlashDeconv deconvolution failed: {e}") from e
|