chatspatial-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,326 @@
+ """
+ Cell2location deconvolution method.
+
+ Cell2location uses a two-stage training process:
+ 1. Reference model (NB regression) learns cell type gene expression signatures
+ 2. Cell2location model performs spatial mapping using these signatures
+ """
+
+ import gc
+ import warnings
+ from typing import TYPE_CHECKING, Any, Optional
+
+ import numpy as np
+ import pandas as pd
+
+ if TYPE_CHECKING:
+     import anndata as ad
+
+     from ...spatial_mcp_adapter import ToolContext
+
+ from ...utils.dependency_manager import is_available, require
+ from ...utils.device_utils import get_device
+ from ...utils.exceptions import DataError, ProcessingError
+ from ...utils.image_utils import non_interactive_backend
+ from ...utils.mcp_utils import suppress_output
+ from .base import (
+     PreparedDeconvolutionData,
+     check_model_convergence,
+     create_deconvolution_stats,
+ )
+
+
+ async def apply_gene_filtering(
+     adata: "ad.AnnData",
+     ctx: "ToolContext",
+     cell_count_cutoff: int = 5,
+     cell_percentage_cutoff2: float = 0.03,
+     nonz_mean_cutoff: float = 1.12,
+ ) -> "ad.AnnData":
+     """Apply cell2location's official gene filtering.
+
+     Reference: cell2location tutorial - "very permissive gene selection"
+
+     This function is called by the preprocess hook in __init__.py before
+     common gene identification.
+
+     Note: The original filter_genes function creates a matplotlib figure.
+     We suppress this by using Agg backend and closing the figure immediately.
+     """
+     if not is_available("cell2location"):
+         await ctx.warning(
+             "cell2location.utils.filtering not available. "
+             "Skipping gene filtering (may degrade results)."
+         )
+         return adata.copy()
+
+     import matplotlib.pyplot as plt
+
+     with non_interactive_backend():
+         from cell2location.utils.filtering import filter_genes
+
+         selected = filter_genes(
+             adata,
+             cell_count_cutoff=cell_count_cutoff,
+             cell_percentage_cutoff2=cell_percentage_cutoff2,
+             nonz_mean_cutoff=nonz_mean_cutoff,
+         )
+         plt.close("all")
+
+     return adata[:, selected].copy()
+
+
+ def deconvolve(
+     data: PreparedDeconvolutionData,
+     ref_model_epochs: int = 250,
+     n_epochs: int = 30000,
+     n_cells_per_spot: int = 30,
+     detection_alpha: float = 20.0,
+     use_gpu: bool = False,
+     batch_key: Optional[str] = None,
+     categorical_covariate_keys: Optional[list[str]] = None,
+     ref_model_lr: float = 0.002,
+     cell2location_lr: float = 0.005,
+     ref_model_train_size: float = 1.0,
+     cell2location_train_size: float = 1.0,
+     early_stopping: bool = False,
+     early_stopping_patience: int = 45,
+     early_stopping_threshold: float = 0.0,
+     use_aggressive_training: bool = False,
+     validation_size: float = 0.1,
+ ) -> tuple[pd.DataFrame, dict[str, Any]]:
+     """Deconvolve spatial data using Cell2location.
+
+     Note: Gene filtering is handled by the preprocess hook in __init__.py.
+     The data parameter contains already filtered and subset data.
+
+     Args:
+         data: Prepared deconvolution data (immutable, already filtered)
+         ref_model_epochs: Epochs for reference model (default: 250)
+         n_epochs: Epochs for Cell2location model (default: 30000)
+         n_cells_per_spot: Expected cells per location (default: 30)
+         detection_alpha: RNA detection sensitivity (default: 20, NEW 2024)
+         use_gpu: Use GPU acceleration
+         batch_key: Column for batch correction
+         categorical_covariate_keys: Technical covariates
+         ref_model_lr: Reference model learning rate (default: 0.002)
+         cell2location_lr: Cell2location learning rate (default: 0.005)
+         *_train_size: Training data fractions
+         early_stopping*: Early stopping parameters
+         use_aggressive_training: Use train_aggressive() method
+         validation_size: Validation set size
+
+     Returns:
+         Tuple of (proportions DataFrame, statistics dictionary)
+     """
+     require("cell2location")
+     from cell2location.models import Cell2location, RegressionModel
+
+     cell_type_key = data.cell_type_key
+
+     try:
+         device = get_device(prefer_gpu=use_gpu)
+
+         # Data already copied in prepare_deconvolution
+         ref = data.reference
+         sp = data.spatial
+
+         # Ensure float32 for scvi-tools compatibility
+         if ref.X.dtype != np.float32:
+             ref.X = ref.X.astype(np.float32)
+         if sp.X.dtype != np.float32:
+             sp.X = sp.X.astype(np.float32)
+
+         # Handle NaN in cell types
+         if ref.obs[cell_type_key].isna().any():
+             warnings.warn(
+                 f"Reference has NaN in {cell_type_key}. Excluding.",
+                 UserWarning,
+                 stacklevel=2,
+             )
+             ref = ref[~ref.obs[cell_type_key].isna()].copy()
+
+         # ===== Stage 1: Train Reference Model =====
+         RegressionModel.setup_anndata(
+             adata=ref,
+             labels_key=cell_type_key,
+             batch_key=batch_key,
+             categorical_covariate_keys=categorical_covariate_keys,
+         )
+
+         ref_model = RegressionModel(ref)
+         with suppress_output():
+             train_kwargs = _build_train_kwargs(
+                 epochs=ref_model_epochs,
+                 lr=ref_model_lr,
+                 train_size=ref_model_train_size,
+                 device=device,
+                 early_stopping=early_stopping,
+                 early_stopping_patience=early_stopping_patience,
+                 validation_size=validation_size,
+                 use_aggressive=use_aggressive_training,
+             )
+             ref_model.train(**train_kwargs)
+
+         # Check convergence
+         converged, warning_msg = check_model_convergence(ref_model, "ReferenceModel")
+         if not converged and warning_msg:
+             warnings.warn(warning_msg, UserWarning, stacklevel=2)
+
+         # Export reference signatures
+         ref = ref_model.export_posterior(
+             ref, sample_kwargs={"num_samples": 1000, "batch_size": 2500}
+         )
+         ref_signatures = _extract_reference_signatures(ref)
+
+         # ===== Stage 2: Train Cell2location Model =====
+         Cell2location.setup_anndata(
+             adata=sp,
+             batch_key=batch_key,
+             categorical_covariate_keys=categorical_covariate_keys,
+         )
+
+         cell2loc_model = Cell2location(
+             sp,
+             cell_state_df=ref_signatures,
+             N_cells_per_location=n_cells_per_spot,
+             detection_alpha=detection_alpha,
+         )
+
+         with suppress_output():
+             train_kwargs = _build_train_kwargs(
+                 epochs=n_epochs,
+                 lr=cell2location_lr,
+                 train_size=cell2location_train_size,
+                 device=device,
+                 early_stopping=early_stopping,
+                 early_stopping_patience=early_stopping_patience,
+                 validation_size=validation_size,
+                 use_aggressive=use_aggressive_training,
+             )
+             cell2loc_model.train(**train_kwargs)
+
+         # Check convergence
+         converged, warning_msg = check_model_convergence(
+             cell2loc_model, "Cell2location"
+         )
+         if not converged and warning_msg:
+             warnings.warn(warning_msg, UserWarning, stacklevel=2)
+
+         # Export results
+         sp = cell2loc_model.export_posterior(
+             sp, sample_kwargs={"num_samples": 1000, "batch_size": 2500}
+         )
+
+         # Extract cell abundance
+         cell_abundance = _extract_cell_abundance(sp)
+
+         # Create proportions DataFrame
+         proportions = pd.DataFrame(
+             cell_abundance,
+             index=sp.obs_names,
+             columns=ref_signatures.columns,
+         )
+
+         # Create statistics
+         stats = create_deconvolution_stats(
+             proportions,
+             data.common_genes,
+             method="Cell2location",
+             device=device,
+             n_epochs=n_epochs,
+             n_cells_per_spot=n_cells_per_spot,
+             detection_alpha=detection_alpha,
+         )
+
+         # Add model performance metrics
+         if hasattr(cell2loc_model, "history") and cell2loc_model.history is not None:
+             history = cell2loc_model.history
+             if "elbo_train" in history and not history["elbo_train"].empty:
+                 stats["final_elbo"] = float(history["elbo_train"].iloc[-1])
+
+         # Memory cleanup
+         del cell2loc_model, ref_model
+         del ref, sp, ref_signatures
+         gc.collect()
+
+         return proportions, stats
+
+     except Exception as e:
+         if isinstance(e, (ProcessingError, DataError)):
+             raise
+         raise ProcessingError(f"Cell2location deconvolution failed: {e}") from e
+
+
+ def _build_train_kwargs(
+     epochs: int,
+     lr: float,
+     train_size: float,
+     device: str,
+     early_stopping: bool,
+     early_stopping_patience: int,
+     validation_size: float,
+     use_aggressive: bool,
+ ) -> dict[str, Any]:
+     """Build training kwargs for scvi-tools models."""
+     if use_aggressive:
+         kwargs = {"max_epochs": epochs, "lr": lr}
+         if device == "cuda":
+             kwargs["accelerator"] = "gpu"
+         if early_stopping:
+             kwargs["early_stopping"] = True
+             kwargs["early_stopping_patience"] = early_stopping_patience
+             kwargs["check_val_every_n_epoch"] = 1
+             kwargs["train_size"] = 1.0 - validation_size
+         else:
+             kwargs["train_size"] = train_size
+     else:
+         kwargs = {
+             "max_epochs": epochs,
+             "batch_size": 2500,
+             "lr": lr,
+             "train_size": train_size,
+         }
+         if device == "cuda":
+             kwargs["accelerator"] = "gpu"
+     return kwargs
+
+
+ def _extract_reference_signatures(ref: "ad.AnnData") -> pd.DataFrame:
+     """Extract reference signatures from trained RegressionModel."""
+     factor_names = ref.uns["mod"]["factor_names"]
+     cols = [f"means_per_cluster_mu_fg_{i}" for i in factor_names]
+
+     if "means_per_cluster_mu_fg" in ref.varm:
+         signatures = ref.varm["means_per_cluster_mu_fg"][cols].copy()
+     else:
+         signatures = ref.var[cols].copy()
+
+     signatures.columns = factor_names
+     return signatures
+
+
+ def _extract_cell_abundance(sp: "ad.AnnData"):
+     """Extract cell abundance from Cell2location results.
+
+     Cell2location stores results as DataFrames with prefixed column names like
+     'q05cell_abundance_w_sf_CellType'. We need to extract the values and
+     return them as a numpy array for consistent downstream processing.
+     """
+     possible_keys = [
+         "q05_cell_abundance_w_sf",
+         "means_cell_abundance_w_sf",
+         "q50_cell_abundance_w_sf",
+     ]
+
+     for key in possible_keys:
+         if key in sp.obsm:
+             result = sp.obsm[key]
+             if hasattr(result, "values"):
+                 return result.values
+             return result
+
+     raise ProcessingError(
+         f"Cell2location did not produce expected output. "
+         f"Available keys: {list(sp.obsm.keys())}"
+     )
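
The file above wraps cell2location's two-stage workflow behind a single deconvolve() call. For orientation, here is an illustrative sketch of the same flow written directly against cell2location's public API; the file paths, the "cell_type" label column, and the assumption that both AnnData objects hold raw counts already restricted to shared genes are placeholders, not part of this package.

    # Illustrative only: mirrors the two-stage flow in chatspatial/tools/deconvolution/cell2location.py
    import scanpy as sc
    from cell2location.models import Cell2location, RegressionModel

    ref = sc.read_h5ad("reference_scrna.h5ad")   # placeholder: annotated single-cell reference
    sp = sc.read_h5ad("visium_slide.h5ad")       # placeholder: spatial counts, same gene set as ref

    # Stage 1: NB regression learns per-cell-type expression signatures
    RegressionModel.setup_anndata(adata=ref, labels_key="cell_type")
    ref_model = RegressionModel(ref)
    ref_model.train(max_epochs=250)
    ref = ref_model.export_posterior(
        ref, sample_kwargs={"num_samples": 1000, "batch_size": 2500}
    )
    factor_names = ref.uns["mod"]["factor_names"]
    signatures = ref.varm["means_per_cluster_mu_fg"][
        [f"means_per_cluster_mu_fg_{n}" for n in factor_names]
    ]
    signatures.columns = factor_names

    # Stage 2: map the signatures onto spatial locations
    Cell2location.setup_anndata(adata=sp)
    sp_model = Cell2location(
        sp,
        cell_state_df=signatures,
        N_cells_per_location=30,
        detection_alpha=20,
    )
    sp_model.train(max_epochs=30000)
    sp = sp_model.export_posterior(
        sp, sample_kwargs={"num_samples": 1000, "batch_size": 2500}
    )
    abundance = sp.obsm["q05_cell_abundance_w_sf"]  # locations x cell types (5% quantile)
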
@@ -0,0 +1,144 @@
+ """
+ DestVI deconvolution method.
+
+ DestVI performs multi-resolution deconvolution by first training a CondSCVI
+ model on reference data, then using it to initialize a DestVI model.
+ """
+
+ import gc
+ from typing import Any
+
+ import pandas as pd
+
+ from ...utils.dependency_manager import is_available
+ from ...utils.exceptions import DataError, DependencyError, ProcessingError
+ from .base import PreparedDeconvolutionData, create_deconvolution_stats
+
+
+ def deconvolve(
+     data: PreparedDeconvolutionData,
+     n_epochs: int = 10000,
+     n_hidden: int = 128,
+     n_latent: int = 10,
+     n_layers: int = 1,
+     dropout_rate: float = 0.1,
+     learning_rate: float = 1e-3,
+     train_size: float = 0.9,
+     vamp_prior_p: int = 15,
+     l1_reg: float = 10.0,
+     use_gpu: bool = False,
+ ) -> tuple[pd.DataFrame, dict[str, Any]]:
+     """Deconvolve spatial data using DestVI from scvi-tools.
+
+     Args:
+         data: Prepared deconvolution data (immutable)
+         n_epochs: Total epochs (split between CondSCVI and DestVI)
+         n_hidden: Hidden units in neural networks
+         n_latent: Latent space dimensionality
+         n_layers: Number of layers
+         dropout_rate: Dropout rate
+         learning_rate: Learning rate
+         train_size: Fraction for training (default: 0.9)
+         vamp_prior_p: VampPrior components (default: 15)
+         l1_reg: L1 regularization (default: 10.0)
+         use_gpu: Use GPU acceleration
+
+     Returns:
+         Tuple of (proportions DataFrame, statistics dictionary)
+     """
+     if not is_available("scvi-tools"):
+         raise DependencyError(
+             "scvi-tools is required for DestVI. Install with: pip install scvi-tools"
+         )
+
+     import scvi
+
+     try:
+         # Data already copied in prepare_deconvolution
+         spatial_data = data.spatial
+         ref_data = data.reference
+
+         # Validate cell types
+         if data.n_cell_types < 2:
+             raise DataError(
+                 f"Reference needs at least 2 cell types, found {data.n_cell_types}"
+             )
+
+         # Calculate epoch distribution
+         condscvi_epochs = max(400, n_epochs // 5)
+         destvi_epochs = max(200, n_epochs // 10)
+
+         # Device setting
+         accelerator = "gpu" if use_gpu else "cpu"
+         plan_kwargs = {"lr": learning_rate}
+
+         # ===== Stage 1: Train CondSCVI on reference =====
+         scvi.model.CondSCVI.setup_anndata(
+             ref_data,
+             labels_key=data.cell_type_key,
+             batch_key=None,
+         )
+
+         condscvi_model = scvi.model.CondSCVI(
+             ref_data,
+             n_hidden=n_hidden,
+             n_latent=n_latent,
+             n_layers=n_layers,
+             dropout_rate=dropout_rate,
+         )
+
+         condscvi_model.train(
+             max_epochs=condscvi_epochs,
+             accelerator=accelerator,
+             train_size=train_size,
+             plan_kwargs=plan_kwargs,
+         )
+
+         # ===== Stage 2: Train DestVI on spatial =====
+         scvi.model.DestVI.setup_anndata(spatial_data)
+
+         destvi_model = scvi.model.DestVI.from_rna_model(
+             spatial_data,
+             condscvi_model,
+             vamp_prior_p=vamp_prior_p,
+             l1_reg=l1_reg,
+         )
+
+         destvi_model.train(
+             max_epochs=destvi_epochs,
+             accelerator=accelerator,
+             train_size=train_size,
+             plan_kwargs=plan_kwargs,
+         )
+
+         # Get proportions
+         proportions = destvi_model.get_proportions()
+         proportions.index = spatial_data.obs_names
+
+         if proportions.empty or len(proportions) != spatial_data.n_obs:
+             raise ProcessingError("Failed to extract valid proportions from DestVI")
+
+         # Create statistics
+         stats = create_deconvolution_stats(
+             proportions,
+             data.common_genes,
+             method="DestVI",
+             device="gpu" if use_gpu else "cpu",
+             n_epochs=n_epochs,
+             condscvi_epochs=condscvi_epochs,
+             destvi_epochs=destvi_epochs,
+             n_hidden=n_hidden,
+             n_latent=n_latent,
+         )
+
+         # Memory cleanup
+         del destvi_model, condscvi_model
+         del spatial_data, ref_data
+         gc.collect()
+
+         return proportions, stats
+
+     except Exception as e:
+         if isinstance(e, (DependencyError, DataError, ProcessingError)):
+             raise
+         raise ProcessingError(f"DestVI deconvolution failed: {e}") from e
@@ -0,0 +1,101 @@
+ """
+ FlashDeconv deconvolution method.
+
+ FlashDeconv is an ultra-fast spatial transcriptomics deconvolution method
+ that uses random sketching for O(N) time complexity.
+ """
+
+ from typing import Any
+
+ import pandas as pd
+
+ from ...utils.dependency_manager import is_available
+ from ...utils.exceptions import DependencyError, ProcessingError
+ from .base import PreparedDeconvolutionData, create_deconvolution_stats
+
+
+ def deconvolve(
+     data: PreparedDeconvolutionData,
+     sketch_dim: int = 512,
+     lambda_spatial: float = 5000.0,
+     n_hvg: int = 2000,
+     n_markers_per_type: int = 50,
+ ) -> tuple[pd.DataFrame, dict[str, Any]]:
+     """Deconvolve spatial data using FlashDeconv.
+
+     FlashDeconv is an ultra-fast deconvolution method with:
+     - O(N) time complexity via random sketching
+     - Processes 1M spots in ~3 minutes on CPU
+     - No GPU required
+     - Automatic marker gene selection
+     - Spatial regularization for smooth proportions
+
+     Args:
+         data: Prepared deconvolution data (immutable)
+         sketch_dim: Dimension for random sketching (default: 512)
+         lambda_spatial: Spatial regularization strength (default: 5000.0)
+         n_hvg: Number of highly variable genes to use (default: 2000)
+         n_markers_per_type: Number of marker genes per cell type (default: 50)
+
+     Returns:
+         Tuple of (proportions DataFrame, statistics dictionary)
+     """
+     if not is_available("flashdeconv"):
+         raise DependencyError(
+             "FlashDeconv is not available. Install with: pip install flashdeconv"
+         )
+
+     try:
+         import flashdeconv as fd
+
+         # Data already copied in prepare_deconvolution
+         adata_st = data.spatial
+         reference = data.reference
+
+         # Run FlashDeconv
+         fd.tl.deconvolve(
+             adata_st,
+             reference,
+             cell_type_key=data.cell_type_key,
+             sketch_dim=sketch_dim,
+             lambda_spatial=lambda_spatial,
+             n_hvg=n_hvg,
+             n_markers_per_type=n_markers_per_type,
+         )
+
+         # Extract proportions
+         if "flashdeconv" not in adata_st.obsm:
+             raise ProcessingError(
+                 "FlashDeconv did not produce output in adata.obsm['flashdeconv']"
+             )
+
+         proportions = adata_st.obsm["flashdeconv"].copy()
+
+         # Ensure DataFrame format
+         if not isinstance(proportions, pd.DataFrame):
+             proportions = pd.DataFrame(
+                 proportions,
+                 index=data.spatial.obs_names,
+                 columns=data.cell_types,
+             )
+         else:
+             proportions.index = data.spatial.obs_names
+
+         # Create statistics
+         stats = create_deconvolution_stats(
+             proportions,
+             data.common_genes,
+             method="FlashDeconv",
+             device="CPU",
+             sketch_dim=sketch_dim,
+             lambda_spatial=lambda_spatial,
+             n_hvg=n_hvg,
+             n_markers_per_type=n_markers_per_type,
+         )
+
+         return proportions, stats
+
+     except Exception as e:
+         if isinstance(e, (DependencyError, ProcessingError)):
+             raise
+         raise ProcessingError(f"FlashDeconv deconvolution failed: {e}") from e