chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trajectory inference for spatial transcriptomics.
|
|
3
|
+
|
|
4
|
+
This module infers cellular trajectories and pseudotime by combining
|
|
5
|
+
expression patterns with optional velocity and spatial information.
|
|
6
|
+
|
|
7
|
+
Key functionality:
|
|
8
|
+
- `analyze_trajectory`: Main MCP entry point for trajectory analysis
|
|
9
|
+
- Supports CellRank (velocity-based), Palantir (expression-based), and DPT (diffusion-based)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
13
|
+
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ..spatial_mcp_adapter import ToolContext
|
|
18
|
+
|
|
19
|
+
from ..models.analysis import TrajectoryResult
|
|
20
|
+
from ..models.data import TrajectoryParameters
|
|
21
|
+
from ..utils.adata_utils import (
|
|
22
|
+
get_spatial_key,
|
|
23
|
+
require_spatial_coords,
|
|
24
|
+
validate_obs_column,
|
|
25
|
+
)
|
|
26
|
+
from ..utils.compute import ensure_diffmap, ensure_neighbors, ensure_pca
|
|
27
|
+
from ..utils.dependency_manager import require
|
|
28
|
+
from ..utils.exceptions import (
|
|
29
|
+
DataError,
|
|
30
|
+
DataNotFoundError,
|
|
31
|
+
ParameterError,
|
|
32
|
+
ProcessingError,
|
|
33
|
+
)
|
|
34
|
+
from ..utils.mcp_utils import suppress_output
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def prepare_gam_model_for_visualization(
    adata,
    genes: list,
    time_key: str = "latent_time",
    fate_key: str = "lineages_fwd",
):
    """
    Build a CellRank GAM model for gene-trend and fate-heatmap plots.

    Holds the computation needed by the CellRank 2.0 gene trends and fate
    heatmap visualizations. The data must already have gone through
    analyze_rna_velocity (dynamical mode) and analyze_trajectory (cellrank
    method).

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix carrying CellRank results.
    genes : list
        Gene names the model is prepared for; every entry must be present in
        ``adata.var_names``.
    time_key : str, default 'latent_time'
        Column of ``adata.obs`` holding pseudotime / latent time values.
    fate_key : str, default 'lineages_fwd'
        Entry of ``adata.obsm`` holding fate probabilities.

    Returns
    -------
    tuple
        ``(model, lineage_names)``: the GAM model and the list of lineage names.

    Raises
    ------
    DataNotFoundError
        If ``fate_key`` is missing from ``adata.obsm`` or any gene is missing.
    DataError
        If the fate probabilities do not expose lineage names.
    """
    require("cellrank")
    from cellrank.models import GAM

    validate_obs_column(adata, time_key, "Time")

    if fate_key not in adata.obsm:
        raise DataNotFoundError(
            f"Fate probabilities '{fate_key}' not found. Run analyze_trajectory first."
        )

    # Only objects exposing a non-None `.names` attribute (a CellRank Lineage)
    # are accepted; anything else cannot provide lineage labels.
    fate_probs = adata.obsm[fate_key]
    names = getattr(fate_probs, "names", None)
    if names is None:
        raise DataError(
            "Fate probabilities must be a CellRank Lineage object with names. "
            "This requires running the full analysis pipeline in memory:\n"
            "1. analyze_rna_velocity(data_id, params={'scvelo_mode': 'dynamical'})\n"
            "2. analyze_trajectory(data_id, params={'method': 'cellrank'})\n"
            "3. Then visualize with plot_type='trajectory', subtype='gene_trends'"
        )
    lineage_names = list(names)

    var_names = adata.var_names
    missing_genes = [gene for gene in genes if gene not in var_names]
    if missing_genes:
        raise DataNotFoundError(
            f"Genes not found in data: {missing_genes}. "
            f"Available genes: {list(adata.var_names[:10])}..."
        )

    return GAM(adata), lineage_names
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _spatial_transition_matrix(spatial_coords):
    """Build a row-stochastic spatial transition matrix from coordinates.

    Affinities decay exponentially with Euclidean distance, using the mean
    pairwise distance as the length scale. Every pair of observations gets a
    non-zero entry, so memory grows as O(n_obs**2).
    """
    import numpy as np
    from scipy.sparse import csr_matrix
    from scipy.spatial.distance import pdist, squareform

    spatial_dist = squareform(pdist(spatial_coords))
    length_scale = spatial_dist.mean()
    # Identical coordinates (or a single observation) give a zero mean;
    # dividing by it would fill the kernel with NaN.
    if not np.isfinite(length_scale) or length_scale <= 0:
        length_scale = 1.0
    spatial_sim = np.exp(-spatial_dist / length_scale)
    # A transition matrix must be row-stochastic. The diagonal of the
    # distance matrix is 0, so each row contains exp(0) == 1 and no row sum
    # can be zero.
    spatial_sim /= spatial_sim.sum(axis=1, keepdims=True)
    return csr_matrix(spatial_sim)


def infer_spatial_trajectory_cellrank(
    adata, spatial_weight=0.5, kernel_weights=(0.8, 0.2), n_states=5
):
    """
    Infers cellular trajectories by combining RNA velocity with CellRank.

    This function uses CellRank to model cell-state transitions by constructing
    a transition matrix from multiple kernels:
    1. A velocity kernel from RNA velocity.
    2. A connectivity kernel based on transcriptomic similarity.
    3. (Optional) A spatial kernel based on physical proximity.

    Parameters
    ----------
    adata : AnnData
        Data with velocity results. When ``adata.uns["velocity_method"]`` is
        ``"velovi"``, kernels are built on ``adata.uns["velovi_adata"]``.
    spatial_weight : float, default 0.5
        Weight of the spatial kernel in the blend; the velocity/connectivity
        kernel gets ``1 - spatial_weight``. Forced to 0 when no spatial
        coordinates are found.
    kernel_weights : tuple of float, default (0.8, 0.2)
        ``(velocity_weight, connectivity_weight)``.
    n_states : int, default 5
        Number of macrostates requested from GPCCA.

    Returns
    -------
    AnnData
        The input ``adata`` with ``obs["pseudotime"]`` and, when CellRank
        produced them, ``obs["terminal_states"]``, ``obs["macrostates"]`` and
        ``obsm["fate_probabilities"]``.

    Raises
    ------
    ProcessingError
        If VELOVI is the recorded velocity method but ``velovi_adata`` is
        missing, if macrostate computation fails, or if neither terminal
        states nor macrostates are available.
    """
    import cellrank as cr

    # Check if spatial data is available
    spatial_key = get_spatial_key(adata)
    has_spatial = spatial_key is not None

    if not has_spatial and spatial_weight > 0:
        spatial_weight = 0

    # VELOVI results live in a separate AnnData stored in uns. The VELOVI
    # layer is aliased to "velocity", presumably so VelocityKernel picks it up
    # under its default layer name.
    if adata.uns.get("velocity_method") == "velovi":
        if "velovi_adata" not in adata.uns:
            raise ProcessingError("VELOVI velocity data not found")
        adata_for_cellrank = adata.uns["velovi_adata"]
        if has_spatial:
            adata_for_cellrank.obsm["spatial"] = adata.obsm[spatial_key]
        if "velocity_velovi" in adata_for_cellrank.layers:
            adata_for_cellrank.layers["velocity"] = adata_for_cellrank.layers[
                "velocity_velovi"
            ]
    else:
        adata_for_cellrank = adata

    vk = cr.kernels.VelocityKernel(adata_for_cellrank)
    vk.compute_transition_matrix()

    ck = cr.kernels.ConnectivityKernel(adata_for_cellrank)
    ck.compute_transition_matrix()

    # Combine kernels
    vk_weight, ck_weight = kernel_weights
    expression_kernel = vk_weight * vk + ck_weight * ck

    if has_spatial and spatial_weight > 0:
        sk = cr.kernels.PrecomputedKernel(
            _spatial_transition_matrix(adata.obsm[spatial_key]), adata_for_cellrank
        )
        sk.compute_transition_matrix()
        combined_kernel = (
            (1 - spatial_weight) * expression_kernel + spatial_weight * sk
        )
    else:
        combined_kernel = expression_kernel

    # GPCCA analysis
    g = cr.estimators.GPCCA(combined_kernel)
    g.compute_eigendecomposition()

    try:
        g.compute_macrostates(n_states=n_states)
    except Exception as e:
        raise ProcessingError(
            f"CellRank failed with n_states={n_states}: {e}. "
            f"Try reducing n_states or use method='palantir'/'dpt'."
        ) from e

    # Predict terminal states. A "No macrostates have been selected" ValueError
    # is tolerated; the macrostate fallback below handles that case.
    try:
        g.predict_terminal_states(method="stability")
    except ValueError as e:
        if "No macrostates have been selected" not in str(e):
            raise

    has_terminal_states = (
        hasattr(g, "terminal_states") and g.terminal_states is not None
    )

    if has_terminal_states and len(g.terminal_states.cat.categories) > 0:
        g.compute_fate_probabilities()
        absorption_probs = g.fate_probabilities
        terminal_states = list(g.terminal_states.cat.categories)
        # Pseudotime = 1 - probability of reaching the first terminal state.
        root_state = terminal_states[0]
        pseudotime = 1 - absorption_probs[root_state].X.flatten()

        adata_for_cellrank.obs["pseudotime"] = pseudotime
        adata_for_cellrank.obsm["fate_probabilities"] = absorption_probs
        adata_for_cellrank.obs["terminal_states"] = g.terminal_states
    elif hasattr(g, "macrostates") and g.macrostates is not None:
        # Fallback: derive pseudotime from membership in the first macrostate.
        macrostate_probs = g.macrostates_memberships
        pseudotime = 1 - macrostate_probs[:, 0].X.flatten()
        adata_for_cellrank.obs["pseudotime"] = pseudotime
    else:
        raise ProcessingError(
            "CellRank could not compute either terminal states or macrostates"
        )

    if hasattr(g, "macrostates") and g.macrostates is not None:
        adata_for_cellrank.obs["macrostates"] = g.macrostates

    # Transfer results back to original adata
    for obs_key in ("pseudotime", "terminal_states", "macrostates"):
        if obs_key in adata_for_cellrank.obs:
            adata.obs[obs_key] = adata_for_cellrank.obs[obs_key]
    if "fate_probabilities" in adata_for_cellrank.obsm:
        adata.obsm["fate_probabilities"] = adata_for_cellrank.obsm["fate_probabilities"]

    # Update velovi_adata if used
    if (
        adata.uns.get("velocity_method") == "velovi"
        and "velovi_adata" in adata.uns
    ):
        adata.uns["velovi_adata"] = adata_for_cellrank

    return adata
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def spatial_aware_embedding(adata, spatial_weight=0.3):
    """
    Generate a spatially-aware low-dimensional embedding.

    Builds a precomputed distance matrix that blends expression distance
    (Euclidean distance between PCA scores) with physical distance between
    spatial coordinates, then embeds it with UMAP.

    Parameters
    ----------
    adata : AnnData
        Data with spatial coordinates; PCA is computed via ``ensure_pca`` if
        missing.
    spatial_weight : float, default 0.3
        Weight of the spatial distance in the blend; the expression distance
        gets ``1 - spatial_weight``.

    Returns
    -------
    AnnData
        ``adata`` with the embedding stored in ``obsm["X_spatial_umap"]``.

    Notes
    -----
    Both distance matrices are divided by their mean pairwise distance before
    blending. PCA scores and spatial coordinates are in unrelated units, so
    without this the larger-scaled matrix dominates regardless of
    ``spatial_weight``. Memory and time grow as O(n_obs**2).
    """
    from sklearn.metrics.pairwise import euclidean_distances
    from umap import UMAP

    def _unit_mean(dist):
        # Leave degenerate matrices (zero or NaN mean) untouched.
        scale = dist.mean()
        return dist / scale if scale > 0 else dist

    spatial_coords = require_spatial_coords(adata)
    ensure_pca(adata)

    expr_dist = _unit_mean(euclidean_distances(adata.obsm["X_pca"]))
    spatial_dist = _unit_mean(euclidean_distances(spatial_coords))
    combined_dist = (1 - spatial_weight) * expr_dist + spatial_weight * spatial_dist

    umap_op = UMAP(metric="precomputed")
    adata.obsm["X_spatial_umap"] = umap_op.fit_transform(combined_dist)

    return adata
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def infer_pseudotime_palantir(
    adata, root_cells=None, n_diffusion_components=10, num_waypoints=500
):
    """
    Infers cellular trajectories and pseudotime using Palantir.

    Palantir models differentiation as a stochastic process on a graph,
    using diffusion maps to capture data geometry and computing fate
    probabilities via random walks from a root cell.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix; PCA is computed via ``ensure_pca`` if
        missing.
    root_cells : list of str, optional
        Cell identifiers; only the first entry is used as Palantir's start
        cell. If omitted, the cell with the largest value on the leading
        non-trivial diffusion component is used.
    n_diffusion_components : int, default 10
        Number of diffusion components computed. The trivial first component
        is discarded, so ``n_diffusion_components - 1`` components form the
        multiscale space. Must be at least 2.
    num_waypoints : int, default 500
        Number of waypoints for trajectory granularity.

    Returns
    -------
    AnnData
        ``adata`` with ``obs["palantir_pseudotime"]`` and
        ``obsm["palantir_branch_probs"]``.

    Raises
    ------
    ParameterError
        If ``n_diffusion_components < 2`` or the requested root cell is not
        present.
    """
    import palantir

    if n_diffusion_components < 2:
        raise ParameterError(
            "n_diffusion_components must be at least 2 for Palantir "
            "(the first diffusion component is trivial and is discarded)."
        )

    ensure_pca(adata)

    pca_df = pd.DataFrame(adata.obsm["X_pca"], index=adata.obs_names)
    dm_res = palantir.utils.run_diffusion_maps(
        pca_df, n_components=n_diffusion_components
    )
    # The leading eigenvector of the row-stochastic diffusion operator is
    # trivial (constant), so using it directly, e.g. for idxmax root
    # selection, picks an essentially arbitrary cell. Palantir's documented
    # workflow drops it and rescales the remaining components by eigenvalue.
    ms_data = palantir.utils.determine_multiscale_space(
        dm_res, n_eigs=n_diffusion_components
    )
    ms_data.index = pca_df.index

    if root_cells is not None and len(root_cells) > 0:
        if root_cells[0] not in ms_data.index:
            raise ParameterError(f"Root cell '{root_cells[0]}' not found in data")
        start_cell = root_cells[0]
    else:
        # One extreme of the dominant diffusion axis.
        start_cell = ms_data.iloc[:, 0].idxmax()

    pr_res = palantir.core.run_palantir(
        ms_data, start_cell, num_waypoints=num_waypoints
    )

    adata.obs["palantir_pseudotime"] = pr_res.pseudotime
    adata.obsm["palantir_branch_probs"] = pr_res.branch_probs

    return adata
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def compute_dpt_trajectory(adata, root_cells=None, ctx: Optional["ToolContext"] = None):
    """
    Compute Diffusion Pseudotime (DPT) trajectory analysis with scanpy.

    Parameters
    ----------
    adata : AnnData
        Data; PCA, neighbors and diffusion map are computed if missing.
    root_cells : list of str, optional
        Cell identifiers; only the first entry is used as the DPT root. If
        omitted, the first observation (position 0) is used as root.
    ctx : ToolContext, optional
        Accepted for signature symmetry with other methods; not used here.

    Returns
    -------
    AnnData
        ``adata`` with ``uns["iroot"]`` set and ``obs["dpt_pseudotime"]``
        recomputed for that root (NaN values replaced by 0).

    Raises
    ------
    ParameterError
        If the requested root cell is not in ``adata.obs_names``.
    ProcessingError
        If ``sc.tl.dpt`` fails or does not produce ``dpt_pseudotime``.
    """
    import numpy as np
    import scanpy as sc

    ensure_pca(adata)
    ensure_neighbors(adata)
    ensure_diffmap(adata)

    if root_cells is not None and len(root_cells) > 0:
        root = root_cells[0]
        if root not in adata.obs_names:
            raise ParameterError(
                f"Root cell '{root}' not found. "
                "Use valid cell ID from adata.obs_names or omit to auto-select."
            )
        adata.uns["iroot"] = int(np.where(adata.obs_names == root)[0][0])
    else:
        adata.uns["iroot"] = 0

    # Always recompute. A pre-existing 'dpt_pseudotime' column was computed
    # from whichever root was current at that time; reusing it would silently
    # ignore the iroot set above.
    try:
        sc.tl.dpt(adata)
    except Exception as e:
        raise ProcessingError(f"DPT computation failed: {e}") from e

    if "dpt_pseudotime" not in adata.obs.columns:
        raise ProcessingError("DPT computation did not create 'dpt_pseudotime' column")

    adata.obs["dpt_pseudotime"] = adata.obs["dpt_pseudotime"].fillna(0)

    return adata
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def has_velocity_data(adata) -> bool:
    """Return True if any RNA velocity method has left results in ``adata.uns``.

    Detection is marker-based: the presence of any of the keys
    ``velocity_graph``, ``velovi_adata`` or ``velocity_method`` in
    ``adata.uns`` counts as velocity being available.
    """
    velocity_markers = ("velocity_graph", "velovi_adata", "velocity_method")
    return any(marker in adata.uns for marker in velocity_markers)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
async def analyze_trajectory(
    data_id: str,
    ctx: "ToolContext",
    params: TrajectoryParameters = TrajectoryParameters(),
) -> TrajectoryResult:
    """
    Analyze trajectory and cell state transitions in spatial transcriptomics data.

    This is the main MCP entry point for trajectory inference. It supports:
    - CellRank: Requires pre-computed velocity data
    - Palantir: Expression-based, no velocity required
    - DPT: Diffusion-based, no velocity required

    Args:
        data_id: Dataset identifier.
        ctx: ToolContext for data access and logging.
        params: Trajectory analysis parameters.

    Returns:
        TrajectoryResult with pseudotime and method metadata.

    Raises:
        ProcessingError: If CellRank is requested without velocity data, if
            the selected method fails (errors from the helpers are wrapped),
            or if the expected pseudotime column is missing afterwards.
        ParameterError: If ``params.method`` is not a supported method.
    """
    adata = await ctx.get_adata(data_id)

    velocity_available = has_velocity_data(adata)
    pseudotime_key = None
    method_used = params.method
    spatial_embedding_computed = False

    # Execute requested method
    if params.method == "cellrank":
        if not velocity_available:
            raise ProcessingError(
                "CellRank requires velocity data. Run velocity analysis first or use palantir/dpt."
            )

        require("cellrank")
        import cellrank as cr  # noqa: F401

        try:
            with suppress_output():
                adata = infer_spatial_trajectory_cellrank(
                    adata,
                    spatial_weight=params.spatial_weight,
                    kernel_weights=params.cellrank_kernel_weights,
                    n_states=params.cellrank_n_states,
                )
                pseudotime_key = "pseudotime"
        except Exception as e:
            raise ProcessingError(f"CellRank trajectory inference failed: {e}") from e

    elif params.method == "palantir":
        try:
            with suppress_output():
                has_spatial = get_spatial_key(adata) is not None
                if has_spatial and params.spatial_weight > 0:
                    # NOTE(review): obsm["X_spatial_umap"] produced here is not
                    # read by infer_pseudotime_palantir, which builds its
                    # diffusion map from obsm["X_pca"]; spatial_weight thus
                    # does not affect the Palantir pseudotime itself.
                    adata = spatial_aware_embedding(
                        adata, spatial_weight=params.spatial_weight
                    )
                    spatial_embedding_computed = True
                elif not has_spatial and params.spatial_weight > 0:
                    await ctx.warning(
                        f"Spatial weight {params.spatial_weight} specified but no spatial "
                        "coordinates found. Using expression-only Palantir."
                    )

                adata = infer_pseudotime_palantir(
                    adata,
                    root_cells=params.root_cells,
                    n_diffusion_components=params.palantir_n_diffusion_components,
                    num_waypoints=params.palantir_num_waypoints,
                )

                pseudotime_key = "palantir_pseudotime"

        except Exception as e:
            raise ProcessingError(f"Palantir trajectory inference failed: {e}") from e

    elif params.method == "dpt":
        try:
            with suppress_output():
                adata = compute_dpt_trajectory(
                    adata, root_cells=params.root_cells, ctx=ctx
                )
                pseudotime_key = "dpt_pseudotime"
        except Exception as e:
            raise ProcessingError(f"DPT analysis failed: {e}") from e

    else:
        raise ParameterError(f"Unknown trajectory method: {params.method}")

    if pseudotime_key is None or pseudotime_key not in adata.obs.columns:
        raise ProcessingError("Failed to compute pseudotime with any available method")

    # Store scientific metadata
    from ..utils.adata_utils import store_analysis_metadata

    candidate_keys: dict[str, list] = {"obs": [pseudotime_key], "obsm": [], "uns": []}

    if method_used == "cellrank":
        candidate_keys["obs"].extend(["terminal_states", "macrostates"])
        candidate_keys["obsm"].append("fate_probabilities")
        candidate_keys["uns"].append("velocity_method")
    elif method_used == "palantir":
        candidate_keys["obsm"].append("palantir_branch_probs")
        if spatial_embedding_computed:
            candidate_keys["obsm"].append("X_spatial_umap")
    elif method_used == "dpt":
        candidate_keys["uns"].append("iroot")

    # CellRank writes terminal_states / fate_probabilities only when terminal
    # states were found, and velocity_method may be absent when velocity was
    # detected via another marker. Record only keys that actually exist.
    containers = {"obs": adata.obs.columns, "obsm": adata.obsm, "uns": adata.uns}
    results_keys_dict: dict[str, Any] = {
        slot: [key for key in keys if key in containers[slot]]
        for slot, keys in candidate_keys.items()
    }

    parameters_dict: dict[str, Any] = {"spatial_weight": params.spatial_weight}
    if method_used == "cellrank":
        parameters_dict.update(
            {
                "kernel_weights": params.cellrank_kernel_weights,
                "n_states": params.cellrank_n_states,
            }
        )
    elif method_used == "palantir":
        parameters_dict.update(
            {
                "n_diffusion_components": params.palantir_n_diffusion_components,
                "num_waypoints": params.palantir_num_waypoints,
            }
        )

    if params.root_cells:
        parameters_dict["root_cells"] = params.root_cells

    statistics_dict = {
        "velocity_computed": velocity_available,
        "pseudotime_key": pseudotime_key,
    }

    store_analysis_metadata(
        adata,
        analysis_name=f"trajectory_{method_used}",
        method=method_used,
        parameters=parameters_dict,
        results_keys=results_keys_dict,
        statistics=statistics_dict,
    )

    return TrajectoryResult(
        data_id=data_id,
        pseudotime_computed=True,
        velocity_computed=velocity_available,
        pseudotime_key=pseudotime_key,
        method=method_used,
        spatial_weight=params.spatial_weight,
    )
|