chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RCTD (Robust Cell Type Decomposition) deconvolution method.
|
|
3
|
+
|
|
4
|
+
RCTD is an R-based deconvolution method that performs robust
|
|
5
|
+
decomposition of cell type mixtures via the spacexr package.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
pass
|
|
16
|
+
|
|
17
|
+
from ...utils.dependency_manager import validate_r_package
|
|
18
|
+
from ...utils.exceptions import DataError, ParameterError, ProcessingError
|
|
19
|
+
from .base import PreparedDeconvolutionData, create_deconvolution_stats
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Run modes accepted by this wrapper; the value is forwarded to R as
# run.RCTD(doublet_mode = ...).
_RCTD_MODES = ("full", "doublet", "multi")

# Every name the R snippets in this module may assign in R's global
# environment: the inputs transferred from Python, the RCTD objects, and the
# temporaries created while extracting results in doublet/multi mode.
_RCTD_R_GLOBALS = (
    "spatial_counts",
    "reference_counts",
    "gene_names_spatial",
    "spot_names",
    "gene_names_ref",
    "cell_names",
    "coords",
    "numi_spatial",
    "cell_types_vec",
    "numi_ref",
    "max_cores_val",
    "rctd_mode",
    "conf_thresh",
    "doub_thresh",
    "max_multi_types_val",
    "puck",
    "cell_types_factor",
    "reference",
    "myRCTD",
    "weights_matrix",
    "cell_type_names",
    "weights_doublet",
    "results_df",
    "results_list",
    "n_spots",
    "n_cell_types",
    "i",
    "j",
    "spot_class",
    "first_type",
    "second_type",
    "first_idx",
    "second_idx",
    "spot_result",
    "predicted_types",
    "proportions",
    "cell_type",
    "col_idx",
)


def _remove_rctd_r_globals() -> None:
    """Delete the RCTD working objects from R's global environment.

    Only names that currently exist are removed, so this is safe to call
    after a partial run (e.g. when the pipeline failed half-way).
    """
    import rpy2.robjects as ro
    from rpy2.robjects.conversion import localconverter

    quoted = ", ".join(f'"{name}"' for name in _RCTD_R_GLOBALS)
    with localconverter(ro.default_converter):
        ro.r(
            f"""
            local({{
                stale <- intersect(ls(envir = .GlobalEnv), c({quoted}))
                rm(list = stale, envir = .GlobalEnv)
            }})
            gc()
            """
        )


def _run_rctd(
    data: PreparedDeconvolutionData,
    mode: str,
    max_cores: int,
    confidence_threshold: float,
    doublet_threshold: float,
    max_multi_types: int,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Transfer the data to R, run RCTD and build the proportions/statistics."""
    import anndata2ri
    import rpy2.robjects as ro
    from rpy2.robjects import pandas2ri
    from rpy2.robjects.conversion import localconverter

    # Load R packages using ro.r() instead of importr() to avoid
    # conversion context issues in async environments
    with localconverter(ro.default_converter + pandas2ri.converter):
        ro.r("library(spacexr)")

    # Data already copied in prepare_deconvolution
    spatial_data = data.spatial
    reference_data = data.reference

    # Spot coordinates: first two columns of the prepared coordinates, or a
    # placeholder layout (x = 0..n-1, y = 0) when none are available.
    if data.spatial_coords is not None:
        coords = pd.DataFrame(
            data.spatial_coords[:, :2],
            index=spatial_data.obs_names,
            columns=["x", "y"],
        )
    else:
        coords = pd.DataFrame(
            {"x": range(spatial_data.n_obs), "y": [0] * spatial_data.n_obs},
            index=spatial_data.obs_names,
        )

    # Cell type labels with "/" and " " replaced by "_"
    cell_types = reference_data.obs[data.cell_type_key].copy()
    cell_types = cell_types.str.replace("/", "_", regex=False)
    cell_types = cell_types.str.replace(" ", "_", regex=False)

    # RCTD requires minimum 25 cells per cell type
    MIN_CELLS_PER_TYPE = 25
    cell_type_counts = cell_types.value_counts()
    rare_types = cell_type_counts[
        cell_type_counts < MIN_CELLS_PER_TYPE
    ].index.tolist()

    if rare_types:
        warnings.warn(
            f"RCTD requires ≥{MIN_CELLS_PER_TYPE} cells per cell type. "
            f"Filtering {len(rare_types)} rare types: {rare_types}",
            UserWarning,
            # 3 = caller of deconvolve() (this helper sits one frame deeper)
            stacklevel=3,
        )
        keep_mask = ~cell_types.isin(rare_types)
        reference_data = reference_data[keep_mask].copy()
        cell_types = cell_types[keep_mask]

    # NOTE(review): the multi-mode max_multi_types check in deconvolve() uses
    # data.n_cell_types, i.e. the count before this filtering step.
    remaining_types = cell_types.unique()
    if len(remaining_types) < 2:
        raise DataError(
            f"After filtering rare cell types, only {len(remaining_types)} "
            f"cell type(s) remain. RCTD requires at least 2 cell types."
        )

    cell_types_series = pd.Series(
        cell_types.values, index=reference_data.obs_names, name="cell_type"
    )

    # Calculate nUMI (row sums of the expression matrices)
    spatial_numi = pd.Series(
        np.asarray(spatial_data.X.sum(axis=1)).ravel(),
        index=spatial_data.obs_names,
        name="nUMI",
    )
    reference_numi = pd.Series(
        np.asarray(reference_data.X.sum(axis=1)).ravel(),
        index=reference_data.obs_names,
        name="nUMI",
    )

    # Transfer matrices to R (transposed: genes as rows) and label them.
    # Every rpy2 call below runs inside an explicit localconverter so that it
    # does not depend on whatever conversion context happens to be active.
    with localconverter(ro.default_converter + anndata2ri.converter):
        ro.globalenv["spatial_counts"] = spatial_data.X.T
        ro.globalenv["reference_counts"] = reference_data.X.T

        ro.globalenv["gene_names_spatial"] = ro.StrVector(spatial_data.var_names)
        ro.globalenv["spot_names"] = ro.StrVector(spatial_data.obs_names)
        ro.globalenv["gene_names_ref"] = ro.StrVector(reference_data.var_names)
        ro.globalenv["cell_names"] = ro.StrVector(reference_data.obs_names)

        ro.r(
            """
            rownames(spatial_counts) <- gene_names_spatial
            colnames(spatial_counts) <- spot_names
            rownames(reference_counts) <- gene_names_ref
            colnames(reference_counts) <- cell_names
            """
        )

    # Transfer other data
    with localconverter(ro.default_converter + pandas2ri.converter):
        ro.globalenv["coords"] = ro.conversion.py2rpy(coords)
        ro.globalenv["numi_spatial"] = ro.conversion.py2rpy(spatial_numi)
        ro.globalenv["cell_types_vec"] = ro.conversion.py2rpy(cell_types_series)
        ro.globalenv["numi_ref"] = ro.conversion.py2rpy(reference_numi)
        ro.globalenv["max_cores_val"] = max_cores
        ro.globalenv["rctd_mode"] = mode
        ro.globalenv["conf_thresh"] = confidence_threshold
        ro.globalenv["doub_thresh"] = doublet_threshold
        ro.globalenv["max_multi_types_val"] = max_multi_types

    # Run RCTD in R
    with localconverter(ro.default_converter):
        ro.r(
            """
            puck <- SpatialRNA(coords, spatial_counts, numi_spatial)
            cell_types_factor <- as.factor(cell_types_vec)
            names(cell_types_factor) <- names(cell_types_vec)
            reference <- Reference(reference_counts, cell_types_factor, numi_ref, min_UMI = 5)
            myRCTD <- create.RCTD(puck, reference, max_cores = max_cores_val,
                                  MAX_MULTI_TYPES = max_multi_types_val, UMI_min_sigma = 10)
            myRCTD@config$CONFIDENCE_THRESHOLD <- conf_thresh
            myRCTD@config$DOUBLET_THRESHOLD <- doub_thresh
            myRCTD <- run.RCTD(myRCTD, doublet_mode = rctd_mode)
            """
        )

    # Extract results
    proportions = _extract_rctd_results(mode)

    # Validate results: NaNs only warn, negative weights are a hard error
    if proportions.isna().any().any():
        nan_count = proportions.isna().sum().sum()
        warnings.warn(
            f"RCTD produced {nan_count} NaN values", UserWarning, stacklevel=3
        )

    if (proportions < 0).any().any():
        neg_count = (proportions < 0).sum().sum()
        raise ProcessingError(f"RCTD error: {neg_count} negative values")

    # Create statistics
    stats = create_deconvolution_stats(
        proportions,
        data.common_genes,
        method=f"RCTD-{mode}",
        device="CPU",
        mode=mode,
        max_cores=max_cores,
        confidence_threshold=confidence_threshold,
        doublet_threshold=doublet_threshold,
    )

    return proportions, stats


def deconvolve(
    data: PreparedDeconvolutionData,
    mode: str = "full",
    max_cores: int = 4,
    confidence_threshold: float = 10.0,
    doublet_threshold: float = 25.0,
    max_multi_types: int = 4,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using RCTD from spacexr R package.

    The R session's global environment is used as scratch space; the objects
    created there are removed again whether the run succeeds or fails.

    Args:
        data: Prepared deconvolution data (immutable, includes spatial coordinates)
        mode: RCTD mode - 'full', 'doublet', or 'multi'
        max_cores: Maximum CPU cores
        confidence_threshold: Confidence threshold
        doublet_threshold: Doublet detection threshold
        max_multi_types: Max cell types per spot in multi mode

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        ParameterError: If ``mode`` is not one of the supported modes, or if
            ``mode == "multi"`` and ``max_multi_types`` is not smaller than
            ``data.n_cell_types``.
        ProcessingError: If any step of the RCTD pipeline fails. Other
            exception types raised inside the pipeline (including the
            DataError for "fewer than 2 cell types remain") are re-raised as
            ProcessingError with the original exception as ``__cause__``.
    """
    ctx = data.ctx

    # Fail fast on a misspelled mode instead of erroring inside R after the
    # reference and RCTD objects have already been built.
    if mode not in _RCTD_MODES:
        raise ParameterError(
            f"Unknown RCTD mode '{mode}'. Expected one of: "
            f"{', '.join(_RCTD_MODES)}."
        )

    # Validate mode-specific parameters
    if mode == "multi" and max_multi_types >= data.n_cell_types:
        raise ParameterError(
            f"MAX_MULTI_TYPES ({max_multi_types}) must be less than "
            f"total cell types ({data.n_cell_types})."
        )

    # Validate R package
    validate_r_package(
        "spacexr",
        ctx,
        install_cmd="devtools::install_github('dmcable/spacexr', build_vignettes = FALSE)",
    )

    try:
        try:
            return _run_rctd(
                data,
                mode,
                max_cores,
                confidence_threshold,
                doublet_threshold,
                max_multi_types,
            )
        finally:
            # Clean up R global environment on success *and* failure so a
            # failed run does not leave count matrices behind in the session.
            _remove_rctd_r_globals()
    except Exception as e:
        if isinstance(e, (ParameterError, ProcessingError)):
            raise
        raise ProcessingError(f"RCTD deconvolution failed: {e}") from e
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _extract_rctd_results(mode: str) -> pd.DataFrame:
    """Extract RCTD results from R environment.

    Reads the ``myRCTD`` object from R's global environment and builds a
    spots x cell-types weight matrix there (``weights_matrix``, together with
    ``cell_type_names`` and ``spot_names``), then converts it to pandas.

    Args:
        mode: ``"full"`` reads ``myRCTD@results$weights`` directly;
            ``"doublet"`` rebuilds the matrix from ``results_df`` and
            ``weights_doublet``; any other value is handled as multi mode and
            rebuilds it from the per-spot ``cell_type_list``/``sub_weights``.

    Returns:
        DataFrame indexed by spot name with one column per cell type.
    """
    import rpy2.robjects as ro
    from rpy2.robjects import numpy2ri, pandas2ri
    from rpy2.robjects.conversion import localconverter

    with localconverter(
        ro.default_converter + pandas2ri.converter + numpy2ri.converter
    ):
        if mode == "full":
            ro.r(
                """
                weights_matrix <- myRCTD@results$weights
                cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
                spot_names <- rownames(weights_matrix)
                """
            )
        elif mode == "doublet":
            # Singlets get weight 1.0 on first_type, doublets get their two
            # fitted weights; any other spot_class keeps an all-zero row.
            # seq_len() is used instead of 1:n so that zero spots means zero
            # iterations (1:0 would iterate over c(1, 0)).
            ro.r(
                """
                if("weights_doublet" %in% names(myRCTD@results) && "results_df" %in% names(myRCTD@results)) {
                    weights_doublet <- myRCTD@results$weights_doublet
                    results_df <- myRCTD@results$results_df
                    cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
                    spot_names <- rownames(results_df)
                    n_spots <- length(spot_names)
                    n_cell_types <- length(cell_type_names)
                    weights_matrix <- matrix(0, nrow = n_spots, ncol = n_cell_types)
                    rownames(weights_matrix) <- spot_names
                    colnames(weights_matrix) <- cell_type_names
                    for(i in seq_len(n_spots)) {
                        spot_class <- results_df$spot_class[i]
                        if(spot_class %in% c("doublet_certain", "doublet_uncertain")) {
                            first_type <- as.character(results_df$first_type[i])
                            second_type <- as.character(results_df$second_type[i])
                            if(first_type %in% cell_type_names) {
                                first_idx <- which(cell_type_names == first_type)
                                weights_matrix[i, first_idx] <- weights_doublet[i, "first_type"]
                            }
                            if(second_type %in% cell_type_names && second_type != first_type) {
                                second_idx <- which(cell_type_names == second_type)
                                weights_matrix[i, second_idx] <- weights_doublet[i, "second_type"]
                            }
                        } else if(spot_class == "singlet") {
                            first_type <- as.character(results_df$first_type[i])
                            if(first_type %in% cell_type_names) {
                                first_idx <- which(cell_type_names == first_type)
                                weights_matrix[i, first_idx] <- 1.0
                            }
                        }
                    }
                } else {
                    stop("Official doublet mode structures not found")
                }
                """
            )
        else:  # multi mode
            # One result entry per spot; scatter each spot's sub_weights into
            # the columns named by its cell_type_list.
            ro.r(
                """
                results_list <- myRCTD@results
                spot_names <- colnames(myRCTD@spatialRNA@counts)
                cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
                n_spots <- length(spot_names)
                n_cell_types <- length(cell_type_names)
                weights_matrix <- matrix(0, nrow = n_spots, ncol = n_cell_types)
                rownames(weights_matrix) <- spot_names
                colnames(weights_matrix) <- cell_type_names
                for(i in seq_len(n_spots)) {
                    spot_result <- results_list[[i]]
                    predicted_types <- spot_result$cell_type_list
                    proportions <- spot_result$sub_weights
                    for(j in seq_along(predicted_types)) {
                        cell_type <- predicted_types[j]
                        if(cell_type %in% cell_type_names) {
                            col_idx <- which(cell_type_names == cell_type)
                            weights_matrix[i, col_idx] <- proportions[j]
                        }
                    }
                }
                """
            )

        weights_r = ro.r("as.matrix(weights_matrix)")
        cell_type_names_r = ro.r("cell_type_names")
        spot_names_r = ro.r("spot_names")

        weights_array = ro.conversion.rpy2py(weights_r)
        cell_type_names = ro.conversion.rpy2py(cell_type_names_r)
        spot_names = ro.conversion.rpy2py(spot_names_r)

    return pd.DataFrame(weights_array, index=spot_names, columns=cell_type_names)
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SPOTlight deconvolution method.
|
|
3
|
+
|
|
4
|
+
SPOTlight is an R-based deconvolution method that uses NMF
|
|
5
|
+
(Non-negative Matrix Factorization) for cell type decomposition.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
from ...utils.adata_utils import to_dense
|
|
17
|
+
from ...utils.dependency_manager import validate_r_package
|
|
18
|
+
from ...utils.exceptions import DataError, ProcessingError
|
|
19
|
+
from .base import PreparedDeconvolutionData, create_deconvolution_stats
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Every name the SPOTlight R snippet may assign in R's global environment:
# inputs transferred from Python, the SCE/SPE objects, the marker tables and
# the temporaries of the marker-formatting loop.
_SPOTLIGHT_R_GLOBALS = (
    "spatial_counts",
    "reference_counts",
    "spatial_coords",
    "gene_names",
    "spatial_names",
    "reference_names",
    "cell_types",
    "nmf_model",
    "min_prop",
    "scale_data",
    "weight_id",
    "sce",
    "spe",
    "markers",
    "mgs",
    "spotlight_result",
    "cell_type_names",
    "mgs_list",
    "ct",
    "ct_markers",
    "n_markers",
    "top_markers",
    "mgs_df",
)


def _remove_spotlight_r_globals() -> None:
    """Delete the SPOTlight working objects from R's global environment.

    Only names that currently exist are removed, so this is safe to call
    after a partial run.
    """
    import rpy2.robjects as ro
    from rpy2.robjects.conversion import localconverter

    quoted = ", ".join(f'"{name}"' for name in _SPOTLIGHT_R_GLOBALS)
    with localconverter(ro.default_converter):
        ro.r(
            f"""
            local({{
                stale <- intersect(ls(envir = .GlobalEnv), c({quoted}))
                rm(list = stale, envir = .GlobalEnv)
            }})
            gc()
            """
        )


def _run_spotlight(
    data: PreparedDeconvolutionData,
    n_top_genes: int,
    nmf_model: str,
    min_prop: float,
    scale: bool,
    weight_id: str,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Transfer the data to R, run SPOTlight and build proportions/statistics."""
    import rpy2.robjects as ro
    from rpy2.robjects import numpy2ri, pandas2ri
    from rpy2.robjects.conversion import localconverter

    # Validate spatial coordinates from prepared data
    if data.spatial_coords is None:
        raise DataError(
            "SPOTlight requires spatial coordinates. "
            "Ensure spatial data has 'spatial' key in obsm."
        )
    spatial_coords = data.spatial_coords

    # Data already copied in prepare_deconvolution
    spatial_data = data.spatial
    reference_data = data.reference

    # Ensure integer counts for R interface. astype(copy=False) already
    # returns the input unchanged when it is int32, so no dtype test is needed.
    spatial_counts = to_dense(spatial_data.X).astype(np.int32, copy=False)
    reference_counts = to_dense(reference_data.X).astype(np.int32, copy=False)

    # Clean cell type labels ("/" and " " become "_")
    cell_types = reference_data.obs[data.cell_type_key].astype(str)
    cell_types = cell_types.str.replace("/", "_", regex=False)
    cell_types = cell_types.str.replace(" ", "_", regex=False)

    # Load R libraries first (using ro.r() to avoid importr conversion issues)
    with localconverter(ro.default_converter + pandas2ri.converter):
        ro.r("library(SPOTlight)")
        ro.r("library(SingleCellExperiment)")
        ro.r("library(SpatialExperiment)")
        ro.r("library(scran)")
        ro.r("library(scuttle)")

    # Transfer matrices to R using numpy2ri (transposed: genes as rows)
    with localconverter(ro.default_converter + numpy2ri.converter):
        ro.globalenv["spatial_counts"] = spatial_counts.T
        ro.globalenv["reference_counts"] = reference_counts.T

    # Transfer other data
    with localconverter(
        ro.default_converter + pandas2ri.converter + numpy2ri.converter
    ):
        ro.globalenv["spatial_coords"] = spatial_coords
        ro.globalenv["gene_names"] = ro.StrVector(data.common_genes)
        ro.globalenv["spatial_names"] = ro.StrVector(list(spatial_data.obs_names))
        ro.globalenv["reference_names"] = ro.StrVector(
            list(reference_data.obs_names)
        )
        ro.globalenv["cell_types"] = ro.StrVector(cell_types.tolist())
        ro.globalenv["nmf_model"] = nmf_model
        ro.globalenv["min_prop"] = min_prop
        ro.globalenv["scale_data"] = scale
        ro.globalenv["weight_id"] = weight_id

    # Create SCE and SPE objects, run SPOTlight.
    # The marker weight column is named after `weight_id` (not hard-coded),
    # so the column SPOTlight is told to read always exists in `mgs`.
    with localconverter(ro.default_converter):
        ro.r(
            """
            # Create SingleCellExperiment for reference
            sce <- SingleCellExperiment(
                assays = list(counts = reference_counts),
                colData = data.frame(
                    cell_type = factor(cell_types),
                    row.names = reference_names
                )
            )
            rownames(sce) <- gene_names
            sce <- logNormCounts(sce)

            # Create SpatialExperiment for spatial data
            spe <- SpatialExperiment(
                assays = list(counts = spatial_counts),
                spatialCoords = spatial_coords,
                colData = data.frame(row.names = spatial_names)
            )
            rownames(spe) <- gene_names
            colnames(spe) <- spatial_names

            # Find marker genes using scran
            markers <- findMarkers(sce, groups = sce$cell_type, test.type = "wilcox")

            # Format marker genes for SPOTlight
            cell_type_names <- names(markers)
            mgs_list <- list()

            for (ct in cell_type_names) {
                ct_markers <- markers[[ct]]
                n_markers <- min(50, nrow(ct_markers))
                top_markers <- head(ct_markers[order(ct_markers$p.value), ], n_markers)

                mgs_df <- data.frame(
                    gene = rownames(top_markers),
                    cluster = ct
                )
                mgs_df[[weight_id]] <- -log10(top_markers$p.value + 1e-10)
                mgs_list[[ct]] <- mgs_df
            }

            mgs <- do.call(rbind, mgs_list)

            # Run SPOTlight
            spotlight_result <- SPOTlight(
                x = sce,
                y = spe,
                groups = sce$cell_type,
                mgs = mgs,
                weight_id = weight_id,
                group_id = "cluster",
                gene_id = "gene",
                model = nmf_model,
                min_prop = min_prop,
                scale = scale_data,
                verbose = TRUE
            )
            """
        )

    # Extract results
    with localconverter(
        ro.default_converter + pandas2ri.converter + numpy2ri.converter
    ):
        proportions_np = np.array(ro.r("spotlight_result$mat"))
        spot_names = list(ro.r("rownames(spotlight_result$mat)"))
        cell_type_names = list(ro.r("colnames(spotlight_result$mat)"))

    proportions = pd.DataFrame(
        proportions_np, index=spot_names, columns=cell_type_names
    )

    # Create statistics
    stats = create_deconvolution_stats(
        proportions,
        data.common_genes,
        method="SPOTlight",
        device="CPU",
        n_top_genes=n_top_genes,
        nmf_model=nmf_model,
        min_prop=min_prop,
    )

    return proportions, stats


def deconvolve(
    data: PreparedDeconvolutionData,
    n_top_genes: int = 2000,
    nmf_model: str = "ns",
    min_prop: float = 0.01,
    scale: bool = True,
    weight_id: str = "mean.AUC",
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using SPOTlight R package.

    The R session's global environment is used as scratch space; the objects
    created there are removed again whether the run succeeds or fails.

    Args:
        data: Prepared deconvolution data (immutable, includes spatial coordinates)
        n_top_genes: Number of top HVGs to use.
            NOTE(review): this value is only recorded in the returned
            statistics; nothing in this function forwards it to R. Confirm
            whether gene selection happens upstream.
        nmf_model: NMF model type - 'ns' (non-smooth) or 'std' (standard)
        min_prop: Minimum proportion threshold
        scale: Whether to scale data
        weight_id: Column name for marker gene weights; the marker table
            built here stores ``-log10(p.value + 1e-10)`` under this name.

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        ProcessingError: If any step fails. Other exception types raised
            inside the pipeline (including the DataError for missing spatial
            coordinates) are re-raised as ProcessingError with the original
            exception as ``__cause__``.
    """
    ctx = data.ctx

    # Validate R package
    validate_r_package(
        "SPOTlight",
        ctx,
        install_cmd="BiocManager::install('SPOTlight')",
    )

    try:
        try:
            return _run_spotlight(
                data, n_top_genes, nmf_model, min_prop, scale, weight_id
            )
        finally:
            # Clean up R global environment on success *and* failure so a
            # failed run does not leave count matrices behind in the session.
            _remove_spotlight_r_globals()
    except Exception as e:
        if isinstance(e, ProcessingError):
            raise
        raise ProcessingError(f"SPOTlight deconvolution failed: {e}") from e
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Stereoscope deconvolution method.
|
|
3
|
+
|
|
4
|
+
Stereoscope uses a two-stage training workflow:
|
|
5
|
+
1. Train RNAStereoscope model on reference data
|
|
6
|
+
2. Train SpatialStereoscope model on spatial data using RNA model
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import gc
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
from ...utils.adata_utils import ensure_categorical
|
|
15
|
+
from ...utils.exceptions import ProcessingError
|
|
16
|
+
from .base import PreparedDeconvolutionData, create_deconvolution_stats
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def deconvolve(
    data: PreparedDeconvolutionData,
    n_epochs: int = 150000,
    learning_rate: float = 0.01,
    batch_size: int = 128,
    use_gpu: bool = False,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using Stereoscope from scvi-tools.

    Trains an RNAStereoscope model on the reference data, then a
    SpatialStereoscope model on the spatial data derived from it.

    Args:
        data: Prepared deconvolution data (immutable)
        n_epochs: Total epochs across both stages (default: 150000, split
            75K+75K). The RNA stage gets ``n_epochs // 2`` and the spatial
            stage the remainder. Must be at least 2.
        learning_rate: Learning rate (default: 0.01)
        batch_size: Minibatch size (default: 128)
        use_gpu: Use GPU acceleration

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        ProcessingError: If ``n_epochs`` is smaller than 2, or if any step of
            the workflow fails (the original exception is kept as
            ``__cause__``).
    """
    from scvi.external import RNAStereoscope, SpatialStereoscope

    try:
        # With fewer than 2 total epochs the RNA stage would get 0 epochs.
        if n_epochs < 2:
            raise ProcessingError(
                "Stereoscope deconvolution failed: n_epochs must be at least 2 "
                f"so that both training stages run (got {n_epochs})."
            )

        # Data already copied in prepare_deconvolution
        spatial_data = data.spatial
        ref_data = data.reference

        # Ensure categorical cell type
        ensure_categorical(ref_data, data.cell_type_key)

        cell_types = list(ref_data.obs[data.cell_type_key].cat.categories)

        # Calculate epoch split (150000 -> 75000 + 75000; odd totals give the
        # extra epoch to the spatial stage)
        rna_epochs = n_epochs // 2
        spatial_epochs = n_epochs - rna_epochs

        train_kwargs: dict[str, Any] = {
            "max_epochs": rna_epochs,
            "batch_size": batch_size,
            "plan_kwargs": {"lr": learning_rate},
        }
        # NOTE(review): the accelerator is only set explicitly for GPU runs;
        # with use_gpu=False train() falls back to its own default while the
        # statistics below report "cpu" — confirm that default is CPU.
        if use_gpu:
            train_kwargs["accelerator"] = "gpu"

        # ===== Stage 1: Train RNAStereoscope =====
        RNAStereoscope.setup_anndata(ref_data, labels_key=data.cell_type_key)
        rna_model = RNAStereoscope(ref_data)
        rna_model.train(**train_kwargs)

        # ===== Stage 2: Train SpatialStereoscope =====
        SpatialStereoscope.setup_anndata(spatial_data)
        spatial_model = SpatialStereoscope.from_rna_model(spatial_data, rna_model)

        train_kwargs["max_epochs"] = spatial_epochs
        spatial_model.train(**train_kwargs)

        # Extract proportions
        proportions = pd.DataFrame(
            spatial_model.get_proportions(),
            index=spatial_data.obs_names,
            columns=cell_types,
        )

        # Create statistics
        stats = create_deconvolution_stats(
            proportions,
            data.common_genes,
            method="Stereoscope",
            device="gpu" if use_gpu else "cpu",
            n_epochs=n_epochs,
            rna_epochs=rna_epochs,
            spatial_epochs=spatial_epochs,
            learning_rate=learning_rate,
        )

        # Memory cleanup
        del spatial_model, rna_model
        del spatial_data, ref_data
        gc.collect()

        return proportions, stats

    except Exception as e:
        if isinstance(e, ProcessingError):
            raise
        raise ProcessingError(f"Stereoscope deconvolution failed: {e}") from e
|