chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RNA velocity visualization functions for spatial transcriptomics.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- Velocity stream plots
|
|
6
|
+
- Phase plots (spliced vs unspliced)
|
|
7
|
+
- Proportions plots (pie charts)
|
|
8
|
+
- Velocity heatmaps
|
|
9
|
+
- PAGA with velocity arrows
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import TYPE_CHECKING, Optional
|
|
13
|
+
|
|
14
|
+
import matplotlib.pyplot as plt
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import anndata as ad
|
|
18
|
+
|
|
19
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
20
|
+
|
|
21
|
+
from ...models.data import VisualizationParameters
|
|
22
|
+
from ...utils.adata_utils import validate_obs_column
|
|
23
|
+
from ...utils.dependency_manager import require
|
|
24
|
+
from ...utils.exceptions import (
|
|
25
|
+
DataCompatibilityError,
|
|
26
|
+
DataNotFoundError,
|
|
27
|
+
ParameterError,
|
|
28
|
+
)
|
|
29
|
+
from .core import (
|
|
30
|
+
create_figure_from_params,
|
|
31
|
+
get_categorical_columns,
|
|
32
|
+
infer_basis,
|
|
33
|
+
resolve_figure_size,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# =============================================================================
|
|
37
|
+
# Main Router
|
|
38
|
+
# =============================================================================
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
async def create_rna_velocity_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Route an RNA velocity visualization request to the matching plotter.

    The plot flavour is taken from ``params.subtype``; when it is unset the
    stream plot is drawn.

    Args:
        adata: AnnData object with computed RNA velocity.
        params: Visualization parameters; ``subtype`` selects the plotter.
        context: Optional MCP tool context used for progress messages.

    Returns:
        The matplotlib figure produced by the selected plotter.

    Raises:
        ParameterError: If ``params.subtype`` is not one of the supported
            subtypes listed below.

    Subtypes:
        - stream (default): velocity embedding stream plot
        - phase: phase portraits of spliced vs. unspliced counts
        - proportions: spliced/unspliced proportions per cluster
        - heatmap: gene expression ordered by ``latent_time``
        - paga: PAGA graph with velocity arrows
    """
    subtype = params.subtype or "stream"

    if context:
        await context.info(f"Creating RNA velocity visualization (subtype: {subtype})")

    # Dispatch table: subtype name -> coroutine function with a shared signature.
    plotters = {
        "stream": _create_velocity_stream_plot,
        "phase": _create_velocity_phase_plot,
        "proportions": _create_velocity_proportions_plot,
        "heatmap": _create_velocity_heatmap,
        "paga": _create_velocity_paga_plot,
    }
    plotter = plotters.get(subtype)
    if plotter is None:
        raise ParameterError(
            f"Unsupported subtype for rna_velocity: '{subtype}'. "
            "Available subtypes: stream, phase, proportions, heatmap, paga"
        )
    return await plotter(adata, params, context)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# =============================================================================
|
|
88
|
+
# Visualization Functions
|
|
89
|
+
# =============================================================================
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def _create_velocity_stream_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create RNA velocity stream plot using scv.pl.velocity_embedding_stream.

    Args:
        adata: AnnData object with computed RNA velocity.
        params: Visualization parameters. ``basis`` selects the embedding and
            ``feature`` the coloring. The plot uses a single axis, so when
            ``feature`` is a list only its first entry is used. ``title``,
            ``alpha`` and ``show_axes`` style the plot.
        context: Optional MCP tool context used for progress messages.

    Returns:
        Matplotlib figure holding the stream plot.

    Raises:
        DataNotFoundError: If ``adata.uns['velocity_graph']`` is missing.
        DataCompatibilityError: If no embedding basis can be inferred.

    Data requirements:
        - adata.uns['velocity_graph']: Velocity transition graph
        - adata.obsm['X_umap'] or 'spatial': Embedding for visualization
    """
    require("scvelo", feature="RNA velocity visualization")
    import scvelo as scv

    if "velocity_graph" not in adata.uns:
        raise DataNotFoundError(
            "RNA velocity not computed. Run analyze_velocity_data first."
        )

    # Determine basis for plotting
    basis = infer_basis(adata, preferred=params.basis)
    if not basis:
        raise DataCompatibilityError(
            "No valid embedding basis found. "
            f"Available keys: {list(adata.obsm.keys())}"
        )
    if context and basis != params.basis:
        await context.info(f"Using '{basis}' as basis")

    # Prepare feature for coloring.
    feature = params.feature
    if feature is not None and not isinstance(feature, str):
        # Other velocity subtypes accept a list of features, but this plot has
        # a single axis. A list would also crash the ``in adata.obs.columns``
        # check below, because pandas Index membership hashes the key.
        requested = list(feature)
        feature = requested[0] if requested else None
        if context and len(requested) > 1:
            await context.info(
                f"Stream plot supports one color feature; using '{feature}'"
            )
    if not feature:
        categorical_cols = get_categorical_columns(adata)
        feature = categorical_cols[0] if categorical_cols else None
        if feature and context:
            await context.info(f"Using '{feature}' for coloring")

    fig, axes = create_figure_from_params(params, "velocity")
    ax = axes[0]

    scv.pl.velocity_embedding_stream(
        adata,
        basis=basis,
        color=feature,
        ax=ax,
        show=False,
        alpha=params.alpha,
        # A legend only makes sense for obs annotations, not for gene colorings.
        legend_loc="right margin" if feature and feature in adata.obs.columns else None,
        frameon=params.show_axes,
        title="",
    )

    title = params.title or f"RNA Velocity Stream on {basis.capitalize()}"
    ax.set_title(title, fontsize=14)

    if basis == "spatial":
        # Flip y so spatial coordinates follow the image convention (origin top-left).
        ax.invert_yaxis()

    plt.tight_layout()
    return fig
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
async def _create_velocity_phase_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create velocity phase plot using scv.pl.velocity.

    Shows spliced vs unspliced counts with the fitted velocity model for the
    selected genes, one panel per gene laid out in a single row.

    Args:
        adata: AnnData object with velocity results.
        params: Visualization parameters. ``feature`` names the gene(s) to
            plot (a string or an iterable of strings). When unset, up to four
            genes are picked: flagged velocity genes first, otherwise the first
            genes in ``adata.var_names``. ``cluster_key`` colors the cells.
        context: Optional MCP tool context used for progress messages.

    Returns:
        The current matplotlib figure after scvelo has drawn the panels.

    Raises:
        DataNotFoundError: If a required layer is missing, or none of the
            candidate genes exist in ``adata.var_names``.

    Data requirements:
        - adata.layers['velocity']: Velocity vectors
        - adata.layers['Ms']: Smoothed spliced counts
        - adata.layers['Mu']: Smoothed unspliced counts
    """
    require("scvelo", feature="velocity phase plots")
    import scvelo as scv

    required_layers = ["velocity", "Ms", "Mu"]
    missing_layers = [layer for layer in required_layers if layer not in adata.layers]
    if missing_layers:
        raise DataNotFoundError(
            f"Missing layers for phase plot: {missing_layers}. Run velocity analysis first."
        )

    if params.feature:
        if isinstance(params.feature, str):
            var_names = [params.feature]
        else:
            var_names = list(params.feature)
    else:
        var_names = []
        if "velocity_genes" in adata.var.columns:
            var_names = list(adata.var_names[adata.var["velocity_genes"]][:4])
        if not var_names:
            # No velocity-gene flag, or it selects no genes: fall back to the
            # first genes instead of failing with an empty default list.
            var_names = list(adata.var_names[:4])

    valid_genes = [g for g in var_names if g in adata.var_names]
    if not valid_genes:
        raise DataNotFoundError(
            f"None of the specified genes found in data: {var_names}. "
            f"Available genes (first 10): {list(adata.var_names[:10])}"
        )

    if context:
        await context.info(f"Creating phase plot for genes: {valid_genes}")

    basis = infer_basis(adata, preferred=params.basis, priority=["umap", "spatial"])
    figsize = resolve_figure_size(
        params, n_panels=len(valid_genes), panel_width=4, panel_height=4
    )
    color = params.cluster_key or None

    scv.pl.velocity(
        adata,
        var_names=valid_genes,
        basis=basis,
        color=color,
        figsize=figsize,
        dpi=params.dpi,
        show=False,
        ncols=len(valid_genes),
    )

    # scvelo draws into its own figure; grab it to add the title.
    fig = plt.gcf()
    title = params.title or "RNA Velocity Phase Plot"
    fig.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()
    return fig
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
async def _create_velocity_proportions_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Plot spliced/unspliced RNA proportions per group via scv.pl.proportions.

    Args:
        adata: AnnData object carrying ``spliced`` and ``unspliced`` layers.
        params: Visualization parameters. ``cluster_key`` selects the grouping
            column; when unset, the first categorical obs column is used.
            ``title`` and ``dpi`` style the output.
        context: Optional MCP tool context used for progress messages.

    Returns:
        The current matplotlib figure after scvelo has drawn the plot.

    Raises:
        DataNotFoundError: If either count layer is missing.
        ParameterError: If no grouping column is given and none can be inferred.

    Data requirements:
        - adata.layers['spliced'] / adata.layers['unspliced']: count layers
        - adata.obs[cluster_key]: cluster labels used for grouping
    """
    require("scvelo", feature="proportions plot")
    import scvelo as scv

    layers_present = "spliced" in adata.layers and "unspliced" in adata.layers
    if not layers_present:
        raise DataNotFoundError(
            "Spliced and unspliced layers are required for proportions plot. "
            "Your data may not contain RNA velocity information."
        )

    group_col = params.cluster_key
    if not group_col:
        candidates = get_categorical_columns(adata)
        if not candidates:
            raise ParameterError(
                "cluster_key is required for proportions plot. "
                f"Available columns: {list(adata.obs.columns)[:10]}"
            )
        group_col = candidates[0]
        if context:
            await context.info(f"Using cluster_key: '{group_col}'")

    validate_obs_column(adata, group_col, "Cluster")

    if context:
        await context.info(f"Creating proportions plot grouped by '{group_col}'")

    scv.pl.proportions(
        adata,
        groupby=group_col,
        figsize=resolve_figure_size(params, "violin"),
        dpi=params.dpi,
        show=False,
    )

    # scvelo draws into its own figure; grab it to add the title.
    figure = plt.gcf()
    figure.suptitle(
        params.title or f"Spliced/Unspliced Proportions by {group_col}",
        fontsize=14,
        y=1.02,
    )
    plt.tight_layout()
    return figure
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
async def _create_velocity_heatmap(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create velocity heatmap using scv.pl.heatmap.

    Shows gene expression patterns with cells sorted by ``latent_time``.

    Args:
        adata: AnnData object with velocity results and ``obs['latent_time']``.
        params: Visualization parameters. ``feature`` names the gene(s) to show.
            When unset, up to 50 genes are taken from the first non-empty
            source among ``var['velocity_genes']``, ``var['highly_variable']``
            and ``adata.var_names``. ``cluster_key`` (optional) adds a column
            color bar. ``title`` and ``dpi`` style the figure.
        context: Optional MCP tool context used for progress messages.

    Returns:
        The current matplotlib figure after scvelo has drawn the heatmap.

    Raises:
        DataNotFoundError: If none of the requested genes exist.
        Whatever ``validate_obs_column`` raises when ``latent_time`` or
        ``cluster_key`` is missing from ``adata.obs``.

    Data requirements:
        - adata.obs['latent_time']: Latent time from dynamical model
        - adata.var['velocity_genes']: Velocity genes (optional)
    """
    require("scvelo", feature="velocity heatmap")
    import scvelo as scv

    validate_obs_column(adata, "latent_time", "Latent time")
    if params.cluster_key:
        # Fail early with the project's standard error instead of an opaque
        # failure inside scvelo when the color-bar column does not exist.
        validate_obs_column(adata, params.cluster_key, "Cluster")

    if params.feature:
        if isinstance(params.feature, str):
            requested = [params.feature]
        else:
            requested = list(params.feature)
        var_names = [g for g in requested if g in adata.var_names]
        if not var_names:
            raise DataNotFoundError(f"None of the specified genes found: {requested}")
    else:
        var_names = []
        # Prefer velocity genes, then highly variable genes. Move on to the
        # next source when a flag column is absent or selects no genes.
        for flag_column in ("velocity_genes", "highly_variable"):
            if flag_column in adata.var.columns:
                var_names = list(adata.var_names[adata.var[flag_column]][:50])
                if var_names:
                    break
        if not var_names:
            var_names = list(adata.var_names[:50])

    if context:
        await context.info(f"Creating velocity heatmap with {len(var_names)} genes")

    figsize = resolve_figure_size(params, "heatmap")

    scv.pl.heatmap(
        adata,
        var_names=var_names,
        sortby="latent_time",
        col_color=params.cluster_key,
        n_convolve=30,
        show=False,
        figsize=figsize,
    )

    # scvelo creates its own figure; fetch it to apply dpi and title.
    fig = plt.gcf()
    fig.set_dpi(params.dpi)

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)
    plt.tight_layout()
    return fig
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
async def _create_velocity_paga_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create PAGA plot with velocity using scv.pl.paga.

    Shows partition-based graph abstraction with directed velocity arrows.

    The grouping column is taken, in order, from ``params.cluster_key``, from
    the ``groups`` recorded in an existing ``adata.uns['paga']``, or from the
    first categorical obs column.

    Side effect: ``sc.tl.paga`` and ``scv.tl.paga`` are run, writing to
    ``adata.uns``, when no PAGA result exists or the stored result was built
    for a different grouping than the one being plotted.

    Args:
        adata: AnnData object with computed RNA velocity.
        params: Visualization parameters (``cluster_key``, ``basis``,
            ``title``, ``show_axes``).
        context: Optional MCP tool context used for progress messages.

    Returns:
        Matplotlib figure holding the PAGA plot.

    Raises:
        DataNotFoundError: If ``adata.uns['velocity_graph']`` is missing.
        ParameterError: If no grouping column can be determined, or the chosen
            one is not in ``adata.obs``.

    Data requirements:
        - adata.uns['velocity_graph']: Velocity transition graph
        - adata.uns['paga']: PAGA results (computed by scv.tl.paga)
        - adata.obs[cluster_key]: Cluster labels used for PAGA
    """
    require("scvelo", feature="velocity PAGA plot")
    import scvelo as scv

    if "velocity_graph" not in adata.uns:
        raise DataNotFoundError("velocity_graph required. Run velocity analysis first.")

    # Grouping used by a previous PAGA run, if any.
    existing_groups = adata.uns.get("paga", {}).get("groups")

    cluster_key = params.cluster_key or existing_groups
    if not cluster_key:
        categorical_cols = get_categorical_columns(adata)
        if categorical_cols:
            cluster_key = categorical_cols[0]

    available = list(adata.obs.columns)[:10]
    if not cluster_key:
        raise ParameterError(
            "cluster_key is required for PAGA plot. "
            f"Available columns: {available}"
        )
    if cluster_key not in adata.obs.columns:
        raise ParameterError(
            f"cluster_key '{cluster_key}' not found in adata.obs. "
            f"Available columns: {available}"
        )

    # Compute PAGA when missing, or when the stored result belongs to another
    # grouping. Otherwise the graph nodes would not match the plotted colors.
    if existing_groups != cluster_key:
        if context:
            await context.info(f"Computing PAGA for cluster_key='{cluster_key}'")
        import scanpy as sc

        sc.tl.paga(adata, groups=cluster_key)
        scv.tl.paga(adata, groups=cluster_key)

    if context:
        await context.info(f"Creating velocity PAGA plot for '{cluster_key}'")

    basis = infer_basis(adata, preferred=params.basis, priority=["umap", "spatial"])
    fig, axes = create_figure_from_params(params, "velocity")
    ax = axes[0]

    scv.pl.paga(
        adata,
        basis=basis,
        color=cluster_key,
        ax=ax,
        show=False,
        frameon=params.show_axes,
    )

    if params.title:
        ax.set_title(params.title, fontsize=14)

    if basis == "spatial":
        # Flip y so spatial coordinates follow the image convention (origin top-left).
        ax.invert_yaxis()

    plt.tight_layout()
    return fig
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for spatial transcriptomics data analysis.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .adata_utils import ( # Constants; Field discovery; Data access; Validation; Ensure; Standardization
|
|
6
|
+
ALTERNATIVE_BATCH_KEYS,
|
|
7
|
+
ALTERNATIVE_CELL_TYPE_KEYS,
|
|
8
|
+
ALTERNATIVE_CLUSTER_KEYS,
|
|
9
|
+
ALTERNATIVE_SPATIAL_KEYS,
|
|
10
|
+
BATCH_KEY,
|
|
11
|
+
CELL_TYPE_KEY,
|
|
12
|
+
CLUSTER_KEY,
|
|
13
|
+
SPATIAL_KEY,
|
|
14
|
+
ensure_categorical,
|
|
15
|
+
ensure_counts_layer,
|
|
16
|
+
find_common_genes,
|
|
17
|
+
get_analysis_parameter,
|
|
18
|
+
get_batch_key,
|
|
19
|
+
get_cell_type_key,
|
|
20
|
+
get_cluster_key,
|
|
21
|
+
get_gene_expression,
|
|
22
|
+
get_genes_expression,
|
|
23
|
+
get_spatial_key,
|
|
24
|
+
standardize_adata,
|
|
25
|
+
to_dense,
|
|
26
|
+
validate_adata,
|
|
27
|
+
validate_adata_basics,
|
|
28
|
+
validate_gene_overlap,
|
|
29
|
+
validate_obs_column,
|
|
30
|
+
validate_var_column,
|
|
31
|
+
)
|
|
32
|
+
from .dependency_manager import (
|
|
33
|
+
DependencyInfo,
|
|
34
|
+
get,
|
|
35
|
+
is_available,
|
|
36
|
+
require,
|
|
37
|
+
validate_r_environment,
|
|
38
|
+
validate_scvi_tools,
|
|
39
|
+
)
|
|
40
|
+
from .device_utils import (
|
|
41
|
+
cuda_available,
|
|
42
|
+
get_device,
|
|
43
|
+
get_ot_backend,
|
|
44
|
+
mps_available,
|
|
45
|
+
resolve_device_async,
|
|
46
|
+
)
|
|
47
|
+
from .exceptions import (
|
|
48
|
+
ChatSpatialError,
|
|
49
|
+
DataCompatibilityError,
|
|
50
|
+
DataError,
|
|
51
|
+
DataNotFoundError,
|
|
52
|
+
DependencyError,
|
|
53
|
+
ParameterError,
|
|
54
|
+
ProcessingError,
|
|
55
|
+
)
|
|
56
|
+
from .mcp_utils import mcp_tool_error_handler, suppress_output
|
|
57
|
+
|
|
58
|
+
# Public API of the utils package, grouped by purpose (order preserved).
__all__ = [
    # Exceptions
    "ChatSpatialError", "DataError", "DataNotFoundError",
    "DataCompatibilityError", "ParameterError", "ProcessingError",
    "DependencyError",
    # MCP utilities
    "suppress_output", "mcp_tool_error_handler",
    # Constants
    "SPATIAL_KEY", "CELL_TYPE_KEY", "CLUSTER_KEY", "BATCH_KEY",
    "ALTERNATIVE_SPATIAL_KEYS", "ALTERNATIVE_CELL_TYPE_KEYS",
    "ALTERNATIVE_CLUSTER_KEYS", "ALTERNATIVE_BATCH_KEYS",
    # Field discovery
    "get_analysis_parameter", "get_batch_key", "get_cell_type_key",
    "get_cluster_key", "get_spatial_key",
    # Expression extraction
    "to_dense", "get_gene_expression", "get_genes_expression",
    # Validation
    "validate_adata", "validate_obs_column", "validate_var_column",
    "validate_adata_basics", "validate_gene_overlap", "ensure_categorical",
    # Gene overlap
    "find_common_genes",
    # Ensure
    "ensure_counts_layer",
    # Standardization
    "standardize_adata",
    # Dependency management
    "DependencyInfo", "require", "get", "is_available",
    "validate_r_environment", "validate_scvi_tools",
    # Device utilities
    "cuda_available", "mps_available", "get_device",
    "resolve_device_async", "get_ot_backend",
]
|