chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,205 @@
1
+ """
2
+ Batch integration visualization functions.
3
+
4
+ This module contains:
5
+ - Batch integration quality assessment visualization
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Optional
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from scipy.stats import entropy
13
+
14
+ from ...models.data import VisualizationParameters
15
+ from ...utils.adata_utils import get_spatial_key, validate_obs_column
16
+
17
+ if TYPE_CHECKING:
18
+ import anndata as ad
19
+
20
+ from ...spatial_mcp_adapter import ToolContext
21
+
22
+
23
+ # =============================================================================
24
+ # Batch Integration Visualization
25
+ # =============================================================================
26
+
27
+
28
def _scatter_by_batch(ax, coords, batch_values, unique_batches, colors, point_size):
    """Draw one scatter series per batch on ``ax`` and attach an outside legend."""
    for color, batch in zip(colors, unique_batches):
        in_batch = (batch_values == batch).to_numpy()
        ax.scatter(
            coords[in_batch, 0],
            coords[in_batch, 1],
            c=[color],
            label=f"{batch}",
            s=point_size,
            alpha=0.7,
        )
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")


def _mark_unavailable(ax, message, title):
    """Fill ``ax`` with a centered placeholder message and a title."""
    ax.text(0.5, 0.5, message, ha="center", va="center")
    ax.set_title(title, fontsize=12)


def _batch_mixing_score(coords, labels, n_batches, n_edges=10, min_cells=10):
    """Return the normalized mean batch entropy over a regular 2-D grid.

    The embedding is cut into ``(n_edges - 1) ** 2`` rectangular bins. For every
    bin holding more than ``min_cells`` cells, the Shannon entropy of the batch
    proportions is computed; the mean of those entropies is divided by
    ``log(n_batches)`` (the entropy of a perfectly even mix).

    Args:
        coords: 2-D array whose first two columns are the embedding coordinates.
        labels: pandas Series of batch labels, positionally aligned with coords.
        n_batches: Number of distinct batches used for normalization.
        n_edges: Number of bin edges per axis.
        min_cells: A bin is scored only if it holds more than this many cells.

    Returns:
        The mixing score, ``0`` when only one batch exists, or ``None`` when no
        bin is populated enough to be scored.
    """
    coords = np.asarray(coords)
    xs, ys = coords[:, 0], coords[:, 1]
    n_bins = n_edges - 1
    x_edges = np.linspace(xs.min(), xs.max(), n_edges)
    y_edges = np.linspace(ys.min(), ys.max(), n_edges)

    # np.digitize places a value equal to the last edge one past the final bin;
    # clipping folds it back so cells on the maximum coordinate are counted
    # (half-open bins on every cell would silently drop them).
    x_idx = np.clip(np.digitize(xs, x_edges) - 1, 0, n_bins - 1)
    y_idx = np.clip(np.digitize(ys, y_edges) - 1, 0, n_bins - 1)
    cell_bin = x_idx * n_bins + y_idx

    entropies = []
    for bin_id in np.unique(cell_bin):
        in_bin = cell_bin == bin_id
        if in_bin.sum() > min_cells:  # Only consider regions with enough cells
            batch_props = labels[in_bin].value_counts(normalize=True)
            entropies.append(entropy(batch_props))

    if not entropies:
        return None

    max_entropy = np.log(n_batches)  # Perfect mixing entropy
    return np.mean(entropies) / max_entropy if max_entropy > 0 else 0


async def create_batch_integration_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create multi-panel visualization to assess batch integration quality.

    The 2x2 figure contains:

    1. UMAP scatter colored by batch (placeholder if ``X_umap`` is missing).
    2. Spatial scatter colored by batch (placeholder if no spatial key is found).
    3. Bar chart of cell counts per batch, using the same color per batch as
       the scatter panels.
    4. Text summary: cell/gene/batch counts, the integration method if given,
       and a grid-entropy mixing score when UMAP coordinates exist.

    Args:
        adata: AnnData object with integrated samples
        params: Visualization parameters (batch_key required; figure_size and
            integration_method are optional)
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: If batch information not found (presumably raised by
            ``validate_obs_column`` - confirm against its implementation)
    """
    if context:
        await context.info("Creating batch integration quality visualization")

    # Validate batch key exists
    batch_key = params.batch_key
    validate_obs_column(adata, batch_key, "Batch")

    # Create multi-panel figure (2x2 layout)
    figsize = params.figure_size if params.figure_size else (16, 12)
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    umap_ax, spatial_ax = axes[0, 0], axes[0, 1]
    counts_ax, metrics_ax = axes[1, 0], axes[1, 1]

    batch_values = adata.obs[batch_key]
    unique_batches = batch_values.unique()
    colors = plt.get_cmap("Set3")(np.linspace(0, 1, len(unique_batches)))

    has_umap = "X_umap" in adata.obsm

    # Panel 1: UMAP colored by batch (shows mixing)
    if has_umap:
        _scatter_by_batch(
            umap_ax, adata.obsm["X_umap"], batch_values, unique_batches, colors, 5
        )
        umap_ax.set_title(
            "UMAP colored by batch\n(Good integration = mixed colors)", fontsize=12
        )
        umap_ax.set_xlabel("UMAP 1")
        umap_ax.set_ylabel("UMAP 2")
    else:
        if context:
            await context.warning(
                "UMAP coordinates not available. "
                "Run preprocessing with UMAP computation for complete visualization."
            )
        _mark_unavailable(
            umap_ax, "UMAP coordinates not available", "UMAP (Not Available)"
        )

    # Panel 2: Spatial plot colored by batch (if spatial data available)
    spatial_key = get_spatial_key(adata)
    if spatial_key:
        _scatter_by_batch(
            spatial_ax,
            adata.obsm[spatial_key],
            batch_values,
            unique_batches,
            colors,
            10,
        )
        spatial_ax.set_title("Spatial coordinates colored by batch", fontsize=12)
        spatial_ax.set_xlabel("Spatial X")
        spatial_ax.set_ylabel("Spatial Y")
        spatial_ax.set_aspect("equal")
    else:
        if context:
            await context.info(
                "Spatial coordinates not available. "
                "This is expected for non-spatial datasets."
            )
        _mark_unavailable(
            spatial_ax, "Spatial coordinates not available", "Spatial (Not Available)"
        )

    # Panel 3: Batch composition bar plot
    # value_counts() orders by descending count while `colors` follows the
    # appearance order of unique(); look colors up by label so each batch keeps
    # the color it has in the scatter panels.
    batch_counts = batch_values.value_counts()
    color_by_batch = {
        batch: tuple(color) for batch, color in zip(unique_batches, colors)
    }
    fallback_color = (0.8, 0.8, 0.8, 1.0)  # labels counted but absent from unique()
    bar_colors = [
        color_by_batch.get(batch, fallback_color) for batch in batch_counts.index
    ]
    counts_ax.bar(range(len(batch_counts)), batch_counts.values, color=bar_colors)
    counts_ax.set_xticks(range(len(batch_counts)))
    counts_ax.set_xticklabels(batch_counts.index, rotation=45, ha="right")
    counts_ax.set_title("Cell counts per batch", fontsize=12)
    counts_ax.set_ylabel("Number of cells")

    # Panel 4: Integration quality metrics (if available)
    metrics_ax.text(
        0.1,
        0.9,
        "Integration Quality Assessment:",
        fontsize=14,
        fontweight="bold",
        transform=metrics_ax.transAxes,
    )

    metrics_text = f"Total cells: {adata.n_obs:,}\n"
    metrics_text += f"Total genes: {adata.n_vars:,}\n"
    metrics_text += (
        f"Batches: {len(unique_batches)} ({', '.join(map(str, unique_batches))})\n\n"
    )

    if params.integration_method:
        metrics_text += f"Integration method: {params.integration_method}\n"

    # Add basic mixing metrics (grid entropy on the UMAP embedding)
    if has_umap:
        mixing_score = _batch_mixing_score(
            adata.obsm["X_umap"], batch_values, len(unique_batches)
        )
        if mixing_score is not None:
            metrics_text += (
                f"Mixing score: {mixing_score:.3f} (0=segregated, 1=perfectly mixed)\n"
            )

    metrics_ax.text(
        0.1,
        0.7,
        metrics_text,
        fontsize=10,
        transform=metrics_ax.transAxes,
        verticalalignment="top",
        fontfamily="monospace",
    )
    metrics_ax.set_xlim(0, 1)
    metrics_ax.set_ylim(0, 1)
    metrics_ax.set_xticks([])
    metrics_ax.set_yticks([])
    metrics_ax.set_title("Integration Metrics", fontsize=12)

    # Lay out this specific figure rather than whichever one pyplot considers current
    fig.tight_layout()
    return fig
@@ -0,0 +1,164 @@
1
+ """
2
+ Main visualization entry point.
3
+
4
+ This module contains the main visualize_data function that dispatches
5
+ to appropriate visualization handlers based on plot_type.
6
+ """
7
+
8
+ import traceback
9
+ from typing import TYPE_CHECKING, Union
10
+
11
+ import matplotlib.pyplot as plt
12
+ import scanpy as sc
13
+ from mcp.server.fastmcp.utilities.types import ImageContent
14
+ from mcp.types import EmbeddedResource
15
+
16
+ from ...models.data import VisualizationParameters
17
+ from ...utils.exceptions import (
18
+ DataCompatibilityError,
19
+ DataNotFoundError,
20
+ ParameterError,
21
+ ProcessingError,
22
+ )
23
+ from ...utils.image_utils import optimize_fig_to_image_with_cache
24
+
25
+ # Import all visualization handlers
26
+ from .basic import (
27
+ create_dotplot_visualization,
28
+ create_heatmap_visualization,
29
+ create_spatial_visualization,
30
+ create_umap_visualization,
31
+ create_violin_visualization,
32
+ )
33
+ from .cell_comm import create_cell_communication_visualization
34
+ from .cnv import create_cnv_heatmap_visualization, create_spatial_cnv_visualization
35
+ from .deconvolution import (
36
+ create_card_imputation_visualization,
37
+ create_deconvolution_visualization,
38
+ )
39
+ from .enrichment import create_pathway_enrichment_visualization
40
+ from .integration import create_batch_integration_visualization
41
+ from .multi_gene import (
42
+ create_gene_correlation_visualization,
43
+ create_lr_pairs_visualization,
44
+ create_multi_gene_visualization,
45
+ create_spatial_interaction_visualization,
46
+ )
47
+ from .spatial_stats import create_spatial_statistics_visualization
48
+ from .trajectory import create_trajectory_visualization
49
+ from .velocity import create_rna_velocity_visualization
50
+
51
+ if TYPE_CHECKING:
52
+ from ...spatial_mcp_adapter import ToolContext
53
+
54
+
55
+ # Handler registry for dispatch: keys are the accepted plot_type values, values are the async handlers - defined here to avoid circular imports
56
+ PLOT_HANDLERS = {
57
+ # Basic plots
58
+ "spatial": create_spatial_visualization,
59
+ "umap": create_umap_visualization,
60
+ "heatmap": create_heatmap_visualization,
61
+ "violin": create_violin_visualization,
62
+ "dotplot": create_dotplot_visualization,
63
+ # Analysis-specific plots
64
+ "deconvolution": create_deconvolution_visualization,
65
+ "cell_communication": create_cell_communication_visualization,
66
+ "rna_velocity": create_rna_velocity_visualization,
67
+ "trajectory": create_trajectory_visualization,
68
+ "spatial_statistics": create_spatial_statistics_visualization,
69
+ "pathway_enrichment": create_pathway_enrichment_visualization,
70
+ # CARD imputation (handler imported from .deconvolution) followed by CNV plots
71
+ "card_imputation": create_card_imputation_visualization,
72
+ "spatial_cnv": create_spatial_cnv_visualization,
73
+ "cnv_heatmap": create_cnv_heatmap_visualization,
74
+ # Integration plots
75
+ "batch_integration": create_batch_integration_visualization,
76
+ # Multi-gene plots
77
+ "multi_gene": create_multi_gene_visualization,
78
+ "lr_pairs": create_lr_pairs_visualization,
79
+ "gene_correlation": create_gene_correlation_visualization,
80
+ "spatial_interaction": create_spatial_interaction_visualization,
81
+ }
82
+
83
+
84
async def visualize_data(
    data_id: str,
    ctx: "ToolContext",
    params: VisualizationParameters = VisualizationParameters(),  # type: ignore[call-arg]
) -> Union[ImageContent, tuple[ImageContent, EmbeddedResource]]:
    """Visualize spatial transcriptomics data.

    Looks up the handler registered for ``params.plot_type`` in
    ``PLOT_HANDLERS``, lets it build a matplotlib figure, and converts that
    figure to image content.

    Args:
        data_id: Dataset ID
        ctx: ToolContext for unified data access and logging
        params: Visualization parameters

    Returns:
        Union[ImageContent, Tuple[ImageContent, EmbeddedResource]]:
            - Small images (<100KB): ImageContent object
            - Large images (>=100KB): Tuple[Preview ImageContent, High-quality Resource]

        NOTE: when a non-domain exception whose message mentions
        "fig_to_image" or "convert" is caught, a plain ``str`` describing the
        failure (including the traceback) is returned instead of raising.

    Raises:
        DataNotFoundError: If the dataset is not found or has fewer than 5
            cells or genes
        ParameterError: If parameters are invalid
        DataCompatibilityError: If data is not compatible with the visualization
        ProcessingError: If processing fails
    """
    # Validate parameters - use PLOT_HANDLERS as single source of truth
    if params.plot_type not in PLOT_HANDLERS:
        raise ParameterError(
            f"Invalid plot_type: {params.plot_type}. "
            f"Must be one of {list(PLOT_HANDLERS)}"
        )

    try:
        # Retrieve the AnnData object via ToolContext
        adata = await ctx.get_adata(data_id)

        # Validate AnnData object - basic validation
        if adata.n_obs < 5:
            raise DataNotFoundError("Dataset has too few cells (minimum 5 required)")
        if adata.n_vars < 5:
            raise DataNotFoundError("Dataset has too few genes (minimum 5 required)")

        # Set matplotlib style for better visualizations
        sc.settings.set_figure_params(dpi=params.dpi or 100, facecolor="white")

        # Dispatch to appropriate handler
        handler = PLOT_HANDLERS[params.plot_type]
        fig = await handler(adata, params, ctx._mcp_context)

        # Generate plot_type_key with subtype if applicable (for cache consistency)
        subtype = params.subtype
        plot_type_key = f"{params.plot_type}_{subtype}" if subtype else params.plot_type

        # Use the optimized conversion function
        return await optimize_fig_to_image_with_cache(
            fig,
            params,
            ctx._mcp_context,
            data_id=data_id,
            plot_type=plot_type_key,
            mode="auto",
        )

    except (DataNotFoundError, ParameterError, DataCompatibilityError):
        # Domain errors are part of the documented contract: propagate them
        # untouched. They must be handled before the message-substring
        # heuristic below, otherwise a domain error whose text happens to
        # contain "convert" would be turned into a returned string.
        plt.close("all")
        raise

    except Exception as e:
        # Make sure to close any open figures in case of error
        plt.close("all")

        # For image conversion errors, return error message as string
        message = str(e)
        if "fig_to_image" in message or "convert" in message.lower():
            error_details = traceback.format_exc()
            return (
                f"Error in {params.plot_type} visualization:\n\n"
                f"{e}\n\n"
                f"Technical details:\n{error_details}"
            )

        # Wrap the error in a more informative exception
        raise ProcessingError(
            f"Failed to create {params.plot_type} visualization: {e}"
        ) from e