chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,402 @@
1
+ """
2
+ Deconvolution module for spatial transcriptomics data.
3
+
4
+ This module provides a unified interface for multiple deconvolution methods:
5
+ - flashdeconv: Ultra-fast deconvolution with O(N) complexity (recommended)
6
+ - cell2location: Bayesian deconvolution with spatial priors
7
+ - destvi: Deep learning-based multi-resolution deconvolution
8
+ - stereoscope: Two-stage probabilistic deconvolution
9
+ - rctd: Robust Cell Type Decomposition (R-based)
10
+ - spotlight: NMF-based deconvolution (R-based)
11
+ - card: CAR model with spatial correlation (R-based)
12
+ - tangram: Deep learning mapping via scvi-tools
13
+
14
+ Usage:
15
+ from chatspatial.tools.deconvolution import deconvolve_spatial_data
16
+ result = await deconvolve_spatial_data(data_id, ctx, params)
17
+ """
18
+
19
+ import gc
20
+ from typing import TYPE_CHECKING, Any
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+
25
+ if TYPE_CHECKING:
26
+ import anndata as ad
27
+
28
+ from ...spatial_mcp_adapter import ToolContext
29
+
30
+ from ...models.analysis import DeconvolutionResult
31
+ from ...models.data import DeconvolutionParameters
32
+ from ...utils.adata_utils import (
33
+ ensure_unique_var_names_async,
34
+ store_analysis_metadata,
35
+ validate_obs_column,
36
+ )
37
+ from ...utils.exceptions import DataError, DependencyError, ParameterError
38
+ from .base import PreparedDeconvolutionData, prepare_deconvolution
39
+
40
+ # Export main function and data container
41
+ __all__ = ["deconvolve_spatial_data", "PreparedDeconvolutionData"]
42
+
43
+
44
async def deconvolve_spatial_data(
    data_id: str,
    ctx: "ToolContext",
    params: DeconvolutionParameters,
) -> DeconvolutionResult:
    """Estimate per-spot cell type proportions for a spatial dataset.

    Single entry point shared by every deconvolution backend. The steps run
    in this order:

    1. Load the spatial dataset and reject empty input.
    2. Load and validate the reference dataset when the chosen method is
       listed in ``_METHODS_REQUIRING_REFERENCE``.
    3. Check that the method's dependencies can be found.
    4. Build the shared prepared-data container via
       ``prepare_deconvolution``.
    5. Run the backend and write its output back to the dataset.

    Args:
        data_id: Identifier of the spatial dataset held by ``ctx``.
        ctx: Tool context used to fetch and store AnnData objects.
        params: Deconvolution settings. ``method`` and ``cell_type_key`` are
            always read; ``reference_data_id`` is read when the method needs
            a reference.

    Returns:
        DeconvolutionResult describing the stored proportions.

    Raises:
        ParameterError: If ``data_id`` is empty or a required
            ``reference_data_id`` is missing.
        DataError: If the spatial or reference dataset has no observations.
    """
    if not data_id:
        raise ParameterError("Dataset ID cannot be empty")

    # Spatial dataset: must exist, be non-empty and have unique gene names.
    spatial_adata = await ctx.get_adata(data_id)
    if spatial_adata.n_obs == 0:
        raise DataError(f"Dataset {data_id} contains no observations")
    await ensure_unique_var_names_async(spatial_adata, ctx, "spatial data")

    method = params.method

    # Reference dataset: only fetched for methods that declare they need one.
    reference_adata = None
    if method in _METHODS_REQUIRING_REFERENCE:
        reference_id = params.reference_data_id
        if not reference_id:
            raise ParameterError(
                f"Method '{params.method}' requires reference_data_id."
            )

        reference_adata = await ctx.get_adata(reference_id)
        if reference_adata.n_obs == 0:
            raise DataError(
                f"Reference dataset {params.reference_data_id} contains no observations"
            )

        await ensure_unique_var_names_async(reference_adata, ctx, "reference data")
        validate_obs_column(reference_adata, params.cell_type_key, "Cell type")

    # Fail before the expensive preparation step if packages are missing.
    _check_method_availability(method)

    # R-based methods ask the preparation step for integer count matrices.
    prepared = await prepare_deconvolution(
        spatial_adata=spatial_adata,
        reference_adata=reference_adata,
        cell_type_key=params.cell_type_key,
        ctx=ctx,
        require_int_dtype=method in _R_BASED_METHODS,
        preprocess=_get_preprocess_hook(params),
    )

    proportions, stats = _dispatch_method(prepared, params)

    # Release the prepared container before results are written back.
    del prepared
    gc.collect()

    return await _store_results(
        spatial_adata, proportions, stats, method, data_id, ctx
    )
122
+
123
+
124
+ # =============================================================================
125
+ # Method Registry
126
+ # =============================================================================
127
+
128
# Importable module names each method needs at runtime. Key order is also the
# order in which alternative methods are listed in dependency error messages.
_METHOD_DEPENDENCIES: dict[str, list[str]] = {
    "flashdeconv": ["flashdeconv"],
    "cell2location": ["cell2location", "torch"],
    "destvi": ["scvi", "torch"],
    "stereoscope": ["scvi", "torch"],
    "tangram": ["scvi", "torch", "tangram", "mudata"],
    "rctd": ["rpy2"],
    "spotlight": ["rpy2"],
    "card": ["rpy2"],
}

# Methods that run through R; data for these is prepared with
# require_int_dtype=True.
_R_BASED_METHODS: set[str] = {"card", "rctd", "spotlight"}

# Methods that cannot run without a reference dataset (reference_data_id).
_METHODS_REQUIRING_REFERENCE: set[str] = {
    "card",
    "cell2location",
    "destvi",
    "flashdeconv",
    "rctd",
    "spotlight",
    "stereoscope",
    "tangram",
}
151
+
152
+
153
def _check_method_availability(method: str) -> None:
    """Ensure the Python modules needed by ``method`` can be located.

    Looks the method up in ``_METHOD_DEPENDENCIES`` and probes every listed
    module with ``importlib.util.find_spec``. A method without an entry has
    no requirements and always passes.

    Args:
        method: Deconvolution method name.

    Raises:
        DependencyError: If at least one required module cannot be found.
            The message names the missing modules and, when any exist, the
            methods whose dependencies are all present.
    """
    import importlib.util
    import sys

    def _is_importable(dep: str) -> bool:
        # Accept the pip-style name "scvi-tools" and dashed distribution
        # names in addition to plain import names.
        import_name = "scvi" if dep == "scvi-tools" else dep.replace("-", "_")
        try:
            return importlib.util.find_spec(import_name) is not None
        except ValueError:
            # find_spec raises ValueError when the module is already in
            # sys.modules but its __spec__ is None or unset. The module is
            # loaded, so treat it as present.
            return sys.modules.get(import_name) is not None
        except ImportError:
            # Raised for dotted names whose parent package is missing.
            return False

    missing = [
        dep
        for dep in _METHOD_DEPENDENCIES.get(method, [])
        if not _is_importable(dep)
    ]
    if not missing:
        return

    # Offer the caller every method that could run in this environment.
    available = [
        name
        for name, deps in _METHOD_DEPENDENCIES.items()
        if all(_is_importable(dep) for dep in deps)
    ]

    message = f"Method '{method}' requires: {', '.join(missing)}."
    if available:
        message += f" Available: {', '.join(available)}"
        if "flashdeconv" in available:
            message += " (flashdeconv recommended - fastest)"

    raise DependencyError(message)
181
+
182
+
183
def _get_preprocess_hook(params: DeconvolutionParameters):
    """Return the method-specific preprocessing coroutine, or ``None``.

    A hook exists only for cell2location when
    ``params.cell2location_apply_gene_filtering`` is truthy. The returned
    coroutine function takes ``(spatial, reference, ctx)``, runs
    ``apply_gene_filtering`` on both datasets with the same cutoffs from
    ``params`` and returns the filtered ``(spatial, reference)`` pair.

    Args:
        params: Deconvolution parameters; captured by the returned closure.

    Returns:
        The async hook, or ``None`` when no preprocessing is needed.
    """
    wants_gene_filtering = (
        params.method == "cell2location"
        and params.cell2location_apply_gene_filtering
    )
    if not wants_gene_filtering:
        return None

    async def cell2location_preprocess(spatial, reference, ctx):
        # Imported lazily so the backend module is only loaded when the
        # hook actually runs.
        from .cell2location import apply_gene_filtering

        # The same cutoffs are applied to both datasets.
        cutoffs = {
            "cell_count_cutoff": params.cell2location_gene_filter_cell_count_cutoff,
            "cell_percentage_cutoff2": params.cell2location_gene_filter_cell_percentage_cutoff2,
            "nonz_mean_cutoff": params.cell2location_gene_filter_nonz_mean_cutoff,
        }
        filtered_spatial = await apply_gene_filtering(spatial, ctx, **cutoffs)
        filtered_reference = await apply_gene_filtering(reference, ctx, **cutoffs)
        return filtered_spatial, filtered_reference

    return cell2location_preprocess
209
+
210
+
211
def _dispatch_method(
    data: PreparedDeconvolutionData,
    params: DeconvolutionParameters,
) -> tuple[pd.DataFrame, dict[str, Any]]:
    """Run the backend selected by ``params.method``.

    Each branch imports its backend module only once that method has been
    selected and gathers the backend's keyword arguments from ``params``.
    A single shared call at the end invokes the backend's ``deconvolve``.

    Args:
        data: Prepared deconvolution input passed to the backend unchanged.
        params: Deconvolution parameters holding the method name and the
            method-specific settings.

    Returns:
        Whatever the backend's ``deconvolve`` returns; annotated as a
        ``(proportions, stats)`` pair.

    Raises:
        ParameterError: If ``params.method`` is not one of the supported
            method names.
    """
    method = params.method

    if method == "flashdeconv":
        from . import flashdeconv as backend

        options = {
            "sketch_dim": params.flashdeconv_sketch_dim,
            "lambda_spatial": params.flashdeconv_lambda_spatial,
            "n_hvg": params.flashdeconv_n_hvg,
            "n_markers_per_type": params.flashdeconv_n_markers_per_type,
        }

    elif method == "cell2location":
        from . import cell2location as backend

        options = {
            "ref_model_epochs": params.cell2location_ref_model_epochs,
            "n_epochs": params.cell2location_n_epochs,
            # A falsy value (None or 0) falls back to 30 cells per spot.
            "n_cells_per_spot": params.cell2location_n_cells_per_spot or 30,
            "detection_alpha": params.cell2location_detection_alpha,
            "use_gpu": params.use_gpu,
            "batch_key": params.cell2location_batch_key,
            "categorical_covariate_keys": params.cell2location_categorical_covariate_keys,
            "ref_model_lr": params.cell2location_ref_model_lr,
            "cell2location_lr": params.cell2location_lr,
            "ref_model_train_size": params.cell2location_ref_model_train_size,
            "cell2location_train_size": params.cell2location_train_size,
            "early_stopping": params.cell2location_early_stopping,
            "early_stopping_patience": params.cell2location_early_stopping_patience,
            "early_stopping_threshold": params.cell2location_early_stopping_threshold,
            "use_aggressive_training": params.cell2location_use_aggressive_training,
            "validation_size": params.cell2location_validation_size,
        }

    elif method == "destvi":
        from . import destvi as backend

        options = {
            "n_epochs": params.destvi_n_epochs,
            "n_hidden": params.destvi_n_hidden,
            "n_latent": params.destvi_n_latent,
            "n_layers": params.destvi_n_layers,
            "dropout_rate": params.destvi_dropout_rate,
            "learning_rate": params.destvi_learning_rate,
            "train_size": params.destvi_train_size,
            "vamp_prior_p": params.destvi_vamp_prior_p,
            "l1_reg": params.destvi_l1_reg,
            "use_gpu": params.use_gpu,
        }

    elif method == "stereoscope":
        from . import stereoscope as backend

        options = {
            "n_epochs": params.stereoscope_n_epochs,
            "learning_rate": params.stereoscope_learning_rate,
            "batch_size": params.stereoscope_batch_size,
            "use_gpu": params.use_gpu,
        }

    elif method == "rctd":
        from . import rctd as backend

        options = {
            "mode": params.rctd_mode,
            "max_cores": params.max_cores,
            "confidence_threshold": params.rctd_confidence_threshold,
            "doublet_threshold": params.rctd_doublet_threshold,
            "max_multi_types": params.rctd_max_multi_types,
        }

    elif method == "spotlight":
        from . import spotlight as backend

        options = {
            "n_top_genes": params.spotlight_n_top_genes,
            "nmf_model": params.spotlight_nmf_model,
            "min_prop": params.spotlight_min_prop,
            "scale": params.spotlight_scale,
            "weight_id": params.spotlight_weight_id,
        }

    elif method == "card":
        from . import card as backend

        options = {
            "sample_key": params.card_sample_key,
            "minCountGene": params.card_minCountGene,
            "minCountSpot": params.card_minCountSpot,
            "imputation": params.card_imputation,
            "NumGrids": params.card_NumGrids,
            "ineibor": params.card_ineibor,
        }

    elif method == "tangram":
        from . import tangram as backend

        options = {
            "n_epochs": params.tangram_n_epochs,
            "mode": params.tangram_mode,
            "learning_rate": params.tangram_learning_rate,
            "density_prior": params.tangram_density_prior,
            "use_gpu": params.use_gpu,
        }

    else:
        raise ParameterError(
            f"Unsupported method: {params.method}. "
            f"Supported: flashdeconv, cell2location, destvi, stereoscope, "
            f"rctd, spotlight, card, tangram"
        )

    return backend.deconvolve(data, **options)
335
+
336
+
337
async def _store_results(
    spatial_adata: "ad.AnnData",
    proportions: pd.DataFrame,
    stats: dict[str, Any],
    method: str,
    data_id: str,
    ctx: "ToolContext",
) -> DeconvolutionResult:
    """Write deconvolution output into the AnnData object and summarize it.

    With ``<key>`` standing for ``deconvolution_<method>``, this stores:

    - ``obsm[<key>]``: the proportions matrix re-indexed to
      ``spatial_adata.obs_names``; spots absent from ``proportions`` and
      NaN entries become 0.
    - ``uns[<key>_cell_types]``: cell type names in matrix column order.
    - ``obs[<key>_<cell type>]``: one column per cell type.
    - ``obs[dominant_celltype_<method>]``: categorical holding the cell type
      with the largest proportion in each spot.

    Analysis metadata is recorded through ``store_analysis_metadata`` and the
    updated object is saved back with ``ctx.set_adata``.

    Args:
        spatial_adata: Spatial dataset that receives the results.
        proportions: Table indexed by spot name with one column per cell
            type.
        stats: Statistics returned by the backend; passed through to the
            result and consulted for ``n_spots`` and ``genes_used``.
        method: Name of the deconvolution method that produced the output.
        data_id: Identifier under which ``spatial_adata`` is saved.
        ctx: Tool context used to persist the updated dataset.

    Returns:
        DeconvolutionResult describing the stored keys and statistics.

    Raises:
        DataError: If ``proportions`` has no cell type columns.
    """
    proportions_key = f"deconvolution_{method}"
    cell_types_key = f"{proportions_key}_cell_types"
    dominant_key = f"dominant_celltype_{method}"

    cell_types = list(proportions.columns)
    if not cell_types:
        # np.argmax over an empty axis raises an opaque ValueError; fail with
        # a clear message before the AnnData object is modified.
        raise DataError(
            f"Deconvolution method '{method}' returned no cell types"
        )

    # Align proportions with spatial_adata.obs_names; missing spots get 0.
    full_proportions = proportions.reindex(spatial_adata.obs_names).fillna(0).values
    n_stored_spots = len(full_proportions)

    spatial_adata.obsm[proportions_key] = full_proportions
    spatial_adata.uns[cell_types_key] = cell_types

    # Expose each cell type as its own obs column.
    for column_index, cell_type in enumerate(cell_types):
        spatial_adata.obs[f"{proportions_key}_{cell_type}"] = full_proportions[
            :, column_index
        ]

    # NOTE(review): an all-zero row (e.g. a spot missing from `proportions`)
    # resolves to the first cell type because argmax returns index 0 on
    # ties - confirm this is the intended label for such spots.
    dominant_types = np.array(cell_types)[np.argmax(full_proportions, axis=1)]
    spatial_adata.obs[dominant_key] = pd.Categorical(dominant_types)

    # Store metadata for provenance tracking (enables reading stored values
    # instead of inferring from key names in visualization).
    store_analysis_metadata(
        spatial_adata,
        analysis_name=f"deconvolution_{method}",
        method=method,
        parameters={},  # Method-specific params already in stats
        results_keys={
            "obsm": [proportions_key],
            "obs": [dominant_key],
            "uns": [cell_types_key],
        },
        statistics={
            "n_cell_types": len(cell_types),
            "n_spots": n_stored_spots,
            "cell_types": cell_types,
            "proportions_key": proportions_key,
            "dominant_type_key": dominant_key,
        },
    )

    # Save updated data
    await ctx.set_adata(data_id, spatial_adata)

    return DeconvolutionResult(
        data_id=data_id,
        method=method,
        dominant_type_key=dominant_key,
        n_cell_types=len(cell_types),
        cell_types=cell_types,
        proportions_key=proportions_key,
        # Fall back to the number of rows actually stored rather than 0, so
        # the result agrees with the metadata recorded above.
        n_spots=stats.get("n_spots", n_stored_spots),
        genes_used=stats.get("genes_used", stats.get("common_genes", 0)),
        statistics=stats,
    )