chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Batch integration visualization functions.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- Batch integration quality assessment visualization
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, Optional
|
|
9
|
+
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import numpy as np
|
|
12
|
+
from scipy.stats import entropy
|
|
13
|
+
|
|
14
|
+
from ...models.data import VisualizationParameters
|
|
15
|
+
from ...utils.adata_utils import get_spatial_key, validate_obs_column
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import anndata as ad
|
|
19
|
+
|
|
20
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# =============================================================================
|
|
24
|
+
# Batch Integration Visualization
|
|
25
|
+
# =============================================================================
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def create_batch_integration_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a 2x2 figure for assessing batch integration quality.

    Panels:
        - top-left: ``adata.obsm["X_umap"]`` colored by batch
        - top-right: spatial coordinates colored by batch
        - bottom-left: number of cells per batch
        - bottom-right: text summary (cell/gene/batch counts, integration
          method, and a grid-based mixing score when UMAP is available)

    The same color is used for a given batch in every panel.

    Args:
        adata: AnnData object with integrated samples.
        params: Visualization parameters. ``batch_key`` names the ``obs``
            column holding batch labels; ``figure_size`` and
            ``integration_method`` are optional.
        context: Optional MCP context used for info/warning messages.

    Returns:
        The matplotlib Figure containing the four panels.

    Raises:
        DataNotFoundError: If batch information is not found (presumably
            raised by ``validate_obs_column`` — confirm against that helper).
    """
    if context:
        await context.info("Creating batch integration quality visualization")

    # Validate batch key exists
    batch_key = params.batch_key
    validate_obs_column(adata, batch_key, "Batch")

    # Create multi-panel figure (2x2 layout)
    figsize = params.figure_size if params.figure_size else (16, 12)
    fig, axes = plt.subplots(2, 2, figsize=figsize)

    batch_values = adata.obs[batch_key]
    unique_batches = batch_values.unique()
    colors = plt.get_cmap("Set3")(np.linspace(0, 1, len(unique_batches)))

    umap_coords = adata.obsm["X_umap"] if "X_umap" in adata.obsm else None

    # Panel 1: UMAP colored by batch (shows mixing)
    umap_ax = axes[0, 0]
    if umap_coords is not None:
        _scatter_by_batch(
            umap_ax, umap_coords, batch_values, unique_batches, colors, point_size=5
        )
        umap_ax.set_title(
            "UMAP colored by batch\n(Good integration = mixed colors)", fontsize=12
        )
        umap_ax.set_xlabel("UMAP 1")
        umap_ax.set_ylabel("UMAP 2")
        umap_ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        if context:
            await context.warning(
                "UMAP coordinates not available. "
                "Run preprocessing with UMAP computation for complete visualization."
            )
        umap_ax.text(
            0.5, 0.5, "UMAP coordinates not available", ha="center", va="center"
        )
        umap_ax.set_title("UMAP (Not Available)", fontsize=12)

    # Panel 2: Spatial plot colored by batch (if spatial data available)
    spatial_ax = axes[0, 1]
    spatial_key = get_spatial_key(adata)
    if spatial_key:
        _scatter_by_batch(
            spatial_ax,
            adata.obsm[spatial_key],
            batch_values,
            unique_batches,
            colors,
            point_size=10,
        )
        spatial_ax.set_title("Spatial coordinates colored by batch", fontsize=12)
        spatial_ax.set_xlabel("Spatial X")
        spatial_ax.set_ylabel("Spatial Y")
        spatial_ax.set_aspect("equal")
        spatial_ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        if context:
            await context.info(
                "Spatial coordinates not available. "
                "This is expected for non-spatial datasets."
            )
        spatial_ax.text(
            0.5, 0.5, "Spatial coordinates not available", ha="center", va="center"
        )
        spatial_ax.set_title("Spatial (Not Available)", fontsize=12)

    # Panel 3: Batch composition bar plot
    counts_ax = axes[1, 0]
    batch_counts = batch_values.value_counts()
    # value_counts() orders by frequency while `colors` follows the order of
    # `unique_batches`; look colors up by label so each bar matches the color
    # its batch has in the scatter panels.
    color_by_batch = {batch: colors[i] for i, batch in enumerate(unique_batches)}
    # Labels that never occur (e.g. unused categories) have no scatter color.
    unused_color = (0.7, 0.7, 0.7, 1.0)
    bar_colors = [
        tuple(color_by_batch.get(batch, unused_color)) for batch in batch_counts.index
    ]
    counts_ax.bar(
        range(len(batch_counts)),
        batch_counts.values,
        color=bar_colors,
    )
    counts_ax.set_xticks(range(len(batch_counts)))
    counts_ax.set_xticklabels(batch_counts.index, rotation=45, ha="right")
    counts_ax.set_title("Cell counts per batch", fontsize=12)
    counts_ax.set_ylabel("Number of cells")

    # Panel 4: Integration quality metrics (if available)
    metrics_ax = axes[1, 1]
    metrics_ax.text(
        0.1,
        0.9,
        "Integration Quality Assessment:",
        fontsize=14,
        fontweight="bold",
        transform=metrics_ax.transAxes,
    )

    metrics_text = f"Total cells: {adata.n_obs:,}\n"
    metrics_text += f"Total genes: {adata.n_vars:,}\n"
    metrics_text += (
        f"Batches: {len(unique_batches)} ({', '.join(map(str, unique_batches))})\n\n"
    )

    if params.integration_method:
        metrics_text += f"Integration method: {params.integration_method}\n"

    # Add basic mixing metrics
    if umap_coords is not None:
        mixing_score = _grid_mixing_score(
            umap_coords, batch_values, len(unique_batches)
        )
        if mixing_score is not None:
            metrics_text += (
                f"Mixing score: {mixing_score:.3f} (0=segregated, 1=perfectly mixed)\n"
            )

    metrics_ax.text(
        0.1,
        0.7,
        metrics_text,
        fontsize=10,
        transform=metrics_ax.transAxes,
        verticalalignment="top",
        fontfamily="monospace",
    )
    metrics_ax.set_xlim(0, 1)
    metrics_ax.set_ylim(0, 1)
    metrics_ax.set_xticks([])
    metrics_ax.set_yticks([])
    metrics_ax.set_title("Integration Metrics", fontsize=12)

    # Lay out this figure explicitly: pyplot's "current figure" may no longer
    # be `fig` if other code created a figure while this coroutine awaited.
    fig.tight_layout()
    return fig


def _scatter_by_batch(ax, coords, batch_values, unique_batches, colors, point_size):
    """Scatter the first two columns of ``coords`` on ``ax``, one color per batch."""
    for color, batch in zip(colors, unique_batches):
        # Plain boolean ndarray (missing comparisons count as False) so the
        # mask can index a NumPy array regardless of the pandas dtype.
        in_batch = (batch_values == batch).to_numpy(dtype=bool, na_value=False)
        ax.scatter(
            coords[in_batch, 0],
            coords[in_batch, 1],
            c=[color],
            label=f"{batch}",
            s=point_size,
            alpha=0.7,
        )


def _grid_mixing_score(coords, batch_values, n_batches, n_edges=10, min_cells=10):
    """Return mean per-bin batch entropy normalized by ``log(n_batches)``.

    The first two columns of ``coords`` are divided into a regular
    ``(n_edges - 1) x (n_edges - 1)`` grid spanning their min/max. Only bins
    holding more than ``min_cells`` cells contribute. Returns ``None`` when no
    bin qualifies, and ``0`` when ``log(n_batches)`` is not positive.
    """
    coords = np.asarray(coords)
    if coords.shape[0] == 0:
        return None

    n_bins = n_edges - 1
    axis_bins = []
    for axis in (0, 1):
        values = coords[:, axis]
        edges = np.linspace(values.min(), values.max(), n_edges)
        # Index i means edges[i] <= value < edges[i + 1]; clipping folds cells
        # lying exactly on the maximum edge into the last bin instead of
        # silently excluding them.
        index = np.searchsorted(edges, values, side="right") - 1
        axis_bins.append(np.clip(index, 0, n_bins - 1))
    cell_bins = axis_bins[0] * n_bins + axis_bins[1]

    occupied, sizes = np.unique(cell_bins, return_counts=True)
    entropies = [
        entropy(batch_values[cell_bins == bin_id].value_counts(normalize=True))
        for bin_id, size in zip(occupied, sizes)
        if size > min_cells  # Only consider regions with enough cells
    ]
    if not entropies:
        return None

    max_entropy = np.log(n_batches)  # Perfect mixing entropy
    return np.mean(entropies) / max_entropy if max_entropy > 0 else 0
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main visualization entry point.
|
|
3
|
+
|
|
4
|
+
This module contains the main visualize_data function that dispatches
|
|
5
|
+
to appropriate visualization handlers based on plot_type.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import traceback
|
|
9
|
+
from typing import TYPE_CHECKING, Union
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import scanpy as sc
|
|
13
|
+
from mcp.server.fastmcp.utilities.types import ImageContent
|
|
14
|
+
from mcp.types import EmbeddedResource
|
|
15
|
+
|
|
16
|
+
from ...models.data import VisualizationParameters
|
|
17
|
+
from ...utils.exceptions import (
|
|
18
|
+
DataCompatibilityError,
|
|
19
|
+
DataNotFoundError,
|
|
20
|
+
ParameterError,
|
|
21
|
+
ProcessingError,
|
|
22
|
+
)
|
|
23
|
+
from ...utils.image_utils import optimize_fig_to_image_with_cache
|
|
24
|
+
|
|
25
|
+
# Import all visualization handlers
|
|
26
|
+
from .basic import (
|
|
27
|
+
create_dotplot_visualization,
|
|
28
|
+
create_heatmap_visualization,
|
|
29
|
+
create_spatial_visualization,
|
|
30
|
+
create_umap_visualization,
|
|
31
|
+
create_violin_visualization,
|
|
32
|
+
)
|
|
33
|
+
from .cell_comm import create_cell_communication_visualization
|
|
34
|
+
from .cnv import create_cnv_heatmap_visualization, create_spatial_cnv_visualization
|
|
35
|
+
from .deconvolution import (
|
|
36
|
+
create_card_imputation_visualization,
|
|
37
|
+
create_deconvolution_visualization,
|
|
38
|
+
)
|
|
39
|
+
from .enrichment import create_pathway_enrichment_visualization
|
|
40
|
+
from .integration import create_batch_integration_visualization
|
|
41
|
+
from .multi_gene import (
|
|
42
|
+
create_gene_correlation_visualization,
|
|
43
|
+
create_lr_pairs_visualization,
|
|
44
|
+
create_multi_gene_visualization,
|
|
45
|
+
create_spatial_interaction_visualization,
|
|
46
|
+
)
|
|
47
|
+
from .spatial_stats import create_spatial_statistics_visualization
|
|
48
|
+
from .trajectory import create_trajectory_visualization
|
|
49
|
+
from .velocity import create_rna_velocity_visualization
|
|
50
|
+
|
|
51
|
+
if TYPE_CHECKING:
|
|
52
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Handler registry for dispatch - defined here to avoid circular imports
|
|
56
|
+
# Maps each supported ``plot_type`` string to the handler that
# ``visualize_data`` awaits as ``handler(adata, params, mcp_context)``.
# The keys double as the list of valid plot types reported in the
# "Invalid plot_type" error, so insertion order is user-visible.
PLOT_HANDLERS = {
    # Basic plots
    "spatial": create_spatial_visualization,
    "umap": create_umap_visualization,
    "heatmap": create_heatmap_visualization,
    "violin": create_violin_visualization,
    "dotplot": create_dotplot_visualization,
    # Analysis-specific plots
    "deconvolution": create_deconvolution_visualization,
    "cell_communication": create_cell_communication_visualization,
    "rna_velocity": create_rna_velocity_visualization,
    "trajectory": create_trajectory_visualization,
    "spatial_statistics": create_spatial_statistics_visualization,
    "pathway_enrichment": create_pathway_enrichment_visualization,
    # CARD imputation (handler lives in .deconvolution) and CNV plots
    "card_imputation": create_card_imputation_visualization,
    "spatial_cnv": create_spatial_cnv_visualization,
    "cnv_heatmap": create_cnv_heatmap_visualization,
    # Integration plots
    "batch_integration": create_batch_integration_visualization,
    # Multi-gene plots
    "multi_gene": create_multi_gene_visualization,
    "lr_pairs": create_lr_pairs_visualization,
    "gene_correlation": create_gene_correlation_visualization,
    "spatial_interaction": create_spatial_interaction_visualization,
}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def visualize_data(
    data_id: str,
    ctx: "ToolContext",
    params: VisualizationParameters = VisualizationParameters(),  # type: ignore[call-arg]
) -> Union[ImageContent, tuple[ImageContent, EmbeddedResource]]:
    """Render the requested plot for a dataset and return it as image content.

    The handler registered in ``PLOT_HANDLERS`` for ``params.plot_type`` is
    awaited with the dataset, and the resulting figure is handed to
    ``optimize_fig_to_image_with_cache`` (``mode="auto"``).

    Args:
        data_id: Dataset ID
        ctx: ToolContext for unified data access and logging
        params: Visualization parameters

    Returns:
        Union[ImageContent, Tuple[ImageContent, EmbeddedResource]]:
            - Small images (<100KB): ImageContent object
            - Large images (>=100KB): Tuple[Preview ImageContent, High-quality Resource]
        NOTE(review): when a failure's message contains "fig_to_image" or
        "convert" (case-insensitive), a plain ``str`` describing the error is
        returned instead of raising, which the annotation does not cover.

    Raises:
        DataNotFoundError: If the dataset is not found, or has fewer than
            5 cells or 5 genes
        ParameterError: If ``params.plot_type`` has no registered handler
        DataCompatibilityError: If data is not compatible with the visualization
        ProcessingError: Wraps any other exception raised while plotting
            (including a ProcessingError raised by a handler)
    """
    # PLOT_HANDLERS is the single source of truth for valid plot types.
    if params.plot_type not in PLOT_HANDLERS:
        raise ParameterError(
            f"Invalid plot_type: {params.plot_type}. "
            f"Must be one of {list(PLOT_HANDLERS)}"
        )

    try:
        # Fetch the dataset through the ToolContext.
        adata = await ctx.get_adata(data_id)

        # Minimal sanity checks on dataset size.
        if adata.n_obs < 5:
            raise DataNotFoundError("Dataset has too few cells (minimum 5 required)")
        if adata.n_vars < 5:
            raise DataNotFoundError("Dataset has too few genes (minimum 5 required)")

        # Global scanpy/matplotlib figure defaults for this render.
        sc.settings.set_figure_params(dpi=params.dpi or 100, facecolor="white")

        # Dispatch to the registered handler.
        render = PLOT_HANDLERS[params.plot_type]
        fig = await render(adata, params, ctx._mcp_context)

        # The cache key carries the subtype, when one is set, so cached
        # images of different subtypes do not collide.
        subtype = params.subtype
        if subtype:
            plot_type_key = f"{params.plot_type}_{subtype}"
        else:
            plot_type_key = params.plot_type

        return await optimize_fig_to_image_with_cache(
            fig,
            params,
            ctx._mcp_context,
            data_id=data_id,
            plot_type=plot_type_key,
            mode="auto",
        )

    except Exception as exc:
        # Release every open figure so a failed render does not leak them.
        plt.close("all")

        message = str(exc)
        conversion_failure = "fig_to_image" in message or "convert" in message.lower()
        if conversion_failure:
            # Image conversion problems are reported as text, not raised.
            technical_details = traceback.format_exc()
            return (
                f"Error in {params.plot_type} visualization:\n\n"
                f"{exc}\n\n"
                f"Technical details:\n{technical_details}"
            )

        # Known domain errors propagate untouched; anything else is wrapped.
        if isinstance(exc, (DataNotFoundError, ParameterError, DataCompatibilityError)):
            raise
        raise ProcessingError(
            f"Failed to create {params.plot_type} visualization: {exc}"
        ) from exc
|