chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67):
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,317 @@
1
+ """
2
+ RCTD (Robust Cell Type Decomposition) deconvolution method.
3
+
4
+ RCTD is an R-based deconvolution method that performs robust
5
+ decomposition of cell type mixtures via the spacexr package.
6
+ """
7
+
8
+ import warnings
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ if TYPE_CHECKING:
15
+ pass
16
+
17
+ from ...utils.dependency_manager import validate_r_package
18
+ from ...utils.exceptions import DataError, ParameterError, ProcessingError
19
+ from .base import PreparedDeconvolutionData, create_deconvolution_stats
20
+
21
+
22
# R objects that deconvolve() and _extract_rctd_results() may leave in R's
# global environment. They are removed after every run, successful or not,
# so repeated calls do not accumulate count matrices in the embedded R
# session. Names that do not exist at cleanup time are skipped.
_RCTD_R_GLOBALS: tuple[str, ...] = (
    # Inputs transferred from Python
    "spatial_counts", "reference_counts", "gene_names_spatial", "spot_names",
    "gene_names_ref", "cell_names", "coords", "numi_spatial", "cell_types_vec",
    "numi_ref", "max_cores_val", "rctd_mode", "conf_thresh", "doub_thresh",
    "max_multi_types_val",
    # Objects built by the RCTD run
    "puck", "cell_types_factor", "reference", "myRCTD",
    # Objects that result extraction may publish or leak
    "weights_matrix", "cell_type_names", "weights_doublet", "results_df",
    "results_list", "n_spots", "n_cell_types", "spot_class", "first_type",
    "second_type", "first_idx", "second_idx", "spot_result",
    "predicted_types", "proportions", "cell_type", "col_idx", "i", "j",
)


def _cleanup_r_globals(ro: Any, names: tuple[str, ...]) -> None:
    """Remove ``names`` from R's global environment, best effort.

    Only names currently defined in ``.GlobalEnv`` are passed to ``rm()``, so
    the call is safe after a partial failure. An error raised while cleaning
    up is reported as a ``UserWarning`` instead of being raised, so it cannot
    mask the result or the exception of the deconvolution run itself.
    """
    r_names = ", ".join(f'"{name}"' for name in names)
    try:
        ro.r(
            f"rm(list = intersect(c({r_names}), ls(envir = .GlobalEnv)), "
            "envir = .GlobalEnv)\n"
            "invisible(gc())"
        )
    except Exception as cleanup_error:  # best effort: never mask the caller's outcome
        warnings.warn(
            f"Failed to clean up R global environment after RCTD: {cleanup_error}",
            UserWarning,
            stacklevel=2,
        )


def deconvolve(
    data: PreparedDeconvolutionData,
    mode: str = "full",
    max_cores: int = 4,
    confidence_threshold: float = 10.0,
    doublet_threshold: float = 25.0,
    max_multi_types: int = 4,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using RCTD from the spacexr R package.

    The spatial and reference count matrices are pushed into R's global
    environment, ``create.RCTD``/``run.RCTD`` are run there, and the weights
    are pulled back as a spots x cell-types DataFrame. The R objects created
    for the run are removed afterwards, whether or not the run succeeded.

    Args:
        data: Prepared deconvolution data (immutable, includes spatial coordinates)
        mode: RCTD mode - 'full', 'doublet', or 'multi'
        max_cores: Maximum CPU cores (passed to ``create.RCTD``)
        confidence_threshold: Value written to RCTD's ``CONFIDENCE_THRESHOLD``
        doublet_threshold: Value written to RCTD's ``DOUBLET_THRESHOLD``
        max_multi_types: Max cell types per spot in multi mode

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        ParameterError: ``mode`` is not supported, or in multi mode
            ``max_multi_types`` is not smaller than the number of cell types
            (checked before and after rare-type filtering).
        DataError: Fewer than two cell types remain after dropping types
            with fewer than 25 reference cells.
        ProcessingError: RCTD returned negative weights, or any other
            failure occurred while running RCTD.
    """
    # Validate parameters before importing or touching R
    if mode not in ("full", "doublet", "multi"):
        raise ParameterError(
            f"Invalid RCTD mode '{mode}'. Must be 'full', 'doublet', or 'multi'."
        )
    if mode == "multi" and max_multi_types >= data.n_cell_types:
        raise ParameterError(
            f"MAX_MULTI_TYPES ({max_multi_types}) must be less than "
            f"total cell types ({data.n_cell_types})."
        )

    import anndata2ri
    import rpy2.robjects as ro
    from rpy2.robjects import numpy2ri, pandas2ri
    from rpy2.robjects.conversion import localconverter

    ctx = data.ctx

    # Validate R package
    validate_r_package(
        "spacexr",
        ctx,
        install_cmd="devtools::install_github('dmcable/spacexr', build_vignettes = FALSE)",
    )

    try:
        # Load R packages using ro.r() instead of importr() to avoid
        # conversion context issues in async environments
        with localconverter(ro.default_converter + pandas2ri.converter):
            ro.r("library(spacexr)")

        # Data already copied in prepare_deconvolution
        spatial_data = data.spatial
        reference_data = data.reference

        # RCTD needs x/y coordinates; without them, lay spots out on a line
        if data.spatial_coords is not None:
            coords = pd.DataFrame(
                data.spatial_coords[:, :2],
                index=spatial_data.obs_names,
                columns=["x", "y"],
            )
        else:
            coords = pd.DataFrame(
                {"x": range(spatial_data.n_obs), "y": [0] * spatial_data.n_obs},
                index=spatial_data.obs_names,
            )

        # Labels are coerced to str first so non-string labels (e.g. ints)
        # work with the .str accessor; '/' and ' ' are replaced with '_'.
        cell_types = reference_data.obs[data.cell_type_key].astype(str)
        cell_types = cell_types.str.replace("/", "_", regex=False)
        cell_types = cell_types.str.replace(" ", "_", regex=False)

        # RCTD requires minimum 25 cells per cell type
        MIN_CELLS_PER_TYPE = 25
        cell_type_counts = cell_types.value_counts()
        rare_types = cell_type_counts[
            cell_type_counts < MIN_CELLS_PER_TYPE
        ].index.tolist()

        if rare_types:
            warnings.warn(
                f"RCTD requires ≥{MIN_CELLS_PER_TYPE} cells per cell type. "
                f"Filtering {len(rare_types)} rare types: {rare_types}",
                UserWarning,
                stacklevel=2,
            )
            keep_mask = ~cell_types.isin(rare_types)
            reference_data = reference_data[keep_mask].copy()
            cell_types = cell_types[keep_mask]

        n_remaining = len(cell_types.unique())
        if n_remaining < 2:
            raise DataError(
                f"After filtering rare cell types, only {n_remaining} "
                "cell type(s) remain. RCTD requires at least 2 cell types."
            )
        # Filtering can shrink the type count below the earlier check
        if mode == "multi" and max_multi_types >= n_remaining:
            raise ParameterError(
                f"MAX_MULTI_TYPES ({max_multi_types}) must be less than the "
                f"{n_remaining} cell types remaining after filtering rare types."
            )

        cell_types_series = pd.Series(
            cell_types.values, index=reference_data.obs_names, name="cell_type"
        )

        # Calculate nUMI (total counts per spot / per reference cell)
        spatial_numi = pd.Series(
            np.asarray(spatial_data.X.sum(axis=1)).ravel(),
            index=spatial_data.obs_names,
            name="nUMI",
        )
        reference_numi = pd.Series(
            np.asarray(reference_data.X.sum(axis=1)).ravel(),
            index=reference_data.obs_names,
            name="nUMI",
        )

        # Transfer matrices to R (genes x spots / genes x cells).
        # numpy2ri covers dense arrays; anndata2ri is added last so its
        # rules take precedence for the types it handles (e.g. sparse).
        with localconverter(
            ro.default_converter + numpy2ri.converter + anndata2ri.converter
        ):
            ro.globalenv["spatial_counts"] = spatial_data.X.T
            ro.globalenv["reference_counts"] = reference_data.X.T

        ro.globalenv["gene_names_spatial"] = ro.StrVector(spatial_data.var_names)
        ro.globalenv["spot_names"] = ro.StrVector(spatial_data.obs_names)
        ro.globalenv["gene_names_ref"] = ro.StrVector(reference_data.var_names)
        ro.globalenv["cell_names"] = ro.StrVector(reference_data.obs_names)

        ro.r(
            """
            rownames(spatial_counts) <- gene_names_spatial
            colnames(spatial_counts) <- spot_names
            rownames(reference_counts) <- gene_names_ref
            colnames(reference_counts) <- cell_names
            """
        )

        # Transfer other data
        with localconverter(ro.default_converter + pandas2ri.converter):
            ro.globalenv["coords"] = ro.conversion.py2rpy(coords)
            ro.globalenv["numi_spatial"] = ro.conversion.py2rpy(spatial_numi)
            ro.globalenv["cell_types_vec"] = ro.conversion.py2rpy(cell_types_series)
            ro.globalenv["numi_ref"] = ro.conversion.py2rpy(reference_numi)
            ro.globalenv["max_cores_val"] = max_cores
            ro.globalenv["rctd_mode"] = mode
            ro.globalenv["conf_thresh"] = confidence_threshold
            ro.globalenv["doub_thresh"] = doublet_threshold
            ro.globalenv["max_multi_types_val"] = max_multi_types

        # Run RCTD in R
        ro.r(
            """
            puck <- SpatialRNA(coords, spatial_counts, numi_spatial)
            cell_types_factor <- as.factor(cell_types_vec)
            names(cell_types_factor) <- names(cell_types_vec)
            reference <- Reference(reference_counts, cell_types_factor, numi_ref, min_UMI = 5)
            myRCTD <- create.RCTD(puck, reference, max_cores = max_cores_val,
                                  MAX_MULTI_TYPES = max_multi_types_val, UMI_min_sigma = 10)
            myRCTD@config$CONFIDENCE_THRESHOLD <- conf_thresh
            myRCTD@config$DOUBLET_THRESHOLD <- doub_thresh
            myRCTD <- run.RCTD(myRCTD, doublet_mode = rctd_mode)
            """
        )

        # Extract results
        proportions = _extract_rctd_results(mode)

        # Validate results
        if proportions.isna().any().any():
            nan_count = proportions.isna().sum().sum()
            warnings.warn(
                f"RCTD produced {nan_count} NaN values", UserWarning, stacklevel=2
            )

        if (proportions < 0).any().any():
            neg_count = (proportions < 0).sum().sum()
            raise ProcessingError(f"RCTD error: {neg_count} negative values")

        # Create statistics
        stats = create_deconvolution_stats(
            proportions,
            data.common_genes,
            method=f"RCTD-{mode}",
            device="CPU",
            mode=mode,
            max_cores=max_cores,
            confidence_threshold=confidence_threshold,
            doublet_threshold=doublet_threshold,
        )

        return proportions, stats

    except (DataError, ParameterError, ProcessingError):
        # Deliberately raised, already-typed errors propagate unchanged
        raise
    except Exception as e:
        raise ProcessingError(f"RCTD deconvolution failed: {e}") from e
    finally:
        # Runs on success and on failure, so R memory is always released
        _cleanup_r_globals(ro, _RCTD_R_GLOBALS)
225
+
226
+
227
+ def _extract_rctd_results(mode: str) -> pd.DataFrame:
228
+ """Extract RCTD results from R environment."""
229
+ import rpy2.robjects as ro
230
+ from rpy2.robjects import numpy2ri, pandas2ri
231
+ from rpy2.robjects.conversion import localconverter
232
+
233
+ with localconverter(
234
+ ro.default_converter + pandas2ri.converter + numpy2ri.converter
235
+ ):
236
+ if mode == "full":
237
+ ro.r(
238
+ """
239
+ weights_matrix <- myRCTD@results$weights
240
+ cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
241
+ spot_names <- rownames(weights_matrix)
242
+ """
243
+ )
244
+ elif mode == "doublet":
245
+ ro.r(
246
+ """
247
+ if("weights_doublet" %in% names(myRCTD@results) && "results_df" %in% names(myRCTD@results)) {
248
+ weights_doublet <- myRCTD@results$weights_doublet
249
+ results_df <- myRCTD@results$results_df
250
+ cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
251
+ spot_names <- rownames(results_df)
252
+ n_spots <- length(spot_names)
253
+ n_cell_types <- length(cell_type_names)
254
+ weights_matrix <- matrix(0, nrow = n_spots, ncol = n_cell_types)
255
+ rownames(weights_matrix) <- spot_names
256
+ colnames(weights_matrix) <- cell_type_names
257
+ for(i in 1:n_spots) {
258
+ spot_class <- results_df$spot_class[i]
259
+ if(spot_class %in% c("doublet_certain", "doublet_uncertain")) {
260
+ first_type <- as.character(results_df$first_type[i])
261
+ second_type <- as.character(results_df$second_type[i])
262
+ if(first_type %in% cell_type_names) {
263
+ first_idx <- which(cell_type_names == first_type)
264
+ weights_matrix[i, first_idx] <- weights_doublet[i, "first_type"]
265
+ }
266
+ if(second_type %in% cell_type_names && second_type != first_type) {
267
+ second_idx <- which(cell_type_names == second_type)
268
+ weights_matrix[i, second_idx] <- weights_doublet[i, "second_type"]
269
+ }
270
+ } else if(spot_class == "singlet") {
271
+ first_type <- as.character(results_df$first_type[i])
272
+ if(first_type %in% cell_type_names) {
273
+ first_idx <- which(cell_type_names == first_type)
274
+ weights_matrix[i, first_idx] <- 1.0
275
+ }
276
+ }
277
+ }
278
+ } else {
279
+ stop("Official doublet mode structures not found")
280
+ }
281
+ """
282
+ )
283
+ else: # multi mode
284
+ ro.r(
285
+ """
286
+ results_list <- myRCTD@results
287
+ spot_names <- colnames(myRCTD@spatialRNA@counts)
288
+ cell_type_names <- myRCTD@cell_type_info$renorm[[2]]
289
+ n_spots <- length(spot_names)
290
+ n_cell_types <- length(cell_type_names)
291
+ weights_matrix <- matrix(0, nrow = n_spots, ncol = n_cell_types)
292
+ rownames(weights_matrix) <- spot_names
293
+ colnames(weights_matrix) <- cell_type_names
294
+ for(i in 1:n_spots) {
295
+ spot_result <- results_list[[i]]
296
+ predicted_types <- spot_result$cell_type_list
297
+ proportions <- spot_result$sub_weights
298
+ for(j in seq_along(predicted_types)) {
299
+ cell_type <- predicted_types[j]
300
+ if(cell_type %in% cell_type_names) {
301
+ col_idx <- which(cell_type_names == cell_type)
302
+ weights_matrix[i, col_idx] <- proportions[j]
303
+ }
304
+ }
305
+ }
306
+ """
307
+ )
308
+
309
+ weights_r = ro.r("as.matrix(weights_matrix)")
310
+ cell_type_names_r = ro.r("cell_type_names")
311
+ spot_names_r = ro.r("spot_names")
312
+
313
+ weights_array = ro.conversion.rpy2py(weights_r)
314
+ cell_type_names = ro.conversion.rpy2py(cell_type_names_r)
315
+ spot_names = ro.conversion.rpy2py(spot_names_r)
316
+
317
+ return pd.DataFrame(weights_array, index=spot_names, columns=cell_type_names)
@@ -0,0 +1,216 @@
1
+ """
2
+ SPOTlight deconvolution method.
3
+
4
+ SPOTlight is an R-based deconvolution method that uses NMF
5
+ (Non-negative Matrix Factorization) for cell type decomposition.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ if TYPE_CHECKING:
14
+ pass
15
+
16
+ from ...utils.adata_utils import to_dense
17
+ from ...utils.dependency_manager import validate_r_package
18
+ from ...utils.exceptions import DataError, ProcessingError
19
+ from .base import PreparedDeconvolutionData, create_deconvolution_stats
20
+
21
+
22
# R objects deconvolve() may create in R's global environment. They are
# removed after every run, successful or not; absent names are skipped.
_SPOTLIGHT_R_GLOBALS: tuple[str, ...] = (
    # Inputs transferred from Python
    "spatial_counts", "reference_counts", "spatial_coords", "gene_names",
    "spatial_names", "reference_names", "cell_types", "nmf_model", "min_prop",
    "scale_data", "weight_id", "n_top_genes",
    # Objects built by the R script (including marker-loop temporaries)
    "sce", "spe", "dec", "hvg", "markers", "cell_type_names", "mgs_list",
    "ct", "ct_markers", "n_markers", "top_markers", "mgs_df", "mgs",
    "spotlight_result",
)


def _cleanup_r_globals(ro: Any, names: tuple[str, ...]) -> None:
    """Remove ``names`` from R's global environment, best effort.

    Only names currently defined in ``.GlobalEnv`` are passed to ``rm()``, so
    the call is safe after a partial failure. An error raised while cleaning
    up is reported as a ``UserWarning`` instead of being raised, so it cannot
    mask the result or the exception of the deconvolution run itself.
    """
    import warnings

    r_names = ", ".join(f'"{name}"' for name in names)
    try:
        ro.r(
            f"rm(list = intersect(c({r_names}), ls(envir = .GlobalEnv)), "
            "envir = .GlobalEnv)\n"
            "invisible(gc())"
        )
    except Exception as cleanup_error:  # best effort: never mask the caller's outcome
        warnings.warn(
            f"Failed to clean up R global environment after SPOTlight: {cleanup_error}",
            UserWarning,
            stacklevel=2,
        )


def _as_int32_counts(matrix: Any) -> np.ndarray:
    """Return ``matrix`` densified as an int32 count array.

    Floating-point input is rounded to the nearest integer before the cast.
    A bare ``astype`` truncates toward zero and would turn a float-stored
    count such as ``2.9999999`` into 2.
    """
    dense = to_dense(matrix)
    if dense.dtype == np.int32:
        return dense
    if np.issubdtype(dense.dtype, np.floating):
        dense = np.rint(dense)
    return dense.astype(np.int32, copy=False)


def deconvolve(
    data: PreparedDeconvolutionData,
    n_top_genes: int = 2000,
    nmf_model: str = "ns",
    min_prop: float = 0.01,
    scale: bool = True,
    weight_id: str = "mean.AUC",
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using SPOTlight R package.

    Builds a SingleCellExperiment (reference) and a SpatialExperiment
    (spatial) in R. It derives up to 50 marker genes per cell type with
    scran's ``findMarkers`` (weighted by ``-log10(p)``), selects the top
    ``n_top_genes`` highly variable genes, and runs ``SPOTlight()``. The
    R objects created for the run are removed afterwards, whether or not
    the run succeeded.

    Args:
        data: Prepared deconvolution data (immutable, includes spatial coordinates)
        n_top_genes: Number of top HVGs (scran ``getTopHVGs``) passed to
            SPOTlight's ``hvg`` argument
        nmf_model: NMF model type - 'ns' (non-smooth) or 'std' (standard)
        min_prop: Minimum proportion threshold
        scale: Whether to scale data
        weight_id: Name of the marker-gene weight column; the marker table
            is built with a column of exactly this name

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        DataError: The prepared data has no spatial coordinates.
        ProcessingError: Any failure while running SPOTlight.
    """
    import rpy2.robjects as ro
    from rpy2.robjects import numpy2ri, pandas2ri
    from rpy2.robjects.conversion import localconverter

    ctx = data.ctx

    # Validate R package
    validate_r_package(
        "SPOTlight",
        ctx,
        install_cmd="BiocManager::install('SPOTlight')",
    )

    try:
        # Validate spatial coordinates from prepared data
        if data.spatial_coords is None:
            raise DataError(
                "SPOTlight requires spatial coordinates. "
                "Ensure spatial data has 'spatial' key in obsm."
            )
        spatial_coords = data.spatial_coords

        # Data already copied in prepare_deconvolution
        spatial_data = data.spatial
        reference_data = data.reference

        # Ensure integer counts for R interface
        spatial_counts = _as_int32_counts(spatial_data.X)
        reference_counts = _as_int32_counts(reference_data.X)

        # Clean cell type labels
        cell_types = reference_data.obs[data.cell_type_key].astype(str)
        cell_types = cell_types.str.replace("/", "_", regex=False)
        cell_types = cell_types.str.replace(" ", "_", regex=False)

        # Load R libraries first (using ro.r() to avoid importr conversion issues)
        with localconverter(ro.default_converter + pandas2ri.converter):
            ro.r("library(SPOTlight)")
            ro.r("library(SingleCellExperiment)")
            ro.r("library(SpatialExperiment)")
            ro.r("library(scran)")
            ro.r("library(scuttle)")

        # Transfer matrices to R (genes x spots / genes x cells) using numpy2ri
        with localconverter(ro.default_converter + numpy2ri.converter):
            ro.globalenv["spatial_counts"] = spatial_counts.T
            ro.globalenv["reference_counts"] = reference_counts.T

        # Transfer other data
        with localconverter(
            ro.default_converter + pandas2ri.converter + numpy2ri.converter
        ):
            ro.globalenv["spatial_coords"] = spatial_coords
            ro.globalenv["gene_names"] = ro.StrVector(data.common_genes)
            ro.globalenv["spatial_names"] = ro.StrVector(list(spatial_data.obs_names))
            ro.globalenv["reference_names"] = ro.StrVector(
                list(reference_data.obs_names)
            )
            ro.globalenv["cell_types"] = ro.StrVector(cell_types.tolist())
            ro.globalenv["nmf_model"] = nmf_model
            ro.globalenv["min_prop"] = min_prop
            ro.globalenv["scale_data"] = scale
            ro.globalenv["weight_id"] = weight_id
            ro.globalenv["n_top_genes"] = n_top_genes

        # Create SCE and SPE objects, run SPOTlight
        ro.r(
            """
            # Create SingleCellExperiment for reference
            sce <- SingleCellExperiment(
                assays = list(counts = reference_counts),
                colData = data.frame(
                    cell_type = factor(cell_types),
                    row.names = reference_names
                )
            )
            rownames(sce) <- gene_names
            sce <- logNormCounts(sce)

            # Top highly variable genes, handed to SPOTlight via `hvg`
            dec <- modelGeneVar(sce)
            hvg <- getTopHVGs(dec, n = n_top_genes)

            # Create SpatialExperiment for spatial data
            spe <- SpatialExperiment(
                assays = list(counts = spatial_counts),
                spatialCoords = spatial_coords,
                colData = data.frame(row.names = spatial_names)
            )
            rownames(spe) <- gene_names
            colnames(spe) <- spatial_names

            # Find marker genes using scran
            markers <- findMarkers(sce, groups = sce$cell_type, test.type = "wilcox")

            # Format marker genes for SPOTlight
            cell_type_names <- names(markers)
            mgs_list <- list()

            for (ct in cell_type_names) {
                ct_markers <- markers[[ct]]
                n_markers <- min(50, nrow(ct_markers))
                top_markers <- head(ct_markers[order(ct_markers$p.value), ], n_markers)

                mgs_df <- data.frame(
                    gene = rownames(top_markers),
                    cluster = ct
                )
                # Name the weight column after weight_id so SPOTlight finds it
                mgs_df[[weight_id]] <- -log10(top_markers$p.value + 1e-10)
                mgs_list[[ct]] <- mgs_df
            }

            mgs <- do.call(rbind, mgs_list)

            # Run SPOTlight
            spotlight_result <- SPOTlight(
                x = sce,
                y = spe,
                groups = sce$cell_type,
                mgs = mgs,
                hvg = hvg,
                weight_id = weight_id,
                group_id = "cluster",
                gene_id = "gene",
                model = nmf_model,
                min_prop = min_prop,
                scale = scale_data,
                verbose = TRUE
            )
            """
        )

        # Extract results
        with localconverter(
            ro.default_converter + pandas2ri.converter + numpy2ri.converter
        ):
            proportions_np = np.array(ro.r("spotlight_result$mat"))
            spot_names = list(ro.r("rownames(spotlight_result$mat)"))
            cell_type_names = list(ro.r("colnames(spotlight_result$mat)"))

        proportions = pd.DataFrame(
            proportions_np, index=spot_names, columns=cell_type_names
        )

        # Create statistics
        stats = create_deconvolution_stats(
            proportions,
            data.common_genes,
            method="SPOTlight",
            device="CPU",
            n_top_genes=n_top_genes,
            nmf_model=nmf_model,
            min_prop=min_prop,
        )

        return proportions, stats

    except (DataError, ProcessingError):
        # Deliberately raised, already-typed errors propagate unchanged
        raise
    except Exception as e:
        raise ProcessingError(f"SPOTlight deconvolution failed: {e}") from e
    finally:
        # Runs on success and on failure, so R memory is always released
        _cleanup_r_globals(ro, _SPOTLIGHT_R_GLOBALS)
@@ -0,0 +1,109 @@
1
+ """
2
+ Stereoscope deconvolution method.
3
+
4
+ Stereoscope uses a two-stage training workflow:
5
+ 1. Train RNAStereoscope model on reference data
6
+ 2. Train SpatialStereoscope model on spatial data using RNA model
7
+ """
8
+
9
+ import gc
10
+ from typing import Any
11
+
12
+ import pandas as pd
13
+
14
+ from ...utils.adata_utils import ensure_categorical
15
+ from ...utils.exceptions import ProcessingError
16
+ from .base import PreparedDeconvolutionData, create_deconvolution_stats
17
+
18
+
19
def deconvolve(
    data: PreparedDeconvolutionData,
    n_epochs: int = 150000,
    learning_rate: float = 0.01,
    batch_size: int = 128,
    use_gpu: bool = False,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Deconvolve spatial data using Stereoscope from scvi-tools.

    Training runs in two stages that share the ``n_epochs`` budget. First an
    ``RNAStereoscope`` model is fitted on the reference. Then a
    ``SpatialStereoscope`` model built from it with ``from_rna_model`` is
    fitted on the spatial data.

    Args:
        data: Prepared deconvolution data (immutable)
        n_epochs: Total epochs across both stages (default: 150000, split
            75K+75K; an odd total gives the spatial stage the extra epoch)
        learning_rate: Learning rate (default: 0.01)
        batch_size: Minibatch size (default: 128)
        use_gpu: Train on GPU when True; otherwise training is pinned to CPU

    Returns:
        Tuple of (proportions DataFrame, statistics dictionary)

    Raises:
        ProcessingError: If model setup, training or proportion extraction fails.
    """
    from scvi.external import RNAStereoscope, SpatialStereoscope

    try:
        # Data already copied in prepare_deconvolution
        spatial_data = data.spatial
        ref_data = data.reference

        # Ensure categorical cell type; its category order names the columns
        ensure_categorical(ref_data, data.cell_type_key)
        cell_types = list(ref_data.obs[data.cell_type_key].cat.categories)

        # Split the epoch budget between the two stages (150000 -> 75000 + 75000)
        rna_epochs = n_epochs // 2
        spatial_epochs = n_epochs - rna_epochs

        accelerator = "gpu" if use_gpu else "cpu"
        # Pass the accelerator in both cases: scvi-tools' train() defaults to
        # accelerator="auto", which would use an available GPU even when
        # use_gpu is False (and the stats would then misreport "cpu").
        common_train_kwargs = {
            "batch_size": batch_size,
            "plan_kwargs": {"lr": learning_rate},
            "accelerator": accelerator,
        }

        # ===== Stage 1: Train RNAStereoscope on the reference =====
        RNAStereoscope.setup_anndata(ref_data, labels_key=data.cell_type_key)
        rna_model = RNAStereoscope(ref_data)
        rna_model.train(max_epochs=rna_epochs, **common_train_kwargs)

        # ===== Stage 2: Train SpatialStereoscope on the spatial data =====
        SpatialStereoscope.setup_anndata(spatial_data)
        spatial_model = SpatialStereoscope.from_rna_model(spatial_data, rna_model)
        spatial_model.train(max_epochs=spatial_epochs, **common_train_kwargs)

        # Extract proportions
        proportions = pd.DataFrame(
            spatial_model.get_proportions(),
            index=spatial_data.obs_names,
            columns=cell_types,
        )

        # Create statistics
        stats = create_deconvolution_stats(
            proportions,
            data.common_genes,
            method="Stereoscope",
            device=accelerator,
            n_epochs=n_epochs,
            rna_epochs=rna_epochs,
            spatial_epochs=spatial_epochs,
            learning_rate=learning_rate,
        )

        # Drop the local model references and collect eagerly so memory held
        # by the models can be reclaimed before returning
        del spatial_model, rna_model
        del spatial_data, ref_data
        gc.collect()

        return proportions, stats

    except Exception as e:
        if isinstance(e, ProcessingError):
            raise
        raise ProcessingError(f"Stereoscope deconvolution failed: {e}") from e