chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,393 @@
1
+ """
2
+ Basic visualization functions for spatial transcriptomics.
3
+
4
+ This module contains:
5
+ - Spatial feature plots
6
+ - UMAP visualizations
7
+ - Heatmap visualizations
8
+ - Violin plots
9
+ - Dot plots
10
+ """
11
+
12
+ from typing import TYPE_CHECKING, Optional
13
+
14
+ import matplotlib.pyplot as plt
15
+ import pandas as pd
16
+ import scanpy as sc
17
+
18
+ if TYPE_CHECKING:
19
+ import anndata as ad
20
+
21
+ from ...spatial_mcp_adapter import ToolContext
22
+
23
+ from ...models.data import VisualizationParameters
24
+ from ...utils.adata_utils import (
25
+ ensure_categorical,
26
+ get_cluster_key,
27
+ get_gene_expression,
28
+ validate_obs_column,
29
+ )
30
+ from ...utils.compute import ensure_umap
31
+ from ...utils.exceptions import DataNotFoundError, ParameterError
32
+ from .core import (
33
+ add_colorbar,
34
+ create_figure,
35
+ get_colormap,
36
+ get_validated_features,
37
+ plot_spatial_feature,
38
+ setup_multi_panel_figure,
39
+ )
40
+
41
+ # =============================================================================
42
+ # Spatial Visualization
43
+ # =============================================================================
44
+
45
+
46
async def create_spatial_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create spatial visualization for one or more features.

    Each requested feature is drawn in its own panel. A panel whose feature
    cannot be plotted shows a truncated error message instead of aborting the
    whole figure; the untruncated error is also reported through ``context``
    when one is provided.

    Args:
        adata: AnnData object with spatial coordinates
        params: Visualization parameters. ``params.feature`` may be ``None``,
            a single feature name, or a list of feature names.
        context: Optional tool context for logging

    Returns:
        matplotlib Figure object

    Raises:
        ParameterError: If no feature is specified and ``get_cluster_key``
            returns no default clustering.
    """
    # Normalize params.feature to a list of feature names.
    if params.feature is None:
        features: list[str] = []
    elif isinstance(params.feature, list):
        features = list(params.feature)
    else:
        features = [params.feature]

    if not features:
        # Default to first available cluster key
        default_cluster = get_cluster_key(adata)
        if not default_cluster:
            raise ParameterError(
                "No features specified and no default clustering found"
            )
        features = [default_cluster]

    fig, axes = setup_multi_panel_figure(len(features), params, "")

    for i, feature in enumerate(features):
        ax = axes[i]
        try:
            mappable = plot_spatial_feature(
                adata,
                ax,
                feature=feature,
                params=params,
                title=feature,
            )
            if mappable is not None and params.show_colorbar:
                add_colorbar(fig, ax, mappable, params)
        except Exception as e:
            # Deliberate best effort: one bad feature must not discard the
            # other panels. The panel text is truncated to 50 characters, so
            # surface the complete message through the context as well.
            if context:
                await context.info(f"Could not plot feature '{feature}': {e}")
            ax.text(
                0.5,
                0.5,
                f"Error: {str(e)[:50]}",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
            ax.axis("off")

    # Lay out the figure we built rather than whatever pyplot considers
    # "current", which is global state other code may have changed.
    fig.tight_layout()
    return fig
106
+
107
+
108
+ # =============================================================================
109
+ # UMAP Visualization
110
+ # =============================================================================
111
+
112
+
113
def _plot_umap_points(
    fig: plt.Figure,
    ax: plt.Axes,
    adata: "ad.AnnData",
    color_by: Optional[str],
    params: VisualizationParameters,
) -> None:
    """Scatter the ``X_umap`` embedding onto ``ax``, colored by ``color_by``.

    ``color_by`` must be ``None``, a name in ``adata.var_names`` or a column
    of ``adata.obs``; the caller validates this. Gene names take precedence
    over obs columns of the same name.
    """
    umap_coords = adata.obsm["X_umap"]

    spot_size = params.spot_size if params.spot_size is not None else 150.0
    # True division: floor division would round small sizes down to 0 and
    # make the points invisible.
    marker_size = spot_size / 3

    if color_by is None:
        # No color - just plot points
        ax.scatter(
            umap_coords[:, 0],
            umap_coords[:, 1],
            s=marker_size,
            alpha=params.alpha,
            c="steelblue",
        )
        return

    if color_by in adata.var_names:
        # Gene expression - use unified utility
        values = get_gene_expression(adata, color_by)
    else:
        values = adata.obs[color_by]
        # isinstance on the dtype replaces the deprecated
        # pd.api.types.is_categorical_dtype; string-dtype columns are treated
        # like object columns because they cannot be mapped through a
        # continuous colormap either.
        is_categorical = (
            isinstance(values.dtype, (pd.CategoricalDtype, pd.StringDtype))
            or values.dtype == object
        )
        if is_categorical:
            ensure_categorical(adata, color_by)
            labels = adata.obs[color_by]
            categories = labels.cat.categories
            colors = get_colormap(params.colormap, n_colors=len(categories))

            # One scatter call per category so each gets its own legend entry.
            for i, cat in enumerate(categories):
                mask = (labels == cat).to_numpy()
                ax.scatter(
                    umap_coords[mask, 0],
                    umap_coords[mask, 1],
                    c=[colors[i]],
                    s=marker_size,
                    alpha=params.alpha,
                    label=cat,
                )

            if params.show_legend:
                ax.legend(
                    loc="center left",
                    bbox_to_anchor=(1, 0.5),
                    fontsize=8,
                    frameon=False,
                )
            return

    # Continuous values (gene expression or a non-categorical obs column).
    scatter = ax.scatter(
        umap_coords[:, 0],
        umap_coords[:, 1],
        c=values,
        cmap=params.colormap,
        s=marker_size,
        alpha=params.alpha,
    )
    if params.show_colorbar:
        add_colorbar(fig, ax, scatter, params, label=color_by)


async def create_umap_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create UMAP visualization.

    Points are colored by ``params.feature`` (a gene or an obs column). When
    no feature is given, the key returned by ``get_cluster_key`` is used, and
    if that is also ``None`` the points are drawn in a single color.

    Args:
        adata: AnnData object
        params: Visualization parameters. ``params.feature`` may be ``None``,
            a single name, or a list holding at most one name.
        context: Optional tool context for logging

    Returns:
        matplotlib Figure object

    Raises:
        ParameterError: If ``params.feature`` is a list with more than one
            entry (this plot has a single panel).
        DataNotFoundError: If the feature is neither a gene nor an obs column.
    """
    # A list is a valid value for params.feature elsewhere in this module;
    # unwrap it here instead of failing later on an unhashable membership test.
    color_by = params.feature
    if isinstance(color_by, list):
        if len(color_by) > 1:
            raise ParameterError(
                "UMAP visualization supports a single feature, "
                f"got {len(color_by)}"
            )
        color_by = color_by[0] if color_by else None

    if context:
        feature_desc = color_by if color_by else "clusters"
        await context.info(f"Creating UMAP plot for {feature_desc}")

    # Ensure UMAP is computed (lazy computation)
    if ensure_umap(adata) and context:
        await context.info("Computed UMAP embedding")

    if color_by is None:
        # Default to first available cluster key
        color_by = get_cluster_key(adata)

    # Validate before creating the figure so a bad feature name does not
    # leave an orphaned figure registered with pyplot.
    if (
        color_by is not None
        and color_by not in adata.var_names
        and color_by not in adata.obs.columns
    ):
        raise DataNotFoundError(
            f"Color feature '{color_by}' not found in genes or obs"
        )

    fig, ax = create_figure(params.figure_size or (10, 8))
    try:
        _plot_umap_points(fig, ax, adata, color_by, params)

        ax.set_xlabel("UMAP1")
        ax.set_ylabel("UMAP2")
        default_title = f"UMAP - {color_by}" if color_by is not None else "UMAP"
        ax.set_title(params.title or default_title)

        if not params.show_axes:
            ax.axis("off")

        fig.tight_layout()
    except Exception:
        # pyplot keeps figures alive until closed; release it on failure.
        plt.close(fig)
        raise
    return fig
227
+
228
+
229
+ # =============================================================================
230
+ # Heatmap Visualization
231
+ # =============================================================================
232
+
233
+
234
async def create_heatmap_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Draw a scanpy expression heatmap of genes grouped by a cluster column.

    Args:
        adata: AnnData object
        params: Visualization parameters; ``cluster_key`` is mandatory and the
            requested features must resolve to at least one gene. The
            ``dotplot_dendrogram``, ``dotplot_swap_axes`` and
            ``dotplot_standard_scale`` settings are forwarded to the heatmap.
        context: Optional tool context for logging

    Returns:
        The pyplot current figure after ``sc.pl.heatmap`` has drawn.

    Raises:
        ParameterError: If ``cluster_key`` is missing or no valid gene
            features remain after validation.
    """
    group_key = params.cluster_key
    if not group_key:
        raise ParameterError("Heatmap requires cluster_key parameter")
    validate_obs_column(adata, group_key, "Cluster")

    genes = await get_validated_features(adata, params, context, genes_only=True)
    if not genes:
        raise ParameterError("No valid gene features provided for heatmap")

    if context:
        message = f"Creating heatmap for {len(genes)} genes grouped by {group_key}"
        await context.info(message)

    heatmap_options = {
        "var_names": genes,
        "groupby": group_key,
        "cmap": params.colormap,
        "show": False,
        "dendrogram": params.dotplot_dendrogram,
        "swap_axes": params.dotplot_swap_axes,
        "standard_scale": params.dotplot_standard_scale,
    }
    # return_fig=True is avoided (issues with newer matplotlib versions), so
    # the figure is picked up from pyplot once scanpy has finished drawing.
    sc.pl.heatmap(adata, **heatmap_options)
    return plt.gcf()
278
+
279
+
280
+ # =============================================================================
281
+ # Violin Plot Visualization
282
+ # =============================================================================
283
+
284
+
285
async def create_violin_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Draw scanpy violin plots of gene expression per cluster.

    Args:
        adata: AnnData object
        params: Visualization parameters; ``cluster_key`` is mandatory and the
            requested features must resolve to at least one gene.
        context: Optional tool context for logging

    Returns:
        The pyplot current figure after ``sc.pl.violin`` has drawn.

    Raises:
        ParameterError: If ``cluster_key`` is missing or no valid gene
            features remain after validation.
    """
    group_key = params.cluster_key
    if not group_key:
        raise ParameterError("Violin plot requires cluster_key parameter")
    validate_obs_column(adata, group_key, "Cluster")

    genes = await get_validated_features(adata, params, context, genes_only=True)
    if not genes:
        raise ParameterError("No valid gene features provided for violin plot")

    if context:
        message = (
            f"Creating violin plot for {len(genes)} genes grouped by {group_key}"
        )
        await context.info(message)

    # return_fig=True is avoided (issues with newer matplotlib/seaborn
    # versions); the figure is taken from pyplot after scanpy has drawn.
    sc.pl.violin(adata, keys=genes, groupby=group_key, show=False)
    return plt.gcf()
326
+
327
+
328
+ # =============================================================================
329
+ # Dot Plot Visualization
330
+ # =============================================================================
331
+
332
+
333
async def create_dotplot_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create dot plot visualization for gene expression.

    Args:
        adata: AnnData object
        params: Visualization parameters (requires cluster_key and feature
            list). Optional ``dotplot_*`` settings are forwarded to scanpy
            only when set. When ``dotplot_var_groups`` is given, its genes
            define the plotted variables and its keys label the groups.
        context: Optional tool context for logging

    Returns:
        The pyplot current figure after ``sc.pl.dotplot`` has drawn.

    Raises:
        ParameterError: If ``cluster_key`` is missing or no valid gene
            features remain after validation.
    """
    if not params.cluster_key:
        raise ParameterError("Dot plot requires cluster_key parameter")

    validate_obs_column(adata, params.cluster_key, "Cluster")

    features = await get_validated_features(adata, params, context, genes_only=True)
    if not features:
        raise ParameterError("No valid gene features provided for dot plot")

    if context:
        await context.info(
            f"Creating dot plot for {len(features)} genes grouped by {params.cluster_key}"
        )

    # Build kwargs for dotplot
    # Note: return_fig=True returns DotPlot object, not Figure
    # Use plt.gcf() instead to get the figure
    dotplot_kwargs = {
        "adata": adata,
        "var_names": features,
        "groupby": params.cluster_key,
        "cmap": params.colormap,
        "show": False,
    }

    # Add optional parameters
    if params.dotplot_dendrogram:
        dotplot_kwargs["dendrogram"] = True
    if params.dotplot_swap_axes:
        dotplot_kwargs["swap_axes"] = True
    if params.dotplot_standard_scale:
        dotplot_kwargs["standard_scale"] = params.dotplot_standard_scale
    if params.dotplot_dot_min is not None:
        dotplot_kwargs["dot_min"] = params.dotplot_dot_min
    if params.dotplot_dot_max is not None:
        dotplot_kwargs["dot_max"] = params.dotplot_dot_max
    if params.dotplot_smallest_dot is not None:
        dotplot_kwargs["smallest_dot"] = params.dotplot_smallest_dot

    if params.dotplot_var_groups:
        # scanpy's var_group_positions expects (start, end) index tuples, not
        # group names. Passing the groups as a var_names mapping lets scanpy
        # derive both labels and positions itself.
        # NOTE(review): assumes dotplot_var_groups maps group label -> gene
        # name(s); confirm against VisualizationParameters.
        grouped_genes = {}
        for label, genes in params.dotplot_var_groups.items():
            members = [genes] if isinstance(genes, str) else list(genes)
            present = [gene for gene in members if gene in adata.var_names]
            if present:
                grouped_genes[label] = present
        # If no group contains a known gene, keep the flat feature list.
        if grouped_genes:
            dotplot_kwargs["var_names"] = grouped_genes

    sc.pl.dotplot(**dotplot_kwargs)
    fig = plt.gcf()

    return fig