chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,320 @@
1
+ """
2
+ CNV (Copy Number Variation) visualization functions.
3
+
4
+ This module contains:
5
+ - CNV heatmap visualization
6
+ - Spatial CNV projection visualization
7
+ """
8
+
9
+ from typing import TYPE_CHECKING, Any, Optional
10
+
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import pandas as pd
14
+ import seaborn as sns
15
+
16
+ from ...models.data import VisualizationParameters
17
+ from ...utils.adata_utils import require_spatial_coords, validate_obs_column
18
+ from ...utils.dependency_manager import require
19
+ from ...utils.exceptions import DataNotFoundError
20
+ from .core import create_figure_from_params, plot_spatial_feature, resolve_figure_size
21
+
22
+ if TYPE_CHECKING:
23
+ import anndata as ad
24
+
25
+ from ...spatial_mcp_adapter import ToolContext
26
+
27
+
28
+ # =============================================================================
29
+ # Spatial CNV Visualization
30
+ # =============================================================================
31
+
32
+
33
async def create_spatial_cnv_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create spatial CNV projection visualization.

    Draws one CNV-related ``adata.obs`` column on the spatial coordinates
    via the shared ``plot_spatial_feature()`` helper.

    Feature selection:
        * ``params.feature`` if given. When it is a list/tuple, only its
          first entry is plotted (this figure has a single axis).
        * Otherwise the first column present among ``numbat_clone``,
          ``cnv_score`` and ``numbat_p_cnv``, checked in that order.

    Colormap: if ``params.colormap`` is unset it is filled in on the
    caller's ``params`` object (the object is mutated): ``"RdBu_r"`` for
    numeric columns and ``"tab20"`` for non-numeric ones (categorical,
    string/object labels).

    Args:
        adata: AnnData object
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: If no feature is specified and none of the known
            CNV columns exists in ``adata.obs``. Missing spatial coordinates
            or a missing specified feature are reported by
            ``require_spatial_coords`` / ``validate_obs_column`` (exception
            types are defined by those helpers).
    """
    if context:
        await context.info("Creating spatial CNV projection visualization")

    # Validate spatial coordinates
    require_spatial_coords(adata)

    # Determine feature to visualize. ``params.feature`` may hold a list of
    # column names; reduce it to a single name before it reaches
    # validate_obs_column / adata.obs[...].
    feature_to_plot = params.feature
    if isinstance(feature_to_plot, (list, tuple)):
        if len(feature_to_plot) > 1 and context:
            await context.info(
                "Multiple features given; spatial CNV projection shows only "
                f"'{feature_to_plot[0]}'"
            )
        feature_to_plot = feature_to_plot[0] if feature_to_plot else None

    # Auto-detect a CNV-related column if none was specified. Tuple order is
    # the priority order.
    if not feature_to_plot:
        auto_candidates = (
            ("numbat_clone", "Numbat clone assignment"),
            ("cnv_score", "CNV score"),
            ("numbat_p_cnv", "Numbat CNV probability"),
        )
        for column, description in auto_candidates:
            if column in adata.obs:
                feature_to_plot = column
                if context:
                    await context.info(
                        f"No feature specified, using '{column}' ({description})"
                    )
                break
        else:
            error_msg = "No CNV features found. Run analyze_cnv() first."
            if context:
                await context.warning(error_msg)
            raise DataNotFoundError(error_msg)

    # Validate feature exists
    validate_obs_column(adata, feature_to_plot, "CNV feature")

    if context:
        await context.info(f"Visualizing {feature_to_plot} on spatial coordinates")

    # Default colormap: a diverging map for continuous values and a
    # qualitative map for discrete labels. is_numeric_dtype() is False for
    # CategoricalDtype and for object/string columns, so all of those get
    # tab20. This replaces pd.api.types.is_categorical_dtype, which is
    # deprecated since pandas 2.1 and also sent plain string labels to RdBu_r.
    if not params.colormap:
        column_dtype = adata.obs[feature_to_plot].dtype
        params.colormap = (
            "RdBu_r" if pd.api.types.is_numeric_dtype(column_dtype) else "tab20"
        )

    # Use centralized figure creation
    fig, axes = create_figure_from_params(params, "spatial")
    ax = axes[0]

    # Use the enhanced plot_spatial_feature helper
    plot_spatial_feature(adata, ax, feature=feature_to_plot, params=params)

    if context:
        await context.info(f"Spatial CNV projection created for {feature_to_plot}")

    return fig
114
+
115
+
116
+ # =============================================================================
117
+ # CNV Heatmap Visualization
118
+ # =============================================================================
119
+
120
+
121
async def create_cnv_heatmap_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CNV heatmap visualization.

    The CNV matrix is read from ``adata.obsm["X_cnv"]`` (infercnvpy) or, when
    that key is absent, from ``adata.obsm["X_cnv_numbat"]`` (Numbat).

    Rendering path:
        * Numbat data and no ``"chromosome"`` column in ``adata.var``: drawn
          with matplotlib/seaborn. If ``params.feature`` names an
          ``adata.obs`` column (its first entry when it is a list), there is
          one row per group holding the mean CNV of that group's cells.
          Otherwise there is one row per cell. infercnvpy is not needed on
          this path.
        * Otherwise: ``infercnvpy.pl.chromosome_heatmap`` grouped by
          ``params.cluster_key``. For Numbat data the matrix is aliased to
          ``adata.obsm["X_cnv"]`` (plus a placeholder ``adata.uns["cnv"]``
          if missing) only for the duration of that call. Both aliases are
          removed afterwards, so ``adata`` is left as it was.

    Args:
        adata: AnnData object
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: If no CNV matrix is found in ``adata.obsm``, or
            infercnvpy output lacks ``adata.uns["cnv"]``.
        DataCompatibilityError: If infercnvpy is needed but not installed
            (raised through ``require``).
    """
    if context:
        await context.info("Creating CNV heatmap visualization")

    # Auto-detect CNV data source; infercnvpy output takes precedence.
    if "X_cnv" in adata.obsm:
        cnv_method = "infercnvpy"
    elif "X_cnv_numbat" in adata.obsm:
        cnv_method = "numbat"
    else:
        error_msg = "CNV data not found in obsm. Run analyze_cnv() first."
        if context:
            await context.warning(error_msg)
        raise DataNotFoundError(error_msg)

    if context:
        await context.info(f"Detected CNV data from {cnv_method} method")

    # infercnvpy results must come with their metadata in uns["cnv"].
    if cnv_method == "infercnvpy" and "cnv" not in adata.uns:
        error_msg = (
            "CNV metadata not found in adata.uns['cnv']. "
            "The CNV analysis may not have completed properly. "
            "Please re-run analyze_cnv()."
        )
        if context:
            await context.warning(error_msg)
        raise DataNotFoundError(error_msg)

    if context:
        await context.info("Generating CNV heatmap...")

    # Use centralized figure size resolution
    figsize = resolve_figure_size(params, "cnv")

    if cnv_method != "numbat" or "chromosome" in adata.var.columns:
        # infercnvpy data, or Numbat data with chromosome info
        return await _plot_infercnvpy_heatmap(
            adata, params, cnv_method, figsize, context
        )

    # Numbat data without chromosome info: draw the matrix directly. It is
    # read locally, so adata is not modified on this path.
    if context:
        await context.info(
            "Creating aggregated CNV heatmap by group (chromosome positions not available)"
        )
    cnv_matrix = adata.obsm["X_cnv_numbat"]

    # params.feature may be a single column name or a list of them.
    group_key = params.feature
    if isinstance(group_key, (list, tuple)):
        group_key = group_key[0] if group_key else None

    fig = None
    if group_key and group_key in adata.obs.columns:
        fig = _plot_grouped_cnv_heatmap(cnv_matrix, adata.obs[group_key], group_key)
    if fig is None:
        # No usable grouping: plot every cell (dense and slow for large data).
        fig = _plot_per_cell_cnv_heatmap(cnv_matrix, figsize)
    fig.tight_layout()

    if context:
        await context.info("CNV heatmap created successfully")

    return fig


async def _plot_infercnvpy_heatmap(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    cnv_method: str,
    figsize: Any,
    context: Optional["ToolContext"],
) -> plt.Figure:
    """Draw infercnvpy's chromosome heatmap, temporarily aliasing Numbat data."""
    # Check if infercnvpy is available (needed for visualization)
    require("infercnvpy", feature="CNV heatmap visualization")
    import infercnvpy as cnv

    # Numbat data is exposed under infercnvpy's keys only for this call.
    # Record exactly what gets added so that only those keys are removed.
    # X_cnv is known to be absent here, because detection prefers it.
    added_obsm = cnv_method == "numbat"
    added_uns = added_obsm and "cnv" not in adata.uns

    if context:
        if added_obsm:
            await context.info(
                "Converting Numbat CNV data to infercnvpy format for visualization"
            )
        if added_uns:
            await context.info(
                "Note: Chromosome labels not available for Numbat heatmap. "
                "Install R packages for full chromosome annotation."
            )
        await context.info("Creating chromosome-organized CNV heatmap...")

    if added_obsm:
        adata.obsm["X_cnv"] = adata.obsm["X_cnv_numbat"]
    if added_uns:
        # NOTE(review): placeholder only. infercnvpy's chromosome_heatmap may
        # need more metadata than this (e.g. chromosome positions) - confirm.
        adata.uns["cnv"] = {
            "genomic_positions": False,
        }
    try:
        # NOTE(review): params.cluster_key may be None; confirm infercnvpy
        # accepts groupby=None together with dendrogram=True.
        cnv.pl.chromosome_heatmap(
            adata,
            groupby=params.cluster_key,
            dendrogram=True,
            show=False,
            figsize=figsize,
        )
        # Get the figure the heatmap was drawn on (current pyplot figure)
        fig = plt.gcf()
    finally:
        if added_obsm and "X_cnv" in adata.obsm:
            del adata.obsm["X_cnv"]
        if added_uns and "cnv" in adata.uns:
            del adata.uns["cnv"]

    if context:
        await context.info("CNV heatmap created successfully")

    return fig


def _plot_grouped_cnv_heatmap(
    cnv_matrix: Any, labels: pd.Series, group_key: str
) -> Optional[plt.Figure]:
    """Plot the mean CNV profile of each group; return None if no labels are present.

    Rows of ``cnv_matrix`` are taken to be aligned with ``labels``. Dense
    arrays and scipy.sparse matrices are both accepted. Missing labels are
    ignored.
    """
    present = labels.dropna().unique()
    if len(present) == 0:
        return None
    try:
        groups = sorted(present)
    except TypeError:
        # Mixed label types (e.g. ints and strings) cannot be ordered directly
        groups = sorted(present, key=str)

    rows = []
    tick_labels = []
    for group in groups:
        mask = (labels == group).to_numpy()
        # .mean(axis=0) on a scipy.sparse matrix returns a 1 x n np.matrix;
        # asarray().ravel() yields a flat vector for dense and sparse inputs.
        rows.append(np.asarray(cnv_matrix[mask, :].mean(axis=0)).ravel())
        tick_labels.append(f"{group} (n={int(mask.sum())})")
    aggregated_cnv = np.vstack(rows)

    # Figure width grows with the number of bins (clamped to 6-12 in);
    # height grows with the number of groups.
    n_bins = aggregated_cnv.shape[1]
    fig_width = min(max(6, n_bins * 0.004), 12)
    fig_height = max(4, len(groups) * 1.2)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # aspect="auto" stretches the cells to fill the axes
    im = ax.imshow(
        aggregated_cnv,
        cmap="RdBu_r",
        aspect="auto",
        vmin=-1,
        vmax=1,
        interpolation="nearest",
    )
    fig.colorbar(im, ax=ax, label="Mean CNV state")

    # Y-axis: group names with cell counts
    ax.set_yticks(range(len(tick_labels)))
    ax.set_yticklabels(tick_labels)
    ax.set_ylabel(group_key, fontsize=12, fontweight="bold")

    ax.set_xlabel("Genomic position (binned)", fontsize=12)
    ax.set_xticks([])  # Hide x-axis ticks for cleaner look

    ax.set_title(
        f"CNV Profile by {group_key}\n(Numbat analysis, aggregated by group)",
        fontsize=14,
        fontweight="bold",
    )

    # White separator lines on the row boundaries between groups
    for i in range(len(groups) + 1):
        ax.axhline(i - 0.5, color="white", linewidth=2)

    return fig


def _plot_per_cell_cnv_heatmap(cnv_matrix: Any, figsize: Any) -> plt.Figure:
    """Plot the full, unlabeled cell-by-bin CNV matrix (one row per cell)."""
    # seaborn.heatmap takes array-like/DataFrame input, so densify sparse input
    if hasattr(cnv_matrix, "toarray"):
        cnv_matrix = cnv_matrix.toarray()

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        cnv_matrix,
        cmap="RdBu_r",
        center=0,
        cbar_kws={"label": "CNV state"},
        yticklabels=False,
        xticklabels=False,
        ax=ax,
        vmin=-1,
        vmax=1,
    )

    ax.set_xlabel("Genomic position (binned)")
    ax.set_ylabel("Cells")
    ax.set_title("CNV Heatmap (Numbat)\nAll cells (ungrouped)")
    return fig