chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RNA velocity analysis for spatial transcriptomics.
|
|
3
|
+
|
|
4
|
+
This module computes RNA velocity to infer the direction of cellular state changes
|
|
5
|
+
by analyzing the balance of spliced and unspliced mRNA counts.
|
|
6
|
+
|
|
7
|
+
Key functionality:
|
|
8
|
+
- `analyze_rna_velocity`: Main MCP entry point for velocity analysis
|
|
9
|
+
- Supports scVelo (standard) and VELOVI (deep learning) methods
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ..spatial_mcp_adapter import ToolContext
|
|
18
|
+
|
|
19
|
+
from ..models.analysis import RNAVelocityResult
|
|
20
|
+
from ..models.data import RNAVelocityParameters
|
|
21
|
+
from ..utils.adata_utils import validate_adata
|
|
22
|
+
from ..utils.dependency_manager import require
|
|
23
|
+
from ..utils.exceptions import (
|
|
24
|
+
DataError,
|
|
25
|
+
DataNotFoundError,
|
|
26
|
+
ParameterError,
|
|
27
|
+
ProcessingError,
|
|
28
|
+
)
|
|
29
|
+
from ..utils.mcp_utils import suppress_output
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def preprocess_for_velocity(
    adata, min_shared_counts=30, n_top_genes=2000, n_pcs=30, n_neighbors=30, params=None
):
    """
    Run the standard scVelo preprocessing on ``adata`` (in place) and return it.

    Steps performed:
        1. Validate that the object carries velocity layers.
        2. ``scv.pp.filter_and_normalize``: gene filtering on counts shared by the
           spliced/unspliced layers, normalization and highly-variable-gene
           selection.
        3. ``scv.pp.moments``: first/second-order moments over nearest neighbors.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix carrying 'spliced' and 'unspliced' layers.
    min_shared_counts : int, default 30
        Minimum counts shared between the spliced and unspliced layers.
    n_top_genes : int, default 2000
        Number of highly variable genes to keep.
    n_pcs : int, default 30
        Number of principal components used for the moments.
    n_neighbors : int, default 30
        Number of nearest neighbors used for the moments.
    params : RNAVelocityParameters, optional
        When given (and of that type), its values replace the four keyword
        arguments above.

    Returns
    -------
    AnnData
        The same object that was passed in, after preprocessing.

    Raises
    ------
    DataError
        If ``validate_adata`` reports missing velocity data.
    """
    import scvelo as scv

    # A parameter object, when supplied, wins over the individual keywords.
    if isinstance(params, RNAVelocityParameters):
        min_shared_counts, n_top_genes, n_pcs, n_neighbors = (
            params.min_shared_counts,
            params.n_top_genes,
            params.n_pcs,
            params.n_neighbors,
        )

    try:
        validate_adata(adata, {}, check_velocity=True)
    except DataNotFoundError as missing:
        raise DataError(f"Invalid velocity data: {missing}") from missing

    scv.pp.filter_and_normalize(
        adata, min_shared_counts=min_shared_counts, n_top_genes=n_top_genes
    )
    scv.pp.moments(adata, n_pcs=n_pcs, n_neighbors=n_neighbors)

    return adata
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def compute_rna_velocity(adata, mode="stochastic", params=None):
    """
    Computes RNA velocity to infer the direction of cellular differentiation.

    Workflow:
        1. Run preprocessing (``preprocess_for_velocity``) when the ``Ms``/``Mu``
           moment layers are missing.
        2. Estimate velocities with the selected scVelo model.
        3. Build the velocity graph of cell-to-cell transitions.
        4. In dynamical mode only, compute latent time from that graph.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix with 'spliced' and 'unspliced' layers.
    mode : str, default 'stochastic'
        The scVelo model for velocity estimation:
        - 'deterministic': steady-state model using first-order moments.
        - 'stochastic': steady-state model that additionally uses second-order
          moments to account for stochasticity.
        - 'dynamical': full transcriptional dynamics fitted per gene
          (``scv.tl.recover_dynamics``); also computes latent time.
    params : RNAVelocityParameters, optional
        When given, ``params.scvelo_mode`` overrides ``mode`` and the object is
        forwarded to preprocessing.

    Returns
    -------
    AnnData
        The (in-place updated) object with velocity layers and velocity graph.
    """
    import scvelo as scv

    # A parameter object, when supplied, decides the mode.
    if isinstance(params, RNAVelocityParameters):
        mode = params.scvelo_mode

    # Moments are required by every scVelo velocity model.
    if "Ms" not in adata.layers or "Mu" not in adata.layers:
        adata = preprocess_for_velocity(adata, params=params)

    if mode == "dynamical":
        scv.tl.recover_dynamics(adata)
        scv.tl.velocity(adata, mode="dynamical")
        # Build the exact velocity graph BEFORE latent time: scVelo's
        # latent_time falls back to computing an approximate graph when none
        # exists, which would then be discarded and recomputed below.
        scv.tl.velocity_graph(adata)
        # Latent time is required for gene_trends visualization.
        scv.tl.latent_time(adata)
    else:
        scv.tl.velocity(adata, mode=mode)
        scv.tl.velocity_graph(adata)

    return adata
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
async def _prepare_velovi_data(
    adata,
    ctx: Optional["ToolContext"],
    *,
    min_shared_counts: int = 30,
    n_top_genes: int = 2000,
    n_pcs: int = 30,
    n_neighbors: int = 30,
):
    """Return a preprocessed copy of ``adata`` for ``VELOVI.setup_anndata``.

    The input object is not modified. On the copy, the raw ``spliced`` and
    ``unspliced`` layers are first mirrored into ``Ms``/``Mu`` so those layers
    exist even if moment computation fails. Then ``scv.pp.filter_and_normalize``
    and ``scv.pp.moments`` are run. Failures in those two scVelo steps are
    reported as warnings rather than raised (best effort).

    Args:
        adata: AnnData with ``spliced`` and ``unspliced`` layers.
        ctx: ToolContext used to report warnings. When ``None``, warnings are
            emitted through the :mod:`warnings` module instead.
        min_shared_counts: Passed to ``scv.pp.filter_and_normalize``.
        n_top_genes: Passed to ``scv.pp.filter_and_normalize``.
        n_pcs: Passed to ``scv.pp.moments``.
        n_neighbors: Passed to ``scv.pp.moments``.

    Returns:
        The preprocessed AnnData copy.

    Raises:
        DataNotFoundError: If the ``spliced`` or ``unspliced`` layer is missing.
    """
    import warnings

    import scvelo as scv

    # Validate before copying so bad input does not pay for a full data copy.
    if "spliced" not in adata.layers or "unspliced" not in adata.layers:
        raise DataNotFoundError("Missing required 'spliced' and 'unspliced' layers")

    adata_velovi = adata.copy()
    # A previous VELOVI run stores its prepared AnnData under this uns key.
    # Drop it from the copy so repeated runs do not nest ever-larger copies.
    adata_velovi.uns.pop("velovi_adata", None)

    # Fallback layer names for VELOVI. scv.pp.moments below is expected to
    # replace them with the computed moments.
    adata_velovi.layers["Ms"] = adata_velovi.layers["spliced"]
    adata_velovi.layers["Mu"] = adata_velovi.layers["unspliced"]

    async def _warn(message: str) -> None:
        # ctx is optional for callers such as analyze_velocity_with_velovi.
        if ctx is not None:
            await ctx.warning(message)
        else:
            warnings.warn(message, RuntimeWarning, stacklevel=3)

    try:
        scv.pp.filter_and_normalize(
            adata_velovi,
            min_shared_counts=min_shared_counts,
            n_top_genes=n_top_genes,
            enforce=False,
        )
    except Exception as e:
        await _warn(f"scvelo preprocessing warning: {e}")

    try:
        scv.pp.moments(adata_velovi, n_pcs=n_pcs, n_neighbors=n_neighbors)
    except Exception as e:
        await _warn(f"moments computation warning: {e}")

    return adata_velovi
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _validate_velovi_data(adata):
    """Sanity-check the ``Ms``/``Mu`` layers that VELOVI will be trained on.

    Args:
        adata: AnnData-like object exposing a ``layers`` mapping.

    Returns:
        True when both layers exist, share one shape and are two-dimensional.

    Raises:
        DataNotFoundError: If either ``Ms`` or ``Mu`` is absent.
        DataError: If the two layers differ in shape or are not 2-D.
    """
    layers = adata.layers
    if not ("Ms" in layers and "Mu" in layers):
        raise DataNotFoundError("Missing required layers 'Ms' and 'Mu' for VELOVI")

    spliced_moments, unspliced_moments = layers["Ms"], layers["Mu"]

    if spliced_moments.shape != unspliced_moments.shape:
        raise DataError(
            f"Shape mismatch: Ms {spliced_moments.shape} vs Mu {unspliced_moments.shape}"
        )

    if not (spliced_moments.ndim == 2 and unspliced_moments.ndim == 2):
        raise DataError(
            f"Expected 2D arrays, got Ms:{spliced_moments.ndim}D, "
            f"Mu:{unspliced_moments.ndim}D"
        )

    return True
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _scale_velovi_velocities(velocities, latent_time, target_max=20.0):
    """Rescale VELOVI velocities so the maximal latent time maps to ``target_max``.

    Computes ``velocities / (target_max / t_max)``. ``t_max`` is the per-column
    maximum of a 2-D ``latent_time``, or the global maximum otherwise. Wherever
    ``t_max`` is not strictly positive (zero, negative or NaN), the scaling
    factor is 1, so those velocities are left unchanged.

    Args:
        velocities: Velocity array (numpy).
        latent_time: Latent time array (numpy).
        target_max: Value the maximal latent time is mapped to.

    Returns:
        ``velocities`` divided (with broadcasting) by the scaling factor(s).
    """
    if latent_time.ndim > 1:
        t_max = np.asarray(np.max(latent_time, axis=0))
    else:
        t_max = np.asarray(np.max(latent_time))
    # Start at 1.0 and divide only where t_max > 0. A masked np.divide avoids
    # the divide-by-zero RuntimeWarning that an eager
    # ``np.where(t_max > 0, 20 / t_max, 1.0)`` emits.
    scaling = np.ones(t_max.shape, dtype=np.result_type(t_max, np.float32))
    np.divide(target_max, t_max, out=scaling, where=t_max > 0)
    return velocities / scaling


async def analyze_velocity_with_velovi(
    adata,
    n_epochs: int = 1000,
    n_hidden: int = 128,
    n_latent: int = 10,
    use_gpu: bool = False,
    ctx: Optional["ToolContext"] = None,
) -> dict[str, Any]:
    """
    Analyzes RNA velocity using the deep learning model VELOVI.

    VELOVI (Velocity Variational Inference) is a probabilistic deep generative model
    that estimates transcriptional dynamics from spliced and unspliced mRNA counts.
    It provides velocity vectors with uncertainty quantification.

    Args:
        adata: AnnData with 'spliced' and 'unspliced' layers.
        n_epochs: Number of training epochs.
        n_hidden: Number of hidden units in neural network layers.
        n_latent: Dimensionality of the latent space.
        use_gpu: If True, train with ``accelerator="gpu"``.
        ctx: ToolContext for logging (optional).

    Side effects on ``adata`` (the object passed in):
        - ``obs["velocity_velovi_norm"]``: per-cell L2 norm of the scaled velocities.
        - ``obsm["X_velovi_latent"]``: VELOVI latent representation.
        - ``uns["velovi_adata"]``: the preprocessed copy, holding
          ``layers["velocity_velovi"]`` and ``layers["latent_time_velovi"]``.
        - ``uns["velovi_gene_names"]``: gene names of the preprocessed copy.

    Returns:
        Dictionary with VELOVI results and metadata.

    Raises:
        ProcessingError: Wrapping any failure, including a missing dependency,
            invalid data, training errors and result extraction errors.
    """
    try:
        require("scvi", feature="VELOVI velocity analysis")
        from scvi.external import VELOVI

        # Train on a preprocessed copy of adata (see _prepare_velovi_data).
        adata_prepared = await _prepare_velovi_data(adata, ctx)
        _validate_velovi_data(adata_prepared)

        VELOVI.setup_anndata(
            adata_prepared,
            spliced_layer="Ms",
            unspliced_layer="Mu",
        )
        velovi_model = VELOVI(adata_prepared, n_hidden=n_hidden, n_latent=n_latent)

        if use_gpu:
            velovi_model.train(max_epochs=n_epochs, accelerator="gpu")
        else:
            velovi_model.train(max_epochs=n_epochs)

        latent_time = velovi_model.get_latent_time(n_samples=25)
        velocities = velovi_model.get_velocity(n_samples=25, velo_statistic="mean")
        latent_repr = velovi_model.get_latent_representation()

        # Results may come back as pandas objects; normalize to numpy arrays.
        if hasattr(latent_time, "values"):
            latent_time = latent_time.values
        if hasattr(velocities, "values"):
            velocities = velocities.values
        latent_time = np.asarray(latent_time)
        velocities = np.asarray(velocities)
        latent_repr = np.asarray(latent_repr)

        scaled_velocities = _scale_velovi_velocities(velocities, latent_time)

        # Store full results on the preprocessed copy.
        adata_prepared.layers["velocity_velovi"] = scaled_velocities
        adata_prepared.layers["latent_time_velovi"] = latent_time
        adata_prepared.obsm["X_velovi_latent"] = latent_repr

        velocity_norm = np.linalg.norm(scaled_velocities, axis=1)
        adata_prepared.obs["velocity_velovi_norm"] = velocity_norm

        # Transfer per-cell results back to the original object.
        adata.obs["velocity_velovi_norm"] = velocity_norm
        adata.obsm["X_velovi_latent"] = latent_repr

        # Keep the preprocessed copy for downstream use.
        adata.uns["velovi_adata"] = adata_prepared
        adata.uns["velovi_gene_names"] = adata_prepared.var_names.tolist()

        return {
            "method": "VELOVI",
            "velocity_computed": True,
            "n_epochs": n_epochs,
            "n_hidden": n_hidden,
            "n_latent": n_latent,
            "velocity_shape": scaled_velocities.shape,
            "latent_time_shape": latent_time.shape,
            "latent_repr_shape": latent_repr.shape,
            "velocity_mean_norm": float(velocity_norm.mean()),
            "velocity_std_norm": float(velocity_norm.std()),
            "n_genes_analyzed": adata_prepared.n_vars,
            "original_n_genes": adata.n_vars,
            "training_completed": True,
            "device": "GPU" if use_gpu else "CPU",
        }

    except Exception as e:
        raise ProcessingError(f"VELOVI velocity analysis failed: {e}") from e
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
async def analyze_rna_velocity(
    data_id: str,
    ctx: "ToolContext",
    params: RNAVelocityParameters = RNAVelocityParameters(),
) -> RNAVelocityResult:
    """
    Computes RNA velocity for spatial transcriptomics data.

    This is the main MCP entry point for velocity analysis. It requires
    'spliced' and 'unspliced' count layers in the input dataset.

    Args:
        data_id: Dataset identifier.
        ctx: ToolContext for data access and logging.
        params: RNA velocity parameters. ``params.method`` selects "scvelo"
            or "velovi".

    Returns:
        RNAVelocityResult with computation metadata.

    Raises:
        DataNotFoundError: If data lacks required layers.
        ParameterError: If ``params.method`` is not recognized.
        ProcessingError: If velocity computation fails.
    """
    require("scvelo")
    import scvelo as scv  # noqa: F401

    # NOTE(review): results are written onto the object returned by
    # ctx.get_adata. The scVelo path also filters genes in place. Confirm
    # whether that object is the stored dataset or a copy.
    adata = await ctx.get_adata(data_id)

    try:
        validate_adata(adata, {}, check_velocity=True)
    except DataNotFoundError as e:
        raise DataNotFoundError(
            f"Missing velocity data: {e}. Requires 'spliced' and 'unspliced' layers."
        ) from e

    velocity_computed = False

    if params.method == "scvelo":
        with suppress_output():
            try:
                adata = compute_rna_velocity(
                    adata, mode=params.scvelo_mode, params=params
                )
                velocity_computed = True
            except Exception as e:
                raise ProcessingError(
                    f"scVelo RNA velocity analysis failed: {e}"
                ) from e

    elif params.method == "velovi":
        require("scvi", feature="VELOVI velocity analysis")

        try:
            velovi_results = await analyze_velocity_with_velovi(
                adata,
                n_epochs=params.velovi_n_epochs,
                n_hidden=params.velovi_n_hidden,
                n_latent=params.velovi_n_latent,
                use_gpu=params.velovi_use_gpu,
                ctx=ctx,
            )
        except ProcessingError:
            # Already prefixed with "VELOVI velocity analysis failed: ".
            # Re-wrapping would duplicate that prefix in the message.
            raise
        except Exception as e:
            raise ProcessingError(f"VELOVI velocity analysis failed: {e}") from e

        if not velovi_results.get("velocity_computed", False):
            raise ProcessingError(
                "VELOVI velocity analysis failed: VELOVI failed to compute velocity"
            )

        velocity_computed = True
        if "velovi_adata" in adata.uns:
            # NOTE(review): this stores a boolean under the key scVelo uses for
            # its graph matrix, and overwrites any earlier scVelo graph. Kept
            # for compatibility with downstream readers; confirm.
            adata.uns["velocity_graph"] = True
            adata.uns["velocity_method"] = "velovi"

    else:
        raise ParameterError(f"Unknown velocity method: {params.method}")

    return RNAVelocityResult(
        data_id=data_id,
        velocity_computed=velocity_computed,
        velocity_graph_key="velocity_graph" if velocity_computed else None,
        # NOTE(review): this reports the method name for both methods, not
        # params.scvelo_mode. The former conditional expression hinted that
        # the scVelo mode may have been intended; confirm against
        # RNAVelocityResult.
        mode=params.method,
    )
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization module for spatial transcriptomics.
|
|
3
|
+
|
|
4
|
+
This module provides visualization functions organized by analysis type:
|
|
5
|
+
- basic: Spatial plots, UMAP, heatmaps, violin plots, dotplots
|
|
6
|
+
- deconvolution: Cell type proportion visualizations
|
|
7
|
+
- cell_comm: Cell-cell communication visualizations
|
|
8
|
+
- velocity: RNA velocity visualizations
|
|
9
|
+
- trajectory: Trajectory and pseudotime visualizations
|
|
10
|
+
- spatial_stats: Spatial statistics visualizations
|
|
11
|
+
- enrichment: Pathway enrichment visualizations
|
|
12
|
+
- cnv: Copy number variation visualizations
|
|
13
|
+
- integration: Batch integration quality visualizations
|
|
14
|
+
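- multi_gene: Multi-gene, ligand-receptor pair, gene correlation and spatial interaction visualizations
- persistence: Visualization saving and export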
|
|
15
|
+
|
|
16
|
+
Usage:
|
|
17
|
+
from chatspatial.tools.visualization import (
|
|
18
|
+
create_spatial_visualization,
|
|
19
|
+
create_umap_visualization,
|
|
20
|
+
create_deconvolution_visualization,
|
|
21
|
+
# ... etc
|
|
22
|
+
)
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
# Basic visualizations
|
|
26
|
+
from .basic import (
|
|
27
|
+
create_dotplot_visualization,
|
|
28
|
+
create_heatmap_visualization,
|
|
29
|
+
create_spatial_visualization,
|
|
30
|
+
create_umap_visualization,
|
|
31
|
+
create_violin_visualization,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Cell communication visualizations
|
|
35
|
+
from .cell_comm import create_cell_communication_visualization
|
|
36
|
+
|
|
37
|
+
# CNV visualizations
|
|
38
|
+
from .cnv import create_cnv_heatmap_visualization, create_spatial_cnv_visualization
|
|
39
|
+
|
|
40
|
+
# Core utilities and data classes
|
|
41
|
+
from .core import (
|
|
42
|
+
FIGURE_DEFAULTS,
|
|
43
|
+
CellCommunicationData,
|
|
44
|
+
DeconvolutionData,
|
|
45
|
+
add_colorbar,
|
|
46
|
+
create_figure,
|
|
47
|
+
create_figure_from_params,
|
|
48
|
+
get_categorical_cmap,
|
|
49
|
+
get_category_colors,
|
|
50
|
+
get_colormap,
|
|
51
|
+
get_diverging_colormap,
|
|
52
|
+
get_validated_features,
|
|
53
|
+
plot_spatial_feature,
|
|
54
|
+
resolve_figure_size,
|
|
55
|
+
setup_multi_panel_figure,
|
|
56
|
+
validate_and_prepare_feature,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
# Deconvolution visualizations (including CARD imputation)
|
|
61
|
+
from .deconvolution import (
|
|
62
|
+
create_card_imputation_visualization,
|
|
63
|
+
create_deconvolution_visualization,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
# Enrichment visualizations
|
|
67
|
+
from .enrichment import (
|
|
68
|
+
create_enrichment_visualization,
|
|
69
|
+
create_pathway_enrichment_visualization,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# Batch integration visualizations
|
|
73
|
+
from .integration import create_batch_integration_visualization
|
|
74
|
+
|
|
75
|
+
# Main entry point and handler registry (from main.py to avoid circular imports)
|
|
76
|
+
from .main import PLOT_HANDLERS, visualize_data
|
|
77
|
+
|
|
78
|
+
# Multi-gene visualizations
|
|
79
|
+
from .multi_gene import (
|
|
80
|
+
create_gene_correlation_visualization,
|
|
81
|
+
create_lr_pairs_visualization,
|
|
82
|
+
create_multi_gene_umap_visualization,
|
|
83
|
+
create_multi_gene_visualization,
|
|
84
|
+
create_spatial_interaction_visualization,
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
# Persistence functions
|
|
88
|
+
from .persistence import (
|
|
89
|
+
clear_visualization_cache,
|
|
90
|
+
export_all_visualizations,
|
|
91
|
+
save_visualization,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
# Spatial statistics visualizations
|
|
95
|
+
from .spatial_stats import create_spatial_statistics_visualization
|
|
96
|
+
|
|
97
|
+
# Trajectory visualizations
|
|
98
|
+
from .trajectory import create_trajectory_visualization
|
|
99
|
+
|
|
100
|
+
# RNA velocity visualizations
|
|
101
|
+
from .velocity import create_rna_velocity_visualization
|
|
102
|
+
|
|
103
|
+
# Public API of the visualization package; every name listed here is imported
# from a submodule above.
__all__ = [
    # Core utilities
    "FIGURE_DEFAULTS",
    "create_figure",
    "create_figure_from_params",
    "resolve_figure_size",
    "setup_multi_panel_figure",
    "add_colorbar",
    "get_colormap",
    "get_categorical_cmap",
    "get_category_colors",
    "get_diverging_colormap",
    "plot_spatial_feature",
    "get_validated_features",
    "validate_and_prepare_feature",
    # Data classes
    "DeconvolutionData",
    "CellCommunicationData",
    # Basic visualizations
    "create_spatial_visualization",
    "create_umap_visualization",
    "create_heatmap_visualization",
    "create_violin_visualization",
    "create_dotplot_visualization",
    # Specialized visualizations
    "create_deconvolution_visualization",
    "create_cell_communication_visualization",
    "create_rna_velocity_visualization",
    "create_trajectory_visualization",
    "create_spatial_statistics_visualization",
    "create_enrichment_visualization",
    "create_pathway_enrichment_visualization",
    # CARD imputation (imported from the deconvolution module)
    "create_card_imputation_visualization",
    # CNV visualizations
    "create_spatial_cnv_visualization",
    "create_cnv_heatmap_visualization",
    # Integration visualizations
    "create_batch_integration_visualization",
    # Multi-gene visualizations
    "create_multi_gene_visualization",
    "create_multi_gene_umap_visualization",
    "create_lr_pairs_visualization",
    "create_gene_correlation_visualization",
    "create_spatial_interaction_visualization",
    # Persistence functions
    "save_visualization",
    "export_all_visualizations",
    "clear_visualization_cache",
    # Main entry point
    "visualize_data",
    # Handler registry
    "PLOT_HANDLERS",
]
|