chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67):
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,639 @@
1
+ """
2
+ Trajectory visualization functions for spatial transcriptomics.
3
+
4
+ This module contains:
5
+ - Pseudotime visualizations
6
+ - CellRank circular projections
7
+ - Fate map visualizations
8
+ - Gene trends along lineages
9
+ - Fate heatmaps
10
+ - Palantir results visualization
11
+ """
12
+
13
+ from typing import TYPE_CHECKING, Optional
14
+
15
+ import matplotlib.pyplot as plt
16
+ import scanpy as sc
17
+
18
+ if TYPE_CHECKING:
19
+ import anndata as ad
20
+
21
+ from ...spatial_mcp_adapter import ToolContext
22
+
23
+ from ...models.data import VisualizationParameters
24
+ from ...utils.adata_utils import validate_obs_column
25
+ from ...utils.dependency_manager import require
26
+ from ...utils.image_utils import non_interactive_backend
27
+ from ...utils.exceptions import (
28
+ DataCompatibilityError,
29
+ DataNotFoundError,
30
+ ParameterError,
31
+ )
32
+ from .core import (
33
+ get_categorical_columns,
34
+ infer_basis,
35
+ resolve_figure_size,
36
+ setup_multi_panel_figure,
37
+ )
38
+
39
+ # =============================================================================
40
+ # Main Router
41
+ # =============================================================================
42
+
43
+
44
async def create_trajectory_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Route a trajectory plot request to the matching renderer.

    The renderer is chosen from ``params.subtype``; an empty or missing
    subtype means "pseudotime".

    Args:
        adata: AnnData object holding trajectory / pseudotime results.
        params: Visualization parameters; ``subtype`` selects the renderer.
        context: Optional MCP context used for progress messages.

    Returns:
        The Matplotlib figure produced by the selected renderer.

    Raises:
        ParameterError: If ``subtype`` is not one of the supported names.

    Subtypes:
        - pseudotime (default): Pseudotime on embedding with optional velocity stream
        - circular: CellRank circular projection of fate probabilities
        - fate_map: CellRank aggregated fate probabilities (bar chart)
        - gene_trends: CellRank gene expression trends along lineages
        - fate_heatmap: CellRank smoothed expression heatmap by pseudotime
        - palantir: Palantir comprehensive results (pseudotime, entropy, fate probs)
    """
    subtype = params.subtype or "pseudotime"

    if context:
        await context.info(f"Creating trajectory visualization (subtype: {subtype})")

    # Ordered (name, renderer) pairs; compared with == so lookup semantics
    # match a plain chain of equality tests.
    routes = (
        ("pseudotime", _create_trajectory_pseudotime_plot),
        ("circular", _create_cellrank_circular_projection),
        ("fate_map", _create_cellrank_fate_map),
        ("gene_trends", _create_cellrank_gene_trends),
        ("fate_heatmap", _create_cellrank_fate_heatmap),
        ("palantir", _create_palantir_results),
    )
    renderer = next((fn for name, fn in routes if name == subtype), None)

    if renderer is None:
        raise ParameterError(
            f"Unsupported subtype for trajectory: '{subtype}'. "
            "Available subtypes: pseudotime, circular, fate_map, gene_trends, "
            "fate_heatmap, palantir"
        )

    return await renderer(adata, params, context)
92
+
93
+
94
+ # =============================================================================
95
+ # Visualization Functions
96
+ # =============================================================================
97
+
98
+
99
async def _create_trajectory_pseudotime_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create trajectory pseudotime visualization.

    Shows pseudotime on an embedding and, when ``adata.uns['velocity_graph']``
    exists, a second panel with an scVelo velocity stream plot colored by the
    same pseudotime column.

    Data requirements:
        - adata.obs['*pseudotime*']: Any pseudotime column
        - adata.obsm entry for the basis chosen by ``infer_basis``
          (e.g. 'X_umap' or 'spatial')
        - adata.uns['velocity_graph']: Optional, for velocity stream panel

    Args:
        adata: AnnData object with computed pseudotime.
        params: Visualization parameters. ``params.feature`` names the
            pseudotime column; when it is a list, its first entry is used.
            When unset, the first obs column whose name contains
            "pseudotime" (case-insensitive) is used.
        context: Optional MCP context for progress messages.

    Returns:
        Matplotlib figure with one panel, or two when velocity is present.

    Raises:
        DataNotFoundError: If no pseudotime column is given or found.
        DataCompatibilityError: If no embedding basis can be inferred.
    """
    # ``feature`` may be a list (other trajectory subtypes accept several
    # genes), but only one pseudotime column can be plotted: use the first.
    pseudotime_key = params.feature
    if pseudotime_key is not None and not isinstance(pseudotime_key, str):
        pseudotime_key = next(iter(pseudotime_key), None)

    if not pseudotime_key:
        pseudotime_candidates = [
            k for k in adata.obs.columns if "pseudotime" in str(k).lower()
        ]
        if not pseudotime_candidates:
            raise DataNotFoundError(
                "No pseudotime found. Run trajectory analysis first."
            )
        pseudotime_key = pseudotime_candidates[0]
        if context:
            await context.info(f"Found pseudotime column: {pseudotime_key}")

    validate_obs_column(adata, pseudotime_key, "Pseudotime")

    # The velocity stream panel is only drawn when a velocity graph exists.
    has_velocity = "velocity_graph" in adata.uns

    # Determine basis for plotting
    basis = infer_basis(adata, preferred=params.basis)
    if not basis:
        raise DataCompatibilityError(
            f"No valid embedding basis found. "
            f"Available keys: {list(adata.obsm.keys())}"
        )

    # 1 panel without velocity, 2 panels with velocity
    n_panels = 2 if has_velocity else 1

    fig, axes = setup_multi_panel_figure(
        n_panels=n_panels,
        params=params,
        default_title=f"Trajectory Analysis - Pseudotime ({pseudotime_key})",
    )

    # Panel 1: Pseudotime plot
    ax1 = axes[0]
    try:
        sc.pl.embedding(
            adata,
            basis=basis,
            color=pseudotime_key,
            cmap=params.colormap,
            ax=ax1,
            show=False,
            frameon=params.show_axes,
            alpha=params.alpha,
            colorbar_loc="right" if params.show_colorbar else None,
        )
        if basis == "spatial":
            # Flip y for spatial coordinates (presumably to match image
            # orientation, origin at top-left).
            ax1.invert_yaxis()
    except Exception as e:
        # Best-effort: report the failure inside the panel instead of
        # failing the whole figure.
        ax1.text(
            0.5,
            0.5,
            f"Error plotting pseudotime:\n{e}",
            ha="center",
            va="center",
            transform=ax1.transAxes,
        )
        ax1.set_title("Pseudotime (Error)", fontsize=12)

    # Panel 2: Velocity stream plot (if available)
    if has_velocity:
        ax2 = axes[1]
        # Keep the import separate from plotting so an ImportError raised
        # *inside* scvelo is not misreported as "not installed".
        try:
            import scvelo as scv
        except ImportError:
            ax2.text(
                0.5,
                0.5,
                "scvelo not installed",
                ha="center",
                va="center",
                transform=ax2.transAxes,
            )
        else:
            try:
                scv.pl.velocity_embedding_stream(
                    adata,
                    basis=basis,
                    color=pseudotime_key,
                    cmap=params.colormap,
                    ax=ax2,
                    show=False,
                    alpha=params.alpha,
                    frameon=params.show_axes,
                )
                ax2.set_title("RNA Velocity Stream", fontsize=12)
                if basis == "spatial":
                    ax2.invert_yaxis()
            except Exception as e:
                ax2.text(
                    0.5,
                    0.5,
                    f"Error: {str(e)[:50]}",
                    ha="center",
                    va="center",
                    transform=ax2.transAxes,
                )

    # Lay out the figure we return (not whatever pyplot considers current);
    # the top 5% is left free, presumably for the suptitle set by
    # setup_multi_panel_figure.
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    return fig
221
+
222
+
223
async def _create_cellrank_circular_projection(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank circular projection using cr.pl.circular_projection.

    Shows fate probabilities in a circular layout.

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': Fate probabilities
          (only checked for presence here; CellRank reads them itself)

    Args:
        adata: AnnData object with CellRank results.
        params: Visualization parameters. ``cluster_key`` (if set) is the
            single key used to color the projection; otherwise up to three
            categorical obs columns are used, falling back to CellRank's
            built-in "entropy" key when none exist.
        context: Optional MCP context for progress messages.

    Returns:
        The Matplotlib figure created by CellRank.

    Raises:
        DataNotFoundError: If no CellRank fate probabilities are present.
    """
    require("cellrank", feature="circular projection")
    import cellrank as cr

    # Check for CellRank results (newer and older obsm key names)
    fate_key = next(
        (k for k in ("lineages_fwd", "to_terminal_states") if k in adata.obsm),
        None,
    )
    if not fate_key:
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    if context:
        await context.info("Creating CellRank circular projection")

    # circular_projection requires at least one key; passing None fails
    # inside CellRank. "entropy" is a key CellRank computes itself from the
    # fate probabilities, so it is always available as a fallback.
    if params.cluster_key:
        keys = [params.cluster_key]
    else:
        keys = get_categorical_columns(adata, limit=3) or ["entropy"]

    # Use centralized figure size resolution
    figsize = resolve_figure_size(params, "trajectory")

    with non_interactive_backend():
        cr.pl.circular_projection(
            adata,
            keys=keys,
            figsize=figsize,
            dpi=params.dpi,
        )
        # CellRank draws on a new figure without returning it; grab it.
        fig = plt.gcf()

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
278
+
279
+
280
async def _create_cellrank_fate_map(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank aggregated fate probabilities.

    Shows fate probabilities aggregated by cluster as a bar chart
    (``cr.pl.aggregate_fate_probabilities`` with ``mode="bar"``).

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': Fate probabilities
        - adata.obs[cluster_key]: Cluster labels for aggregation

    Args:
        adata: AnnData object with CellRank results.
        params: Visualization parameters. ``cluster_key`` selects the
            grouping column; when unset, the first categorical obs column
            is used.
        context: Optional MCP context for progress messages.

    Returns:
        The Matplotlib figure created by CellRank.

    Raises:
        DataNotFoundError: If no CellRank fate probabilities are present.
        ParameterError: If no cluster key is given and none can be inferred.
    """
    require("cellrank", feature="fate map")
    import cellrank as cr

    # Check for CellRank results (newer and older obsm key names)
    fate_key = next(
        (k for k in ("lineages_fwd", "to_terminal_states") if k in adata.obsm),
        None,
    )
    if not fate_key:
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    # Determine cluster key
    cluster_key = params.cluster_key
    if not cluster_key:
        categorical_cols = get_categorical_columns(adata)
        if not categorical_cols:
            raise ParameterError("cluster_key is required for fate map visualization.")
        cluster_key = categorical_cols[0]
        if context:
            await context.info(f"Using cluster_key: '{cluster_key}'")

    # Fail early with the project's standard error instead of deep inside
    # CellRank when a user-supplied column does not exist.
    validate_obs_column(adata, cluster_key, "Cluster")

    if context:
        await context.info(f"Creating CellRank fate map for '{cluster_key}'")

    # Use centralized figure size resolution
    figsize = resolve_figure_size(params, "violin")  # similar width to violin plots

    with non_interactive_backend():
        cr.pl.aggregate_fate_probabilities(
            adata,
            cluster_key=cluster_key,
            mode="bar",
            figsize=figsize,
            dpi=params.dpi,
        )
        # CellRank draws on a new figure without returning it; grab it.
        fig = plt.gcf()

    title = params.title or f"CellRank Fate Probabilities by {cluster_key}"
    fig.suptitle(title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
341
+
342
+
343
async def _create_cellrank_gene_trends(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank gene expression trends using cr.pl.gene_trends.

    Shows gene expression trends along lineages/pseudotime for at most six
    genes.

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': Fate probabilities
        - adata.obs['latent_time'], 'palantir_pseudotime' or 'dpt_pseudotime'
        - Gene expression in adata.X

    Args:
        adata: AnnData object with CellRank results and a pseudotime column.
        params: Visualization parameters. ``feature`` is a gene name or list
            of gene names; unknown names are skipped (and reported via
            ``context``). Without ``feature``, highly variable genes are
            used, or the first genes in ``var_names`` if none are flagged.
        context: Optional MCP context for progress messages.

    Returns:
        The Matplotlib figure created by CellRank.

    Raises:
        DataNotFoundError: If fate probabilities, a pseudotime column, or
            all requested genes are missing.
    """
    require("cellrank", feature="gene trends")
    import cellrank as cr

    # Import GAM model preparation from trajectory module
    from ..trajectory import prepare_gam_model_for_visualization

    max_genes = 6  # one panel per gene; keep the figure readable

    # Check for fate probabilities (newer and older obsm key names)
    fate_key = next(
        (k for k in ("lineages_fwd", "to_terminal_states") if k in adata.obsm),
        None,
    )
    if not fate_key:
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    # Find time key, in order of preference
    time_key = next(
        (
            k
            for k in ("latent_time", "palantir_pseudotime", "dpt_pseudotime")
            if k in adata.obs.columns
        ),
        None,
    )
    if not time_key:
        raise DataNotFoundError("No pseudotime found. Run trajectory analysis first.")

    # Get genes to plot
    if params.feature:
        if isinstance(params.feature, str):
            requested = [params.feature]
        else:
            # Order-preserving de-duplication: one panel per gene.
            requested = list(dict.fromkeys(params.feature))
        valid_genes = [g for g in requested if g in adata.var_names]
        if not valid_genes:
            raise DataNotFoundError(f"None of the specified genes found: {requested}")
        missing = [g for g in requested if g not in adata.var_names]
        if missing and context:
            await context.info(f"Skipping genes not found in data: {missing}")
        genes = valid_genes[:max_genes]
    else:
        genes = []
        if "highly_variable" in adata.var.columns:
            # .eq(True) treats NaN / non-boolean entries as "not HVG".
            hvg_mask = adata.var["highly_variable"].eq(True).to_numpy()
            genes = list(adata.var_names[hvg_mask][:max_genes])
        if not genes:
            # No HVG column, or no gene flagged: fall back to the first genes.
            genes = list(adata.var_names[:max_genes])

    if context:
        await context.info(f"Creating gene trends for: {genes}")

    # Use centralized figure size resolution with dynamic panel height
    figsize = resolve_figure_size(
        params, n_panels=len(genes), panel_width=12, panel_height=3
    )

    model, lineage_names = prepare_gam_model_for_visualization(
        adata, genes, time_key=time_key, fate_key=fate_key
    )

    if context:
        await context.info(f"Lineages: {lineage_names}")

    with non_interactive_backend():
        cr.pl.gene_trends(
            adata,
            model=model,
            genes=genes,
            time_key=time_key,
            figsize=figsize,
            n_jobs=1,
            show_progress_bar=False,
        )
        # CellRank draws on a new figure without returning it; grab it.
        fig = plt.gcf()
        fig.set_dpi(params.dpi)

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
437
+
438
+
439
async def _create_cellrank_fate_heatmap(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellRank fate heatmap using cr.pl.heatmap.

    Shows smoothed gene expression ordered by pseudotime per lineage, for at
    most 50 genes.

    Data requirements:
        - adata.obsm['lineages_fwd'] or 'to_terminal_states': Fate probabilities
        - adata.obs['latent_time'], 'palantir_pseudotime' or 'dpt_pseudotime'
        - Gene expression in adata.X

    Args:
        adata: AnnData object with CellRank results and a pseudotime column.
        params: Visualization parameters. ``feature`` is a gene name or list
            of gene names; unknown names are skipped (and reported via
            ``context``). Without ``feature``, highly variable genes are
            used, or the first genes in ``var_names`` if none are flagged.
        context: Optional MCP context for progress messages.

    Returns:
        The Matplotlib figure created by CellRank.

    Raises:
        DataNotFoundError: If fate probabilities, a pseudotime column, or
            all requested genes are missing.
    """
    require("cellrank", feature="fate heatmap")
    import cellrank as cr

    # Import GAM model preparation from trajectory module
    from ..trajectory import prepare_gam_model_for_visualization

    max_genes = 50  # heatmap rows; larger values become unreadable/slow

    # Check for fate probabilities (newer and older obsm key names)
    fate_key = next(
        (k for k in ("lineages_fwd", "to_terminal_states") if k in adata.obsm),
        None,
    )
    if not fate_key:
        raise DataNotFoundError(
            "CellRank fate probabilities not found. Run trajectory analysis first."
        )

    # Find time key, in order of preference
    time_key = next(
        (
            k
            for k in ("latent_time", "palantir_pseudotime", "dpt_pseudotime")
            if k in adata.obs.columns
        ),
        None,
    )
    if not time_key:
        raise DataNotFoundError("No pseudotime found for fate heatmap.")

    # Get genes
    if params.feature:
        if isinstance(params.feature, str):
            requested = [params.feature]
        else:
            # Order-preserving de-duplication: one heatmap row per gene.
            requested = list(dict.fromkeys(params.feature))
        valid_genes = [g for g in requested if g in adata.var_names]
        if not valid_genes:
            raise DataNotFoundError(f"None of the genes found: {requested}")
        missing = [g for g in requested if g not in adata.var_names]
        if missing and context:
            await context.info(f"Skipping genes not found in data: {missing}")
        genes = valid_genes[:max_genes]
    else:
        genes = []
        if "highly_variable" in adata.var.columns:
            # .eq(True) treats NaN / non-boolean entries as "not HVG".
            hvg_mask = adata.var["highly_variable"].eq(True).to_numpy()
            genes = list(adata.var_names[hvg_mask][:max_genes])
        if not genes:
            # No HVG column, or no gene flagged: fall back to the first genes.
            genes = list(adata.var_names[:max_genes])

    if context:
        await context.info(f"Creating fate heatmap with {len(genes)} genes")

    # Use centralized figure size resolution
    figsize = resolve_figure_size(params, "heatmap")

    model, lineage_names = prepare_gam_model_for_visualization(
        adata, genes, time_key=time_key, fate_key=fate_key
    )

    if context:
        await context.info(f"Lineages: {lineage_names}")

    with non_interactive_backend():
        cr.pl.heatmap(
            adata,
            model=model,
            genes=genes,
            time_key=time_key,
            figsize=figsize,
            n_jobs=1,
            show_progress_bar=False,
        )
        # CellRank draws on a new figure without returning it; grab it.
        fig = plt.gcf()
        fig.set_dpi(params.dpi)

    if params.title:
        fig.suptitle(params.title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig
531
+
532
+
533
async def _create_palantir_results(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create Palantir comprehensive results visualization.

    Shows pseudotime, entropy (if present) and dominant fate (if fate
    probabilities are present) side by side on one embedding.

    Data requirements:
        - adata.obs['palantir_pseudotime']: Pseudotime
        - adata.obs['palantir_entropy']: Optional differentiation entropy
        - adata.obsm['palantir_fate_probs'] or 'palantir_branch_probs':
          Optional fate probabilities, as an array or a DataFrame
          (AnnData permits both in obsm)

    Args:
        adata: AnnData object with Palantir results.
        params: Visualization parameters (basis, size, dpi, title, axes).
        context: Optional MCP context for progress messages.

    Returns:
        Matplotlib figure with one to three panels.

    Raises:
        DataNotFoundError: If ``palantir_pseudotime`` is missing.
        DataCompatibilityError: If no embedding basis can be inferred.
    """
    import numpy as np

    # Check for Palantir results
    has_pseudotime = "palantir_pseudotime" in adata.obs.columns
    has_entropy = "palantir_entropy" in adata.obs.columns
    fate_key = next(
        (
            k
            for k in ("palantir_fate_probs", "palantir_branch_probs")
            if k in adata.obsm
        ),
        None,
    )

    if not has_pseudotime:
        raise DataNotFoundError(
            "Palantir results not found. Run trajectory analysis first."
        )

    if context:
        await context.info("Creating Palantir results visualization")

    # Determine basis; without one, sc.pl.embedding would fail obscurely.
    basis = infer_basis(
        adata, preferred=params.basis, priority=["umap", "spatial", "pca"]
    )
    if not basis:
        raise DataCompatibilityError(
            f"No valid embedding basis found. "
            f"Available keys: {list(adata.obsm.keys())}"
        )

    n_panels = 1 + int(has_entropy) + int(bool(fate_key))

    figsize = resolve_figure_size(params, n_panels=n_panels, panel_width=5, panel_height=5)
    # squeeze=False always yields a 2-D array, so indexing is uniform.
    fig, axes = plt.subplots(
        1, n_panels, figsize=figsize, dpi=params.dpi, squeeze=False
    )
    axes = axes[0]

    panel_idx = 0

    # Panel 1: Pseudotime
    ax = axes[panel_idx]
    sc.pl.embedding(
        adata,
        basis=basis,
        color="palantir_pseudotime",
        cmap="viridis",
        ax=ax,
        show=False,
        frameon=params.show_axes,
        title="Palantir Pseudotime",
    )
    if basis == "spatial":
        ax.invert_yaxis()
    panel_idx += 1

    # Panel 2: Entropy (if available)
    if has_entropy:
        ax = axes[panel_idx]
        sc.pl.embedding(
            adata,
            basis=basis,
            color="palantir_entropy",
            cmap="magma",
            ax=ax,
            show=False,
            frameon=params.show_axes,
            title="Differentiation Entropy",
        )
        if basis == "spatial":
            ax.invert_yaxis()
        panel_idx += 1

    # Panel 3: Dominant fate per cell (if fate probabilities are available)
    if fate_key:
        ax = axes[panel_idx]
        fate_probs = adata.obsm[fate_key]
        if hasattr(fate_probs, "idxmax"):
            # DataFrame: pandas has no DataFrame.argmax; idxmax also gives
            # the lineage (column) name rather than a bare index.
            dominant_fate = fate_probs.idxmax(axis=1).astype(str).to_numpy()
        else:
            dominant_fate = np.asarray(fate_probs).argmax(axis=1).astype(str)

        tmp_key = "_dominant_fate"
        adata.obs[tmp_key] = dominant_fate
        try:
            sc.pl.embedding(
                adata,
                basis=basis,
                color=tmp_key,
                ax=ax,
                show=False,
                frameon=params.show_axes,
                title="Dominant Fate",
            )
            if basis == "spatial":
                ax.invert_yaxis()
        finally:
            # Always remove the temporary column, plus the color palette
            # scanpy stores in uns for categorical obs columns, so the
            # caller's AnnData is left as it was.
            del adata.obs[tmp_key]
            adata.uns.pop(f"{tmp_key}_colors", None)

    title = params.title or "Palantir Trajectory Analysis"
    fig.suptitle(title, fontsize=14, y=1.02)

    fig.tight_layout()
    return fig