chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,411 @@
1
+ """
2
+ RNA velocity visualization functions for spatial transcriptomics.
3
+
4
+ This module contains:
5
+ - Velocity stream plots
6
+ - Phase plots (spliced vs unspliced)
7
+ - Proportions plots (pie charts)
8
+ - Velocity heatmaps
9
+ - PAGA with velocity arrows
10
+ """
11
+
12
+ from typing import TYPE_CHECKING, Optional
13
+
14
+ import matplotlib.pyplot as plt
15
+
16
+ if TYPE_CHECKING:
17
+ import anndata as ad
18
+
19
+ from ...spatial_mcp_adapter import ToolContext
20
+
21
+ from ...models.data import VisualizationParameters
22
+ from ...utils.adata_utils import validate_obs_column
23
+ from ...utils.dependency_manager import require
24
+ from ...utils.exceptions import (
25
+ DataCompatibilityError,
26
+ DataNotFoundError,
27
+ ParameterError,
28
+ )
29
+ from .core import (
30
+ create_figure_from_params,
31
+ get_categorical_columns,
32
+ infer_basis,
33
+ resolve_figure_size,
34
+ )
35
+
36
+ # =============================================================================
37
+ # Main Router
38
+ # =============================================================================
39
+
40
+
41
async def create_rna_velocity_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Build the RNA velocity figure selected by ``params.subtype``.

    Resolves the requested subtype (falling back to ``"stream"`` when none
    is set), reports it through ``context`` when one is supplied, and hands
    the work to the matching private plot builder in this module.

    Args:
        adata: AnnData object, passed unchanged to the selected builder.
        params: Visualization parameters; ``subtype`` selects the builder.
        context: Optional MCP context used for progress messages.

    Returns:
        The Matplotlib figure returned by the selected builder.

    Raises:
        ParameterError: If the subtype is not one of the supported names.

    Subtypes:
        - stream (default): Velocity embedding stream plot
        - phase: Phase plot showing spliced vs unspliced
        - proportions: Pie chart of spliced/unspliced ratios
        - heatmap: Gene expression ordered by latent_time
        - paga: PAGA with velocity arrows
    """
    requested = params.subtype or "stream"

    if context:
        await context.info(
            f"Creating RNA velocity visualization (subtype: {requested})"
        )

    # Ordered (name, builder) pairs; compared with == in sequence so the
    # lookup behaves exactly like an if/elif chain.
    builders = (
        ("stream", _create_velocity_stream_plot),
        ("phase", _create_velocity_phase_plot),
        ("proportions", _create_velocity_proportions_plot),
        ("heatmap", _create_velocity_heatmap),
        ("paga", _create_velocity_paga_plot),
    )
    for name, builder in builders:
        if requested == name:
            return await builder(adata, params, context)

    raise ParameterError(
        f"Unsupported subtype for rna_velocity: '{requested}'. "
        f"Available subtypes: stream, phase, proportions, heatmap, paga"
    )
85
+
86
+
87
+ # =============================================================================
88
+ # Visualization Functions
89
+ # =============================================================================
90
+
91
+
92
async def _create_velocity_stream_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create RNA velocity stream plot using scv.pl.velocity_embedding_stream.

    Data requirements:
        - adata.uns['velocity_graph']: Velocity transition graph
        - adata.obsm['X_umap'] or 'spatial': Embedding for visualization

    Args:
        adata: AnnData object with a computed velocity graph.
        params: Visualization parameters. ``basis``, ``feature``, ``alpha``,
            ``show_axes`` and ``title`` are read here.
        context: Optional MCP context used for progress messages.

    Returns:
        The Matplotlib figure containing the stream plot.

    Raises:
        DataNotFoundError: If ``velocity_graph`` is missing from ``adata.uns``.
        DataCompatibilityError: If ``infer_basis`` finds no usable embedding.
    """
    require("scvelo", feature="RNA velocity visualization")
    import scvelo as scv

    if "velocity_graph" not in adata.uns:
        raise DataNotFoundError(
            "RNA velocity not computed. Run analyze_velocity_data first."
        )

    # Determine basis for plotting
    basis = infer_basis(adata, preferred=params.basis)
    if not basis:
        raise DataCompatibilityError(
            f"No valid embedding basis found. "
            f"Available keys: {list(adata.obsm.keys())}"
        )
    if context and basis != params.basis:
        await context.info(f"Using '{basis}' as basis")

    # Prepare feature for coloring; fall back to the first categorical column
    feature = params.feature
    if not feature:
        categorical_cols = get_categorical_columns(adata)
        feature = categorical_cols[0] if categorical_cols else None
        if feature and context:
            await context.info(f"Using '{feature}' for coloring")

    # Only a plain string can name an obs column. The isinstance guard also
    # keeps an unhashable value (e.g. a list of features) away from the
    # Index membership test, which would otherwise raise TypeError.
    colors_by_obs_column = isinstance(feature, str) and feature in adata.obs.columns

    fig, axes = create_figure_from_params(params, "velocity")
    try:
        ax = axes[0]

        scv.pl.velocity_embedding_stream(
            adata,
            basis=basis,
            color=feature,
            ax=ax,
            show=False,
            alpha=params.alpha,
            legend_loc="right margin" if colors_by_obs_column else None,
            frameon=params.show_axes,
            title="",
        )

        title = params.title or f"RNA Velocity Stream on {basis.capitalize()}"
        ax.set_title(title, fontsize=14)

        if basis == "spatial":
            ax.invert_yaxis()

        plt.tight_layout()
    except Exception:
        # Do not leave a half-drawn figure open when plotting fails.
        plt.close(fig)
        raise
    return fig
152
+
153
+
154
async def _create_velocity_phase_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Draw per-gene phase portraits with scv.pl.velocity.

    One panel per gene, laid out in a single row. Genes come from
    ``params.feature`` (a single name or an iterable of names). When no
    feature is given, the first four entries flagged in
    ``adata.var['velocity_genes']`` are used, or the first four variables
    if that column is absent. Names not present in ``adata.var_names`` are
    dropped.

    Data requirements:
        - adata.layers['velocity']
        - adata.layers['Ms']
        - adata.layers['Mu']

    Args:
        adata: AnnData object carrying the layers listed above.
        params: Visualization parameters. ``feature``, ``basis``,
            ``cluster_key``, ``dpi`` and ``title`` are read here.
        context: Optional MCP context used for progress messages.

    Returns:
        The current pyplot figure after scVelo has drawn into it.

    Raises:
        DataNotFoundError: If a required layer is missing, or none of the
            candidate genes exists in ``adata.var_names``.
    """
    require("scvelo", feature="velocity phase plots")
    import scvelo as scv

    absent = [name for name in ("velocity", "Ms", "Mu") if name not in adata.layers]
    if absent:
        raise DataNotFoundError(
            f"Missing layers for phase plot: {absent}. Run velocity analysis first."
        )

    # Work out which genes were asked for (or pick defaults).
    requested = params.feature
    if requested:
        candidates = [requested] if isinstance(requested, str) else list(requested)
    elif "velocity_genes" in adata.var.columns:
        flagged = adata.var_names[adata.var["velocity_genes"]]
        candidates = list(flagged[:4])
    else:
        candidates = list(adata.var_names[:4])

    present = [gene for gene in candidates if gene in adata.var_names]
    if not present:
        raise DataNotFoundError(
            f"None of the specified genes found in data: {candidates}. "
            f"Available genes (first 10): {list(adata.var_names[:10])}"
        )

    if context:
        await context.info(f"Creating phase plot for genes: {present}")

    n_genes = len(present)
    basis = infer_basis(adata, preferred=params.basis, priority=["umap", "spatial"])
    figsize = resolve_figure_size(
        params, n_panels=n_genes, panel_width=4, panel_height=4
    )

    scv.pl.velocity(
        adata,
        var_names=present,
        basis=basis,
        color=params.cluster_key or None,
        figsize=figsize,
        dpi=params.dpi,
        show=False,
        ncols=n_genes,
    )

    # scVelo creates its own figure here, so pick up the current one.
    fig = plt.gcf()
    fig.suptitle(params.title or "RNA Velocity Phase Plot", fontsize=14, y=1.02)
    plt.tight_layout()
    return fig
222
+
223
+
224
async def _create_velocity_proportions_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Plot spliced/unspliced proportions per group with scv.pl.proportions.

    Groups are taken from ``params.cluster_key``. When that is unset, the
    first column returned by ``get_categorical_columns`` is used instead.

    Data requirements:
        - adata.layers['spliced']
        - adata.layers['unspliced']
        - adata.obs[cluster_key]: labels used for grouping

    Args:
        adata: AnnData object with spliced and unspliced layers.
        params: Visualization parameters. ``cluster_key``, ``dpi`` and
            ``title`` are read here.
        context: Optional MCP context used for progress messages.

    Returns:
        The current pyplot figure after scVelo has drawn into it.

    Raises:
        DataNotFoundError: If either count layer is missing.
        ParameterError: If no cluster key is given and no categorical
            column can be found to stand in for it.
    """
    require("scvelo", feature="proportions plot")
    import scvelo as scv

    if not all(layer in adata.layers for layer in ("spliced", "unspliced")):
        raise DataNotFoundError(
            "Spliced and unspliced layers are required for proportions plot. "
            "Your data may not contain RNA velocity information."
        )

    group_key = params.cluster_key
    if not group_key:
        candidates = get_categorical_columns(adata)
        if not candidates:
            raise ParameterError(
                "cluster_key is required for proportions plot. "
                f"Available columns: {list(adata.obs.columns)[:10]}"
            )
        group_key = candidates[0]
        if context:
            await context.info(f"Using cluster_key: '{group_key}'")

    validate_obs_column(adata, group_key, "Cluster")

    if context:
        await context.info(f"Creating proportions plot grouped by '{group_key}'")

    scv.pl.proportions(
        adata,
        groupby=group_key,
        figsize=resolve_figure_size(params, "violin"),
        dpi=params.dpi,
        show=False,
    )

    # scVelo creates its own figure here, so pick up the current one.
    fig = plt.gcf()
    heading = params.title or f"Spliced/Unspliced Proportions by {group_key}"
    fig.suptitle(heading, fontsize=14, y=1.02)
    plt.tight_layout()
    return fig
280
+
281
+
282
async def _create_velocity_heatmap(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create velocity heatmap using scv.pl.heatmap.

    Shows gene expression patterns ordered by latent time.

    Gene selection: names from ``params.feature`` that exist in
    ``adata.var_names``. When no feature is given, up to 50 genes are taken
    from ``adata.var['velocity_genes']``, else ``adata.var['highly_variable']``,
    else the first variables in ``adata.var_names``.

    Data requirements:
        - adata.obs['latent_time']: Latent time from dynamical model
        - adata.var['velocity_genes']: Velocity genes (optional)

    Args:
        adata: AnnData object with ``latent_time`` in ``obs``.
        params: Visualization parameters. ``feature``, ``cluster_key``,
            ``dpi`` and ``title`` are read here.
        context: Optional MCP context used for progress messages.

    Returns:
        The current pyplot figure after scVelo has drawn into it.

    Raises:
        DataNotFoundError: If none of the requested genes exists, or the
            default selection yields no genes at all.
        Whatever ``validate_obs_column`` raises when ``latent_time`` or the
        requested ``cluster_key`` is missing from ``adata.obs`` (its
        exception type is not visible from this module).
    """
    require("scvelo", feature="velocity heatmap")
    import scvelo as scv

    validate_obs_column(adata, "latent_time", "Latent time")

    # cluster_key is forwarded as col_color below. Check it up front, the
    # same way the proportions plot does, so a bad key fails with the
    # project's own error instead of deep inside scVelo.
    if params.cluster_key:
        validate_obs_column(adata, params.cluster_key, "Cluster")

    if params.feature:
        if isinstance(params.feature, str):
            var_names = [params.feature]
        else:
            var_names = list(params.feature)
        valid_genes = [g for g in var_names if g in adata.var_names]
        if not valid_genes:
            raise DataNotFoundError(f"None of the specified genes found: {var_names}")
        var_names = valid_genes
    elif "velocity_genes" in adata.var.columns:
        velocity_genes = adata.var_names[adata.var["velocity_genes"]]
        var_names = list(velocity_genes[:50])
    elif "highly_variable" in adata.var.columns:
        hvg = adata.var_names[adata.var["highly_variable"]]
        var_names = list(hvg[:50])
    else:
        var_names = list(adata.var_names[:50])

    # A boolean mask with no True entries (or an empty dataset) leaves
    # nothing to draw; say so explicitly rather than passing [] to scVelo.
    if not var_names:
        raise DataNotFoundError(
            "No genes available for velocity heatmap. "
            "Specify genes via 'feature' or run velocity analysis first."
        )

    if context:
        await context.info(f"Creating velocity heatmap with {len(var_names)} genes")

    figsize = resolve_figure_size(params, "heatmap")

    scv.pl.heatmap(
        adata,
        var_names=var_names,
        sortby="latent_time",
        col_color=params.cluster_key,
        n_convolve=30,
        show=False,
        figsize=figsize,
    )

    # scVelo creates its own figure here, so pick up the current one.
    fig = plt.gcf()
    fig.set_dpi(params.dpi)

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)
    plt.tight_layout()
    return fig
342
+
343
+
344
async def _create_velocity_paga_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create PAGA plot with velocity using scv.pl.paga.

    Shows partition-based graph abstraction with directed velocity arrows.

    The grouping column is ``params.cluster_key`` when set; otherwise the
    ``groups`` entry stored in ``adata.uns['paga']`` if present, else the
    first column returned by ``get_categorical_columns``. When
    ``adata.uns`` has no ``paga`` entry, PAGA is computed here (scanpy
    first, then scVelo), which writes into ``adata``.

    Data requirements:
        - adata.uns['velocity_graph']: Velocity transition graph
        - adata.uns['paga']: PAGA results (computed on demand if absent)
        - adata.obs[cluster_key]: Cluster labels used for PAGA

    Args:
        adata: AnnData object with a computed velocity graph.
        params: Visualization parameters. ``cluster_key``, ``basis``,
            ``show_axes`` and ``title`` are read here.
        context: Optional MCP context used for progress messages.

    Returns:
        The Matplotlib figure containing the PAGA plot.

    Raises:
        DataNotFoundError: If ``velocity_graph`` is missing from ``adata.uns``.
        ParameterError: If no grouping column can be resolved, or the
            resolved one is not in ``adata.obs``.
    """
    require("scvelo", feature="velocity PAGA plot")
    import scvelo as scv

    if "velocity_graph" not in adata.uns:
        raise DataNotFoundError("velocity_graph required. Run velocity analysis first.")

    cluster_key = params.cluster_key
    if not cluster_key:
        # Prefer the grouping a previous PAGA run recorded, if any.
        if "groups" in adata.uns.get("paga", {}):
            cluster_key = adata.uns["paga"].get("groups")
        else:
            categorical_cols = get_categorical_columns(adata)
            if categorical_cols:
                cluster_key = categorical_cols[0]

    if not cluster_key or cluster_key not in adata.obs.columns:
        raise ParameterError(
            f"cluster_key is required for PAGA plot. "
            f"Available columns: {list(adata.obs.columns)[:10]}"
        )

    # Compute PAGA if not already done
    if "paga" not in adata.uns:
        if context:
            await context.info(f"Computing PAGA for cluster_key='{cluster_key}'")
        import scanpy as sc

        sc.tl.paga(adata, groups=cluster_key)
        scv.tl.paga(adata, groups=cluster_key)

    if context:
        await context.info(f"Creating velocity PAGA plot for '{cluster_key}'")

    basis = infer_basis(adata, preferred=params.basis, priority=["umap", "spatial"])
    fig, axes = create_figure_from_params(params, "velocity")
    try:
        ax = axes[0]

        scv.pl.paga(
            adata,
            basis=basis,
            color=cluster_key,
            ax=ax,
            show=False,
            frameon=params.show_axes,
        )

        if params.title:
            ax.set_title(params.title, fontsize=14)

        if basis == "spatial":
            ax.invert_yaxis()

        plt.tight_layout()
    except Exception:
        # Do not leave a half-drawn figure open when plotting fails.
        plt.close(fig)
        raise
    return fig
@@ -0,0 +1,115 @@
1
+ """
2
+ Utility functions for spatial transcriptomics data analysis.
3
+ """
4
+
5
+ from .adata_utils import ( # Constants; Field discovery; Data access; Validation; Ensure; Standardization
6
+ ALTERNATIVE_BATCH_KEYS,
7
+ ALTERNATIVE_CELL_TYPE_KEYS,
8
+ ALTERNATIVE_CLUSTER_KEYS,
9
+ ALTERNATIVE_SPATIAL_KEYS,
10
+ BATCH_KEY,
11
+ CELL_TYPE_KEY,
12
+ CLUSTER_KEY,
13
+ SPATIAL_KEY,
14
+ ensure_categorical,
15
+ ensure_counts_layer,
16
+ find_common_genes,
17
+ get_analysis_parameter,
18
+ get_batch_key,
19
+ get_cell_type_key,
20
+ get_cluster_key,
21
+ get_gene_expression,
22
+ get_genes_expression,
23
+ get_spatial_key,
24
+ standardize_adata,
25
+ to_dense,
26
+ validate_adata,
27
+ validate_adata_basics,
28
+ validate_gene_overlap,
29
+ validate_obs_column,
30
+ validate_var_column,
31
+ )
32
+ from .dependency_manager import (
33
+ DependencyInfo,
34
+ get,
35
+ is_available,
36
+ require,
37
+ validate_r_environment,
38
+ validate_scvi_tools,
39
+ )
40
+ from .device_utils import (
41
+ cuda_available,
42
+ get_device,
43
+ get_ot_backend,
44
+ mps_available,
45
+ resolve_device_async,
46
+ )
47
+ from .exceptions import (
48
+ ChatSpatialError,
49
+ DataCompatibilityError,
50
+ DataError,
51
+ DataNotFoundError,
52
+ DependencyError,
53
+ ParameterError,
54
+ ProcessingError,
55
+ )
56
+ from .mcp_utils import mcp_tool_error_handler, suppress_output
57
+
58
# Public names exported by this package. Every entry below is imported at
# the top of this module from one of the sibling utility modules; the
# section comments group them by purpose only.
__all__ = [
    # Exceptions
    "ChatSpatialError",
    "DataError",
    "DataNotFoundError",
    "DataCompatibilityError",
    "ParameterError",
    "ProcessingError",
    "DependencyError",
    # MCP utilities
    "suppress_output",
    "mcp_tool_error_handler",
    # Constants
    "SPATIAL_KEY",
    "CELL_TYPE_KEY",
    "CLUSTER_KEY",
    "BATCH_KEY",
    "ALTERNATIVE_SPATIAL_KEYS",
    "ALTERNATIVE_CELL_TYPE_KEYS",
    "ALTERNATIVE_CLUSTER_KEYS",
    "ALTERNATIVE_BATCH_KEYS",
    # Field discovery
    "get_analysis_parameter",
    "get_batch_key",
    "get_cell_type_key",
    "get_cluster_key",
    "get_spatial_key",
    # Expression extraction
    "to_dense",
    "get_gene_expression",
    "get_genes_expression",
    # Validation (ensure_categorical is also listed in this group)
    "validate_adata",
    "validate_obs_column",
    "validate_var_column",
    "validate_adata_basics",
    "validate_gene_overlap",
    "ensure_categorical",
    # Gene overlap
    "find_common_genes",
    # Ensure
    "ensure_counts_layer",
    # Standardization
    "standardize_adata",
    # Dependency management
    "DependencyInfo",
    "require",
    "get",
    "is_available",
    "validate_r_environment",
    "validate_scvi_tools",
    # Device utilities
    "cuda_available",
    "mps_available",
    "get_device",
    "resolve_device_async",
    "get_ot_backend",
]