chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,699 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cell communication visualization functions for spatial transcriptomics.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- LIANA+ cluster-based visualizations (dotplot, tileplot, circle_plot)
|
|
6
|
+
- LIANA+ spatial bivariate visualizations
|
|
7
|
+
- CellPhoneDB visualizations (heatmap, dotplot, chord)
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import TYPE_CHECKING, Optional
|
|
11
|
+
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import anndata as ad
|
|
18
|
+
|
|
19
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
20
|
+
|
|
21
|
+
from ...models.data import VisualizationParameters
|
|
22
|
+
from ...utils.adata_utils import (
|
|
23
|
+
get_cluster_key,
|
|
24
|
+
require_spatial_coords,
|
|
25
|
+
validate_obs_column,
|
|
26
|
+
)
|
|
27
|
+
from ...utils.dependency_manager import require
|
|
28
|
+
from ...utils.exceptions import DataNotFoundError, ParameterError, ProcessingError
|
|
29
|
+
from .core import CellCommunicationData
|
|
30
|
+
|
|
31
|
+
# =============================================================================
|
|
32
|
+
# Data Retrieval
|
|
33
|
+
# =============================================================================
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
async def get_cell_communication_data(
    adata: "ad.AnnData",
    method: Optional[str] = None,
    context: Optional["ToolContext"] = None,
) -> CellCommunicationData:
    """
    Unified function to retrieve cell communication results from AnnData.

    Result stores are probed in a fixed priority order and the first usable
    one wins:

    1. LIANA+ spatial bivariate results
       (``adata.obsm["liana_spatial_scores"]``)
    2. LIANA+ cluster-based results
       (non-empty DataFrame in ``adata.uns["liana_res"]``)
    3. CellPhoneDB results
       (DataFrame in ``adata.uns["cellphonedb_means"]``)

    Args:
        adata: AnnData object with cell communication results
        method: Analysis method hint (optional). It is accepted for interface
            compatibility but is not consulted by this function; selection
            is driven purely by which keys are present in ``adata``.
        context: MCP context for logging

    Returns:
        CellCommunicationData object with results and metadata

    Raises:
        DataNotFoundError: No cell communication results found
    """
    # Check for LIANA+ spatial bivariate results (highest priority)
    if "liana_spatial_scores" in adata.obsm:
        spatial_scores = adata.obsm["liana_spatial_scores"]

        # The stored interaction names are not guaranteed to be a Python list;
        # they can be array-like (NOTE(review): presumably a NumPy array after
        # the object has been serialised and reloaded - confirm). Truth-testing
        # a multi-element array raises ValueError and arrays lack ``.index``,
        # which the spatial renderer relies on, so normalise to a plain list.
        raw_pairs = adata.uns.get("liana_spatial_interactions")
        if raw_pairs is None:
            lr_pairs = []
        elif isinstance(raw_pairs, (list, tuple)):
            lr_pairs = list(raw_pairs)
        else:
            lr_pairs = np.atleast_1d(raw_pairs).tolist()

        results_df = adata.uns.get("liana_spatial_res", pd.DataFrame())
        if not isinstance(results_df, pd.DataFrame):
            # Anything that is not a DataFrame is treated as "no global results".
            results_df = pd.DataFrame()

        if context:
            await context.info(
                f"Found LIANA+ spatial results: {len(lr_pairs)} LR pairs, "
                f"{spatial_scores.shape[0]} spots"
            )

        return CellCommunicationData(
            results=results_df,
            method="liana_spatial",
            analysis_type="spatial",
            lr_pairs=lr_pairs,
            spatial_scores=spatial_scores,
            spatial_pvals=adata.obsm.get("liana_spatial_pvals"),
            results_key="liana_spatial_res",
        )

    # Check for LIANA+ cluster-based results
    if "liana_res" in adata.uns:
        results = adata.uns["liana_res"]
        if isinstance(results, pd.DataFrame) and len(results) > 0:
            if (
                "ligand_complex" in results.columns
                and "receptor_complex" in results.columns
            ):
                # Pair identifier is "<ligand_complex>^<receptor_complex>".
                lr_pairs = (
                    (results["ligand_complex"] + "^" + results["receptor_complex"])
                    .unique()
                    .tolist()
                )
            else:
                lr_pairs = []

            source_labels = (
                results["source"].unique().tolist()
                if "source" in results.columns
                else None
            )
            target_labels = (
                results["target"].unique().tolist()
                if "target" in results.columns
                else None
            )

            if context:
                await context.info(
                    f"Found LIANA+ cluster results: {len(lr_pairs)} LR pairs"
                )

            return CellCommunicationData(
                results=results,
                method="liana_cluster",
                analysis_type="cluster",
                lr_pairs=lr_pairs,
                source_labels=source_labels,
                target_labels=target_labels,
                results_key="liana_res",
            )

    # Check for CellPhoneDB results
    if "cellphonedb_means" in adata.uns:
        means = adata.uns["cellphonedb_means"]
        if isinstance(means, pd.DataFrame):
            # The row index of the means table is used as the pair list.
            lr_pairs = means.index.tolist()

            if context:
                await context.info(
                    f"Found CellPhoneDB results: {len(lr_pairs)} LR pairs"
                )

            return CellCommunicationData(
                results=means,
                method="cellphonedb",
                analysis_type="cluster",
                lr_pairs=lr_pairs,
                results_key="cellphonedb_means",
            )

    # No results found
    raise DataNotFoundError(
        "No cell communication results found. "
        "Run analyze_cell_communication() first."
    )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# =============================================================================
|
|
155
|
+
# Main Router
|
|
156
|
+
# =============================================================================
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
async def create_cell_communication_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Build the cell communication figure that matches the stored results.

    The stored results are fetched through ``get_cell_communication_data`` and
    the call is then routed by analysis type, method and ``params.subtype``:

    - Spatial analysis: multi-panel spatial plot (subtype is not consulted)
    - CellPhoneDB: "dotplot", "chord", anything else -> heatmap
    - LIANA+ cluster: "tileplot", "circle_plot", anything else -> dotplot

    Args:
        adata: AnnData object with cell communication results
        params: Visualization parameters (use params.subtype to select viz type)
        context: MCP context for logging

    Returns:
        matplotlib Figure object
    """
    if context:
        await context.info("Creating cell communication visualization")

    comm = await get_cell_communication_data(adata, context=context)

    if context:
        await context.info(
            f"Using {comm.method} results ({comm.analysis_type} analysis, "
            f"{len(comm.lr_pairs)} LR pairs)"
        )

    # Spatial results have a single renderer.
    if comm.analysis_type == "spatial":
        return _create_spatial_lr_visualization(adata, comm, params, context)

    # CellPhoneDB renderers are synchronous; heatmap is the fallback.
    if comm.method == "cellphonedb":
        cpdb_renderers = {
            "dotplot": _create_cellphonedb_dotplot,
            "chord": _create_cellphonedb_chord,
        }
        render = cpdb_renderers.get(
            params.subtype or "heatmap", _create_cellphonedb_heatmap
        )
        return render(adata, comm, params, context)

    # LIANA+ cluster renderers are coroutines; dotplot is the fallback.
    liana_renderers = {
        "tileplot": _create_liana_tileplot,
        "circle_plot": _create_liana_circle_plot,
    }
    render_async = liana_renderers.get(
        params.subtype or "dotplot", _create_cluster_lr_visualization
    )
    return await render_async(adata, comm, params, context)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# =============================================================================
|
|
213
|
+
# LIANA+ Visualizations
|
|
214
|
+
# =============================================================================
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _create_spatial_lr_visualization(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create spatial L-R visualization: one scatter panel per L-R pair.

    Up to six pairs are drawn (fewer if ``params.plot_top_pairs`` is smaller).
    Pairs are ranked by the first of ``morans``, ``lee`` or ``global_score``
    found in ``data.results``; without such a column the first entries of
    ``data.lr_pairs`` are used. Each panel colours the spatial coordinates by
    the matching column of ``data.spatial_scores``.

    Args:
        adata: AnnData object providing the spatial coordinates
        data: Retrieved communication results (scores, pair names, metrics)
        params: Visualization parameters (plot_top_pairs, figure_size,
            colormap, spot_size, alpha are read)
        context: MCP context (unused here)

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: No spatial scores or no L-R pairs are available
    """
    if data.spatial_scores is None or len(data.lr_pairs) == 0:
        raise DataNotFoundError(
            "No spatial communication scores found. Run spatial analysis first."
        )

    # Hard cap of six panels regardless of the requested number.
    n_pairs = min(params.plot_top_pairs or 6, len(data.lr_pairs), 6)

    # Determine top pairs based on global metric
    metric_col = None
    if len(data.results) > 0:
        metric_col = next(
            (
                col
                for col in ("morans", "lee", "global_score")
                if col in data.results.columns
            ),
            None,
        )

    if metric_col:
        top_pairs = data.results.nlargest(n_pairs, metric_col).index.tolist()
    else:
        top_pairs = data.lr_pairs[:n_pairs]

    if not top_pairs:
        raise DataNotFoundError("No LR pairs found in spatial results.")

    # Map each pair name to its first position in lr_pairs; that position is
    # used as the column index into spatial_scores.
    first_position = {}
    for position, name in enumerate(data.lr_pairs):
        first_position.setdefault(name, position)

    valid_pairs = [pair for pair in top_pairs if pair in first_position]
    pair_indices = [first_position[pair] for pair in valid_pairs]

    if not valid_pairs:
        # The results index did not match any stored pair name: fall back to
        # the leading pairs in storage order.
        valid_pairs = list(data.lr_pairs[:n_pairs])
        pair_indices = list(range(len(valid_pairs)))

    # Resolve coordinates before allocating a figure so that a failure here
    # does not leave an orphaned figure registered with pyplot.
    coords = require_spatial_coords(adata)
    x_coords, y_coords = coords[:, 0], coords[:, 1]

    # Create figure
    n_panels = len(valid_pairs)
    n_cols = min(3, n_panels)
    n_rows = (n_panels + n_cols - 1) // n_cols

    figsize = params.figure_size or (5 * n_cols, 4 * n_rows)
    # squeeze=False always yields a 2-D array, so one code path handles
    # single-panel and multi-panel layouts alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for ax, pair, pair_idx in zip(axes, valid_pairs, pair_indices):
        if pair_idx < data.spatial_scores.shape[1]:
            scores = data.spatial_scores[:, pair_idx]
        else:
            # No score column for this pair: draw an all-zero panel.
            scores = np.zeros(len(adata))

        scatter = ax.scatter(
            x_coords,
            y_coords,
            c=scores,
            cmap=params.colormap or "viridis",
            s=params.spot_size or 15,
            alpha=params.alpha or 0.8,
            edgecolors="none",
        )

        # "^" is the ligand/receptor separator. When it is present, any
        # underscore is assumed to join subunits of a complex (TODO confirm
        # against the analysis output) and is left alone; only names without
        # "^" fall back to treating "_" as the separator.
        pair_label = str(pair)
        if "^" in pair_label:
            display_name = pair_label.replace("^", " → ")
        else:
            display_name = pair_label.replace("_", " → ")

        if metric_col and pair in data.results.index:
            val = data.results.loc[pair, metric_col]
            display_name += f"\n({metric_col}: {val:.3f})"

        ax.set_title(display_name, fontsize=10)
        ax.set_aspect("equal")
        ax.set_xlabel("")
        ax.set_ylabel("")
        fig.colorbar(scatter, ax=ax, shrink=0.7, label="Score")

    # Hide grid cells that have no pair assigned.
    for unused_ax in axes[n_panels:]:
        unused_ax.set_visible(False)

    # Figure-level calls avoid depending on pyplot's "current figure" state.
    fig.suptitle("Spatial Cell Communication", fontsize=14, fontweight="bold")
    fig.tight_layout()

    return fig
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
async def _create_cluster_lr_visualization(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Render cluster-level L-R results with the LIANA+ dotplot.

    Interactions are ordered (ascending) by the first of ``magnitude_rank``,
    ``specificity_rank`` or ``lr_means`` present in ``data.results``. Dot
    colour and size use ``magnitude_rank`` / ``specificity_rank`` when those
    columns exist, otherwise ``None`` is passed.

    Args:
        adata: AnnData object holding the LIANA results in ``uns``
        data: Retrieved communication results
        params: Visualization parameters (plot_top_pairs, colormap,
            figure_size; dpi is applied during conversion)
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        ProcessingError: Any failure while building the plot, including a
            missing ordering column
    """
    require("liana", feature="LIANA+ plotting")
    require("plotnine", feature="LIANA+ plotting")
    import liana as li

    if context:
        await context.info("Using LIANA+ official dotplot")

    try:
        available = data.results.columns

        ranking_col = next(
            (
                name
                for name in ("magnitude_rank", "specificity_rank", "lr_means")
                if name in available
            ),
            None,
        )
        if ranking_col is None:
            # Raised inside the try on purpose: it surfaces as ProcessingError.
            raise DataNotFoundError("No valid orderby column found in LIANA results")

        colour_col = "magnitude_rank" if "magnitude_rank" in available else None
        size_col = "specificity_rank" if "specificity_rank" in available else None

        dot_plot = li.pl.dotplot(
            adata=adata,
            uns_key=data.results_key,
            colour=colour_col,
            size=size_col,
            orderby=ranking_col,
            orderby_ascending=True,
            top_n=params.plot_top_pairs or 20,
            inverse_colour=True,
            inverse_size=True,
            cmap=params.colormap or "viridis",
            figure_size=params.figure_size or (10, 8),
            return_fig=True,
        )

        return _plotnine_to_matplotlib(dot_plot, params)

    except Exception as err:
        raise ProcessingError(
            f"LIANA+ dotplot failed: {err}\n\n"
            "Ensure cell communication analysis completed successfully."
        ) from err
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
async def _create_liana_tileplot(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Render cluster-level L-R results with the LIANA+ tileplot.

    Interactions are ordered (ascending) by the first of ``magnitude_rank``,
    ``specificity_rank`` or ``lr_means`` present in ``data.results``. Tiles
    are filled by ``magnitude_rank`` when available (else the ordering
    column) and labelled with ``lr_means`` when available (else the fill
    column).

    Args:
        adata: AnnData object holding the LIANA results in ``uns``
        data: Retrieved communication results
        params: Visualization parameters (plot_top_pairs, figure_size; dpi is
            applied during conversion)
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        ProcessingError: Any failure inside this function, including a failed
            ``liana`` import or a missing ordering column
    """
    try:
        import liana as li

        if context:
            await context.info("Creating LIANA+ tileplot")

        available = data.results.columns

        ranking_col = next(
            (
                name
                for name in ("magnitude_rank", "specificity_rank", "lr_means")
                if name in available
            ),
            None,
        )
        if ranking_col is None:
            raise DataNotFoundError("No valid orderby column found in LIANA results")

        if "magnitude_rank" in available:
            fill_col = "magnitude_rank"
        else:
            fill_col = ranking_col

        if "lr_means" in available:
            label_col = "lr_means"
        else:
            label_col = fill_col

        tile_plot = li.pl.tileplot(
            adata=adata,
            uns_key=data.results_key,
            fill=fill_col,
            label=label_col,
            orderby=ranking_col,
            orderby_ascending=True,
            top_n=params.plot_top_pairs or 15,
            figure_size=params.figure_size or (14, 8),
            return_fig=True,
        )

        return _plotnine_to_matplotlib(tile_plot, params)

    except Exception as err:
        raise ProcessingError(
            f"LIANA+ tileplot failed: {err}\n\n"
            "Ensure cell communication analysis completed successfully."
        ) from err
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
async def _create_liana_circle_plot(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create LIANA+ circle plot (network diagram) visualization.

    The score and ordering column is the first of ``magnitude_rank``,
    ``specificity_rank`` or ``lr_means`` present in ``data.results``. The
    grouping column is ``params.cluster_key`` or, when that is unset, the
    default returned by ``get_cluster_key(adata)``.

    Args:
        adata: AnnData object holding the LIANA results in ``uns``
        data: Retrieved communication results
        params: Visualization parameters (cluster_key, figure_size,
            plot_top_pairs are read)
        context: MCP context for logging

    Returns:
        matplotlib Figure object (the current pyplot figure after LIANA+ has
        drawn)

    Raises:
        ProcessingError: Any failure inside this function, including a failed
            ``liana`` import, a missing score column or a missing cluster key
    """
    placeholder_fig = None
    try:
        import liana as li

        if context:
            await context.info("Creating LIANA+ circle plot")

        score_col = next(
            (
                col
                for col in ("magnitude_rank", "specificity_rank", "lr_means")
                if col in data.results.columns
            ),
            None,
        )
        if score_col is None:
            raise DataNotFoundError("No valid score column found in LIANA results")

        # ``groupby`` is handed to LIANA+ as a grouping key. The earlier
        # fallback took the first *value* of the "source" column, i.e. a
        # cell-type label, not a column name; use the default cluster key
        # instead, as the CellPhoneDB renderers in this module do.
        groupby = params.cluster_key or get_cluster_key(adata)
        if not groupby:
            raise ParameterError(
                "cluster_key is required for circle_plot. "
                "Specify the cell type column used in analysis."
            )
        validate_obs_column(adata, groupby, "Cluster")

        fig_size = params.figure_size or (10, 10)
        placeholder_fig, _ = plt.subplots(figsize=fig_size)

        li.pl.circle_plot(
            adata=adata,
            uns_key=data.results_key,
            groupby=groupby,
            score_key=score_col,
            inverse_score=True,
            top_n=params.plot_top_pairs * 3 if params.plot_top_pairs else 50,
            orderby=score_col,
            orderby_ascending=True,
            figure_size=fig_size,
        )

        fig = plt.gcf()
        # If LIANA+ drew on a figure of its own, the one created above is an
        # empty orphan; close it so it does not accumulate in pyplot.
        if fig is not placeholder_fig:
            plt.close(placeholder_fig)
        return fig

    except Exception as e:
        # Do not leave a half-built figure registered with pyplot on failure.
        if placeholder_fig is not None:
            plt.close(placeholder_fig)
        raise ProcessingError(
            f"LIANA+ circle_plot failed: {e}\n\n"
            "Ensure cell communication analysis completed successfully."
        ) from e
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
# =============================================================================
|
|
485
|
+
# CellPhoneDB Visualizations
|
|
486
|
+
# =============================================================================
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _create_cellphonedb_heatmap(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellPhoneDB heatmap visualization using ktplotspy.

    The heatmap itself is built from the p-value table stored in
    ``adata.uns["cellphonedb_pvalues"]``; ``data.results`` (the means table)
    is only checked for being a non-empty DataFrame.

    Args:
        adata: AnnData object holding the CellPhoneDB p-values in ``uns``
        data: Retrieved communication results
        params: Visualization parameters (title is read)
        context: MCP context (unused here)

    Returns:
        matplotlib Figure object (the ``fig`` attribute of the object
        returned by ktplotspy)

    Raises:
        DataNotFoundError: Means table empty or p-values table missing
    """
    # Same dependency check the dotplot and chord renderers perform before
    # importing ktplotspy.
    require("ktplotspy", feature="CellPhoneDB heatmap visualization")
    import ktplotspy as kpy

    means = data.results

    if not isinstance(means, pd.DataFrame) or len(means) == 0:
        raise DataNotFoundError("CellPhoneDB results empty. Re-run analysis.")

    pvalues = adata.uns.get("cellphonedb_pvalues")

    # isinstance already rejects None, so no separate None test is needed.
    if not isinstance(pvalues, pd.DataFrame):
        raise DataNotFoundError("CellPhoneDB pvalues not found. Re-run analysis.")

    grid = kpy.plot_cpdb_heatmap(
        pvals=pvalues,
        title=params.title or "CellPhoneDB: Significant Interactions",
        alpha=0.05,
        symmetrical=True,
    )

    return grid.fig
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def _create_cellphonedb_dotplot(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellPhoneDB dotplot visualization using ktplotspy.

    All cell-type pairs are requested (``cell_type1``/``cell_type2`` are
    ``"."``) and only significant interactions (alpha 0.05) are kept.

    Args:
        adata: AnnData object holding the CellPhoneDB p-values in ``uns``
        data: Retrieved communication results (means table)
        params: Visualization parameters (cluster_key, figure_size, title are
            read)
        context: MCP context (unused here)

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: Means table is empty
        ProcessingError: Any later failure, including missing p-values or a
            missing/invalid cluster key
    """
    means = data.results

    if not isinstance(means, pd.DataFrame) or len(means) == 0:
        raise DataNotFoundError("CellPhoneDB results empty. Re-run analysis.")

    require("ktplotspy", feature="CellPhoneDB dotplot visualization")
    import ktplotspy as kpy

    try:
        pvalues = adata.uns.get("cellphonedb_pvalues")

        if not isinstance(pvalues, pd.DataFrame):
            raise DataNotFoundError("Missing pvalues DataFrame for ktplotspy dotplot")

        cluster_key = params.cluster_key or get_cluster_key(adata)
        if not cluster_key:
            raise ParameterError(
                "cluster_key required for CellPhoneDB dotplot. "
                "No default cluster key found in data."
            )
        validate_obs_column(adata, cluster_key, "Cluster")

        gg = kpy.plot_cpdb(
            adata=adata,
            cell_type1=".",
            cell_type2=".",
            means=means,
            pvals=pvalues,
            celltype_key=cluster_key,
            genes=None,
            figsize=params.figure_size or (12, 10),
            # Honour a caller-supplied title, as the heatmap renderer does.
            title=params.title or "CellPhoneDB: L-R Interactions",
            max_size=10,
            alpha=0.05,
            keep_significant_only=True,
            standard_scale=True,
        )

        # draw() on the returned plot object yields the figure.
        fig = gg.draw()
        return fig

    except Exception as e:
        raise ProcessingError(
            f"Failed to create CellPhoneDB dotplot: {e}\n\n"
            "Try using subtype='heatmap' instead."
        ) from e
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def _create_cellphonedb_chord(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellPhoneDB chord/circos diagram using ktplotspy.

    When the deconvoluted table has an ``interacting_pair`` column, the first
    ``params.plot_top_pairs`` (default 50) unique pairs each get a distinct
    link colour and a legend entry. The legend is also stored on the figure
    as ``_chatspatial_extra_artists``.

    Args:
        adata: AnnData object holding the CellPhoneDB p-values and
            deconvoluted tables in ``uns``
        data: Retrieved communication results (means table)
        params: Visualization parameters (cluster_key, plot_top_pairs,
            figure_size are read)
        context: MCP context (unused here)

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: Means table is empty
        ProcessingError: Any later failure, including missing tables or a
            missing/invalid cluster key
    """
    from matplotlib.lines import Line2D

    means = data.results

    if not isinstance(means, pd.DataFrame) or len(means) == 0:
        raise DataNotFoundError("CellPhoneDB results empty. Re-run analysis.")

    require("ktplotspy", feature="CellPhoneDB chord visualization")
    import ktplotspy as kpy
    import matplotlib.colors as mcolors

    try:
        pvalues = adata.uns.get("cellphonedb_pvalues")
        deconvoluted = adata.uns.get("cellphonedb_deconvoluted")

        if not isinstance(pvalues, pd.DataFrame):
            raise DataNotFoundError(
                "Missing pvalues DataFrame for ktplotspy chord plot"
            )

        if not isinstance(deconvoluted, pd.DataFrame):
            raise DataNotFoundError(
                "Missing deconvoluted DataFrame for chord plot. "
                "Re-run CellPhoneDB analysis."
            )

        cluster_key = params.cluster_key or get_cluster_key(adata)
        if not cluster_key:
            raise ParameterError(
                "cluster_key required for CellPhoneDB chord plot. "
                "No default cluster key found in data."
            )
        validate_obs_column(adata, cluster_key, "Cluster")

        # None means "let ktplotspy choose colours"; a dict maps pair -> hex.
        link_colors = None

        if "interacting_pair" in deconvoluted.columns:
            unique_pairs = deconvoluted["interacting_pair"].unique()
            n_pairs = min(params.plot_top_pairs or 50, len(unique_pairs))
            top_pairs = unique_pairs[:n_pairs]

            # Choose a palette with enough distinct entries for the pairs.
            # plt.get_cmap is used because matplotlib.cm.get_cmap has been
            # deprecated and removed in recent Matplotlib releases.
            if n_pairs <= 10:
                cmap = plt.get_cmap("tab10", 10)
            elif n_pairs <= 20:
                cmap = plt.get_cmap("tab20", 20)
            else:
                cmap = plt.get_cmap("nipy_spectral", n_pairs)

            link_colors = {
                pair: mcolors.rgb2hex(cmap(i % cmap.N))
                for i, pair in enumerate(top_pairs)
            }

        circos = kpy.plot_cpdb_chord(
            adata=adata,
            means=means,
            pvals=pvalues,
            deconvoluted=deconvoluted,
            celltype_key=cluster_key,
            cell_type1=".",
            cell_type2=".",
            link_colors=link_colors,
        )

        fig = circos.ax.figure
        # Honour a caller-supplied size; (14, 10) remains the default.
        width, height = params.figure_size or (14, 10)
        fig.set_size_inches(width, height)

        if link_colors:
            line_handles = [
                Line2D([], [], color=color, label=label, linewidth=2)
                for label, color in link_colors.items()
            ]

            legend = circos.ax.legend(
                handles=line_handles,
                loc="center left",
                bbox_to_anchor=(1.15, 0.5),
                fontsize=6,
                frameon=True,
                framealpha=0.9,
                title="L-R Pairs",
                title_fontsize=7,
            )

            # The legend is anchored outside the axes; keep a handle on the
            # figure (presumably so the save path can include it - verify
            # against the code that reads this attribute).
            fig._chatspatial_extra_artists = [legend]

        return fig

    except Exception as e:
        raise ProcessingError(
            f"Failed to create CellPhoneDB chord diagram: {e}\n\n"
            "Try using subtype='heatmap' instead."
        ) from e
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
# =============================================================================
|
|
678
|
+
# Utilities
|
|
679
|
+
# =============================================================================
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def _plotnine_to_matplotlib(p, params: VisualizationParameters) -> plt.Figure:
    """Turn a plotnine plot object into its matplotlib Figure.

    The figure is obtained from the plot's own ``draw()`` method rather than
    by rendering to an image buffer. A truthy ``params.dpi`` is applied to
    the figure afterwards.

    Args:
        p: Plot object exposing ``draw()``
        params: Visualization parameters (dpi is read)

    Returns:
        matplotlib Figure object

    Raises:
        ProcessingError: ``draw()`` or the DPI update failed
    """
    try:
        figure = p.draw()

        requested_dpi = params.dpi
        if requested_dpi:
            figure.set_dpi(requested_dpi)

    except Exception as err:
        raise ProcessingError(f"Failed to convert plotnine figure: {err}") from err

    return figure
|