chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,660 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Enrichment analysis visualization functions.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- Pathway enrichment barplots and dotplots
|
|
6
|
+
- GSEA enrichment score plots
|
|
7
|
+
- Spatial enrichment score visualization
|
|
8
|
+
- EnrichMap spatial autocorrelation plots
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import TYPE_CHECKING, Optional
|
|
12
|
+
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
import pandas as pd
|
|
15
|
+
import seaborn as sns
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import anndata as ad
|
|
19
|
+
|
|
20
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
21
|
+
|
|
22
|
+
from ...models.data import VisualizationParameters
|
|
23
|
+
from ...utils.adata_utils import get_analysis_parameter, validate_obs_column
|
|
24
|
+
from ...utils.exceptions import DataNotFoundError, ParameterError, ProcessingError
|
|
25
|
+
from .core import (
|
|
26
|
+
create_figure,
|
|
27
|
+
get_categorical_columns,
|
|
28
|
+
plot_spatial_feature,
|
|
29
|
+
resolve_figure_size,
|
|
30
|
+
setup_multi_panel_figure,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# =============================================================================
|
|
34
|
+
# Helper Functions
|
|
35
|
+
# =============================================================================
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _ensure_enrichmap_compatibility(adata: "ad.AnnData") -> None:
    """Ensure data has the metadata structure EnrichMap visualization expects.

    EnrichMap and squidpy require:
    1. ``adata.obs['library_id']`` - sample identifier column
    2. ``adata.uns['spatial']`` - spatial metadata dictionary keyed by library id

    Missing pieces are added in place:

    - When ``adata.obs`` has no ``library_id`` column and ``adata.uns['spatial']``
      already holds exactly one library, that library's key is reused so the
      obs column and the uns entry agree. Otherwise every observation is
      labelled ``"sample_1"``.
    - Each library id in ``adata.obs['library_id']`` without an entry in
      ``adata.uns['spatial']`` gets a placeholder entry (no images, all scale
      factors 1.0). Existing entries are left untouched, and a non-mapping
      ``adata.uns['spatial']`` is not modified.

    Args:
        adata: AnnData object; modified in place.
    """
    from collections.abc import MutableMapping

    existing_spatial = adata.uns.get("spatial")

    if "library_id" not in adata.obs.columns:
        if isinstance(existing_spatial, MutableMapping) and len(existing_spatial) == 1:
            # Reuse the sole existing key: a fresh "sample_1" label would not
            # match adata.uns['spatial'] and library lookups would fail.
            adata.obs["library_id"] = next(iter(existing_spatial))
        else:
            adata.obs["library_id"] = "sample_1"

    if "spatial" not in adata.uns:
        adata.uns["spatial"] = {}
    spatial = adata.uns["spatial"]
    if not isinstance(spatial, MutableMapping):
        # Unknown structure: leave it alone rather than clobber user metadata.
        return

    for lib_id in adata.obs["library_id"].unique():
        if lib_id not in spatial:
            spatial[lib_id] = {
                "images": {},
                "scalefactors": {
                    "spot_diameter_fullres": 1.0,
                    "tissue_hires_scalef": 1.0,
                    "fiducial_diameter_fullres": 1.0,
                    "tissue_lowres_scalef": 1.0,
                },
            }
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _get_score_columns(adata: "ad.AnnData") -> list[str]:
    """Collect the enrichment score columns available in ``adata.obs``.

    Lookup order:
    1. Columns recorded under ``results_keys["obs"]`` in the stored analysis
       metadata of ``enrichment_spatial`` and then ``enrichment_ssgsea``
       (e.g. 'Wnt_score', 'ssgsea_Wnt'). Only columns that actually exist in
       ``adata.obs`` are kept, de-duplicated, in first-seen order.
    2. If step 1 yields nothing (legacy data without metadata), every
       ``adata.obs`` column whose name ends with ``"_score"``.
    """
    recorded = []
    for analysis_name in ("enrichment_spatial", "enrichment_ssgsea"):
        keys = get_analysis_parameter(adata, analysis_name, "results_keys")
        # Metadata may be absent or malformed; only a dict with "obs" counts.
        if keys and isinstance(keys, dict) and "obs" in keys:
            recorded.extend(keys["obs"])

    # dict.fromkeys keeps first occurrences in order, dropping duplicates.
    found = [col for col in dict.fromkeys(recorded) if col in adata.obs.columns]
    if found:
        return found

    return [col for col in adata.obs.columns if col.endswith("_score")]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _resolve_score_column(
    adata: "ad.AnnData",
    feature: "Optional[str | list[str]]",
    score_cols: list[str],
) -> str:
    """Resolve a requested feature to an existing ``adata.obs`` score column.

    Args:
        adata: AnnData object whose ``obs`` holds enrichment scores.
        feature: Requested score. This is either an exact ``adata.obs`` column
            or a signature name whose ``"<name>_score"`` column exists. A
            one-element list/tuple is unwrapped. ``None`` or an empty value
            selects the first entry of ``score_cols``.
        score_cols: Known score columns, used as the default and in messages.

    Returns:
        Name of an ``adata.obs`` column.

    Raises:
        ParameterError: if ``feature`` is a list with more than one entry.
        DataNotFoundError: if the feature or any default score is missing.
    """
    if isinstance(feature, (list, tuple)):
        # A list is unhashable and would crash the `in columns` lookup below.
        if len(feature) > 1:
            raise ParameterError(
                f"Expected a single score feature, got {len(feature)}: {list(feature)}"
            )
        feature = feature[0] if feature else None

    if feature:
        if feature in adata.obs.columns:
            return feature
        suffixed = f"{feature}_score"
        if suffixed in adata.obs.columns:
            return suffixed
        raise DataNotFoundError(
            f"Score column '{feature}' not found. Available: {score_cols}"
        )
    if score_cols:
        return score_cols[0]
    raise DataNotFoundError("No enrichment scores found in adata.obs")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# =============================================================================
|
|
114
|
+
# Main Routers
|
|
115
|
+
# =============================================================================
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
async def create_enrichment_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Render enrichment scores stored in ``adata.obs``.

    Dispatch rules, checked in this order:
    - ``params.plot_type == "violin"``: per-cluster violin plots.
    - ``params.subtype`` starting with ``"spatial_"``: EnrichMap plots.
    - Anything else: spatial scatter plot of the score(s).

    Args:
        adata: AnnData object with enrichment scores
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        Matplotlib figure

    Raises:
        DataNotFoundError: if no enrichment score columns are found.
    """
    if context:
        await context.info("Creating enrichment visualization")

    score_cols = _get_score_columns(adata)
    if not score_cols:
        raise DataNotFoundError(
            "No enrichment scores found. Run 'analyze_enrichment' first."
        )

    if params.plot_type == "violin":
        return _create_enrichment_violin(adata, params, score_cols, context)

    subtype = params.subtype
    if not subtype or not subtype.startswith("spatial_"):
        # Default: spatial scatter plot
        return await _create_enrichment_spatial(adata, params, score_cols, context)

    return _create_enrichmap_spatial(adata, params, score_cols, context)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
async def create_pathway_enrichment_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create pathway enrichment visualization (GSEA/ORA results).

    Supports multiple visualization types (selected by ``params.subtype``):
    Traditional:
    - barplot: Top enriched pathways barplot (default, also used for
      unrecognized subtypes)
    - dotplot: Multi-cluster enrichment dotplot
    - enrichment_plot: Classic GSEA running score plot

    Spatial EnrichMap:
    - spatial_score, spatial_correlogram, etc. (delegated to
      ``create_enrichment_visualization``)

    Results are read from ``adata.uns[params.gsea_results_key]``. The key
    defaults to ``"gsea_results"`` when the attribute is missing or unset.
    If that key is absent, the fallback keys ``rank_genes_groups``,
    ``de_results`` and ``pathway_enrichment`` are tried in order.

    Args:
        adata: AnnData object with enrichment results
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        Matplotlib figure

    Raises:
        DataNotFoundError: if none of the candidate result keys exist.
    """
    if context:
        await context.info("Creating pathway enrichment visualization")

    plot_type = params.subtype or "barplot"

    # Route spatial subtypes to enrichment visualization
    if plot_type.startswith("spatial_"):
        return await create_enrichment_visualization(adata, params, context)

    # `or` also covers a declared-but-unset (None/"") gsea_results_key field,
    # which getattr's default alone would not.
    requested_key = getattr(params, "gsea_results_key", None) or "gsea_results"
    fallback_keys = ["rank_genes_groups", "de_results", "pathway_enrichment"]
    gsea_key = next(
        (key for key in [requested_key, *fallback_keys] if key in adata.uns), None
    )
    if gsea_key is None:
        raise DataNotFoundError(
            f"GSEA results not found. Expected key: {requested_key} "
            f"(also tried: {', '.join(fallback_keys)})"
        )
    if context and gsea_key != requested_key:
        await context.info(f"Using enrichment results from adata.uns['{gsea_key}']")

    gsea_results = adata.uns[gsea_key]

    if plot_type == "enrichment_plot":
        return _create_gsea_enrichment_plot(gsea_results, params)
    if plot_type == "dotplot":
        return _create_gsea_dotplot(gsea_results, params)
    # Default to barplot
    return _create_gsea_barplot(gsea_results, params)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# =============================================================================
|
|
213
|
+
# Enrichment Score Visualizations
|
|
214
|
+
# =============================================================================
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _create_enrichment_violin(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    score_cols: list[str],
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create violin plots of enrichment scores grouped by cluster.

    One panel is drawn per score. Scores come from ``params.feature`` (a name
    or list of names). Each name may be an exact ``adata.obs`` column or a
    signature name whose ``"<name>_score"`` column exists; names matching
    neither are skipped. Without ``params.feature``, the first three entries
    of ``score_cols`` are shown.

    Raises:
        ParameterError: if ``params.cluster_key`` is not set.
        DataNotFoundError: if none of the requested features resolve to a column.
    """
    if not params.cluster_key:
        categorical_cols = get_categorical_columns(adata, limit=15)
        raise ParameterError(
            "Enrichment violin plot requires 'cluster_key' parameter.\n"
            f"Available categorical columns: {', '.join(categorical_cols)}"
        )

    validate_obs_column(adata, params.cluster_key, "Cluster")

    # Determine scores to plot, resolving signature names to "<name>_score"
    # the same way the spatial scatter path does.
    requested = _resolve_feature_list(params.feature, adata.obs.columns, score_cols)
    if requested:
        scores_to_plot = []
        for feat in requested:
            if feat in adata.obs.columns:
                scores_to_plot.append(feat)
            elif f"{feat}_score" in adata.obs.columns:
                scores_to_plot.append(f"{feat}_score")
        if not scores_to_plot:
            raise DataNotFoundError(
                f"None of the specified scores found: {requested}. "
                f"Available: {score_cols}"
            )
    else:
        scores_to_plot = score_cols[:3]

    n_scores = len(scores_to_plot)
    # Use centralized figure size resolution for multi-panel layout
    figsize = resolve_figure_size(params, n_panels=n_scores, panel_width=5, panel_height=6)
    # squeeze=False always yields a 2-D array, even for a single panel.
    fig, axes = plt.subplots(1, n_scores, figsize=figsize, squeeze=False)

    for ax, score in zip(axes[0], scores_to_plot):
        data = pd.DataFrame(
            {
                params.cluster_key: adata.obs[params.cluster_key],
                "Score": adata.obs[score],
            }
        )
        sns.violinplot(data=data, x=params.cluster_key, y="Score", ax=ax)

        # Strip only a trailing "_score", not every occurrence inside the name.
        sig_name = str(score).removesuffix("_score")
        ax.set_title(f"{sig_name} by {params.cluster_key}")
        ax.set_xlabel(params.cluster_key)
        ax.set_ylabel("Enrichment Score")
        ax.tick_params(axis="x", rotation=45)
        for label in ax.get_xticklabels():
            label.set_horizontalalignment("right")

    fig.tight_layout()
    return fig
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
async def _create_enrichment_spatial(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    score_cols: list[str],
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create spatial scatter plots of enrichment scores.

    With more than one requested feature, a multi-panel figure shows every
    feature that resolves to an ``adata.obs`` column. A feature resolves either
    directly or via ``"<name>_score"``; unresolved ones are skipped and
    reported through ``context``. Otherwise a single panel shows the one
    requested score, or the first entry of ``score_cols`` when none was given.

    Raises:
        DataNotFoundError: if no requested score resolves to a column.
    """
    feature_list = _resolve_feature_list(params.feature, adata.obs.columns, score_cols)

    if len(feature_list) > 1:
        # Multi-score visualization
        scores_to_plot = []
        missing = []
        for feat in feature_list:
            if feat in adata.obs.columns:
                scores_to_plot.append(feat)
            elif f"{feat}_score" in adata.obs.columns:
                scores_to_plot.append(f"{feat}_score")
            else:
                missing.append(feat)

        if not scores_to_plot:
            raise DataNotFoundError(
                f"None of the specified scores found: {feature_list}"
            )
        if missing and context:
            await context.info(f"Skipping scores not found in adata.obs: {missing}")

        fig, axes = setup_multi_panel_figure(
            n_panels=len(scores_to_plot),
            params=params,
            default_title="Enrichment Scores",
        )

        for ax, score in zip(axes, scores_to_plot):
            plot_spatial_feature(adata, feature=score, ax=ax, params=params)
            sig_name = score.replace("_score", "")
            ax.set_title(f"{sig_name} Enrichment")
    else:
        # Single score visualization. Pass a plain name (not a one-element
        # list) so the column lookup does not hit an unhashable key.
        single_feature = feature_list[0] if feature_list else None
        score_col = _resolve_score_column(adata, single_feature, score_cols)
        if context:
            await context.info(f"Using score column: {score_col}")

        # Honor a user-provided figure size like the multi-panel path does.
        fig, ax = create_figure(figsize=params.figure_size or (10, 8))
        plot_spatial_feature(adata, feature=score_col, ax=ax, params=params)

        sig_name = score_col.replace("_score", "")
        ax.set_title(f"{sig_name} Enrichment Score", fontsize=14)

        if params.show_colorbar and hasattr(ax, "collections") and ax.collections:
            cbar = plt.colorbar(ax.collections[0], ax=ax)
            cbar.set_label("Enrichment Score", fontsize=12)

    fig.tight_layout()
    return fig
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _create_enrichmap_spatial(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    score_cols: list[str],
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create EnrichMap spatial autocorrelation visualizations.

    ``params.subtype == "spatial_cross_correlation"`` draws a cross-Moran
    scatter. Other ``spatial_*`` subtypes draw a single-score plot. Missing
    ``library_id``/``uns['spatial']`` metadata is filled in first, and the
    first library id is used.

    Raises:
        ProcessingError: if EnrichMap is not installed, or plotting fails for
            any reason other than the errors listed below.
        DataNotFoundError, ParameterError: propagated unchanged from the
            plotting helpers.
    """
    try:
        import enrichmap as em
    except ImportError as e:
        raise ProcessingError(
            f"Spatial enrichment visualization ('{params.subtype}') requires EnrichMap.\n"
            "Install with: pip install enrichmap"
        ) from e

    _ensure_enrichmap_compatibility(adata)
    library_id = adata.obs["library_id"].unique()[0]

    # Snapshot open figures so cleanup after a failure closes only figures
    # opened by this call, not ones other visualizations still hold.
    figures_before = set(plt.get_fignums())
    try:
        if params.subtype == "spatial_cross_correlation":
            return _create_enrichmap_cross_correlation(adata, params, library_id, em)
        return _create_enrichmap_single_score(adata, params, library_id, em, context)
    except (DataNotFoundError, ParameterError):
        raise
    except Exception as e:
        for fig_num in set(plt.get_fignums()) - figures_before:
            plt.close(fig_num)
        raise ProcessingError(
            f"EnrichMap {params.subtype} visualization failed: {e}\n\n"
            "Solutions:\n"
            "1. Verify the enrichment analysis completed successfully\n"
            "2. Check that spatial neighbors graph exists\n"
            "3. Ensure enrichment scores are stored in adata.obs"
        ) from e
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _create_enrichmap_cross_correlation(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    library_id: str,
    em,
) -> plt.Figure:
    """Create EnrichMap cross-correlation (cross-Moran scatter) visualization.

    The first two pathways in ``adata.uns['enrichment_gene_sets']`` are
    compared through their ``"<pathway>_score"`` columns in ``adata.obs``.

    Args:
        adata: AnnData with enrichment scores and gene-set metadata.
        params: Visualization parameters (``figure_size``/``dpi`` applied
            to the resulting figure when set).
        library_id: Library id passed through to EnrichMap.
        em: The imported ``enrichmap`` module.

    Returns:
        The current matplotlib figure after EnrichMap has drawn on it.

    Raises:
        DataNotFoundError: if gene-set metadata is missing, fewer than two
            pathways are stored, or either score column is absent.
    """
    if "enrichment_gene_sets" not in adata.uns:
        raise DataNotFoundError("enrichment_gene_sets not found in adata.uns")

    pathways = list(adata.uns["enrichment_gene_sets"].keys())
    if len(pathways) < 2:
        raise DataNotFoundError("Need at least 2 pathways for cross-correlation")

    score_x = f"{pathways[0]}_score"
    score_y = f"{pathways[1]}_score"

    # Fail with a clear message instead of an opaque EnrichMap KeyError.
    missing = [col for col in (score_x, score_y) if col not in adata.obs.columns]
    if missing:
        raise DataNotFoundError(
            f"Score columns {missing} not found in adata.obs for cross-correlation. "
            "Re-run spatial enrichment analysis for these pathways."
        )

    em.pl.cross_moran_scatter(
        adata, score_x=score_x, score_y=score_y, library_id=library_id
    )

    fig = plt.gcf()
    if params.figure_size:
        fig.set_size_inches(params.figure_size)
    if params.dpi:
        fig.set_dpi(params.dpi)

    return fig
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _create_enrichmap_single_score(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    library_id: str,
    em,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create a single-score EnrichMap visualization.

    Supported ``params.subtype`` values:
    - ``spatial_correlogram``: ``em.pl.morans_correlogram``
    - ``spatial_variogram``: ``em.pl.variogram``
    - ``spatial_score``: ``em.pl.spatial_enrichmap`` (seismic colormap,
      centered at 0, spot size ``params.spot_size`` or 0.5)

    The score column is ``"<feature>_score"``. If that column is absent but
    ``params.feature`` itself is an ``adata.obs`` column (e.g. the caller
    passed ``"Wnt_score"``), that column is used. A one-element list feature
    is unwrapped.

    Returns:
        The current matplotlib figure after EnrichMap has drawn on it, resized
        to ``params.figure_size``/``params.dpi`` when set.

    Raises:
        DataNotFoundError: if no feature is given.
        ParameterError: for a multi-element feature list or an unsupported subtype.
    """
    feature = params.feature
    if isinstance(feature, (list, tuple)):
        if len(feature) > 1:
            raise ParameterError(
                f"Spatial enrichment subtype '{params.subtype}' plots one score; "
                f"got {len(feature)} features: {list(feature)}"
            )
        feature = feature[0] if feature else None

    if not feature:
        raise DataNotFoundError(
            "Feature parameter required for spatial enrichment visualization"
        )

    score_col = f"{feature}_score"
    if score_col not in adata.obs.columns and feature in adata.obs.columns:
        score_col = feature
    validate_obs_column(adata, score_col, "Score")

    subtype = params.subtype
    if subtype == "spatial_correlogram":
        em.pl.morans_correlogram(adata, score_key=score_col, library_id=library_id)
    elif subtype == "spatial_variogram":
        em.pl.variogram(adata, score_keys=[score_col])
    elif subtype == "spatial_score":
        spot_size = params.spot_size if params.spot_size is not None else 0.5
        em.pl.spatial_enrichmap(
            adata,
            score_key=score_col,
            library_id=library_id,
            cmap="seismic",
            vcenter=0,
            size=spot_size,
            img=False,
        )
    else:
        # Previously an unknown subtype silently returned an empty figure.
        raise ParameterError(
            f"Unsupported spatial enrichment subtype '{subtype}'. Supported: "
            "spatial_score, spatial_correlogram, spatial_variogram, "
            "spatial_cross_correlation"
        )

    fig = plt.gcf()
    if params.figure_size:
        fig.set_size_inches(params.figure_size)
    if params.dpi:
        fig.set_dpi(params.dpi)

    return fig
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
# =============================================================================
|
|
434
|
+
# GSEA/ORA Pathway Visualizations
|
|
435
|
+
# =============================================================================
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def _create_gsea_enrichment_plot(
    gsea_results,
    params: VisualizationParameters,
) -> plt.Figure:
    """Create classic GSEA running enrichment score plot.

    Requires full gseapy-style per-pathway results that include ``RES``
    (running enrichment scores), stored as ``{pathway: {"RES": ..., ...}}``.

    Args:
        gsea_results: Stored GSEA results. A DataFrame (summary statistics
            only) is rejected.
        params: Visualization parameters. ``params.feature`` selects the
            pathway; a list/tuple uses its first entry. When unset or not in
            the results, the first stored pathway is plotted.

    Returns:
        Figure returned by ``gseapy.gseaplot``.

    Raises:
        DataNotFoundError: if results are a DataFrame, empty, or lack ``RES``.
        ParameterError: for any other results type.
    """
    pathway = params.feature
    if isinstance(pathway, (list, tuple)):
        # One running-score plot per call; a list would also be unhashable
        # in the dict membership test below.
        pathway = pathway[0] if pathway else None
    if not pathway:
        pathway = None

    if isinstance(gsea_results, pd.DataFrame):
        raise DataNotFoundError(
            "Enrichment plot requires running enrichment scores (RES) data.\n"
            "The stored results contain only summary statistics.\n\n"
            "Solutions:\n"
            "1. Use subtype='barplot' or subtype='dotplot' instead\n"
            "2. Re-run GSEA analysis and store the full result object"
        )

    if not isinstance(gsea_results, dict):
        raise ParameterError(f"Unsupported GSEA results format: {type(gsea_results)}")

    if not gsea_results:
        # next(iter({})) below would otherwise raise a bare StopIteration.
        raise DataNotFoundError("No GSEA results found")

    if pathway is None or pathway not in gsea_results:
        pathway = next(iter(gsea_results))
    result = gsea_results[pathway]

    if not isinstance(result, dict) or "RES" not in result:
        raise DataNotFoundError(
            "Enrichment plot requires 'RES' (running enrichment scores) data.\n"
            "Use subtype='barplot' or subtype='dotplot' instead."
        )

    import gseapy as gp

    # Use centralized figure size with enrichment default
    figsize = resolve_figure_size(params, "enrichment")
    fig = gp.gseaplot(
        term=pathway,
        hits=result.get("hits", result.get("hit_indices", [])),
        nes=result.get("NES", result.get("nes", 0)),
        pval=result.get("pval", result.get("NOM p-val", 0)),
        fdr=result.get("fdr", result.get("FDR q-val", 0)),
        RES=result["RES"],
        rank_metric=result.get("rank_metric"),
        figsize=figsize,
        ofname=None,
    )
    return fig
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def _create_gsea_barplot(
    gsea_results,
    params: VisualizationParameters,
) -> plt.Figure:
    """Create a barplot of the top enriched pathways with ``gseapy.barplot``.

    Args:
        gsea_results: DataFrame or ``{pathway: {stat: value}}`` dict.
        params: Visualization parameters. ``n_top_pathways`` defaults to 10.
            ``colormap`` is used as the bar colour only when it is a valid
            single matplotlib colour; otherwise "salmon" is used.

    Returns:
        The figure containing the barplot.

    Raises:
        DataNotFoundError: if there are no results.
        ProcessingError: if ``gseapy.barplot`` fails.
    """
    import gseapy as gp
    from matplotlib.colors import is_color_like

    n_top = getattr(params, "n_top_pathways", None)
    if n_top is None:
        # getattr's default only covers a missing attribute, not one set to None.
        n_top = 10
    df = _gsea_results_to_dataframe(gsea_results)

    if df.empty:
        raise DataNotFoundError("No enrichment results found")

    pval_col = _find_pvalue_column(df)
    _ensure_term_column(df)

    # Use centralized figure size with dynamic height based on pathway count
    figsize = resolve_figure_size(
        params, n_panels=n_top, panel_width=6, panel_height=0.4
    )
    # gseapy.barplot takes a single colour, not a colormap name such as
    # "viridis" (or the "coolwarm" default), so fall back for non-colours.
    color = params.colormap if is_color_like(params.colormap) else "salmon"

    try:
        ax = gp.barplot(
            df=df,
            column=pval_col,
            title=params.title or "Top Enriched Pathways",
            cutoff=1.0,
            top_term=n_top,
            figsize=figsize,
            color=color,
            ofname=None,
        )
        fig = ax.get_figure()
        fig.tight_layout()
        return fig
    except Exception as e:
        raise ProcessingError(
            f"gseapy.barplot failed: {e}\n" f"Available columns: {list(df.columns)}"
        ) from e
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def _is_multi_condition_results(gsea_results) -> bool:
    """Return True for ``{condition: {pathway: {stat: value}}}`` results.

    A flat ``{pathway: {stat: value}}`` dict also has only dict values, so the
    check looks one level deeper: every condition must map to a non-empty dict
    whose values are themselves dicts.
    """
    if not isinstance(gsea_results, dict) or not gsea_results:
        return False
    return all(
        isinstance(per_condition, dict)
        and bool(per_condition)
        and all(isinstance(stats, dict) for stats in per_condition.values())
        for per_condition in gsea_results.values()
    )


def _create_gsea_dotplot(
    gsea_results,
    params: VisualizationParameters,
) -> plt.Figure:
    """Create a dotplot of pathway enrichment with ``gseapy.dotplot``.

    Multi-condition results (``{condition: {pathway: stats}}``) are plotted
    with one x-position per condition (``Group`` column). Flat results
    (DataFrame or ``{pathway: stats}``) are plotted in a single column.

    Raises:
        DataNotFoundError: if there are no results.
        ProcessingError: if ``gseapy.dotplot`` fails.
    """
    import gseapy as gp

    n_top = getattr(params, "n_top_pathways", None)
    if n_top is None:
        n_top = 10

    # Handle nested dict (multi-condition). A flat {pathway: stats} dict must
    # not be mistaken for it, or every row is dropped as non-dict stats.
    if _is_multi_condition_results(gsea_results):
        df, x_col = _nested_dict_to_dataframe(gsea_results)
    else:
        df = _gsea_results_to_dataframe(gsea_results)
        x_col = None

    if df.empty:
        raise DataNotFoundError("No enrichment results found")

    _ensure_term_column(df)
    pval_col = _find_pvalue_column(df)

    figsize = params.figure_size or (6, 8)
    cmap = params.colormap if params.colormap != "coolwarm" else "viridis_r"

    try:
        ax = gp.dotplot(
            df=df,
            column=pval_col,
            x=x_col,
            y="Term",
            title=params.title or "Pathway Enrichment",
            cutoff=1.0,
            top_term=n_top,
            figsize=figsize,
            cmap=cmap,
            size=5,
            ofname=None,
        )
        fig = ax.get_figure()
        fig.tight_layout()
        return fig
    except Exception as e:
        raise ProcessingError(
            f"gseapy.dotplot failed: {e}\n" f"Available columns: {list(df.columns)}"
        ) from e
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
# =============================================================================
|
|
583
|
+
# Utility Functions
|
|
584
|
+
# =============================================================================
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def _resolve_feature_list(
    feature,
    obs_columns: pd.Index,
    score_cols: list[str],
) -> list[str]:
    """Normalize the ``feature`` parameter into a list of feature names.

    No validation is performed. Names are returned as given and may be exact
    ``adata.obs`` columns or signature names that callers resolve to
    ``"<name>_score"``. ``obs_columns`` and ``score_cols`` are accepted for
    interface compatibility and are currently unused.

    Returns:
        ``[]`` for ``None`` or an empty string; a new list for a list/tuple
        (so callers never alias the parameter object); otherwise ``[feature]``.
    """
    if feature is None:
        return []
    if isinstance(feature, (list, tuple)):
        return list(feature)
    if isinstance(feature, str) and not feature:
        return []
    return [feature]
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def _gsea_results_to_dataframe(gsea_results) -> pd.DataFrame:
    """Flatten stored GSEA/ORA results into a DataFrame.

    - A DataFrame input is returned as a copy, so callers may add columns.
    - A ``{pathway: {stat: value}}`` dict becomes one row per pathway, with
      the pathway name in ``Term`` followed by its stats. Entries whose value
      is not a dict are skipped.

    Raises:
        ParameterError: for any other input type.
    """
    if isinstance(gsea_results, pd.DataFrame):
        return gsea_results.copy()
    if not isinstance(gsea_results, dict):
        raise ParameterError("Unsupported GSEA results format")

    # Stats are unpacked after "Term", so a stats-level "Term" key wins.
    records = [
        {"Term": term, **stats}
        for term, stats in gsea_results.items()
        if isinstance(stats, dict)
    ]
    return pd.DataFrame(records)
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
def _nested_dict_to_dataframe(gsea_results: dict):
    """Flatten ``{condition: {pathway: {stat: value}}}`` results.

    Returns:
        A ``(DataFrame, "Group")`` pair. The DataFrame has one row per
        (condition, pathway) with ``Term`` and ``Group`` columns followed by
        the stats; pathway entries whose value is not a dict are skipped.
        ``"Group"`` is the column name to use for the plot's x-axis.
    """
    group_col = "Group"
    records = [
        {"Term": term, group_col: condition, **stats}
        for condition, per_condition in gsea_results.items()
        for term, stats in per_condition.items()
        if isinstance(stats, dict)
    ]
    return pd.DataFrame(records), group_col
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def _find_pvalue_column(df: pd.DataFrame) -> str:
    """Find the p-value column in a GSEA/ORA results DataFrame.

    Handles multiple naming conventions from different enrichment methods:
    1. Exact match against the known names, adjusted/FDR values first, then
       nominal p-values.
    2. Case-insensitive match against the same names plus other common
       spellings (``p.adjust``, ``padj``, ``pvals_adj``, ``pvals``), again
       preferring adjusted values.
    3. ``"Adjusted P-value"`` as a last resort, even if absent. Plotting
       callers then report the available columns on failure.
    """
    adjusted = [
        "Adjusted P-value",  # gseapy standard
        "adjusted_pvalue",  # ChatSpatial internal format
        "FDR q-val",  # GSEA standard
        "fdr",
    ]
    nominal = ["P-value", "pvalue", "NOM p-val", "pval"]

    for col in adjusted + nominal:
        if col in df.columns:
            return col

    # Map lower-cased names to the real column (first occurrence wins).
    lowered = {}
    for col in df.columns:
        if isinstance(col, str):
            lowered.setdefault(col.lower(), col)
    for name in adjusted + ["p.adjust", "padj", "pvals_adj"] + nominal + ["pvals"]:
        match = lowered.get(name.lower())
        if match is not None:
            return match

    return "Adjusted P-value"
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def _ensure_term_column(df: pd.DataFrame) -> None:
    """Make sure ``df`` has a ``Term`` column, adding one in place if needed.

    Sources, in order: an existing ``Term`` column (nothing to do), a
    ``pathway`` column, or the index when it carries information (it has a
    name, or it differs from the default ``0..n-1`` RangeIndex).

    Raises:
        DataNotFoundError: when none of these sources is available.
    """
    if "Term" in df.columns:
        return

    if "pathway" in df.columns:
        source = df["pathway"]
    elif df.index.name or not df.index.equals(pd.RangeIndex(len(df))):
        source = df.index
    else:
        raise DataNotFoundError(
            "No pathway/term column found. Expected 'Term' or 'pathway' column."
        )

    df["Term"] = source
|