chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,699 @@
1
+ """
2
+ Cell communication visualization functions for spatial transcriptomics.
3
+
4
+ This module contains:
5
+ - LIANA+ cluster-based visualizations (dotplot, tileplot, circle_plot)
6
+ - LIANA+ spatial bivariate visualizations
7
+ - CellPhoneDB visualizations (heatmap, dotplot, chord)
8
+ """
9
+
10
+ from typing import TYPE_CHECKING, Optional
11
+
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
+ if TYPE_CHECKING:
17
+ import anndata as ad
18
+
19
+ from ...spatial_mcp_adapter import ToolContext
20
+
21
+ from ...models.data import VisualizationParameters
22
+ from ...utils.adata_utils import (
23
+ get_cluster_key,
24
+ require_spatial_coords,
25
+ validate_obs_column,
26
+ )
27
+ from ...utils.dependency_manager import require
28
+ from ...utils.exceptions import DataNotFoundError, ParameterError, ProcessingError
29
+ from .core import CellCommunicationData
30
+
31
+ # =============================================================================
32
+ # Data Retrieval
33
+ # =============================================================================
34
+
35
+
36
async def get_cell_communication_data(
    adata: "ad.AnnData",
    method: Optional[str] = None,
    context: Optional["ToolContext"] = None,
) -> CellCommunicationData:
    """
    Unified function to retrieve cell communication results from AnnData.

    Result sources are checked in a fixed priority order; the first match wins:

    1. LIANA+ spatial bivariate results: ``adata.obsm["liana_spatial_scores"]``
       (with ``uns["liana_spatial_interactions"]`` / ``uns["liana_spatial_res"]``)
    2. LIANA+ cluster-based results: a non-empty DataFrame in
       ``adata.uns["liana_res"]``
    3. CellPhoneDB results: a DataFrame in ``adata.uns["cellphonedb_means"]``

    Args:
        adata: AnnData object with cell communication results
        method: Analysis method hint (optional). Not consulted by the current
            implementation; the priority order above always applies.
        context: MCP context for logging

    Returns:
        CellCommunicationData object with results and metadata. ``lr_pairs``
        is always a plain Python list.

    Raises:
        DataNotFoundError: No cell communication results found
    """
    # --- LIANA+ spatial bivariate results (highest priority) ---
    if "liana_spatial_scores" in adata.obsm:
        spatial_scores = adata.obsm["liana_spatial_scores"]

        # Normalize pair names to a list. The stored value may be a numpy
        # array (e.g. after an h5ad round-trip), whose truth value is
        # ambiguous and which has no ``.index()`` used by the spatial plot.
        raw_pairs = adata.uns.get("liana_spatial_interactions")
        if raw_pairs is None:
            lr_pairs = []
        elif isinstance(raw_pairs, str):
            lr_pairs = [raw_pairs] if raw_pairs else []
        else:
            lr_pairs = list(raw_pairs)

        results_df = adata.uns.get("liana_spatial_res", pd.DataFrame())
        if not isinstance(results_df, pd.DataFrame):
            results_df = pd.DataFrame()

        if context:
            await context.info(
                f"Found LIANA+ spatial results: {len(lr_pairs)} LR pairs, "
                f"{spatial_scores.shape[0]} spots"
            )

        return CellCommunicationData(
            results=results_df,
            method="liana_spatial",
            analysis_type="spatial",
            lr_pairs=lr_pairs,
            spatial_scores=spatial_scores,
            spatial_pvals=adata.obsm.get("liana_spatial_pvals"),
            results_key="liana_spatial_res",
        )

    # --- LIANA+ cluster-based results ---
    if "liana_res" in adata.uns:
        results = adata.uns["liana_res"]
        if isinstance(results, pd.DataFrame) and len(results) > 0:
            if (
                "ligand_complex" in results.columns
                and "receptor_complex" in results.columns
            ):
                # Cast to str first: these columns may be categorical, and
                # categorical Series do not support string concatenation.
                lr_pairs = (
                    (
                        results["ligand_complex"].astype(str)
                        + "^"
                        + results["receptor_complex"].astype(str)
                    )
                    .unique()
                    .tolist()
                )
            else:
                lr_pairs = []

            source_labels = (
                results["source"].unique().tolist()
                if "source" in results.columns
                else None
            )
            target_labels = (
                results["target"].unique().tolist()
                if "target" in results.columns
                else None
            )

            if context:
                await context.info(
                    f"Found LIANA+ cluster results: {len(lr_pairs)} LR pairs"
                )

            return CellCommunicationData(
                results=results,
                method="liana_cluster",
                analysis_type="cluster",
                lr_pairs=lr_pairs,
                source_labels=source_labels,
                target_labels=target_labels,
                results_key="liana_res",
            )

    # --- CellPhoneDB results ---
    if "cellphonedb_means" in adata.uns:
        means = adata.uns["cellphonedb_means"]
        if isinstance(means, pd.DataFrame):
            lr_pairs = means.index.tolist()

            if context:
                await context.info(
                    f"Found CellPhoneDB results: {len(lr_pairs)} LR pairs"
                )

            return CellCommunicationData(
                results=means,
                method="cellphonedb",
                analysis_type="cluster",
                lr_pairs=lr_pairs,
                results_key="cellphonedb_means",
            )

    # No results found
    raise DataNotFoundError(
        "No cell communication results found. "
        "Run analyze_cell_communication() first."
    )
152
+
153
+
154
+ # =============================================================================
155
+ # Main Router
156
+ # =============================================================================
157
+
158
+
159
async def create_cell_communication_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Dispatch to the cell communication plot matching the stored results.

    Results are located with ``get_cell_communication_data`` and the renderer
    is chosen as follows:

    - Spatial (LIANA+ bivariate) results: multi-panel spatial score plot;
      ``params.subtype`` is not consulted.
    - CellPhoneDB results: ``"dotplot"`` or ``"chord"``; any other value
      (including ``None``) renders the heatmap.
    - LIANA+ cluster results: ``"tileplot"`` or ``"circle_plot"``; any other
      value (including ``None``) renders the LIANA+ dotplot.

    Args:
        adata: AnnData object with cell communication results
        params: Visualization parameters; ``params.subtype`` selects the plot
        context: MCP context for logging

    Returns:
        matplotlib Figure object
    """
    if context:
        await context.info("Creating cell communication visualization")

    comm_data = await get_cell_communication_data(adata, context=context)

    if context:
        await context.info(
            f"Using {comm_data.method} results ({comm_data.analysis_type} analysis, "
            f"{len(comm_data.lr_pairs)} LR pairs)"
        )

    if comm_data.analysis_type == "spatial":
        return _create_spatial_lr_visualization(adata, comm_data, params, context)

    if comm_data.method == "cellphonedb":
        # CellPhoneDB renderers are plain functions; unknown subtype -> heatmap.
        cellphonedb_renderer = {
            "dotplot": _create_cellphonedb_dotplot,
            "chord": _create_cellphonedb_chord,
        }.get(params.subtype or "heatmap", _create_cellphonedb_heatmap)
        return cellphonedb_renderer(adata, comm_data, params, context)

    # LIANA+ cluster renderers are coroutines; unknown subtype -> dotplot.
    liana_renderer = {
        "tileplot": _create_liana_tileplot,
        "circle_plot": _create_liana_circle_plot,
    }.get(params.subtype or "dotplot", _create_cluster_lr_visualization)
    return await liana_renderer(adata, comm_data, params, context)
210
+
211
+
212
+ # =============================================================================
213
+ # LIANA+ Visualizations
214
+ # =============================================================================
215
+
216
+
217
def _create_spatial_lr_visualization(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Plot LIANA+ spatial bivariate scores for the top L-R pairs.

    One matplotlib scatter panel per pair (at most 6, three per row) is drawn
    on the spatial coordinates, coloured by that pair's per-spot score.

    Pair selection: if ``data.results`` has a ``morans``, ``lee`` or
    ``global_score`` column (checked in that order), the pairs with the
    largest values are used; otherwise the first entries of ``data.lr_pairs``.
    If none of the selected names occur in ``data.lr_pairs``, the first
    entries of ``data.lr_pairs`` are plotted instead.

    Args:
        adata: AnnData object providing spatial coordinates
        data: Spatial cell communication results
        params: Visualization parameters (plot_top_pairs, figure_size,
            colormap, spot_size, alpha, title)
        context: MCP context (not used here)

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: No spatial scores or no LR pairs available
    """
    # Work on a list copy: lr_pairs may be a numpy array (e.g. read back from
    # h5ad), which has no .index() and an ambiguous truth value.
    lr_pairs = [] if data.lr_pairs is None else list(data.lr_pairs)
    if data.spatial_scores is None or not lr_pairs:
        raise DataNotFoundError(
            "No spatial communication scores found. Run spatial analysis first."
        )

    # Never more than 6 panels, regardless of params.plot_top_pairs.
    n_pairs = min(params.plot_top_pairs or 6, len(lr_pairs), 6)

    # Determine top pairs based on the first available global metric
    metric_col = None
    if len(data.results) > 0:
        metric_col = next(
            (
                col
                for col in ("morans", "lee", "global_score")
                if col in data.results.columns
            ),
            None,
        )
    if metric_col is not None:
        top_pairs = data.results.nlargest(n_pairs, metric_col).index.tolist()
    else:
        top_pairs = lr_pairs[:n_pairs]

    if not top_pairs:
        raise DataNotFoundError("No LR pairs found in spatial results.")

    # Map the selected pair names to their column position in the score matrix
    valid_pairs = [pair for pair in top_pairs if pair in lr_pairs]
    pair_indices = [lr_pairs.index(pair) for pair in valid_pairs]
    if not valid_pairs:
        # The results index does not use the pair names; plot the first pairs.
        valid_pairs = lr_pairs[:n_pairs]
        pair_indices = list(range(len(valid_pairs)))

    # Accept DataFrame / ndarray / scipy-sparse score matrices.
    score_matrix = data.spatial_scores
    if isinstance(score_matrix, pd.DataFrame):
        score_matrix = score_matrix.to_numpy()
    n_score_columns = score_matrix.shape[1]

    # Create figure
    n_panels = len(valid_pairs)
    n_cols = min(3, n_panels)
    n_rows = (n_panels + n_cols - 1) // n_cols

    figsize = params.figure_size or (5 * n_cols, 4 * n_rows)
    # squeeze=False always yields a 2-D array of Axes, so ravel() is uniform.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    coords = require_spatial_coords(adata)
    x_coords, y_coords = coords[:, 0], coords[:, 1]

    for ax, pair, pair_idx in zip(axes, valid_pairs, pair_indices, strict=False):
        if pair_idx < n_score_columns:
            column = score_matrix[:, pair_idx]
            if hasattr(column, "toarray"):  # scipy sparse column -> dense
                column = column.toarray()
            scores = np.asarray(column).ravel()
        else:
            scores = np.zeros(len(adata))

        scatter = ax.scatter(
            x_coords,
            y_coords,
            c=scores,
            cmap=params.colormap or "viridis",
            s=params.spot_size or 15,
            alpha=params.alpha or 0.8,
            edgecolors="none",
        )

        # "^" separates ligand from receptor; complex names may themselves
        # contain "_", so "_" is only treated as the separator when no "^"
        # is present.
        pair_label = str(pair)
        if "^" in pair_label:
            display_name = pair_label.replace("^", " → ")
        else:
            display_name = pair_label.replace("_", " → ")

        if len(data.results) > 0 and pair in data.results.index:
            for metric in ("morans", "lee", "global_score"):
                if metric in data.results.columns:
                    val = data.results.loc[pair, metric]
                    display_name += f"\n({metric}: {val:.3f})"
                    break

        ax.set_title(display_name, fontsize=10)
        ax.set_aspect("equal")
        ax.set_xlabel("")
        ax.set_ylabel("")
        fig.colorbar(scatter, ax=ax, shrink=0.7, label="Score")

    # Hide the unused grid cells of the last row
    for ax in axes[n_panels:]:
        ax.set_visible(False)

    fig.suptitle(
        params.title or "Spatial Cell Communication", fontsize=14, fontweight="bold"
    )
    fig.tight_layout()

    return fig
317
+
318
+
319
async def _create_cluster_lr_visualization(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create cluster-based L-R visualization using LIANA+ dotplot.

    Interactions are ordered by the first available column among
    ``magnitude_rank``, ``specificity_rank`` and ``lr_means``. Rank columns are
    sorted ascending (smaller rank = stronger evidence); ``lr_means`` is
    sorted descending so ``top_n`` keeps the strongest interactions.

    Args:
        adata: AnnData object holding the LIANA+ results
        data: Cluster-based LIANA+ results
        params: Visualization parameters (plot_top_pairs, colormap,
            figure_size, dpi)
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        ProcessingError: The LIANA+ dotplot could not be produced (also wraps
            a missing ordering column).
    """
    require("liana", feature="LIANA+ plotting")
    require("plotnine", feature="LIANA+ plotting")
    import liana as li

    if context:
        await context.info("Using LIANA+ official dotplot")

    try:
        orderby_col = next(
            (
                col
                for col in ("magnitude_rank", "specificity_rank", "lr_means")
                if col in data.results.columns
            ),
            None,
        )
        if orderby_col is None:
            raise DataNotFoundError("No valid orderby column found in LIANA results")

        # Ranks: smaller is better -> ascending. lr_means: larger is better ->
        # descending; ascending would make top_n pick the weakest pairs.
        orderby_ascending = orderby_col.endswith("_rank")

        p = li.pl.dotplot(
            adata=adata,
            uns_key=data.results_key,
            colour=(
                "magnitude_rank" if "magnitude_rank" in data.results.columns else None
            ),
            size=(
                "specificity_rank"
                if "specificity_rank" in data.results.columns
                else None
            ),
            orderby=orderby_col,
            orderby_ascending=orderby_ascending,
            top_n=params.plot_top_pairs or 20,
            inverse_colour=True,
            inverse_size=True,
            cmap=params.colormap or "viridis",
            figure_size=params.figure_size or (10, 8),
            return_fig=True,
        )

        return _plotnine_to_matplotlib(p, params)

    except Exception as e:
        raise ProcessingError(
            f"LIANA+ dotplot failed: {e}\n\n"
            "Ensure cell communication analysis completed successfully."
        ) from e
372
+
373
+
374
async def _create_liana_tileplot(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create LIANA+ tileplot visualization.

    Tiles are filled by ``magnitude_rank`` when present (otherwise by the
    ordering column) and labelled with ``lr_means`` when present. The top
    interactions are chosen by the first available column among
    ``magnitude_rank``, ``specificity_rank`` and ``lr_means``. Rank columns are
    sorted ascending; ``lr_means`` is sorted descending.

    Args:
        adata: AnnData object holding the LIANA+ results
        data: Cluster-based LIANA+ results
        params: Visualization parameters (plot_top_pairs, figure_size, dpi)
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        ProcessingError: Any failure, including missing plotting packages or
            missing score columns.
    """
    try:
        # Dependency checks run inside the try so failures still surface as
        # ProcessingError, now reported through the project's dependency
        # checker (as in the LIANA+ dotplot).
        require("liana", feature="LIANA+ plotting")
        require("plotnine", feature="LIANA+ plotting")
        import liana as li

        if context:
            await context.info("Creating LIANA+ tileplot")

        orderby_col = next(
            (
                col
                for col in ("magnitude_rank", "specificity_rank", "lr_means")
                if col in data.results.columns
            ),
            None,
        )
        if orderby_col is None:
            raise DataNotFoundError("No valid orderby column found in LIANA results")

        # Ranks: smaller is better -> ascending. lr_means: larger is better ->
        # descending; ascending would make top_n pick the weakest pairs.
        orderby_ascending = orderby_col.endswith("_rank")

        fill_col = (
            "magnitude_rank"
            if "magnitude_rank" in data.results.columns
            else orderby_col
        )
        label_col = "lr_means" if "lr_means" in data.results.columns else fill_col

        p = li.pl.tileplot(
            adata=adata,
            uns_key=data.results_key,
            fill=fill_col,
            label=label_col,
            orderby=orderby_col,
            orderby_ascending=orderby_ascending,
            top_n=params.plot_top_pairs or 15,
            figure_size=params.figure_size or (14, 8),
            return_fig=True,
        )

        return _plotnine_to_matplotlib(p, params)

    except Exception as e:
        raise ProcessingError(
            f"LIANA+ tileplot failed: {e}\n\n"
            "Ensure cell communication analysis completed successfully."
        ) from e
423
+
424
+
425
async def _create_liana_circle_plot(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create LIANA+ circle plot (network diagram) visualization.

    Edges are scored by the first available column among ``magnitude_rank``,
    ``specificity_rank`` and ``lr_means``. For rank columns, LIANA's inverse
    score transform is applied and ordering is ascending. For ``lr_means``,
    raw values are used and ordering is descending.

    Args:
        adata: AnnData object holding the LIANA+ results
        data: Cluster-based LIANA+ results
        params: Visualization parameters. ``cluster_key`` names the
            ``adata.obs`` column with the cell-type labels used in the
            analysis; when unset, ``get_cluster_key(adata)`` is used.
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        ProcessingError: Any failure, including a missing or invalid
            ``cluster_key`` and missing score columns.
    """
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    try:
        require("liana", feature="LIANA+ plotting")
        import liana as li

        if context:
            await context.info("Creating LIANA+ circle plot")

        score_col = next(
            (
                col
                for col in ("magnitude_rank", "specificity_rank", "lr_means")
                if col in data.results.columns
            ),
            None,
        )
        if score_col is None:
            raise DataNotFoundError("No valid score column found in LIANA results")

        # Ranks: smaller is better (inverse transform + ascending order).
        # lr_means: larger is better (raw values + descending order).
        is_rank = score_col.endswith("_rank")

        # groupby must name the adata.obs column with cell-type labels, as
        # for the CellPhoneDB plots. It must not be a label value taken from
        # the results' "source" column.
        groupby = params.cluster_key or get_cluster_key(adata)
        if not groupby:
            raise ParameterError(
                "cluster_key is required for circle_plot. "
                "Specify the cell type column used in analysis."
            )
        validate_obs_column(adata, groupby, "Cluster")

        fig_size = params.figure_size or (10, 10)
        # Fresh canvas in case circle_plot draws on the current axes.
        canvas_fig, _ = plt.subplots(figsize=fig_size)

        plotted = li.pl.circle_plot(
            adata=adata,
            uns_key=data.results_key,
            groupby=groupby,
            score_key=score_col,
            inverse_score=is_rank,
            top_n=params.plot_top_pairs * 3 if params.plot_top_pairs else 50,
            orderby=score_col,
            orderby_ascending=is_rank,
            figure_size=fig_size,
        )

        # Resolve the figure actually drawn on. circle_plot may return an
        # Axes or Figure, or may create its own figure.
        if isinstance(plotted, Figure):
            fig = plotted
        elif isinstance(plotted, Axes):
            fig = plotted.figure
        else:
            fig = plt.gcf()
        if fig is not canvas_fig:
            plt.close(canvas_fig)  # don't leak the unused placeholder figure

        return fig

    except Exception as e:
        raise ProcessingError(
            f"LIANA+ circle_plot failed: {e}\n\n"
            "Ensure cell communication analysis completed successfully."
        ) from e
482
+
483
+
484
+ # =============================================================================
485
+ # CellPhoneDB Visualizations
486
+ # =============================================================================
487
+
488
+
489
def _create_cellphonedb_heatmap(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellPhoneDB heatmap visualization using ktplotspy.

    Draws ``kpy.plot_cpdb_heatmap`` from ``adata.uns["cellphonedb_pvalues"]``
    with a 0.05 significance threshold and a symmetrical layout.

    Args:
        adata: AnnData object holding the CellPhoneDB results
        data: CellPhoneDB results (``results`` is the means DataFrame)
        params: Visualization parameters (``title``)
        context: MCP context (not used here)

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: Means are empty or p-values are missing
    """
    means = data.results

    if not isinstance(means, pd.DataFrame) or len(means) == 0:
        raise DataNotFoundError("CellPhoneDB results empty. Re-run analysis.")

    pvalues = adata.uns.get("cellphonedb_pvalues")

    if pvalues is None or not isinstance(pvalues, pd.DataFrame):
        raise DataNotFoundError("CellPhoneDB pvalues not found. Re-run analysis.")

    # Same dependency check as the CellPhoneDB dotplot/chord renderers.
    require("ktplotspy", feature="CellPhoneDB heatmap visualization")
    import ktplotspy as kpy

    grid = kpy.plot_cpdb_heatmap(
        pvals=pvalues,
        title=params.title or "CellPhoneDB: Significant Interactions",
        alpha=0.05,
        symmetrical=True,
    )

    # Newer seaborn grids expose ``.figure``; ``.fig`` is the legacy name.
    figure = getattr(grid, "figure", None)
    return figure if figure is not None else grid.fig
516
+
517
+
518
def _create_cellphonedb_dotplot(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellPhoneDB dotplot visualization using ktplotspy.

    Plots all cell-type pairs (``cell_type1=cell_type2="."``) with
    ``kpy.plot_cpdb``, keeping only interactions significant at 0.05.

    Args:
        adata: AnnData object holding the CellPhoneDB results
        data: CellPhoneDB results (``results`` is the means DataFrame)
        params: Visualization parameters (cluster_key, figure_size, title, dpi)
        context: MCP context (not used here)

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: Means are empty
        ProcessingError: The dotplot could not be produced (also wraps missing
            p-values or cluster key)
    """
    means = data.results

    if not isinstance(means, pd.DataFrame) or len(means) == 0:
        raise DataNotFoundError("CellPhoneDB results empty. Re-run analysis.")

    require("ktplotspy", feature="CellPhoneDB dotplot visualization")
    import ktplotspy as kpy

    try:
        pvalues = adata.uns.get("cellphonedb_pvalues")

        if pvalues is None or not isinstance(pvalues, pd.DataFrame):
            raise DataNotFoundError("Missing pvalues DataFrame for ktplotspy dotplot")

        cluster_key = params.cluster_key or get_cluster_key(adata)
        if not cluster_key:
            raise ParameterError(
                "cluster_key required for CellPhoneDB dotplot. "
                "No default cluster key found in data."
            )
        validate_obs_column(adata, cluster_key, "Cluster")

        gg = kpy.plot_cpdb(
            adata=adata,
            cell_type1=".",
            cell_type2=".",
            means=means,
            pvals=pvalues,
            celltype_key=cluster_key,
            genes=None,
            figsize=params.figure_size or (12, 10),
            title=params.title or "CellPhoneDB: L-R Interactions",
            max_size=10,
            alpha=0.05,
            keep_significant_only=True,
            standard_scale=True,
        )

        # Same plotnine -> matplotlib conversion as the LIANA+ plots, so
        # params.dpi is honoured here too.
        return _plotnine_to_matplotlib(gg, params)

    except Exception as e:
        raise ProcessingError(
            f"Failed to create CellPhoneDB dotplot: {e}\n\n"
            "Try using subtype='heatmap' instead."
        ) from e
571
+
572
+
573
def _create_cellphonedb_chord(
    adata: "ad.AnnData",
    data: CellCommunicationData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CellPhoneDB chord/circos diagram using ktplotspy.

    When the deconvoluted table has an ``interacting_pair`` column, the first
    ``params.plot_top_pairs`` (default 50) unique pairs get distinct link
    colours and a legend entry.

    Args:
        adata: AnnData object holding the CellPhoneDB results
        data: CellPhoneDB results (``results`` is the means DataFrame)
        params: Visualization parameters (cluster_key, plot_top_pairs,
            figure_size)
        context: MCP context (not used here)

    Returns:
        matplotlib Figure object. When a legend is drawn, it is stored on
        ``fig._chatspatial_extra_artists``.

    Raises:
        DataNotFoundError: Means are empty
        ProcessingError: The chord diagram could not be produced (also wraps
            missing p-values, deconvoluted table or cluster key)
    """
    from matplotlib.lines import Line2D

    means = data.results

    if not isinstance(means, pd.DataFrame) or len(means) == 0:
        raise DataNotFoundError("CellPhoneDB results empty. Re-run analysis.")

    require("ktplotspy", feature="CellPhoneDB chord visualization")
    import ktplotspy as kpy
    import matplotlib.colors as mcolors

    try:
        pvalues = adata.uns.get("cellphonedb_pvalues")
        deconvoluted = adata.uns.get("cellphonedb_deconvoluted")

        if pvalues is None or not isinstance(pvalues, pd.DataFrame):
            raise DataNotFoundError(
                "Missing pvalues DataFrame for ktplotspy chord plot"
            )

        if deconvoluted is None or not isinstance(deconvoluted, pd.DataFrame):
            raise DataNotFoundError(
                "Missing deconvoluted DataFrame for chord plot. "
                "Re-run CellPhoneDB analysis."
            )

        cluster_key = params.cluster_key or get_cluster_key(adata)
        if not cluster_key:
            raise ParameterError(
                "cluster_key required for CellPhoneDB chord plot. "
                "No default cluster key found in data."
            )
        validate_obs_column(adata, cluster_key, "Cluster")

        link_colors = None
        legend_items = []

        if "interacting_pair" in deconvoluted.columns:
            unique_pairs = deconvoluted["interacting_pair"].unique()
            n_pairs = min(params.plot_top_pairs or 50, len(unique_pairs))
            top_pairs = unique_pairs[:n_pairs]

            # plt.get_cmap: matplotlib.cm.get_cmap (plt.cm.get_cmap) was
            # removed in matplotlib 3.9, while pyplot.get_cmap remains.
            if n_pairs <= 10:
                cmap = plt.get_cmap("tab10", 10)
            elif n_pairs <= 20:
                cmap = plt.get_cmap("tab20", 20)
            else:
                cmap = plt.get_cmap("nipy_spectral", n_pairs)

            link_colors = {}
            for i, pair in enumerate(top_pairs):
                color = mcolors.rgb2hex(cmap(i % cmap.N))
                link_colors[pair] = color
                legend_items.append((pair, color))

        circos = kpy.plot_cpdb_chord(
            adata=adata,
            means=means,
            pvals=pvalues,
            deconvoluted=deconvoluted,
            celltype_key=cluster_key,
            cell_type1=".",
            cell_type2=".",
            link_colors=link_colors,
        )

        fig = circos.ax.figure
        fig.set_size_inches(params.figure_size or (14, 10))

        if legend_items:
            line_handles = [
                Line2D([], [], color=color, label=label, linewidth=2)
                for label, color in legend_items
            ]

            legend = circos.ax.legend(
                handles=line_handles,
                loc="center left",
                bbox_to_anchor=(1.15, 0.5),
                fontsize=6,
                frameon=True,
                framealpha=0.9,
                title="L-R Pairs",
                title_fontsize=7,
            )

            # Legend sits outside the axes; presumably consumed by the saver
            # as bbox_extra_artists — verify against persistence code.
            fig._chatspatial_extra_artists = [legend]

        return fig

    except Exception as e:
        raise ProcessingError(
            f"Failed to create CellPhoneDB chord diagram: {e}\n\n"
            "Try using subtype='heatmap' instead."
        ) from e
675
+
676
+
677
+ # =============================================================================
678
+ # Utilities
679
+ # =============================================================================
680
+
681
+
682
def _plotnine_to_matplotlib(p, params: VisualizationParameters) -> plt.Figure:
    """Render a plotnine ``ggplot`` and return its backing matplotlib Figure.

    ``p.draw()`` hands back the Figure itself, so the plot stays vector data
    rather than being rasterized through an image buffer. A truthy
    ``params.dpi`` is applied to the returned Figure.

    Raises:
        ProcessingError: Drawing the plot or applying the DPI failed.
    """
    try:
        figure = p.draw()
        requested_dpi = params.dpi
        if requested_dpi:
            figure.set_dpi(requested_dpi)
    except Exception as exc:
        raise ProcessingError(f"Failed to convert plotnine figure: {exc}") from exc
    return figure