chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,639 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trajectory visualization functions for spatial transcriptomics.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- Pseudotime visualizations
|
|
6
|
+
- CellRank circular projections
|
|
7
|
+
- Fate map visualizations
|
|
8
|
+
- Gene trends along lineages
|
|
9
|
+
- Fate heatmaps
|
|
10
|
+
- Palantir results visualization
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import TYPE_CHECKING, Optional
|
|
14
|
+
|
|
15
|
+
import matplotlib.pyplot as plt
|
|
16
|
+
import scanpy as sc
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
import anndata as ad
|
|
20
|
+
|
|
21
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
22
|
+
|
|
23
|
+
from ...models.data import VisualizationParameters
|
|
24
|
+
from ...utils.adata_utils import validate_obs_column
|
|
25
|
+
from ...utils.dependency_manager import require
|
|
26
|
+
from ...utils.image_utils import non_interactive_backend
|
|
27
|
+
from ...utils.exceptions import (
|
|
28
|
+
DataCompatibilityError,
|
|
29
|
+
DataNotFoundError,
|
|
30
|
+
ParameterError,
|
|
31
|
+
)
|
|
32
|
+
from .core import (
|
|
33
|
+
get_categorical_columns,
|
|
34
|
+
infer_basis,
|
|
35
|
+
resolve_figure_size,
|
|
36
|
+
setup_multi_panel_figure,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# =============================================================================
|
|
40
|
+
# Main Router
|
|
41
|
+
# =============================================================================
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
async def create_trajectory_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Route a trajectory visualization request to the renderer for its subtype.

    Args:
        adata: AnnData object holding computed trajectory / pseudotime results.
        params: Visualization parameters; ``params.subtype`` selects the plot
            and defaults to ``"pseudotime"`` when empty.
        context: Optional MCP tool context used for progress messages.

    Returns:
        The matplotlib figure produced by the selected renderer.

    Raises:
        ParameterError: If ``params.subtype`` names an unsupported subtype.

    Supported subtypes:
        - pseudotime: pseudotime on an embedding, plus a velocity stream panel
          when ``adata.uns['velocity_graph']`` exists
        - circular: CellRank circular projection of fate probabilities
        - fate_map: CellRank fate probabilities aggregated per cluster (bar chart)
        - gene_trends: CellRank gene expression trends along lineages
        - fate_heatmap: CellRank smoothed expression heatmap ordered by pseudotime
        - palantir: Palantir pseudotime, entropy and dominant-fate panels
    """
    subtype = params.subtype or "pseudotime"

    if context:
        await context.info(f"Creating trajectory visualization (subtype: {subtype})")

    # Subtype name -> coroutine function that renders it.
    renderers = {
        "pseudotime": _create_trajectory_pseudotime_plot,
        "circular": _create_cellrank_circular_projection,
        "fate_map": _create_cellrank_fate_map,
        "gene_trends": _create_cellrank_gene_trends,
        "fate_heatmap": _create_cellrank_fate_heatmap,
        "palantir": _create_palantir_results,
    }

    renderer = renderers.get(subtype)
    if renderer is None:
        raise ParameterError(
            f"Unsupported subtype for trajectory: '{subtype}'. "
            "Available subtypes: pseudotime, circular, fate_map, gene_trends, "
            "fate_heatmap, palantir"
        )

    return await renderer(adata, params, context)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# =============================================================================
|
|
95
|
+
# Visualization Functions
|
|
96
|
+
# =============================================================================
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
async def _create_trajectory_pseudotime_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create trajectory pseudotime visualization.

    Draws pseudotime on an embedding. When RNA velocity is available, a second
    panel shows the scVelo velocity stream coloured by the same pseudotime.

    Data requirements:
        - adata.obs[<pseudotime>]: ``params.feature`` if set, otherwise the first
          obs column whose name contains "pseudotime" (case-insensitive)
        - an embedding resolved by ``infer_basis`` (``params.basis`` preferred)
        - adata.uns['velocity_graph']: optional, enables the velocity stream panel

    Args:
        adata: AnnData object with computed pseudotime.
        params: Visualization parameters (feature, basis, colormap, alpha,
            show_axes, show_colorbar, ...).
        context: Optional MCP tool context for progress messages.

    Returns:
        Figure with one panel, or two when ``velocity_graph`` is present.

    Raises:
        DataNotFoundError: No pseudotime column was given or detected.
        DataCompatibilityError: No embedding basis could be resolved.

    Plotting failures inside a panel are not raised. The panel shows the error
    text instead (and the failure is reported through ``context``), so the rest
    of the figure is still returned.
    """
    # Resolve the pseudotime column: explicit feature first, else auto-detect.
    pseudotime_key = params.feature
    if not pseudotime_key:
        candidates = [col for col in adata.obs.columns if "pseudotime" in col.lower()]
        if not candidates:
            raise DataNotFoundError(
                "No pseudotime found. Run trajectory analysis first."
            )
        pseudotime_key = candidates[0]
        if context:
            await context.info(f"Found pseudotime column: {pseudotime_key}")

    validate_obs_column(adata, pseudotime_key, "Pseudotime")

    has_velocity = "velocity_graph" in adata.uns

    basis = infer_basis(adata, preferred=params.basis)
    if not basis:
        raise DataCompatibilityError(
            "No valid embedding basis found. "
            f"Available keys: {list(adata.obsm.keys())}"
        )

    # One panel, plus a second one for the velocity stream when available.
    fig, axes = setup_multi_panel_figure(
        n_panels=2 if has_velocity else 1,
        params=params,
        default_title=f"Trajectory Analysis - Pseudotime ({pseudotime_key})",
    )

    # Panel 1: pseudotime on the embedding (best effort).
    ax1 = axes[0]
    try:
        sc.pl.embedding(
            adata,
            basis=basis,
            color=pseudotime_key,
            cmap=params.colormap,
            ax=ax1,
            show=False,
            frameon=params.show_axes,
            alpha=params.alpha,
            colorbar_loc="right" if params.show_colorbar else None,
        )
        # Spatial coordinates get a flipped y axis (presumably to match image
        # orientation with row 0 at the top).
        if basis == "spatial":
            ax1.invert_yaxis()
    except Exception as e:
        ax1.text(
            0.5,
            0.5,
            f"Error plotting pseudotime:\n{e}",
            ha="center",
            va="center",
            transform=ax1.transAxes,
        )
        ax1.set_title("Pseudotime (Error)", fontsize=12)
        if context:
            await context.info(f"Pseudotime panel failed: {e}")

    # Panel 2: RNA velocity stream; axes[1] only exists when has_velocity.
    if has_velocity:
        ax2 = axes[1]
        try:
            import scvelo as scv

            scv.pl.velocity_embedding_stream(
                adata,
                basis=basis,
                color=pseudotime_key,
                cmap=params.colormap,
                ax=ax2,
                show=False,
                alpha=params.alpha,
                frameon=params.show_axes,
            )
            ax2.set_title("RNA Velocity Stream", fontsize=12)

            if basis == "spatial":
                ax2.invert_yaxis()
        except ImportError:
            ax2.text(
                0.5,
                0.5,
                "scvelo not installed",
                ha="center",
                va="center",
                transform=ax2.transAxes,
            )
        except Exception as e:
            ax2.text(
                0.5,
                0.5,
                f"Error: {str(e)[:50]}",
                ha="center",
                va="center",
                transform=ax2.transAxes,
            )
            ax2.set_title("RNA Velocity Stream (Error)", fontsize=12)
            if context:
                await context.info(f"Velocity stream panel failed: {e}")

    # Lay out the figure being returned explicitly, rather than whatever pyplot
    # currently considers the "current" figure. Leave room for the suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    return fig
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
async def _create_cellrank_circular_projection(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank circular projection using cr.pl.circular_projection.

    Shows fate probabilities in a circular layout, coloured by one or more
    obs columns.

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': fate probabilities
          (only their presence is checked here)
        - adata.obs[params.cluster_key], or at least one categorical obs column
          (up to three are used when no cluster_key is given)

    Args:
        adata: AnnData object with CellRank results.
        params: Visualization parameters (cluster_key, dpi, title, figure size).
        context: Optional MCP tool context for progress messages.

    Returns:
        The figure CellRank drew into (taken from ``plt.gcf()``).

    Raises:
        DataNotFoundError: No CellRank fate probabilities were found.
        ParameterError: No cluster_key was given and no categorical obs column
            exists to colour by. CellRank's ``keys`` argument must be a string
            or sequence, so ``None`` cannot be passed.
    """
    require("cellrank", feature="circular projection")
    import cellrank as cr

    # Only check presence; the key itself is not passed to cellrank below.
    if not any(key in adata.obsm for key in ("lineages_fwd", "to_terminal_states")):
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    if context:
        await context.info("Creating CellRank circular projection")

    # Determine obs keys used for colouring.
    if params.cluster_key:
        keys = [params.cluster_key]
    else:
        keys = get_categorical_columns(adata, limit=3)
        if not keys:
            raise ParameterError(
                "cluster_key is required for circular projection: no categorical "
                "columns found in adata.obs to colour by."
            )

    figsize = resolve_figure_size(params, "trajectory")

    with non_interactive_backend():
        cr.pl.circular_projection(
            adata,
            keys=keys,
            figsize=figsize,
            dpi=params.dpi,
        )
        # CellRank draws into its own figure; grab it while still inside the backend.
        fig = plt.gcf()

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
async def _create_cellrank_fate_map(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank fate probabilities aggregated per cluster.

    Draws ``cr.pl.aggregate_fate_probabilities`` in ``mode="bar"``. Only the
    bar mode is produced here.

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': fate probabilities
          (only their presence is checked here)
        - adata.obs[cluster_key]: cluster labels for aggregation. This is
          ``params.cluster_key``, or the first categorical obs column when unset.

    Args:
        adata: AnnData object with CellRank results.
        params: Visualization parameters (cluster_key, dpi, title, figure size).
        context: Optional MCP tool context for progress messages.

    Returns:
        The figure CellRank drew into (taken from ``plt.gcf()``).

    Raises:
        DataNotFoundError: No CellRank fate probabilities were found.
        ParameterError: No cluster_key was given and no categorical column exists.
        Whatever ``validate_obs_column`` raises for an unusable cluster column.
    """
    require("cellrank", feature="fate map")
    import cellrank as cr

    # Only check presence; the key itself is not passed to cellrank below.
    if not any(key in adata.obsm for key in ("lineages_fwd", "to_terminal_states")):
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    # Determine cluster key: explicit parameter, else first categorical column.
    cluster_key = params.cluster_key
    if not cluster_key:
        categorical_cols = get_categorical_columns(adata)
        if not categorical_cols:
            raise ParameterError("cluster_key is required for fate map visualization.")
        cluster_key = categorical_cols[0]
        if context:
            await context.info(f"Using cluster_key: '{cluster_key}'")

    # Validate the cluster column up front, the same way the pseudotime plot
    # validates its column, so a bad key is reported before CellRank runs.
    validate_obs_column(adata, cluster_key, "Cluster")

    if context:
        await context.info(f"Creating CellRank fate map for '{cluster_key}'")

    figsize = resolve_figure_size(params, "violin")  # similar width to violin plots

    with non_interactive_backend():
        cr.pl.aggregate_fate_probabilities(
            adata,
            cluster_key=cluster_key,
            mode="bar",
            figsize=figsize,
            dpi=params.dpi,
        )
        fig = plt.gcf()

    title = params.title or f"CellRank Fate Probabilities by {cluster_key}"
    fig.suptitle(title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
async def _create_cellrank_gene_trends(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank gene expression trends using cr.pl.gene_trends.

    Shows gene expression trends along lineages/pseudotime.

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': fate probabilities
        - adata.obs time column: the first of 'latent_time',
          'palantir_pseudotime', 'dpt_pseudotime' that exists
        - genes in adata.var_names

    Gene selection:
        - ``params.feature`` is a gene name or a sequence of names. It is
          filtered to genes present in ``adata.var_names`` and capped at 6.
        - otherwise the first 6 highly variable genes are used. If there is no
          ``highly_variable`` column, or it flags no gene, the first 6 genes of
          ``adata.var_names`` are used instead.

    Args:
        adata: AnnData object with CellRank results and expression data.
        params: Visualization parameters (feature, dpi, title, figure size).
        context: Optional MCP tool context for progress messages.

    Returns:
        The figure CellRank drew into (taken from ``plt.gcf()``).

    Raises:
        DataNotFoundError: Fate probabilities or pseudotime are missing, or
            none of the requested genes exist.
    """
    require("cellrank", feature="gene trends")
    import cellrank as cr

    # Import GAM model preparation from trajectory module
    from ..trajectory import prepare_gam_model_for_visualization

    max_genes = 6  # one panel per gene; keep the figure readable

    fate_key = next(
        (key for key in ("lineages_fwd", "to_terminal_states") if key in adata.obsm),
        None,
    )
    if not fate_key:
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    time_key = next(
        (
            key
            for key in ("latent_time", "palantir_pseudotime", "dpt_pseudotime")
            if key in adata.obs.columns
        ),
        None,
    )
    if not time_key:
        raise DataNotFoundError("No pseudotime found. Run trajectory analysis first.")

    # Get genes to plot
    if params.feature:
        requested = (
            [params.feature]
            if isinstance(params.feature, str)
            else list(params.feature)
        )
        valid_genes = [g for g in requested if g in adata.var_names]
        if not valid_genes:
            raise DataNotFoundError(f"None of the specified genes found: {requested}")
        if context:
            # Report what is dropped instead of silently discarding it.
            missing = [g for g in requested if g not in adata.var_names]
            if missing:
                await context.info(f"Genes not found in adata.var_names (skipped): {missing}")
            if len(valid_genes) > max_genes:
                await context.info(
                    f"Plotting only the first {max_genes} of {len(valid_genes)} genes"
                )
        genes = valid_genes[:max_genes]
    else:
        genes = []
        if "highly_variable" in adata.var.columns:
            genes = list(adata.var_names[adata.var["highly_variable"]][:max_genes])
        if not genes:
            # No HVG annotation, or no gene flagged: fall back to the first genes.
            genes = list(adata.var_names[:max_genes])

    if context:
        await context.info(f"Creating gene trends for: {genes}")

    # Use centralized figure size resolution with dynamic panel height
    figsize = resolve_figure_size(
        params, n_panels=len(genes), panel_width=12, panel_height=3
    )

    model, lineage_names = prepare_gam_model_for_visualization(
        adata, genes, time_key=time_key, fate_key=fate_key
    )

    if context:
        await context.info(f"Lineages: {lineage_names}")

    with non_interactive_backend():
        cr.pl.gene_trends(
            adata,
            model=model,
            genes=genes,
            time_key=time_key,
            figsize=figsize,
            n_jobs=1,
            show_progress_bar=False,
        )
        fig = plt.gcf()
        fig.set_dpi(params.dpi)

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
async def _create_cellrank_fate_heatmap(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank fate heatmap using cr.pl.heatmap.

    Shows smoothed gene expression ordered by pseudotime per lineage.

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': fate probabilities
        - adata.obs time column: the first of 'latent_time',
          'palantir_pseudotime', 'dpt_pseudotime' that exists
        - genes in adata.var_names

    Gene selection:
        - ``params.feature`` is a gene name or a sequence of names. It is
          filtered to genes present in ``adata.var_names`` and capped at 50.
        - otherwise the first 50 highly variable genes are used. If there is no
          ``highly_variable`` column, or it flags no gene, the first 50 genes of
          ``adata.var_names`` are used instead.

    Args:
        adata: AnnData object with CellRank results and expression data.
        params: Visualization parameters (feature, dpi, title, figure size).
        context: Optional MCP tool context for progress messages.

    Returns:
        The figure CellRank drew into (taken from ``plt.gcf()``).

    Raises:
        DataNotFoundError: Fate probabilities or pseudotime are missing, or
            none of the requested genes exist.
    """
    require("cellrank", feature="fate heatmap")
    import cellrank as cr

    # Import GAM model preparation from trajectory module
    from ..trajectory import prepare_gam_model_for_visualization

    max_genes = 50  # heatmap rows; keep the figure readable

    fate_key = next(
        (key for key in ("lineages_fwd", "to_terminal_states") if key in adata.obsm),
        None,
    )
    if not fate_key:
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    time_key = next(
        (
            key
            for key in ("latent_time", "palantir_pseudotime", "dpt_pseudotime")
            if key in adata.obs.columns
        ),
        None,
    )
    if not time_key:
        raise DataNotFoundError("No pseudotime found for fate heatmap.")

    # Get genes
    if params.feature:
        requested = (
            [params.feature]
            if isinstance(params.feature, str)
            else list(params.feature)
        )
        valid_genes = [g for g in requested if g in adata.var_names]
        if not valid_genes:
            raise DataNotFoundError(f"None of the genes found: {requested}")
        if context:
            # Report what is dropped instead of silently discarding it.
            missing = [g for g in requested if g not in adata.var_names]
            if missing:
                await context.info(f"Genes not found in adata.var_names (skipped): {missing}")
            if len(valid_genes) > max_genes:
                await context.info(
                    f"Plotting only the first {max_genes} of {len(valid_genes)} genes"
                )
        genes = valid_genes[:max_genes]
    else:
        genes = []
        if "highly_variable" in adata.var.columns:
            genes = list(adata.var_names[adata.var["highly_variable"]][:max_genes])
        if not genes:
            # No HVG annotation, or no gene flagged: fall back to the first genes.
            genes = list(adata.var_names[:max_genes])

    if context:
        await context.info(f"Creating fate heatmap with {len(genes)} genes")

    # Use centralized figure size resolution
    figsize = resolve_figure_size(params, "heatmap")

    model, lineage_names = prepare_gam_model_for_visualization(
        adata, genes, time_key=time_key, fate_key=fate_key
    )

    if context:
        await context.info(f"Lineages: {lineage_names}")

    with non_interactive_backend():
        cr.pl.heatmap(
            adata,
            model=model,
            genes=genes,
            time_key=time_key,
            figsize=figsize,
            n_jobs=1,
            show_progress_bar=False,
        )
        fig = plt.gcf()
        fig.set_dpi(params.dpi)

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
async def _create_palantir_results(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create Palantir comprehensive results visualization.

    Shows pseudotime, and when available differentiation entropy and the
    dominant fate, as side-by-side embedding panels.

    Data requirements:
        - adata.obs['palantir_pseudotime']: pseudotime (required)
        - adata.obs['palantir_entropy']: differentiation entropy (optional panel)
        - adata.obsm['palantir_fate_probs'] or 'palantir_branch_probs': fate
          probabilities (optional panel). The array or DataFrame is assumed to
          be cells x fates; the dominant fate is the column index of each
          row's maximum.

    Args:
        adata: AnnData object with Palantir results.
        params: Visualization parameters (basis, show_axes, dpi, title, size).
        context: Optional MCP tool context for progress messages.

    Returns:
        Figure with one panel per available result.

    Raises:
        DataNotFoundError: 'palantir_pseudotime' is missing.
        DataCompatibilityError: No embedding basis could be resolved.

    ``adata.obs`` and ``adata.uns`` are left without the temporary
    dominant-fate entries, even if plotting fails.
    """
    import numpy as np

    has_pseudotime = "palantir_pseudotime" in adata.obs.columns
    has_entropy = "palantir_entropy" in adata.obs.columns
    fate_key = next(
        (
            key
            for key in ("palantir_fate_probs", "palantir_branch_probs")
            if key in adata.obsm
        ),
        None,
    )

    if not has_pseudotime:
        raise DataNotFoundError(
            "Palantir results not found. Run trajectory analysis first."
        )

    if context:
        await context.info("Creating Palantir results visualization")

    basis = infer_basis(
        adata, preferred=params.basis, priority=["umap", "spatial", "pca"]
    )
    if not basis:
        # Same check as the pseudotime plot: fail clearly instead of passing
        # an empty basis into scanpy.
        raise DataCompatibilityError(
            "No valid embedding basis found. "
            f"Available keys: {list(adata.obsm.keys())}"
        )

    n_panels = 1 + int(has_entropy) + (1 if fate_key else 0)

    figsize = resolve_figure_size(params, n_panels=n_panels, panel_width=5, panel_height=5)
    # squeeze=False always yields a 2-D array, so a single panel needs no special case.
    fig, axes = plt.subplots(
        1, n_panels, figsize=figsize, dpi=params.dpi, squeeze=False
    )
    panel_axes = iter(axes[0])

    def _embed(ax, color, title, **extra):
        """Draw one embedding panel; flip y for spatial coordinates."""
        sc.pl.embedding(
            adata,
            basis=basis,
            color=color,
            ax=ax,
            show=False,
            frameon=params.show_axes,
            title=title,
            **extra,
        )
        if basis == "spatial":
            ax.invert_yaxis()

    # Panel 1: Pseudotime
    _embed(next(panel_axes), "palantir_pseudotime", "Palantir Pseudotime", cmap="viridis")

    # Panel 2: Entropy (if available)
    if has_entropy:
        _embed(
            next(panel_axes),
            "palantir_entropy",
            "Differentiation Entropy",
            cmap="magma",
        )

    # Panel 3: Dominant fate (if fate probabilities are available)
    if fate_key:
        ax = next(panel_axes)
        # np.asarray also handles a pandas DataFrame, which has no .argmax.
        dominant_fate = np.asarray(adata.obsm[fate_key]).argmax(axis=1)

        # Pick an obs column name that does not clobber an existing column.
        tmp_key = "_dominant_fate"
        while tmp_key in adata.obs.columns:
            tmp_key = f"_{tmp_key}"

        adata.obs[tmp_key] = dominant_fate.astype(str)
        try:
            _embed(ax, tmp_key, "Dominant Fate")
        finally:
            # Remove the temporary column and the palette scanpy stores for
            # categorical colourings, even if plotting failed.
            del adata.obs[tmp_key]
            adata.uns.pop(f"{tmp_key}_colors", None)

    title = params.title or "Palantir Trajectory Analysis"
    fig.suptitle(title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
|