chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CNV (Copy Number Variation) visualization functions.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- CNV heatmap visualization
|
|
6
|
+
- Spatial CNV projection visualization
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
import seaborn as sns
|
|
15
|
+
|
|
16
|
+
from ...models.data import VisualizationParameters
|
|
17
|
+
from ...utils.adata_utils import require_spatial_coords, validate_obs_column
|
|
18
|
+
from ...utils.dependency_manager import require
|
|
19
|
+
from ...utils.exceptions import DataNotFoundError
|
|
20
|
+
from .core import create_figure_from_params, plot_spatial_feature, resolve_figure_size
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
import anndata as ad
|
|
24
|
+
|
|
25
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# =============================================================================
|
|
29
|
+
# Spatial CNV Visualization
|
|
30
|
+
# =============================================================================
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def create_spatial_cnv_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a spatial projection of a CNV-related ``adata.obs`` column.

    Uses the unified ``plot_spatial_feature()`` helper for consistent
    parameter handling.

    Feature selection:
        * ``params.feature`` is used when given. A single panel is drawn, so if
          a list/tuple of features is supplied only its first entry is used;
          an empty list is treated like "no feature".
        * Otherwise the first existing column among ``numbat_clone``,
          ``cnv_score`` and ``numbat_p_cnv`` (in that order) is used.

    Colormap:
        If ``params.colormap`` is unset, it is set on ``params`` (in place) to
        ``"RdBu_r"`` for numeric features and ``"tab20"`` for discrete ones
        (categorical, string/object dtype).

    Args:
        adata: AnnData object
        params: Visualization parameters (``colormap`` may be filled in)
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: If spatial coordinates or CNV features not found
    """
    if context:
        await context.info("Creating spatial CNV projection visualization")

    # Validate spatial coordinates
    require_spatial_coords(adata)

    feature_to_plot = params.feature
    # Only one panel is drawn, so a list of features collapses to its first
    # entry; passing the list through would fail on unhashable obs lookups.
    if isinstance(feature_to_plot, (list, tuple)):
        feature_to_plot = feature_to_plot[0] if feature_to_plot else None

    # Auto-detect CNV-related features if none specified (priority order).
    if not feature_to_plot:
        candidates = (
            ("numbat_clone", "Numbat clone assignment"),
            ("cnv_score", "CNV score"),
            ("numbat_p_cnv", "Numbat CNV probability"),
        )
        for column, description in candidates:
            if column in adata.obs:
                feature_to_plot = column
                if context:
                    await context.info(
                        f"No feature specified, using '{column}' ({description})"
                    )
                break
        else:
            error_msg = "No CNV features found. Run analyze_cnv() first."
            if context:
                await context.warning(error_msg)
            raise DataNotFoundError(error_msg)

    # Validate feature exists
    validate_obs_column(adata, feature_to_plot, "CNV feature")

    if context:
        await context.info(f"Visualizing {feature_to_plot} on spatial coordinates")

    # Default colormap: diverging RdBu_r suits continuous CNV scores and
    # probabilities; discrete labels (categorical or string/object columns)
    # need a qualitative palette. `is_categorical_dtype` is deprecated in
    # pandas >= 2.1, so dtypes are inspected directly.
    if not params.colormap:
        values = adata.obs[feature_to_plot]
        is_continuous = pd.api.types.is_numeric_dtype(values) and not isinstance(
            values.dtype, pd.CategoricalDtype
        )
        params.colormap = "RdBu_r" if is_continuous else "tab20"

    # Use centralized figure creation
    fig, axes = create_figure_from_params(params, "spatial")
    ax = axes[0]

    # Use the enhanced plot_spatial_feature helper
    plot_spatial_feature(adata, ax, feature=feature_to_plot, params=params)

    if context:
        await context.info(f"Spatial CNV projection created for {feature_to_plot}")

    return fig
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# =============================================================================
|
|
117
|
+
# CNV Heatmap Visualization
|
|
118
|
+
# =============================================================================
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
async def create_cnv_heatmap_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a CNV heatmap from infercnvpy or Numbat results.

    The data source is auto-detected from ``adata.obsm``:

    * ``"X_cnv"`` (infercnvpy): chromosome-organized heatmap drawn with
      ``infercnvpy.pl.chromosome_heatmap``, grouped by ``params.cluster_key``.
    * ``"X_cnv_numbat"`` (Numbat) with no ``"chromosome"`` column in
      ``adata.var``: matplotlib heatmap. If ``params.feature`` names an
      ``adata.obs`` column with at least one non-missing value, cells are
      aggregated to one row per group (mean CNV). Otherwise every cell is
      drawn as its own row.
    * ``"X_cnv_numbat"`` with a ``"chromosome"`` column: the Numbat matrix is
      aliased to ``obsm["X_cnv"]`` only for the duration of the infercnvpy
      call.

    Any temporary ``obsm``/``uns`` entries added for infercnvpy are removed
    again, even if plotting fails. Later calls therefore still detect the
    Numbat data correctly.

    Args:
        adata: AnnData object
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: If no CNV matrix is found in ``adata.obsm``, or if
            ``adata.uns['cnv']`` is missing on the infercnvpy path.
        DataCompatibilityError: If infercnvpy is needed (chromosome-organized
            heatmap) but not installed.
    """
    if context:
        await context.info("Creating CNV heatmap visualization")

    # Auto-detect CNV data source; infercnvpy output takes precedence.
    if "X_cnv" in adata.obsm:
        cnv_method = "infercnvpy"
    elif "X_cnv_numbat" in adata.obsm:
        cnv_method = "numbat"
    else:
        error_msg = "CNV data not found in obsm. Run analyze_cnv() first."
        if context:
            await context.warning(error_msg)
        raise DataNotFoundError(error_msg)

    if context:
        await context.info(f"Detected CNV data from {cnv_method} method")

    # Use centralized figure size resolution
    figsize = resolve_figure_size(params, "cnv")

    if cnv_method == "numbat" and "chromosome" not in adata.var.columns:
        fig = await _create_numbat_matrix_heatmap(adata, params, figsize, context)
    else:
        fig = await _create_chromosome_heatmap(
            adata, params, figsize, cnv_method, context
        )

    if context:
        await context.info("CNV heatmap created successfully")

    return fig


async def _create_chromosome_heatmap(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    figsize: Any,
    cnv_method: str,
    context: Optional["ToolContext"],
) -> plt.Figure:
    """Draw an infercnvpy chromosome heatmap, aliasing Numbat data temporarily."""
    # infercnvpy is only required for this path.
    require("infercnvpy", feature="CNV heatmap visualization")
    import infercnvpy as cnv

    # Record exactly what is added so adata can be restored afterwards.
    added_obsm = False
    added_uns = False
    try:
        if cnv_method == "numbat":
            if context:
                await context.info(
                    "Converting Numbat CNV data to infercnvpy format for visualization"
                )
            # Detection guarantees obsm["X_cnv"] is absent here, so nothing
            # pre-existing is overwritten.
            adata.obsm["X_cnv"] = adata.obsm["X_cnv_numbat"]
            added_obsm = True
            # Also ensure cnv metadata exists for infercnvpy plotting
            if "cnv" not in adata.uns:
                adata.uns["cnv"] = {
                    "genomic_positions": False,
                }
                added_uns = True
                if context:
                    await context.info(
                        "Note: Chromosome labels not available for Numbat heatmap. "
                        "Install R packages for full chromosome annotation."
                    )

        # Check if CNV metadata exists
        if "cnv" not in adata.uns:
            error_msg = (
                "CNV metadata not found in adata.uns['cnv']. "
                "The CNV analysis may not have completed properly. "
                "Please re-run analyze_cnv()."
            )
            if context:
                await context.warning(error_msg)
            raise DataNotFoundError(error_msg)

        if context:
            await context.info("Generating CNV heatmap...")
            await context.info("Creating chromosome-organized CNV heatmap...")

        cnv.pl.chromosome_heatmap(
            adata,
            groupby=params.cluster_key,
            dendrogram=True,
            show=False,
            figsize=figsize,
        )
        # infercnvpy draws via pyplot; return the current figure it drew on.
        return plt.gcf()
    finally:
        # Undo the temporary aliasing so later calls see the original state.
        if added_obsm:
            del adata.obsm["X_cnv"]
        if added_uns:
            del adata.uns["cnv"]


async def _create_numbat_matrix_heatmap(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    figsize: Any,
    context: Optional["ToolContext"],
) -> plt.Figure:
    """Draw a Numbat CNV heatmap with matplotlib (no chromosome positions)."""
    if context:
        if "cnv" not in adata.uns:
            await context.info(
                "Note: Chromosome labels not available for Numbat heatmap. "
                "Install R packages for full chromosome annotation."
            )
        await context.info("Generating CNV heatmap...")
        await context.info(
            "Creating aggregated CNV heatmap by group (chromosome positions not available)"
        )

    # Read the Numbat matrix directly rather than aliasing it into
    # obsm["X_cnv"]; a persisted alias makes later calls mis-detect the data
    # as infercnvpy output.
    cnv_matrix = _as_dense_array(adata.obsm["X_cnv_numbat"])

    # Grouping uses a single obs column; a list collapses to its first entry.
    group_key = params.feature
    if isinstance(group_key, (list, tuple)):
        group_key = group_key[0] if group_key else None

    if (
        group_key
        and group_key in adata.obs.columns
        and adata.obs[group_key].notna().any()
    ):
        return _plot_grouped_cnv_heatmap(
            cnv_matrix, adata.obs[group_key], str(group_key)
        )

    # No usable grouping: plot every cell (can be large and hard to read).
    if context:
        await context.warning(
            "No valid grouping feature given; plotting every cell individually. "
            "Set 'feature' to an adata.obs column (e.g. a clone label) for an "
            "aggregated view."
        )
    return _plot_per_cell_cnv_heatmap(cnv_matrix, figsize)


def _as_dense_array(matrix: Any) -> np.ndarray:
    """Return an obsm matrix as a dense ndarray (densifying sparse inputs)."""
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    return np.asarray(matrix)


def _plot_grouped_cnv_heatmap(
    cnv_matrix: np.ndarray, feature_values: pd.Series, feature_label: str
) -> plt.Figure:
    """Plot one heatmap row per group: the mean CNV profile of its cells."""
    # Missing labels are skipped: NaN never compares equal, so it would form
    # an empty group whose mean is all-NaN.
    groups = sorted(feature_values.dropna().unique())

    group_means = []
    row_labels = []
    for group in groups:
        mask = (feature_values == group).to_numpy()
        group_means.append(cnv_matrix[mask].mean(axis=0))
        row_labels.append(f"{group} (n={int(mask.sum())})")
    aggregated_cnv = np.vstack(group_means)

    # Width scales with the number of bins (clamped to 6-12 in); height scales
    # with the number of groups.
    n_bins = aggregated_cnv.shape[1]
    fig_width = min(max(6, n_bins * 0.004), 12)
    fig_height = max(4, len(groups) * 1.2)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    im = ax.imshow(
        aggregated_cnv,
        cmap="RdBu_r",
        aspect="auto",
        vmin=-1,
        vmax=1,
        interpolation="nearest",
    )
    fig.colorbar(im, ax=ax, label="Mean CNV state")

    # Y-axis: group names with cell counts
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)
    ax.set_ylabel(feature_label, fontsize=12, fontweight="bold")

    # X-axis ticks hidden: bins carry no chromosome labels here
    ax.set_xlabel("Genomic position (binned)", fontsize=12)
    ax.set_xticks([])

    ax.set_title(
        f"CNV Profile by {feature_label}\n(Numbat analysis, aggregated by group)",
        fontsize=14,
        fontweight="bold",
    )

    # White separators between group rows
    for boundary in np.arange(len(row_labels) + 1) - 0.5:
        ax.axhline(boundary, color="white", linewidth=2)

    fig.tight_layout()
    return fig


def _plot_per_cell_cnv_heatmap(cnv_matrix: np.ndarray, figsize: Any) -> plt.Figure:
    """Plot the CNV matrix with one row per cell (ungrouped)."""
    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(
        cnv_matrix,
        cmap="RdBu_r",
        center=0,
        cbar_kws={"label": "CNV state"},
        yticklabels=False,
        xticklabels=False,
        ax=ax,
        vmin=-1,
        vmax=1,
    )

    ax.set_xlabel("Genomic position (binned)")
    ax.set_ylabel("Cells")
    ax.set_title("CNV Heatmap (Numbat)\nAll cells (ungrouped)")

    fig.tight_layout()
    return fig
|