chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Basic visualization functions for spatial transcriptomics.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- Spatial feature plots
|
|
6
|
+
- UMAP visualizations
|
|
7
|
+
- Heatmap visualizations
|
|
8
|
+
- Violin plots
|
|
9
|
+
- Dot plots
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import TYPE_CHECKING, Optional
|
|
13
|
+
|
|
14
|
+
import matplotlib.pyplot as plt
|
|
15
|
+
import pandas as pd
|
|
16
|
+
import scanpy as sc
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
import anndata as ad
|
|
20
|
+
|
|
21
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
22
|
+
|
|
23
|
+
from ...models.data import VisualizationParameters
|
|
24
|
+
from ...utils.adata_utils import (
|
|
25
|
+
ensure_categorical,
|
|
26
|
+
get_cluster_key,
|
|
27
|
+
get_gene_expression,
|
|
28
|
+
validate_obs_column,
|
|
29
|
+
)
|
|
30
|
+
from ...utils.compute import ensure_umap
|
|
31
|
+
from ...utils.exceptions import DataNotFoundError, ParameterError
|
|
32
|
+
from .core import (
|
|
33
|
+
add_colorbar,
|
|
34
|
+
create_figure,
|
|
35
|
+
get_colormap,
|
|
36
|
+
get_validated_features,
|
|
37
|
+
plot_spatial_feature,
|
|
38
|
+
setup_multi_panel_figure,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# =============================================================================
|
|
42
|
+
# Spatial Visualization
|
|
43
|
+
# =============================================================================
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
async def create_spatial_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Render one spatial panel per requested feature.

    ``params.feature`` may be a single name, a list of names, or ``None``.
    When nothing is requested, the key returned by ``get_cluster_key`` is
    plotted instead.

    Args:
        adata: AnnData object with spatial coordinates
        params: Visualization parameters (features, colorbar toggle, ...)
        context: Optional tool context (not used by this function)

    Returns:
        matplotlib Figure containing all panels

    Raises:
        ParameterError: If no feature is requested and no cluster key exists.
    """
    requested = params.feature
    if requested is None:
        features: list[str] = []
    elif isinstance(requested, list):
        features = requested
    else:
        features = [requested]

    if not features:
        # Fall back to the first available clustering column.
        fallback = get_cluster_key(adata)
        if not fallback:
            raise ParameterError(
                "No features specified and no default clustering found"
            )
        features = [fallback]

    fig, axes = setup_multi_panel_figure(len(features), params, "")

    for panel_idx, name in enumerate(features):
        panel = axes[panel_idx]
        try:
            mappable = plot_spatial_feature(
                adata, panel, feature=name, params=params, title=name
            )
            if mappable is not None and params.show_colorbar:
                add_colorbar(fig, panel, mappable, params)
        except Exception as exc:
            # Best effort: a failing feature shows a short inline error
            # message instead of aborting the remaining panels.
            panel.text(
                0.5,
                0.5,
                f"Error: {str(exc)[:50]}",
                ha="center",
                va="center",
                transform=panel.transAxes,
            )
            panel.axis("off")

    plt.tight_layout()
    return fig
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# =============================================================================
|
|
109
|
+
# UMAP Visualization
|
|
110
|
+
# =============================================================================
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
async def create_umap_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a UMAP scatter plot colored by a gene or an obs column.

    The coloring feature comes from ``params.feature``. If it is a list, only
    its first entry is used because a UMAP scatter has a single color
    channel. If no feature is given, the key returned by ``get_cluster_key``
    is used. The UMAP embedding is computed on demand via ``ensure_umap``.

    Args:
        adata: AnnData object
        params: Visualization parameters
        context: Optional tool context for logging

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: If the coloring feature is neither a gene in
            ``adata.var_names`` nor a column of ``adata.obs``.
    """
    # Normalize a list-valued feature to its first entry. Passing the raw
    # list to the ``in adata.var_names`` check below would raise TypeError
    # (lists are unhashable).
    color_by = params.feature
    if isinstance(color_by, list):
        if len(color_by) > 1 and context:
            await context.info(
                f"UMAP colors by a single feature; using '{color_by[0]}' "
                f"and ignoring {len(color_by) - 1} other(s)"
            )
        color_by = color_by[0] if color_by else None

    if context:
        feature_desc = color_by if color_by else "clusters"
        await context.info(f"Creating UMAP plot for {feature_desc}")

    # Ensure UMAP is computed (lazy computation)
    if ensure_umap(adata) and context:
        await context.info("Computed UMAP embedding")

    if color_by is None:
        # Default to first available cluster key
        color_by = get_cluster_key(adata)

    fig, ax = create_figure(params.figure_size or (10, 8))
    umap_coords = adata.obsm["X_umap"]

    spot_size = params.spot_size if params.spot_size is not None else 150.0
    # True division: floor division collapsed any spot_size < 3 to a marker
    # area of 0, making every point invisible.
    point_size = spot_size / 3

    if color_by is None:
        # No color - just plot points
        ax.scatter(
            umap_coords[:, 0],
            umap_coords[:, 1],
            s=point_size,
            alpha=params.alpha,
            c="steelblue",
        )
    elif color_by in adata.var_names:
        # Gene expression - use unified utility
        values = get_gene_expression(adata, color_by)
        scatter = ax.scatter(
            umap_coords[:, 0],
            umap_coords[:, 1],
            c=values,
            cmap=params.colormap,
            s=point_size,
            alpha=params.alpha,
        )
        if params.show_colorbar:
            add_colorbar(fig, ax, scatter, params, label=color_by)
    elif color_by in adata.obs.columns:
        values = adata.obs[color_by]
        # ``pd.api.types.is_categorical_dtype`` is deprecated (pandas >= 2.1),
        # so check the dtype instance instead. Pandas ``StringDtype`` columns
        # are treated like object-dtype string columns; a numeric scatter
        # cannot color by strings.
        is_categorical = (
            isinstance(values.dtype, (pd.CategoricalDtype, pd.StringDtype))
            or values.dtype == object
        )

        if is_categorical:
            ensure_categorical(adata, color_by)
            categories = adata.obs[color_by].cat.categories
            colors = get_colormap(params.colormap, n_colors=len(categories))

            # One scatter call per category so each gets a legend entry.
            for i, cat in enumerate(categories):
                mask = (adata.obs[color_by] == cat).to_numpy()
                ax.scatter(
                    umap_coords[mask, 0],
                    umap_coords[mask, 1],
                    c=[colors[i]],
                    s=point_size,
                    alpha=params.alpha,
                    label=cat,
                )

            if params.show_legend:
                ax.legend(
                    loc="center left",
                    bbox_to_anchor=(1, 0.5),
                    fontsize=8,
                    frameon=False,
                )
        else:
            scatter = ax.scatter(
                umap_coords[:, 0],
                umap_coords[:, 1],
                c=values,
                cmap=params.colormap,
                s=point_size,
                alpha=params.alpha,
            )
            if params.show_colorbar:
                add_colorbar(fig, ax, scatter, params, label=color_by)
    else:
        raise DataNotFoundError(f"Color feature '{color_by}' not found in genes or obs")

    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    # Avoid rendering "UMAP - None" when there is nothing to color by.
    default_title = f"UMAP - {color_by}" if color_by is not None else "UMAP"
    ax.set_title(params.title or default_title)

    if not params.show_axes:
        ax.axis("off")

    plt.tight_layout()
    return fig
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# =============================================================================
|
|
230
|
+
# Heatmap Visualization
|
|
231
|
+
# =============================================================================
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
async def create_heatmap_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a scanpy heatmap of gene expression grouped by cluster.

    Args:
        adata: AnnData object. ``adata.obs[params.cluster_key]`` is passed
            through ``ensure_categorical`` and may be converted in place.
        params: Visualization parameters. ``cluster_key`` and a feature list
            are required. ``colormap``, ``dotplot_dendrogram``,
            ``dotplot_swap_axes`` and ``dotplot_standard_scale`` are
            forwarded to ``sc.pl.heatmap``.
        context: Optional tool context for logging

    Returns:
        matplotlib Figure object (the current figure after scanpy draws)

    Raises:
        ParameterError: If ``cluster_key`` is missing or no valid gene
            features are provided.
    """
    if not params.cluster_key:
        raise ParameterError("Heatmap requires cluster_key parameter")

    validate_obs_column(adata, params.cluster_key, "Cluster")

    features = await get_validated_features(adata, params, context, genes_only=True)
    if not features:
        raise ParameterError("No valid gene features provided for heatmap")

    if context:
        await context.info(
            f"Creating heatmap for {len(features)} genes grouped by {params.cluster_key}"
        )

    # scanpy groups categorical columns by label. A non-categorical (e.g.
    # integer) groupby column is instead binned by value into
    # ``num_categories`` intervals, which mixes up clusters.
    ensure_categorical(adata, params.cluster_key)

    # Note: return_fig=True causes issues with newer matplotlib versions,
    # so the figure is taken from plt.gcf() after drawing.
    sc.pl.heatmap(
        adata,
        var_names=features,
        groupby=params.cluster_key,
        cmap=params.colormap,
        show=False,
        dendrogram=params.dotplot_dendrogram,
        swap_axes=params.dotplot_swap_axes,
        standard_scale=params.dotplot_standard_scale,
    )
    return plt.gcf()
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# =============================================================================
|
|
281
|
+
# Violin Plot Visualization
|
|
282
|
+
# =============================================================================
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
async def create_violin_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create scanpy violin plots of gene expression grouped by cluster.

    Args:
        adata: AnnData object. ``adata.obs[params.cluster_key]`` is passed
            through ``ensure_categorical`` and may be converted in place.
        params: Visualization parameters (requires cluster_key and feature)
        context: Optional tool context for logging

    Returns:
        matplotlib Figure object (the current figure after scanpy draws)

    Raises:
        ParameterError: If ``cluster_key`` is missing or no valid gene
            features are provided.
    """
    if not params.cluster_key:
        raise ParameterError("Violin plot requires cluster_key parameter")

    validate_obs_column(adata, params.cluster_key, "Cluster")

    features = await get_validated_features(adata, params, context, genes_only=True)
    if not features:
        raise ParameterError("No valid gene features provided for violin plot")

    if context:
        await context.info(
            f"Creating violin plot for {len(features)} genes grouped by {params.cluster_key}"
        )

    # scanpy's violin expects a categorical groupby column and rejects other
    # dtypes (e.g. integer cluster labels), so convert first, as the UMAP
    # plot does.
    ensure_categorical(adata, params.cluster_key)

    # Note: return_fig=True causes issues with newer matplotlib/seaborn
    # versions. Instead, we use plt.gcf() to get the current figure.
    sc.pl.violin(
        adata,
        keys=features,
        groupby=params.cluster_key,
        show=False,
    )
    return plt.gcf()
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
# =============================================================================
|
|
329
|
+
# Dot Plot Visualization
|
|
330
|
+
# =============================================================================
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
async def create_dotplot_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a scanpy dot plot of gene expression grouped by cluster.

    When ``params.dotplot_var_groups`` is given, genes are drawn in labelled
    groups. Each group keeps only the genes present in ``adata.var_names``,
    and the groups then determine the plotted genes. If no group has any
    present gene, the validated feature list is plotted ungrouped.

    Args:
        adata: AnnData object. ``adata.obs[params.cluster_key]`` is passed
            through ``ensure_categorical`` and may be converted in place.
        params: Visualization parameters (requires cluster_key and feature list)
        context: Optional tool context for logging

    Returns:
        matplotlib Figure object (the current figure after scanpy draws)

    Raises:
        ParameterError: If ``cluster_key`` is missing or no valid gene
            features are provided.
    """
    if not params.cluster_key:
        raise ParameterError("Dot plot requires cluster_key parameter")

    validate_obs_column(adata, params.cluster_key, "Cluster")

    features = await get_validated_features(adata, params, context, genes_only=True)
    if not features:
        raise ParameterError("No valid gene features provided for dot plot")

    if context:
        await context.info(
            f"Creating dot plot for {len(features)} genes grouped by {params.cluster_key}"
        )

    # scanpy bins a non-categorical groupby column by value instead of
    # grouping by label, so convert it first (as the UMAP plot does).
    ensure_categorical(adata, params.cluster_key)

    var_names = features
    if params.dotplot_var_groups:
        # scanpy's ``var_group_positions`` expects (start, end) index pairs,
        # not group labels. Passing a ``{label: [genes]}`` mapping as
        # ``var_names`` is scanpy's supported way to draw grouped genes.
        # NOTE(review): assumes dotplot_var_groups maps label -> gene list;
        # confirm against VisualizationParameters.
        grouped = {}
        for label, genes in params.dotplot_var_groups.items():
            gene_list = [genes] if isinstance(genes, str) else list(genes)
            present = [gene for gene in gene_list if gene in adata.var_names]
            if present:
                grouped[label] = present
        if grouped:
            var_names = grouped
        elif context:
            await context.info(
                "No genes from dotplot_var_groups were found; plotting ungrouped features"
            )

    # Note: return_fig=True returns a DotPlot object, not a Figure,
    # so use plt.gcf() instead to get the figure.
    dotplot_kwargs = {
        "adata": adata,
        "var_names": var_names,
        "groupby": params.cluster_key,
        "cmap": params.colormap,
        "show": False,
    }

    # Add optional parameters only when set, leaving scanpy's defaults otherwise.
    if params.dotplot_dendrogram:
        dotplot_kwargs["dendrogram"] = True
    if params.dotplot_swap_axes:
        dotplot_kwargs["swap_axes"] = True
    if params.dotplot_standard_scale:
        dotplot_kwargs["standard_scale"] = params.dotplot_standard_scale
    if params.dotplot_dot_min is not None:
        dotplot_kwargs["dot_min"] = params.dotplot_dot_min
    if params.dotplot_dot_max is not None:
        dotplot_kwargs["dot_max"] = params.dotplot_dot_max
    if params.dotplot_smallest_dot is not None:
        dotplot_kwargs["smallest_dot"] = params.dotplot_smallest_dot

    sc.pl.dotplot(**dotplot_kwargs)
    return plt.gcf()
|