chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,469 @@
1
+ """
2
+ Spatial statistics visualization functions for spatial transcriptomics.
3
+
4
+ This module contains:
5
+ - Neighborhood enrichment heatmaps
6
+ - Co-occurrence plots
7
+ - Ripley's function visualizations
8
+ - Moran's I scatter plots
9
+ - Centrality scores
10
+ - Getis-Ord Gi* hotspot maps
11
+ """
12
+
13
+ from typing import TYPE_CHECKING, Optional
14
+
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+
18
+ if TYPE_CHECKING:
19
+ import anndata as ad
20
+
21
+ from ...spatial_mcp_adapter import ToolContext
22
+
23
+ from ...models.data import VisualizationParameters
24
+ from ...utils.adata_utils import get_analysis_parameter, require_spatial_coords
25
+ from ...utils.dependency_manager import require
26
+ from ...utils.exceptions import DataNotFoundError, ParameterError
27
+ from .core import (
28
+ create_figure_from_params,
29
+ get_categorical_columns,
30
+ resolve_figure_size,
31
+ setup_multi_panel_figure,
32
+ )
33
+
34
+
35
def _resolve_cluster_key(
    adata: "ad.AnnData",
    analysis_type: str,
    params_cluster_key: Optional[str],
) -> str:
    """Pick the cluster_key to use for a spatial-statistics plot.

    An explicit ``params_cluster_key`` wins. When it is falsy, the value stored
    under ``spatial_stats_<analysis_type>`` / ``cluster_key`` is looked up via
    ``get_analysis_parameter``.

    Args:
        adata: AnnData object to inspect.
        analysis_type: Analysis name, e.g. "neighborhood" or "co_occurrence".
        params_cluster_key: cluster_key supplied by the caller, if any.

    Returns:
        The resolved cluster_key.

    Raises:
        ParameterError: If neither source yields a cluster_key. The message
            lists the columns returned by ``get_categorical_columns(limit=10)``
            as suggestions.
    """
    if params_cluster_key:
        resolved = params_cluster_key
    else:
        resolved = get_analysis_parameter(
            adata, f"spatial_stats_{analysis_type}", "cluster_key"
        )

    if resolved:
        return resolved

    suggestions = get_categorical_columns(adata, limit=10)
    raise ParameterError(
        f"cluster_key required for {analysis_type} visualization. "
        f"Available categorical columns: {', '.join(suggestions)}"
    )
67
+
68
+ # =============================================================================
69
+ # Main Router
70
+ # =============================================================================
71
+
72
+
73
async def create_spatial_statistics_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Dispatch to the spatial-statistics renderer selected by ``params.subtype``.

    A missing or empty subtype defaults to "neighborhood".

    Args:
        adata: AnnData object with spatial statistics results.
        params: Visualization parameters; ``subtype`` selects the renderer.
        context: Optional MCP context used for progress messages.

    Returns:
        Matplotlib figure produced by the selected renderer.

    Raises:
        ParameterError: If ``subtype`` is not one of the supported values.

    Subtypes:
        - neighborhood: Neighborhood enrichment heatmap
        - co_occurrence: Co-occurrence analysis plot
        - ripley: Ripley's statistic curves (squidpy ``sq.pl.ripley``)
        - moran: Moran's I scatter plot
        - centrality: Graph centrality scores
        - getis_ord: Getis-Ord Gi* hotspot/coldspot maps
    """
    subtype = params.subtype if params.subtype else "neighborhood"

    if context:
        await context.info(f"Creating {subtype} spatial statistics visualization")

    # The Moran's I renderer is a plain function; all others are coroutines.
    if subtype == "moran":
        return _create_moran_visualization(adata, params, context)

    async_renderers = {
        "neighborhood": _create_neighborhood_enrichment_visualization,
        "co_occurrence": _create_co_occurrence_visualization,
        "ripley": _create_ripley_visualization,
        "centrality": _create_centrality_visualization,
        "getis_ord": _create_getis_ord_visualization,
    }
    renderer = async_renderers.get(subtype)
    if renderer is None:
        raise ParameterError(
            f"Unsupported subtype for spatial_statistics: '{subtype}'. "
            f"Available subtypes: neighborhood, co_occurrence, ripley, moran, "
            f"centrality, getis_ord"
        )
    return await renderer(adata, params, context)
121
+
122
+
123
+ # =============================================================================
124
+ # Visualization Functions
125
+ # =============================================================================
126
+
127
+
128
async def _create_neighborhood_enrichment_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Draw squidpy's neighborhood-enrichment heatmap on a new figure.

    Data requirements:
        - adata.uns['{cluster_key}_nhood_enrichment']: Enrichment results
        - adata.obs[cluster_key]: Cluster labels

    Args:
        adata: AnnData object holding the enrichment results.
        params: Visualization parameters. ``cluster_key``, ``colormap``
            (default "coolwarm") and ``title`` are read here.
        context: Unused; kept for signature parity with the other renderers.

    Returns:
        The figure created by ``create_figure_from_params``.

    Raises:
        ParameterError: If no cluster_key can be resolved.
        DataNotFoundError: If the enrichment results are absent from adata.uns.
    """
    require("squidpy", feature="neighborhood enrichment visualization")
    import squidpy as sq

    cluster_key = _resolve_cluster_key(adata, "neighborhood", params.cluster_key)

    if f"{cluster_key}_nhood_enrichment" not in adata.uns:
        raise DataNotFoundError(
            f"Neighborhood enrichment not found. Run analyze_spatial_statistics "
            f"with cluster_key='{cluster_key}' first."
        )

    fig, panel_axes = create_figure_from_params(params, "spatial")
    heatmap_ax = panel_axes[0]

    heatmap_cmap = params.colormap or "coolwarm"
    heatmap_title = params.title or f"Neighborhood Enrichment ({cluster_key})"

    sq.pl.nhood_enrichment(
        adata,
        cluster_key=cluster_key,
        cmap=heatmap_cmap,
        ax=heatmap_ax,
        title=heatmap_title,
    )

    plt.tight_layout()
    return fig
164
+
165
+
166
async def _create_co_occurrence_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Plot squidpy co-occurrence curves for the first few clusters.

    Data requirements:
        - adata.uns['{cluster_key}_co_occurrence']: Co-occurrence results
        - adata.obs[cluster_key]: Cluster labels (accessed via ``.cat``)

    Args:
        adata: AnnData object holding the co-occurrence results.
        params: Visualization parameters. ``cluster_key``, ``dpi``, ``title``
            and the "heatmap" figure size are used.
        context: Unused; kept for signature parity with the other renderers.

    Returns:
        The current pyplot figure after squidpy has drawn into it.

    Raises:
        ParameterError: If no cluster_key can be resolved.
        DataNotFoundError: If the co-occurrence results are absent from adata.uns.
    """
    require("squidpy", feature="co-occurrence visualization")
    import squidpy as sq

    cluster_key = _resolve_cluster_key(adata, "co_occurrence", params.cluster_key)

    if f"{cluster_key}_co_occurrence" not in adata.uns:
        raise DataNotFoundError(
            f"Co-occurrence not found. Run analyze_spatial_statistics "
            f"with cluster_key='{cluster_key}' first."
        )

    # Limit the plot to (at most) the first four categories.
    shown_clusters = adata.obs[cluster_key].cat.categories.tolist()[:4]

    plot_kwargs = {
        "cluster_key": cluster_key,
        "clusters": shown_clusters,
        "figsize": resolve_figure_size(params, "heatmap"),
        "dpi": params.dpi,
    }
    sq.pl.co_occurrence(adata, **plot_kwargs)

    # No axes are handed to squidpy, so the figure is taken from pyplot's
    # current-figure state.
    current_fig = plt.gcf()
    if params.title:
        current_fig.suptitle(params.title)

    plt.tight_layout()
    return current_fig
208
+
209
+
210
async def _create_ripley_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create Ripley's function visualization using squidpy.

    squidpy stores Ripley results per mode under
    ``adata.uns['{cluster_key}_ripley_{mode}']`` (modes F, G, L). The L
    function is preferred, as before. If only F or G results are present,
    that mode is plotted instead of failing.

    Data requirements:
        - adata.uns['{cluster_key}_ripley_{L|F|G}']: Ripley function results
        - adata.obs[cluster_key]: Cluster labels

    Args:
        adata: AnnData object holding Ripley results.
        params: Visualization parameters (``cluster_key``, ``title`` and figure
            settings are used).
        context: Optional MCP context; used to report a non-L mode fallback.

    Returns:
        The figure created by ``create_figure_from_params``.

    Raises:
        ParameterError: If no cluster_key can be resolved.
        DataNotFoundError: If no Ripley results exist for the cluster_key.
    """
    require("squidpy", feature="Ripley visualization")
    import squidpy as sq

    cluster_key = _resolve_cluster_key(adata, "ripley", params.cluster_key)

    # Prefer L (the only mode previously supported), then fall back to the
    # other modes squidpy can compute so F/G results remain plottable.
    mode = next(
        (m for m in ("L", "F", "G") if f"{cluster_key}_ripley_{m}" in adata.uns),
        None,
    )
    if mode is None:
        raise DataNotFoundError(
            f"Ripley results not found. Run analyze_spatial_statistics "
            f"with cluster_key='{cluster_key}' and analysis_type='ripley' first."
        )
    if mode != "L" and context:
        await context.info(
            f"Ripley's L results not found; plotting Ripley's {mode} function instead."
        )

    fig, axes = create_figure_from_params(params, "spatial")
    ax = axes[0]

    sq.pl.ripley(adata, cluster_key=cluster_key, mode=mode, plot_sims=True, ax=ax)

    if params.title:
        ax.set_title(params.title)

    plt.tight_layout()
    return fig
243
+
244
+
245
def _create_moran_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create Moran's I volcano-style visualization.

    Plots Moran's I (y axis) against -log10(p-value) (x axis) for every gene
    in the results. Points are colored by I on a diverging colormap centered
    at 0 (positive = clustered, negative = dispersed). Up to ten genes with
    p < 0.05 and positive I are annotated, ranked by I * -log10(p).

    Data requirements:
        - adata.uns['moranI']: DataFrame with ``I`` and ``pval_norm`` columns,
          whose index supplies the gene names used for annotation.

    Args:
        adata: AnnData object holding Moran's I results.
        params: Visualization parameters (``alpha``, ``title``,
            ``show_colorbar`` and figure settings are used).
        context: Unused; kept for signature parity with the other renderers.

    Returns:
        The figure created by ``create_figure_from_params``.

    Raises:
        DataNotFoundError: If ``adata.uns['moranI']`` is missing or empty.
    """
    if "moranI" not in adata.uns:
        raise DataNotFoundError("Moran's I results not found. Expected key: moranI")

    moran_data = adata.uns["moranI"]
    if len(moran_data) == 0:
        # Previously surfaced as an opaque ValueError from ``.min()``.
        raise DataNotFoundError("Moran's I results are empty. Expected key: moranI")

    moran_i = np.asarray(moran_data["I"], dtype=float)
    pvals = np.asarray(moran_data["pval_norm"], dtype=float)

    # Clip so p == 0 maps to a finite -log10 value; NaN p-values stay NaN.
    neg_log_pval = -np.log10(np.clip(pvals, 1e-300, 1.0))

    # Symmetric color limits centered at I = 0. Only finite values count:
    # a single NaN in I would otherwise turn both limits into NaN.
    finite_i = moran_i[np.isfinite(moran_i)]
    color_limit = float(np.abs(finite_i).max()) if finite_i.size else 1.0

    fig, axes = create_figure_from_params(params, "spatial")
    ax = axes[0]

    scatter = ax.scatter(
        neg_log_pval,
        moran_i,
        s=50,
        alpha=params.alpha,
        c=moran_i,
        cmap="RdBu_r",  # Diverging colormap centered at 0
        vmin=-color_limit,
        vmax=color_limit,
    )

    # Annotate the strongest significant, positively autocorrelated genes.
    # Rank ONLY among significant genes. Ranking all genes first let
    # non-significant genes, and NaN scores (which argsort puts last and
    # the reversal puts first), take the top slots, so significant genes
    # went unlabeled.
    gene_names = moran_data.index.tolist()
    sig_threshold = -np.log10(0.05)
    max_labels = 10
    # NaN compares False, so NaN entries are never treated as significant.
    significant_idx = np.flatnonzero((neg_log_pval > sig_threshold) & (moran_i > 0))
    if significant_idx.size:
        scores = moran_i[significant_idx] * neg_log_pval[significant_idx]
        top_indices = significant_idx[np.argsort(scores)[::-1][:max_labels]]
        for idx in top_indices:
            ax.annotate(
                gene_names[idx],
                (neg_log_pval[idx], moran_i[idx]),
                xytext=(5, 5),
                textcoords="offset points",
                fontsize=8,
                alpha=0.8,
            )

    # Reference lines
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5, label="I=0")
    ax.axvline(x=sig_threshold, color="red", linestyle="--", alpha=0.5, label="p=0.05")

    title = params.title or "Moran's I Spatial Autocorrelation"
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("-log₁₀(p-value)", fontsize=12)
    ax.set_ylabel("Moran's I", fontsize=12)

    if params.show_colorbar:
        # Attach to this figure explicitly rather than pyplot's current figure.
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label("Moran's I (+ clustered, − dispersed)", fontsize=10)

    # Legend for the reference lines
    ax.legend(loc="upper left", fontsize=9)

    fig.tight_layout()
    return fig
323
+
324
+
325
async def _create_centrality_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Plot squidpy graph-centrality scores per cluster.

    Data requirements:
        - adata.uns['{cluster_key}_centrality_scores']: Centrality scores
        - adata.obs[cluster_key]: Cluster labels

    Args:
        adata: AnnData object holding centrality scores.
        params: Visualization parameters. ``cluster_key``, ``dpi``, ``title``
            and the "spatial" figure size are used.
        context: Unused; kept for signature parity with the other renderers.

    Returns:
        The current pyplot figure after squidpy has drawn into it.

    Raises:
        ParameterError: If no cluster_key can be resolved.
        DataNotFoundError: If the centrality scores are absent from adata.uns.
    """
    require("squidpy", feature="centrality visualization")
    import squidpy as sq

    cluster_key = _resolve_cluster_key(adata, "centrality", params.cluster_key)

    if f"{cluster_key}_centrality_scores" not in adata.uns:
        raise DataNotFoundError(
            f"Centrality scores not found. Run analyze_spatial_statistics "
            f"with cluster_key='{cluster_key}' first."
        )

    plot_kwargs = {
        "cluster_key": cluster_key,
        "figsize": resolve_figure_size(params, "spatial"),
        "dpi": params.dpi,
    }
    sq.pl.centrality_scores(adata, **plot_kwargs)

    # No axes are handed to squidpy, so the figure is taken from pyplot's
    # current-figure state.
    current_fig = plt.gcf()
    if params.title:
        current_fig.suptitle(params.title)

    plt.tight_layout()
    return current_fig
363
+
364
+
365
async def _create_getis_ord_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create Getis-Ord Gi* hotspot/coldspot visualization.

    One panel per gene. Spots are drawn at their spatial coordinates and
    colored by the Gi* z-score on a fixed [-3, 3] diverging scale.

    Genes are taken from ``params.feature`` (a string or a list). Requested
    genes without results are skipped and reported through ``context``. When
    no feature is given, the first six genes with results are used.

    Data requirements:
        - adata.obs['{gene}_getis_ord_z']: Z-scores for each gene
        - adata.obs['{gene}_getis_ord_p']: P-values for each gene
        - spatial coordinates obtainable via ``require_spatial_coords``
          (columns 0 and 1 are used as x and y)

    Args:
        adata: AnnData object holding Getis-Ord results.
        params: Visualization parameters. ``feature``, ``spot_size`` (default
            20), ``alpha``, ``show_colorbar`` and ``add_gene_labels`` are used;
            the latter adds significant (p < 0.05) hot/cold spot counts to
            each panel title.
        context: Optional MCP context for progress messages.

    Returns:
        The figure created by ``setup_multi_panel_figure``.

    Raises:
        DataNotFoundError: If no Getis-Ord results exist, or none of the
            requested genes have results.
    """
    z_suffix = "_getis_ord_z"
    p_suffix = "_getis_ord_p"

    # A gene is usable only when both its z-score and p-value columns exist.
    # Non-string column labels are skipped (``endswith`` would raise), and
    # only the trailing suffix is stripped from the name.
    obs_columns = set(adata.obs.columns)
    getis_ord_genes = [
        col[: -len(z_suffix)]
        for col in adata.obs.columns
        if isinstance(col, str)
        and col.endswith(z_suffix)
        and f"{col[: -len(z_suffix)]}{p_suffix}" in obs_columns
    ]

    if not getis_ord_genes:
        raise DataNotFoundError("No Getis-Ord results found in adata.obs")

    # Normalize params.feature (str, list, or empty) into a list.
    if isinstance(params.feature, list):
        feature_list = params.feature
    elif params.feature:
        feature_list = [params.feature]
    else:
        feature_list = []

    missing_genes = []
    if feature_list:
        available = set(getis_ord_genes)
        genes_to_plot = [g for g in feature_list if g in available]
        missing_genes = [g for g in feature_list if g not in available]
    else:
        genes_to_plot = getis_ord_genes[:6]

    if not genes_to_plot:
        raise DataNotFoundError(
            f"None of the specified genes have Getis-Ord results: {feature_list}"
        )

    if context:
        if missing_genes:
            # Previously these genes were dropped without any feedback.
            await context.info(
                f"Skipping genes without Getis-Ord results: {missing_genes}"
            )
        await context.info(
            f"Plotting Getis-Ord results for {len(genes_to_plot)} genes: {genes_to_plot}"
        )

    fig, axes = setup_multi_panel_figure(
        n_panels=len(genes_to_plot),
        params=params,
        default_title="Getis-Ord Gi* Hotspots/Coldspots",
    )

    coords = require_spatial_coords(adata)
    significance_level = 0.05

    # zip() stops at the shorter sequence, so genes beyond the available axes
    # are not drawn. Both result columns are guaranteed to exist for every
    # gene here because of the filtering above.
    for ax, gene in zip(axes, genes_to_plot):
        z_scores = adata.obs[f"{gene}{z_suffix}"].values
        p_vals = adata.obs[f"{gene}{p_suffix}"].values

        scatter = ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=z_scores,
            cmap="RdBu_r",
            s=params.spot_size or 20,
            alpha=params.alpha,
            vmin=-3,
            vmax=3,
        )

        if params.show_colorbar:
            # Attach to this figure explicitly rather than pyplot's current one.
            fig.colorbar(scatter, ax=ax, label="Gi* Z-score")

        # Count significant hot (z > 0) and cold (z < 0) spots.
        significant = p_vals < significance_level
        hot_spots = np.sum((z_scores > 0) & significant)
        cold_spots = np.sum((z_scores < 0) & significant)

        if params.add_gene_labels:
            ax.set_title(f"{gene}\nHot: {hot_spots}, Cold: {cold_spots}")
        else:
            ax.set_title(f"{gene}")

        ax.set_xlabel("Spatial X")
        ax.set_ylabel("Spatial Y")
        ax.set_aspect("equal")
        # NOTE(review): y inverted, presumably to match image-style coordinates.
        ax.invert_yaxis()

    # Leave headroom at the top for the figure-level title.
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    return fig