chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,660 @@
1
+ """
2
+ Enrichment analysis visualization functions.
3
+
4
+ This module contains:
5
+ - Pathway enrichment barplots and dotplots
6
+ - GSEA enrichment score plots
7
+ - Spatial enrichment score visualization
8
+ - EnrichMap spatial autocorrelation plots
9
+ """
10
+
11
+ from typing import TYPE_CHECKING, Optional
12
+
13
+ import matplotlib.pyplot as plt
14
+ import pandas as pd
15
+ import seaborn as sns
16
+
17
+ if TYPE_CHECKING:
18
+ import anndata as ad
19
+
20
+ from ...spatial_mcp_adapter import ToolContext
21
+
22
+ from ...models.data import VisualizationParameters
23
+ from ...utils.adata_utils import get_analysis_parameter, validate_obs_column
24
+ from ...utils.exceptions import DataNotFoundError, ParameterError, ProcessingError
25
+ from .core import (
26
+ create_figure,
27
+ get_categorical_columns,
28
+ plot_spatial_feature,
29
+ resolve_figure_size,
30
+ setup_multi_panel_figure,
31
+ )
32
+
33
+ # =============================================================================
34
+ # Helper Functions
35
+ # =============================================================================
36
+
37
+
38
def _ensure_enrichmap_compatibility(adata: "ad.AnnData") -> None:
    """Backfill the metadata that EnrichMap/squidpy plotting looks for.

    The object is modified in place:

    * ``adata.obs['library_id']`` is created with the constant value
      ``"sample_1"`` when the column is absent.
    * ``adata.uns['spatial']`` is created when the key is absent, with one
      entry per unique library id holding an empty ``images`` dict and
      unit ``scalefactors``.

    Existing values are never overwritten.
    """
    if "library_id" not in adata.obs.columns:
        adata.obs["library_id"] = "sample_1"

    # Nothing to do when spatial metadata is already present.
    if "spatial" in adata.uns:
        return

    adata.uns["spatial"] = {
        sample_id: {
            "images": {},
            "scalefactors": {
                "spot_diameter_fullres": 1.0,
                "tissue_hires_scalef": 1.0,
                "fiducial_diameter_fullres": 1.0,
                "tissue_lowres_scalef": 1.0,
            },
        }
        for sample_id in adata.obs["library_id"].unique()
    }
63
+
64
+
65
def _get_score_columns(adata: "ad.AnnData") -> list[str]:
    """Collect the enrichment score columns present in ``adata.obs``.

    Lookup order:
    1. The ``results_keys`` parameter stored for the ``enrichment_spatial``
       and ``enrichment_ssgsea`` analyses; only the names listed under its
       ``"obs"`` entry that really exist in ``adata.obs`` are kept, without
       duplicates and in the order they are encountered.
    2. If that yields nothing, every ``adata.obs`` column whose name ends
       with ``_score`` (legacy data without stored metadata).

    Returns:
        List of column names; empty when neither source finds anything.
    """
    found: list[str] = []

    for analysis in ("enrichment_spatial", "enrichment_ssgsea"):
        stored = get_analysis_parameter(adata, analysis, "results_keys")
        if not (stored and isinstance(stored, dict) and "obs" in stored):
            continue
        # Keep only names that are actual obs columns and not yet collected.
        for name in stored["obs"]:
            if name in adata.obs.columns and name not in found:
                found.append(name)

    if found:
        return found

    # Legacy fallback: rely on the naming convention alone.
    return [name for name in adata.obs.columns if name.endswith("_score")]
92
+
93
+
94
def _resolve_score_column(
    adata: "ad.AnnData",
    feature: Optional[str],
    score_cols: list[str],
) -> str:
    """Resolve a feature name to an existing score column in ``adata.obs``.

    Args:
        adata: Object whose ``obs`` columns are searched.
        feature: Requested score. May be an exact column name or a signature
            name to which ``_score`` is appended. A list/tuple is tolerated
            and its first element is used (an empty one counts as "not
            given"), because the same parameter is treated as possibly
            list-valued elsewhere in this module.
        score_cols: Known score columns; the first one is the default when
            no feature is given.

    Returns:
        Name of the column to plot.

    Raises:
        DataNotFoundError: If the requested feature matches no column, or no
            feature was given and ``score_cols`` is empty.
    """
    # A list cannot be tested for membership in a pandas Index (unhashable),
    # so reduce it to a single name before any lookup.
    if isinstance(feature, (list, tuple)):
        feature = feature[0] if feature else None

    if feature:
        if feature in adata.obs.columns:
            return feature
        suffixed = f"{feature}_score"
        if suffixed in adata.obs.columns:
            return suffixed
        raise DataNotFoundError(
            f"Score column '{feature}' not found. Available: {score_cols}"
        )

    if score_cols:
        return score_cols[0]
    raise DataNotFoundError("No enrichment scores found in adata.obs")
111
+
112
+
113
+ # =============================================================================
114
+ # Main Routers
115
+ # =============================================================================
116
+
117
+
118
async def create_enrichment_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Dispatch to the matching enrichment score plot.

    Routing rules, checked in this order:
    - ``params.plot_type == "violin"``: violin plot of scores by cluster
    - ``params.subtype`` starting with ``spatial_``: EnrichMap plots
    - anything else: standard spatial scatter plot

    Args:
        adata: AnnData object with enrichment scores
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        Matplotlib figure

    Raises:
        DataNotFoundError: If no enrichment score columns are found.
    """
    if context:
        await context.info("Creating enrichment visualization")

    available_scores = _get_score_columns(adata)
    if not available_scores:
        raise DataNotFoundError(
            "No enrichment scores found. Run 'analyze_enrichment' first."
        )

    if params.plot_type == "violin":
        return _create_enrichment_violin(adata, params, available_scores, context)

    subtype = params.subtype
    is_enrichmap_request = bool(subtype) and subtype.startswith("spatial_")
    if is_enrichmap_request:
        return _create_enrichmap_spatial(adata, params, available_scores, context)

    # Fallback: plain spatial scatter of the score(s).
    return await _create_enrichment_spatial(adata, params, available_scores, context)
156
+
157
+
158
async def create_pathway_enrichment_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Plot pathway enrichment (GSEA/ORA) results.

    ``params.subtype`` selects the plot (``barplot`` when unset):
        - barplot: top enriched pathways
        - dotplot: pathway enrichment dotplot
        - enrichment_plot: classic GSEA running score plot
        - spatial_*: forwarded to ``create_enrichment_visualization``

    Results are read from ``adata.uns`` under ``params.gsea_results_key``
    (default ``"gsea_results"``). When that key is missing, the first of
    ``rank_genes_groups``, ``de_results``, ``pathway_enrichment`` found in
    ``adata.uns`` is used instead.

    Args:
        adata: AnnData object with enrichment results
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        Matplotlib figure

    Raises:
        DataNotFoundError: If none of the result keys exist in ``adata.uns``.
    """
    if context:
        await context.info("Creating pathway enrichment visualization")

    plot_type = params.subtype or "barplot"

    # Spatial subtypes belong to the score-based visualizations.
    if plot_type.startswith("spatial_"):
        return await create_enrichment_visualization(adata, params, context)

    gsea_key = getattr(params, "gsea_results_key", "gsea_results")
    if gsea_key not in adata.uns:
        fallback_keys = ("rank_genes_groups", "de_results", "pathway_enrichment")
        substitute = next((key for key in fallback_keys if key in adata.uns), None)
        if substitute is None:
            raise DataNotFoundError(f"GSEA results not found. Expected key: {gsea_key}")
        gsea_key = substitute

    gsea_results = adata.uns[gsea_key]

    # Any subtype without a dedicated renderer falls back to the barplot.
    renderers = {
        "enrichment_plot": _create_gsea_enrichment_plot,
        "dotplot": _create_gsea_dotplot,
    }
    render = renderers.get(plot_type, _create_gsea_barplot)
    return render(gsea_results, params)
210
+
211
+
212
+ # =============================================================================
213
+ # Enrichment Score Visualizations
214
+ # =============================================================================
215
+
216
+
217
def _create_enrichment_violin(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    score_cols: list[str],
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create violin plots of enrichment scores grouped by cluster.

    One panel is drawn per score. Scores come from ``params.feature`` when
    given; each requested name is accepted either as an exact ``adata.obs``
    column or as a signature name with ``_score`` appended (the same fallback
    the spatial scatter plot uses), and names matching neither are skipped.
    Without ``params.feature`` the first three entries of ``score_cols`` are
    plotted.

    Args:
        adata: AnnData object with enrichment scores in ``obs``
        params: Visualization parameters; ``cluster_key`` is required
        score_cols: Known enrichment score columns
        context: MCP context (currently unused here)

    Returns:
        Matplotlib figure

    Raises:
        ParameterError: If ``params.cluster_key`` is not set.
        DataNotFoundError: If none of the requested scores exist, or there
            are no scores to plot at all.
    """
    if not params.cluster_key:
        categorical_cols = get_categorical_columns(adata, limit=15)
        raise ParameterError(
            "Enrichment violin plot requires 'cluster_key' parameter.\n"
            f"Available categorical columns: {', '.join(categorical_cols)}"
        )

    validate_obs_column(adata, params.cluster_key, "Cluster")

    requested = _resolve_feature_list(params.feature, adata.obs.columns, score_cols)
    if requested:
        # Map each requested name onto a real obs column so that indexing
        # adata.obs below cannot fail with a bare KeyError.
        scores_to_plot = []
        for feat in requested:
            if feat in adata.obs.columns:
                scores_to_plot.append(feat)
            elif f"{feat}_score" in adata.obs.columns:
                scores_to_plot.append(f"{feat}_score")
        if not scores_to_plot:
            raise DataNotFoundError(
                f"None of the specified scores found: {requested}"
            )
    else:
        scores_to_plot = score_cols[:3]

    n_scores = len(scores_to_plot)
    if n_scores == 0:
        # plt.subplots cannot build a grid with zero columns.
        raise DataNotFoundError("No enrichment scores found in adata.obs")

    # Use centralized figure size resolution for multi-panel layout
    figsize = resolve_figure_size(
        params, n_panels=n_scores, panel_width=5, panel_height=6
    )
    fig, axes = plt.subplots(1, n_scores, figsize=figsize)
    if n_scores == 1:
        # A single subplot comes back as a bare Axes, not an array.
        axes = [axes]

    for ax, score in zip(axes, scores_to_plot):
        data = pd.DataFrame(
            {
                params.cluster_key: adata.obs[params.cluster_key],
                "Score": adata.obs[score],
            }
        )
        sns.violinplot(data=data, x=params.cluster_key, y="Score", ax=ax)

        sig_name = score.replace("_score", "")
        ax.set_title(f"{sig_name} by {params.cluster_key}")
        ax.set_xlabel(params.cluster_key)
        ax.set_ylabel("Enrichment Score")
        ax.tick_params(axis="x", rotation=45)
        for label in ax.get_xticklabels():
            label.set_horizontalalignment("right")

    # Lay out this figure explicitly rather than whichever one is current.
    fig.tight_layout()
    return fig
267
+
268
+
269
async def _create_enrichment_spatial(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    score_cols: list[str],
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create spatial scatter plot(s) of enrichment scores.

    With more than one requested feature a multi-panel figure is built, one
    panel per score that can be matched to an ``adata.obs`` column (exact
    name or name plus ``_score``). Otherwise a single 10x8 figure is drawn
    for the one requested score, or for the first known score when no
    feature was given.

    Args:
        adata: AnnData object with enrichment scores in ``obs``
        params: Visualization parameters
        score_cols: Known enrichment score columns
        context: MCP context for logging

    Returns:
        Matplotlib figure

    Raises:
        DataNotFoundError: If no requested score can be matched to a column.
    """
    feature_list = _resolve_feature_list(params.feature, adata.obs.columns, score_cols)

    if feature_list and len(feature_list) > 1:
        # Multi-score visualization
        scores_to_plot = []
        for feat in feature_list:
            if feat in adata.obs.columns:
                scores_to_plot.append(feat)
            elif f"{feat}_score" in adata.obs.columns:
                scores_to_plot.append(f"{feat}_score")

        if not scores_to_plot:
            raise DataNotFoundError(
                f"None of the specified scores found: {feature_list}"
            )

        fig, axes = setup_multi_panel_figure(
            n_panels=len(scores_to_plot),
            params=params,
            default_title="Enrichment Scores",
        )

        for i, score in enumerate(scores_to_plot):
            if i < len(axes):
                ax = axes[i]
                plot_spatial_feature(adata, feature=score, ax=ax, params=params)
                sig_name = score.replace("_score", "")
                ax.set_title(f"{sig_name} Enrichment")
    else:
        # Single score visualization. Pass the one element of the list (or
        # None) rather than params.feature itself: a one-element list is not
        # a valid lookup key for the column resolver.
        single_feature = feature_list[0] if feature_list else None
        score_col = _resolve_score_column(adata, single_feature, score_cols)
        if context:
            await context.info(f"Using score column: {score_col}")

        fig, ax = create_figure(figsize=(10, 8))
        plot_spatial_feature(adata, feature=score_col, ax=ax, params=params)

        sig_name = score_col.replace("_score", "")
        ax.set_title(f"{sig_name} Enrichment Score", fontsize=14)

        # Attach a colorbar to the first drawn collection, if there is one.
        if params.show_colorbar and hasattr(ax, "collections") and ax.collections:
            cbar = plt.colorbar(ax.collections[0], ax=ax)
            cbar.set_label("Enrichment Score", fontsize=12)

    plt.tight_layout()
    return fig
322
+
323
+
324
def _create_enrichmap_spatial(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    score_cols: list[str],
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Entry point for the EnrichMap-based spatial plots.

    Imports ``enrichmap`` lazily, makes sure the metadata it needs exists,
    takes the first unique ``library_id`` and hands over to the
    cross-correlation plot (``spatial_cross_correlation``) or to the
    single-score plot (every other subtype).

    Raises:
        ProcessingError: If ``enrichmap`` is not installed, or the plotting
            call fails for any reason other than missing data.
        DataNotFoundError: Passed through unchanged from the plot helpers.
    """
    try:
        import enrichmap as em
    except ImportError as e:
        missing_dependency = (
            f"Spatial enrichment visualization ('{params.subtype}') requires EnrichMap.\n"
            "Install with: pip install enrichmap"
        )
        raise ProcessingError(missing_dependency) from e

    _ensure_enrichmap_compatibility(adata)
    library_id = adata.obs["library_id"].unique()[0]

    try:
        if params.subtype != "spatial_cross_correlation":
            return _create_enrichmap_single_score(
                adata, params, library_id, em, context
            )
        return _create_enrichmap_cross_correlation(adata, params, library_id, em)
    except DataNotFoundError:
        # Missing-data errors already carry a precise message.
        raise
    except Exception as e:
        # Discard any partially drawn figures before reporting the failure.
        plt.close("all")
        failure = (
            f"EnrichMap {params.subtype} visualization failed: {e}\n\n"
            "Solutions:\n"
            "1. Verify the enrichment analysis completed successfully\n"
            "2. Check that spatial neighbors graph exists\n"
            "3. Ensure enrichment scores are stored in adata.obs"
        )
        raise ProcessingError(failure) from e
360
+
361
+
362
def _create_enrichmap_cross_correlation(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    library_id: str,
    em,
) -> plt.Figure:
    """Draw the EnrichMap cross-Moran scatter for the first two pathways.

    The pathway names are the first two keys of
    ``adata.uns['enrichment_gene_sets']``; their score columns are taken to
    be ``<pathway>_score``. ``em`` is the imported ``enrichmap`` module.

    Returns:
        The current matplotlib figure after plotting, resized / re-DPI'd
        when ``params.figure_size`` / ``params.dpi`` are set.

    Raises:
        DataNotFoundError: If the gene sets are missing or fewer than two
            pathways are stored.
    """
    if "enrichment_gene_sets" not in adata.uns:
        raise DataNotFoundError("enrichment_gene_sets not found in adata.uns")

    pathway_names = list(adata.uns["enrichment_gene_sets"].keys())
    if len(pathway_names) < 2:
        raise DataNotFoundError("Need at least 2 pathways for cross-correlation")

    first, second = pathway_names[0], pathway_names[1]
    em.pl.cross_moran_scatter(
        adata,
        score_x=f"{first}_score",
        score_y=f"{second}_score",
        library_id=library_id,
    )

    # The plot call above draws into the current figure; pick that figure up.
    figure = plt.gcf()
    if params.figure_size:
        figure.set_size_inches(params.figure_size)
    if params.dpi:
        figure.set_dpi(params.dpi)
    return figure
390
+
391
+
392
def _create_enrichmap_single_score(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    library_id: str,
    em,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a single-score EnrichMap visualization.

    Supported ``params.subtype`` values:
        - spatial_correlogram: ``em.pl.morans_correlogram``
        - spatial_variogram: ``em.pl.variogram``
        - spatial_score: ``em.pl.spatial_enrichmap``

    Args:
        adata: AnnData object with the score stored in ``obs``
        params: Visualization parameters. ``feature`` names the score: the
            column ``<feature>_score`` is preferred, and ``feature`` itself
            is used when only that exact column exists. A list/tuple is
            reduced to its first element.
        library_id: Library identifier forwarded to EnrichMap
        em: The imported ``enrichmap`` module
        context: MCP context (currently unused here)

    Returns:
        The current matplotlib figure after plotting.

    Raises:
        DataNotFoundError: If no feature is given.
        ParameterError: If the subtype is not one of the supported values.
            Previously such a subtype drew nothing and returned whatever
            figure happened to be current.
    """
    feature = params.feature
    if isinstance(feature, (list, tuple)):
        feature = feature[0] if feature else None
    if not feature:
        raise DataNotFoundError(
            "Feature parameter required for spatial enrichment visualization"
        )

    # Keep the historical "<feature>_score" lookup first; only fall back to
    # the exact name when the suffixed column is absent.
    score_col = f"{feature}_score"
    if score_col not in adata.obs.columns and feature in adata.obs.columns:
        score_col = feature
    validate_obs_column(adata, score_col, "Score")

    if params.subtype == "spatial_correlogram":
        em.pl.morans_correlogram(adata, score_key=score_col, library_id=library_id)
    elif params.subtype == "spatial_variogram":
        em.pl.variogram(adata, score_keys=[score_col])
    elif params.subtype == "spatial_score":
        spot_size = params.spot_size if params.spot_size is not None else 0.5
        em.pl.spatial_enrichmap(
            adata,
            score_key=score_col,
            library_id=library_id,
            cmap="seismic",
            vcenter=0,
            size=spot_size,
            img=False,
        )
    else:
        raise ParameterError(
            f"Unsupported spatial enrichment subtype: '{params.subtype}'. "
            "Supported: spatial_score, spatial_correlogram, spatial_variogram"
        )

    # The EnrichMap calls draw into the current figure; pick that figure up.
    fig = plt.gcf()
    if params.figure_size:
        fig.set_size_inches(params.figure_size)
    if params.dpi:
        fig.set_dpi(params.dpi)

    return fig
431
+
432
+
433
+ # =============================================================================
434
+ # GSEA/ORA Pathway Visualizations
435
+ # =============================================================================
436
+
437
+
438
def _create_gsea_enrichment_plot(
    gsea_results,
    params: VisualizationParameters,
) -> plt.Figure:
    """Create classic GSEA running enrichment score plot.

    Requires per-pathway result dicts that contain ``RES`` (running
    enrichment scores). The pathway is ``params.feature`` when it is a key
    of ``gsea_results``; otherwise the first stored pathway is plotted.

    Args:
        gsea_results: Dict mapping pathway name to its result dict
        params: Visualization parameters

    Returns:
        Matplotlib figure

    Raises:
        DataNotFoundError: If results are a summary DataFrame, the dict is
            empty, or the selected pathway has no ``RES`` data.
        ParameterError: If ``gsea_results`` is neither DataFrame nor dict.
    """
    pathway = params.feature if params.feature else None
    # A list-valued feature cannot be used as a dict key; use its first item.
    if isinstance(pathway, (list, tuple)):
        pathway = pathway[0] if pathway else None

    if isinstance(gsea_results, pd.DataFrame):
        raise DataNotFoundError(
            "Enrichment plot requires running enrichment scores (RES) data.\n"
            "The stored results contain only summary statistics.\n\n"
            "Solutions:\n"
            "1. Use subtype='barplot' or subtype='dotplot' instead\n"
            "2. Re-run GSEA analysis and store the full result object"
        )

    if isinstance(gsea_results, dict):
        if not gsea_results:
            # next(iter({})) would otherwise escape as a bare StopIteration.
            raise DataNotFoundError("No enrichment results found")

        if not (pathway and pathway in gsea_results):
            pathway = next(iter(gsea_results))
        result = gsea_results[pathway]

        if not isinstance(result, dict) or "RES" not in result:
            raise DataNotFoundError(
                "Enrichment plot requires 'RES' (running enrichment scores) data.\n"
                "Use subtype='barplot' or subtype='dotplot' instead."
            )

        import gseapy as gp

        # Use centralized figure size with enrichment default
        figsize = resolve_figure_size(params, "enrichment")
        plotted = gp.gseaplot(
            term=pathway,
            hits=result.get("hits", result.get("hit_indices", [])),
            nes=result.get("NES", result.get("nes", 0)),
            pval=result.get("pval", result.get("NOM p-val", 0)),
            fdr=result.get("fdr", result.get("FDR q-val", 0)),
            RES=result["RES"],
            rank_metric=result.get("rank_metric"),
            figsize=figsize,
            ofname=None,
        )

        # NOTE(review): the return type of gseaplot appears to vary between
        # gseapy releases (Figure vs. list of Axes) - confirm against the
        # pinned version. Normalize so callers always receive a Figure.
        if hasattr(plotted, "savefig"):
            return plotted
        axes = list(plotted) if plotted is not None else []
        if axes:
            return axes[0].get_figure()
        return plt.gcf()

    raise ParameterError(f"Unsupported GSEA results format: {type(gsea_results)}")
488
+
489
+
490
def _create_gsea_barplot(
    gsea_results,
    params: VisualizationParameters,
) -> plt.Figure:
    """Render the top enriched pathways as a gseapy barplot.

    The number of pathways is ``params.n_top_pathways`` (10 when the
    attribute is absent). The bar colour is ``params.colormap``, except that
    the value ``"coolwarm"`` is replaced by ``"salmon"``.

    Raises:
        DataNotFoundError: If the results convert to an empty table.
        ProcessingError: If ``gseapy.barplot`` (or the layout step) fails.
    """
    import gseapy as gp

    top_n = getattr(params, "n_top_pathways", 10)
    table = _gsea_results_to_dataframe(gsea_results)
    if table.empty:
        raise DataNotFoundError("No enrichment results found")

    pvalue_column = _find_pvalue_column(table)
    _ensure_term_column(table)

    # Height grows with the number of pathways shown.
    figsize = resolve_figure_size(
        params, n_panels=top_n, panel_width=6, panel_height=0.4
    )
    bar_color = "salmon" if params.colormap == "coolwarm" else params.colormap

    try:
        axis = gp.barplot(
            df=table,
            column=pvalue_column,
            title=params.title or "Top Enriched Pathways",
            cutoff=1.0,
            top_term=top_n,
            figsize=figsize,
            color=bar_color,
            ofname=None,
        )
        figure = axis.get_figure()
        plt.tight_layout()
        return figure
    except Exception as e:
        raise ProcessingError(
            f"gseapy.barplot failed: {e}\n" f"Available columns: {list(table.columns)}"
        ) from e
530
+
531
+
532
def _create_gsea_dotplot(
    gsea_results,
    params: VisualizationParameters,
) -> plt.Figure:
    """Create dotplot of pathway enrichment.

    Two dict layouts are accepted besides a DataFrame:
        - flat: ``{pathway: {stat: value, ...}}``
        - nested (multi-condition): ``{condition: {pathway: {stat: value}}}``,
          which adds a ``Group`` column used as the x axis

    Args:
        gsea_results: Enrichment results (DataFrame or dict, see above)
        params: Visualization parameters

    Returns:
        Matplotlib figure

    Raises:
        DataNotFoundError: If the results convert to an empty table.
        ProcessingError: If ``gseapy.dotplot`` (or the layout step) fails.
    """
    import gseapy as gp

    n_top = getattr(params, "n_top_pathways", 10)

    # "Values are dicts" is true for the flat layout as well, so look one
    # level deeper: only in the nested layout are the inner values dicts too.
    # Requiring at least one non-empty group keeps {} and {pathway: {}} flat.
    groups = list(gsea_results.values()) if isinstance(gsea_results, dict) else []
    is_nested = all(
        isinstance(group, dict)
        and all(isinstance(stats, dict) for stats in group.values())
        for group in groups
    ) and any(bool(group) for group in groups)

    if is_nested:
        df, x_col = _nested_dict_to_dataframe(gsea_results)
    else:
        df = _gsea_results_to_dataframe(gsea_results)
        x_col = None

    if df.empty:
        raise DataNotFoundError("No enrichment results found")

    _ensure_term_column(df)
    pval_col = _find_pvalue_column(df)

    figsize = params.figure_size or (6, 8)
    # "coolwarm" is replaced by a colormap suited to p-values.
    cmap = params.colormap if params.colormap != "coolwarm" else "viridis_r"

    try:
        ax = gp.dotplot(
            df=df,
            column=pval_col,
            x=x_col,
            y="Term",
            title=params.title or "Pathway Enrichment",
            cutoff=1.0,
            top_term=n_top,
            figsize=figsize,
            cmap=cmap,
            size=5,
            ofname=None,
        )
        fig = ax.get_figure()
        plt.tight_layout()
        return fig
    except Exception as e:
        raise ProcessingError(
            f"gseapy.dotplot failed: {e}\n" f"Available columns: {list(df.columns)}"
        ) from e
580
+
581
+
582
+ # =============================================================================
583
+ # Utility Functions
584
+ # =============================================================================
585
+
586
+
587
def _resolve_feature_list(
    feature,
    obs_columns: pd.Index,
    score_cols: list[str],
) -> list[str]:
    """Normalize the ``feature`` parameter to a list of names.

    No validation against the data is done here; callers are responsible
    for matching the names to real columns.

    Args:
        feature: ``None``, a single name, or a list/tuple of names.
        obs_columns: Unused; kept for interface compatibility.
        score_cols: Unused; kept for interface compatibility.

    Returns:
        ``[]`` for ``None``; a new list with the items of a list/tuple (so
        the caller's object is never aliased); otherwise a one-element list.
    """
    if feature is None:
        return []
    # Check str first: a string must stay one feature, not become characters.
    if isinstance(feature, str):
        return [feature]
    if isinstance(feature, (list, tuple)):
        return list(feature)
    return [feature]
598
+
599
+
600
def _gsea_results_to_dataframe(gsea_results) -> pd.DataFrame:
    """Return GSEA results as a DataFrame.

    A DataFrame input is copied. A dict input produces one row per entry
    whose value is itself a dict: the key goes into ``Term`` and the value's
    items become the remaining columns (a ``Term`` item inside the value
    takes precedence). Entries with non-dict values are skipped.

    Raises:
        ParameterError: For any other input type.
    """
    if isinstance(gsea_results, pd.DataFrame):
        return gsea_results.copy()

    if not isinstance(gsea_results, dict):
        raise ParameterError("Unsupported GSEA results format")

    records = [
        {"Term": term, **stats}
        for term, stats in gsea_results.items()
        if isinstance(stats, dict)
    ]
    return pd.DataFrame(records)
613
+
614
+
615
def _nested_dict_to_dataframe(gsea_results: dict):
    """Flatten ``{condition: {pathway: stats}}`` into a long DataFrame.

    Each pathway whose ``stats`` is a dict becomes one row with ``Term``
    (pathway), ``Group`` (condition) and the stats as further columns;
    pathways with non-dict stats are skipped.

    Returns:
        Tuple of the DataFrame and the name of the grouping column
        (always ``"Group"``).
    """
    records = [
        {"Term": term, "Group": group, **stats}
        for group, terms in gsea_results.items()
        for term, stats in terms.items()
        if isinstance(stats, dict)
    ]
    return pd.DataFrame(records), "Group"
625
+
626
+
627
def _find_pvalue_column(df: pd.DataFrame) -> str:
    """Pick the p-value column of an enrichment results table.

    Known column names are tried in order of preference (adjusted values
    before raw ones); the first one present in ``df`` wins. When none is
    present, ``"Adjusted P-value"`` is returned regardless.
    """
    preferred_names = (
        "Adjusted P-value",  # gseapy standard
        "adjusted_pvalue",  # ChatSpatial internal format
        "FDR q-val",  # GSEA standard
        "fdr",
        "P-value",
        "pvalue",
        "NOM p-val",
        "pval",
    )
    present = (name for name in preferred_names if name in df.columns)
    return next(present, "Adjusted P-value")
647
+
648
+
649
def _ensure_term_column(df: pd.DataFrame) -> None:
    """Make sure ``df`` carries a ``Term`` column, adding it in place.

    Sources, in order: an existing ``Term`` column (left untouched), a
    ``pathway`` column, or the index when it is named or differs from the
    default ``RangeIndex``.

    Raises:
        DataNotFoundError: If none of these sources is available.
    """
    if "Term" in df.columns:
        return

    if "pathway" in df.columns:
        df["Term"] = df["pathway"]
        return

    index = df.index
    # A named or non-default index is taken to hold the pathway names.
    index_is_informative = index.name or not index.equals(pd.RangeIndex(len(df)))
    if not index_is_informative:
        raise DataNotFoundError(
            "No pathway/term column found. Expected 'Term' or 'pathway' column."
        )
    df["Term"] = index