chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,852 @@
1
+ """
2
+ Deconvolution visualization functions for spatial transcriptomics.
3
+
4
+ This module contains:
5
+ - Cell type proportion spatial maps
6
+ - Dominant cell type visualization
7
+ - Diversity/entropy maps
8
+ - Stacked barplots
9
+ - Scatterpie plots (SPOTlight-style)
10
+ - UMAP proportion plots
11
+ - CARD imputation visualization
12
+ """
13
+
14
+ from typing import TYPE_CHECKING, Optional
15
+
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import pandas as pd
19
+ from matplotlib.patches import Patch, Wedge
20
+ from scipy.stats import entropy
21
+
22
+ if TYPE_CHECKING:
23
+ import anndata as ad
24
+
25
+ from ...spatial_mcp_adapter import ToolContext
26
+
27
+ from ...models.data import VisualizationParameters
28
+ from ...utils.adata_utils import (
29
+ get_analysis_parameter,
30
+ get_cluster_key,
31
+ get_spatial_key,
32
+ require_spatial_coords,
33
+ )
34
+ from ...utils.exceptions import DataNotFoundError, ParameterError
35
+ from .core import (
36
+ DeconvolutionData,
37
+ create_figure_from_params,
38
+ get_category_colors,
39
+ plot_spatial_feature,
40
+ resolve_figure_size,
41
+ setup_multi_panel_figure,
42
+ )
43
+
44
+ # =============================================================================
45
+ # Data Retrieval
46
+ # =============================================================================
47
+
48
+
49
def _get_available_methods(adata: "ad.AnnData") -> list[str]:
    """Get available deconvolution methods from metadata or key names.

    Priority:
    1. Read from stored metadata (most reliable)
    2. Fall back to key name search (for legacy data)

    The method name is recovered by stripping the fixed ``deconvolution_``
    prefix (and, for metadata keys, the ``_metadata`` suffix) exactly once.
    A method name that itself contains one of those substrings is therefore
    returned intact instead of having every occurrence removed.

    Args:
        adata: Object exposing ``uns`` and ``obsm`` mappings.

    Returns:
        Method names in first-seen order, without duplicates. Keys that would
        yield an empty method name, and non-string keys, are ignored.
    """
    prefix = "deconvolution_"
    suffix = "_metadata"
    methods: list[str] = []

    def _register(name: str) -> None:
        # Preserve discovery order while dropping blanks and duplicates.
        if name and name not in methods:
            methods.append(name)

    # First: try to get from stored metadata
    # (deconvolution_{method}_metadata -> {method})
    for key in adata.uns.keys():
        if not isinstance(key, str):
            continue
        # The length check rejects keys where prefix and suffix overlap,
        # e.g. "deconvolution_metadata", which carry no method name.
        if (
            len(key) > len(prefix) + len(suffix)
            and key.startswith(prefix)
            and key.endswith(suffix)
        ):
            _register(key[len(prefix) : -len(suffix)])

    # Fallback: search obsm keys (for legacy data without metadata)
    if not methods:
        for key in adata.obsm.keys():
            if isinstance(key, str) and key.startswith(prefix):
                _register(key[len(prefix) :])

    return methods
75
+
76
+
77
async def get_deconvolution_data(
    adata: "ad.AnnData",
    method: Optional[str] = None,
    context: Optional["ToolContext"] = None,
) -> DeconvolutionData:
    """
    Unified function to retrieve deconvolution results from AnnData.

    This function consolidates all deconvolution data retrieval logic into
    a single, consistent interface. It handles:
    - Auto-detection when only one result exists
    - Explicit method specification
    - Clear error messages with solutions

    Priority for reading data:
    1. Read from stored metadata (most reliable)
    2. Fall back to key name inference (for legacy data)

    Args:
        adata: AnnData object with deconvolution results
        method: Deconvolution method name (e.g., "cell2location", "rctd").
            If None and only one result exists, auto-selects it.
            If None and multiple results exist, raises ParameterError.
        context: MCP context for logging

    Returns:
        DeconvolutionData object with proportions and metadata

    Raises:
        DataNotFoundError: No deconvolution results found, the requested
            method is not available, or the proportions matrix is missing
            from ``adata.obsm``
        ParameterError: Multiple results found but method not specified
    """
    available_methods = _get_available_methods(adata)

    # Handle method specification
    if method is not None:
        if method not in available_methods:
            raise DataNotFoundError(
                f"Deconvolution '{method}' not found. "
                f"Available: {available_methods if available_methods else 'None'}. "
                f"Run deconvolve_data() first."
            )
    else:
        # Auto-detect
        if not available_methods:
            raise DataNotFoundError(
                "No deconvolution results found. Run deconvolve_data() first."
            )

        if len(available_methods) > 1:
            raise ParameterError(
                f"Multiple deconvolution results: {available_methods}. "
                f"Specify deconv_method parameter."
            )

        # Single result - auto-select
        method = available_methods[0]
        if context:
            await context.info(f"Auto-selected deconvolution method: {method}")

    # Get data from metadata or fall back to convention
    analysis_name = f"deconvolution_{method}"

    # Try to get from stored metadata first
    proportions_key = get_analysis_parameter(adata, analysis_name, "proportions_key")
    cell_types = get_analysis_parameter(adata, analysis_name, "cell_types")
    dominant_type_key = get_analysis_parameter(
        adata, analysis_name, "dominant_type_key"
    )

    # Fall back to convention-based keys
    if not proportions_key:
        proportions_key = f"deconvolution_{method}"

    if proportions_key not in adata.obsm:
        raise DataNotFoundError(
            f"Proportions data '{proportions_key}' not found in adata.obsm"
        )

    # Decide whether stored cell type names are usable. A plain truthiness
    # test raises "truth value of an array is ambiguous" when the names come
    # back as a multi-element NumPy array, so test emptiness by length.
    if cell_types is None:
        names_missing = True
    else:
        try:
            names_missing = len(cell_types) == 0
        except TypeError:
            names_missing = not cell_types

    # Get cell type names
    if names_missing:
        cell_types_key = f"{proportions_key}_cell_types"
        if cell_types_key in adata.uns:
            cell_types = list(adata.uns[cell_types_key])
        else:
            # Fallback: generate generic names from shape
            n_cell_types = adata.obsm[proportions_key].shape[1]
            cell_types = [f"CellType_{i}" for i in range(n_cell_types)]
            if context:
                await context.warning("Cell type names not found. Using generic names.")
    elif hasattr(cell_types, "tolist"):
        # Normalise array-like names to a plain list of Python objects.
        cell_types = cell_types.tolist()
    else:
        cell_types = list(cell_types)

    # Check dominant type key
    if not dominant_type_key:
        dominant_type_key = f"dominant_celltype_{method}"

    # The dominant-type column is optional; report None when it is absent.
    if dominant_type_key not in adata.obs.columns:
        dominant_type_key = None

    # Create DataFrame (rows follow adata.obs_names order)
    proportions = pd.DataFrame(
        adata.obsm[proportions_key], index=adata.obs_names, columns=cell_types
    )

    return DeconvolutionData(
        proportions=proportions,
        method=method,
        cell_types=cell_types,
        proportions_key=proportions_key,
        dominant_type_key=dominant_type_key,
    )
187
+
188
+
189
+ # =============================================================================
190
+ # Visualization Functions
191
+ # =============================================================================
192
+
193
+
194
async def create_deconvolution_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Build the deconvolution figure selected by ``params.subtype``.

    Supported subtypes (``spatial_multi`` is used when none is given):
    - spatial_multi: Multi-panel spatial maps
    - dominant_type: Dominant cell type map (CARD-style)
    - diversity: Shannon entropy diversity map
    - stacked_bar: Stacked barplot
    - scatterpie: Spatial scatterpie (SPOTlight-style)
    - umap: UMAP colored by proportions

    Args:
        adata: AnnData object with deconvolution results
        params: Visualization parameters
        context: MCP context

    Returns:
        Matplotlib figure produced by the selected renderer

    Raises:
        ParameterError: If the subtype is not one of the supported values
    """
    viz_type = params.subtype or "spatial_multi"

    # Ordered (subtype, renderer) pairs; compared with == in sequence.
    routes = (
        ("dominant_type", _create_dominant_celltype_map),
        ("diversity", _create_diversity_map),
        ("stacked_bar", _create_stacked_barplot),
        ("scatterpie", _create_scatterpie_plot),
        ("umap", _create_umap_proportions),
        ("spatial_multi", _create_spatial_multi_deconvolution),
    )
    for subtype_name, renderer in routes:
        if viz_type == subtype_name:
            return await renderer(adata, params, context)

    raise ParameterError(
        f"Unknown deconvolution visualization type: {viz_type}. "
        f"Available: spatial_multi, dominant_type, diversity, stacked_bar, "
        f"scatterpie, umap"
    )
237
+
238
+
239
async def _create_dominant_celltype_map(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Draw a spatial map coloured by each spot's dominant cell type (CARD-style).

    When ``params.show_mixed_spots`` is set, spots whose largest proportion is
    below ``params.min_proportion_threshold`` are relabelled "Mixed" and drawn
    in grey with slight transparency.
    """
    deconv = await get_deconvolution_data(adata, params.deconv_method, context)

    # Per-spot winner: its column name and its proportion
    winner_idx = deconv.proportions.values.argmax(axis=1)
    winner_names = deconv.proportions.columns[winner_idx].values
    winner_share = deconv.proportions.values.max(axis=1)

    # Label each spot, optionally demoting low-confidence spots to "Mixed"
    if params.show_mixed_spots:
        spot_labels = np.where(
            winner_share >= params.min_proportion_threshold,
            winner_names,
            "Mixed",
        )
    else:
        spot_labels = winner_names

    coords = require_spatial_coords(adata)

    fig, axes = create_figure_from_params(params, "deconvolution")
    ax = axes[0]

    categories = np.unique(spot_labels)
    n_categories = len(categories)

    # "Mixed" only gets special treatment when mixed-spot marking is enabled
    mixed_shown = params.show_mixed_spots and "Mixed" in categories
    if mixed_shown:
        palette_targets = [c for c in categories if c != "Mixed"]
    else:
        palette_targets = categories

    palette = get_category_colors(len(palette_targets), params.colormap)
    colour_for = {name: palette[i] for i, name in enumerate(palette_targets)}
    if mixed_shown:
        colour_for["Mixed"] = (0.7, 0.7, 0.7, 1.0)

    marker_size = params.spot_size or 10
    for name in categories:
        members = spot_labels == name
        if mixed_shown and name == "Mixed":
            opacity = 0.8
        else:
            opacity = 1.0
        ax.scatter(
            coords[members, 0],
            coords[members, 1],
            c=[colour_for[name]],
            s=marker_size,
            alpha=opacity,
            label=name,
            edgecolors="none",
        )

    # Axis labels, title and legend
    ax.set_xlabel("Spatial X")
    ax.set_ylabel("Spatial Y")
    if params.show_mixed_spots:
        title_text = (
            f"Dominant Cell Type Map ({deconv.method})\n"
            f"Threshold: {params.min_proportion_threshold:.2f}"
        )
    else:
        title_text = f"Dominant Cell Type Map ({deconv.method})"
    ax.set_title(title_text)

    legend_columns = 1 if n_categories <= 15 else 2
    ax.legend(
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        ncol=legend_columns,
        fontsize=8,
        markerscale=0.5,
    )
    ax.set_aspect("equal")

    plt.tight_layout()
    return fig
333
+
334
+
335
async def _create_diversity_map(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create Shannon entropy diversity map.

    Shows cell type diversity at each spatial location using Shannon entropy
    (base 2), normalized by ``log2(number of cell types)``.
    Higher entropy = more diverse/mixed cell types.
    Lower entropy = more homogeneous/dominated by single type.

    With fewer than two cell types the normalizer is zero, so diversity is
    reported as 0 for every spot instead of dividing by zero.

    Args:
        adata: AnnData object with deconvolution results
        params: Visualization parameters (deconv_method, colormap, spot_size)
        context: MCP context; when given, summary statistics are logged

    Returns:
        Matplotlib figure with the diversity map
    """
    data = await get_deconvolution_data(adata, params.deconv_method, context)

    # Calculate Shannon entropy for each spot. The small offset keeps
    # all-zero rows from producing an undefined distribution.
    epsilon = 1e-10
    proportions_safe = data.proportions.values + epsilon
    spot_entropy = entropy(proportions_safe.T, base=2)

    # Normalize to [0, 1] range; log2(1) == 0 would otherwise yield NaN
    n_types = data.proportions.shape[1]
    max_entropy = np.log2(n_types) if n_types > 1 else 0.0
    if max_entropy > 0:
        normalized_entropy = spot_entropy / max_entropy
    else:
        normalized_entropy = np.zeros(np.shape(spot_entropy), dtype=float)

    # Get spatial coordinates
    spatial_coords = require_spatial_coords(adata)

    # Create figure
    fig, axes = create_figure_from_params(params, "deconvolution")
    ax = axes[0]

    scatter = ax.scatter(
        spatial_coords[:, 0],
        spatial_coords[:, 1],
        c=normalized_entropy,
        cmap=params.colormap or "viridis",
        s=params.spot_size or 10,
        alpha=1.0,
        edgecolors="none",
    )

    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Cell Type Diversity (Shannon Entropy)", rotation=270, labelpad=20)

    ax.set_xlabel("Spatial X")
    ax.set_ylabel("Spatial Y")
    ax.set_title(
        f"Cell Type Diversity Map ({data.method})\n"
        f"Shannon Entropy (0=homogeneous, 1=maximally diverse)"
    )
    ax.set_aspect("equal")

    plt.tight_layout()

    if context:
        mean_entropy = normalized_entropy.mean()
        std_entropy = normalized_entropy.std()
        high_div_pct = (normalized_entropy > 0.7).sum() / len(normalized_entropy) * 100
        low_div_pct = (normalized_entropy < 0.3).sum() / len(normalized_entropy) * 100
        await context.info(
            f"Created diversity map:\n"
            f"  Mean entropy: {mean_entropy:.3f} ± {std_entropy:.3f}\n"
            f"  High diversity (>0.7): {high_div_pct:.1f}% of spots\n"
            f"  Low diversity (<0.3): {low_div_pct:.1f}% of spots"
        )

    return fig
400
+
401
+
402
async def _create_stacked_barplot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create stacked barplot of cell type proportions.

    Shows cell type proportions for each spot as stacked bars.
    Spots can be sorted by dominant cell type, spatial order, or cluster;
    any other ``params.sort_by`` value keeps the incoming order.

    When there are more spots than ``params.max_spots``, a random subset of
    that size is plotted (sampling is not seeded).

    Args:
        adata: AnnData object with deconvolution results
        params: Visualization parameters (deconv_method, max_spots, sort_by,
            cluster_key, colormap)
        context: MCP context; a warning is logged when spots are subsampled

    Returns:
        Matplotlib figure with one stacked bar per plotted spot
    """
    data = await get_deconvolution_data(adata, params.deconv_method, context)

    # Limit number of spots for readability. Row positions are tracked so
    # per-spot arrays in adata can be indexed positionally later; the
    # proportions frame is built with adata.obs_names as its index, so row i
    # of the frame is row i of adata.
    n_spots = len(data.proportions)
    if n_spots > params.max_spots:
        row_positions = np.random.choice(
            n_spots, size=params.max_spots, replace=False
        )
        proportions_plot = data.proportions.iloc[row_positions]
        if context:
            await context.warning(
                f"Sampled {params.max_spots} spots out of {n_spots} for readability."
            )
    else:
        row_positions = np.arange(n_spots)
        proportions_plot = data.proportions

    n_plot = len(proportions_plot)

    # Sort spots based on sort_by parameter (default: keep current order)
    sort_order = np.arange(n_plot)
    if params.sort_by == "dominant_type":
        dominant_idx = proportions_plot.values.argmax(axis=1)
        dominant_types = np.asarray(proportions_plot.columns[dominant_idx])
        # Stable sort keeps the original spot order within each cell type
        sort_order = np.argsort(dominant_types, kind="stable")
    elif params.sort_by == "spatial":
        spatial_key = get_spatial_key(adata)
        # Hierarchical clustering needs at least two observations
        if spatial_key and n_plot >= 2:
            from scipy.cluster.hierarchy import dendrogram, linkage

            # Index coordinates by position: obs names are labels and cannot
            # be used to index a NumPy array.
            spatial_coords = np.asarray(adata.obsm[spatial_key])[row_positions]
            linkage_matrix = linkage(spatial_coords, method="ward")
            dend = dendrogram(linkage_matrix, no_plot=True)
            sort_order = np.asarray(dend["leaves"])
    elif params.sort_by == "cluster":
        cluster_key = params.cluster_key or get_cluster_key(adata)
        if cluster_key and cluster_key in adata.obs.columns:
            cluster_values = adata.obs[cluster_key].iloc[row_positions]
            sort_order = np.argsort(
                cluster_values.astype(str).to_numpy(), kind="stable"
            )

    proportions_sorted = proportions_plot.iloc[sort_order]

    # Create figure
    fig, axes = create_figure_from_params(params, "violin")  # violin uses (12, 6)
    ax = axes[0]

    cell_types = proportions_sorted.columns.tolist()
    n_cell_types = len(cell_types)

    # Use centralized colormap utility
    colors = get_category_colors(n_cell_types, params.colormap)

    x_positions = np.arange(n_plot)
    bottom = np.zeros(n_plot)

    for i, cell_type in enumerate(cell_types):
        values = proportions_sorted[cell_type].values
        ax.bar(
            x_positions,
            values,
            bottom=bottom,
            color=colors[i],
            label=cell_type,
            width=1.0,
            edgecolor="none",
        )
        # Rebind rather than add in place so the float accumulator never
        # has to absorb a wider or incompatible dtype.
        bottom = bottom + values

    sort_label = params.sort_by.replace("_", " ").title()
    ax.set_xlabel(sort_label)
    ax.set_ylabel("Cell Type Proportion")
    ax.set_title(
        f"Cell Type Proportions ({data.method})\n" f"Sorted by: {sort_label}"
    )
    ax.set_ylim((0, 1))
    # Bars are centred on integer positions with width 1, so the visible
    # range must extend half a bar beyond the first and last position.
    ax.set_xlim((-0.5, max(n_plot, 1) - 0.5))
    ax.legend(
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        ncol=1 if n_cell_types <= 15 else 2,
        fontsize=8,
    )
    ax.set_xticks([])

    plt.tight_layout()
    return fig
498
+
499
+
500
async def _create_scatterpie_plot(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Draw a pie chart of cell type composition at every spot (SPOTlight-style).

    Each spot's proportions are renormalised to sum to one; slices of 1% or
    less are not drawn, and spots whose proportions sum to zero are skipped.
    """
    deconv = await get_deconvolution_data(adata, params.deconv_method, context)
    coords = require_spatial_coords(adata)

    prop_table = deconv.proportions
    type_names = prop_table.columns.tolist()
    n_cell_types = len(type_names)

    # One colour per cell type from the shared palette helper
    palette = get_category_colors(n_cell_types, params.colormap)
    colour_by_type = {name: palette[i] for i, name in enumerate(type_names)}

    fig, axes = create_figure_from_params(params, "deconvolution")
    ax = axes[0]

    # Pie radius is 2% of the larger coordinate extent, times the user scale
    extent = np.ptp(coords, axis=0).max()
    radius = extent * 0.02 * params.pie_scale

    for (cx, cy), row in zip(coords, prop_table.to_numpy()):
        total = row.sum()
        if total == 0:
            continue

        fractions = row / total
        theta = 0
        for name, frac in zip(type_names, fractions, strict=False):
            if not frac > 0.01:
                continue
            sweep = frac * 360
            ax.add_patch(
                Wedge(
                    center=(cx, cy),
                    r=radius,
                    theta1=theta,
                    theta2=theta + sweep,
                    facecolor=colour_by_type[name],
                    edgecolor="white",
                    linewidth=0.5,
                    alpha=params.scatterpie_alpha,
                )
            )
            theta += sweep

    # Patches do not trigger autoscaling, so set the view limits explicitly
    xs = coords[:, 0]
    ys = coords[:, 1]
    margin = radius * 2
    ax.set_xlim((xs.min() - margin, xs.max() + margin))
    ax.set_ylim((ys.min() - margin, ys.max() + margin))

    ax.set_xlabel("Spatial X")
    ax.set_ylabel("Spatial Y")
    ax.set_title(
        f"Spatial Scatterpie Plot ({deconv.method})\n"
        f"Cell Type Composition (pie scale: {params.pie_scale:.2f})"
    )
    ax.set_aspect("equal")

    handles = [Patch(facecolor=colour_by_type[name], label=name) for name in type_names]
    legend_columns = 1 if n_cell_types <= 15 else 2
    ax.legend(
        handles=handles,
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        ncol=legend_columns,
        fontsize=8,
    )

    plt.tight_layout()
    return fig
581
+
582
+
583
async def _create_umap_proportions(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Plot the UMAP embedding once per top cell type, coloured by proportion.

    The ``params.n_cell_types`` cell types with the highest mean proportion
    each get a panel (up to three panels per row) on a fixed 0-1 colour scale.

    Raises:
        DataNotFoundError: If ``adata.obsm`` has no ``X_umap`` entry
    """
    deconv = await get_deconvolution_data(adata, params.deconv_method, context)

    if "X_umap" not in adata.obsm:
        raise DataNotFoundError(
            "UMAP coordinates not found in adata.obsm['X_umap']. "
            "Run UMAP dimensionality reduction first."
        )
    embedding = adata.obsm["X_umap"]

    # Rank cell types by their mean proportion across spots
    ranked_means = deconv.proportions.mean(axis=0).sort_values(ascending=False)
    selected_types = ranked_means.head(params.n_cell_types).index.tolist()

    n_panels = len(selected_types)
    ncols = min(3, n_panels)
    nrows = int(np.ceil(n_panels / ncols))

    # Figure size comes from the shared resolver
    figsize = resolve_figure_size(
        params, n_panels=n_panels, panel_width=4, panel_height=3.5
    )
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    colour_scale = params.colormap or "viridis"
    marker_size = params.spot_size or 5

    for panel, type_name in zip(axes, selected_types):
        points = panel.scatter(
            embedding[:, 0],
            embedding[:, 1],
            c=deconv.proportions[type_name].values,
            cmap=colour_scale,
            s=marker_size,
            alpha=0.8,
            vmin=0,
            vmax=1,
            edgecolors="none",
        )

        panel.set_xlabel("UMAP 1")
        panel.set_ylabel("UMAP 2")
        panel.set_title(f"{type_name}\n(mean: {ranked_means[type_name]:.3f})")
        panel.set_aspect("equal")

        cbar = plt.colorbar(points, ax=panel)
        cbar.set_label("Proportion", rotation=270, labelpad=15, fontsize=8)

    # Hide grid cells that have no cell type assigned
    for spare in axes[n_panels:]:
        spare.axis("off")

    fig.suptitle(
        f"UMAP Cell Type Proportions ({deconv.method})\n"
        f"Top {n_panels} cell types (out of {len(deconv.cell_types)})",
        fontsize=12,
        y=0.995,
    )

    plt.tight_layout()
    return fig
653
+
654
+
655
async def _create_spatial_multi_deconvolution(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Multi-panel spatial deconvolution visualization.

    Shows top N cell types (ranked by mean proportion) as separate panels.
    When ``adata.obsm`` has a "spatial" entry each panel is a spatial map;
    otherwise it is a bar chart of the sorted proportions.

    Plotting is best-effort per panel: a failure in one panel is rendered as
    an error message inside that panel and the remaining panels still draw.

    The spatial branch plots through a temporary ``adata.obs`` column. It is
    removed in a ``finally`` block so the caller's AnnData is left unchanged
    even if plotting is interrupted.

    Args:
        adata: AnnData object with deconvolution results
        params: Visualization parameters (deconv_method, n_cell_types, alpha)
        context: MCP context

    Returns:
        Matplotlib figure with one panel per selected cell type
    """
    data = await get_deconvolution_data(adata, params.deconv_method, context)

    n_cell_types = min(params.n_cell_types, len(data.cell_types))
    top_cell_types = (
        data.proportions.mean().sort_values(ascending=False).index[:n_cell_types]
    )

    fig, axes = setup_multi_panel_figure(
        n_panels=len(top_cell_types),
        params=params,
        default_title=f"{data.method.upper()} Cell Type Proportions",
    )

    temp_feature_key = "_deconv_viz_temp"

    try:
        for i, cell_type in enumerate(top_cell_types):
            if i >= len(axes):
                break
            ax = axes[i]
            try:
                if "spatial" in adata.obsm:
                    proportions_values = data.proportions[cell_type].values

                    # Missing proportions are plotted as zero
                    if pd.isna(proportions_values).any():
                        proportions_values = (
                            pd.Series(proportions_values).fillna(0).values
                        )

                    # plot_spatial_feature reads the feature by name from adata
                    adata.obs[temp_feature_key] = proportions_values

                    plot_spatial_feature(
                        adata, feature=temp_feature_key, ax=ax, params=params
                    )
                    ax.set_title(cell_type)
                    ax.invert_yaxis()
                else:
                    sorted_props = data.proportions[cell_type].sort_values(
                        ascending=False
                    )
                    ax.bar(
                        range(len(sorted_props)),
                        sorted_props.values,
                        alpha=params.alpha,
                    )
                    ax.set_title(cell_type)
                    ax.set_xlabel("Spots (sorted)")
                    ax.set_ylabel("Proportion")

            except Exception as e:
                # Deliberate best-effort: show the failure in the panel
                ax.text(
                    0.5,
                    0.5,
                    f"Error plotting {cell_type}:\n{e}",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )
                ax.set_title(f"{cell_type} (Error)")
    finally:
        # Always remove the scratch column, whatever happened above
        if temp_feature_key in adata.obs.columns:
            del adata.obs[temp_feature_key]

    fig.subplots_adjust(top=0.92, wspace=0.1, hspace=0.3, right=0.98)
    return fig
725
+
726
+
727
+ # =============================================================================
728
+ # CARD Imputation Visualization
729
+ # =============================================================================
730
+
731
+
732
async def create_card_imputation_visualization(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Create CARD imputation visualization.

    CARD's unique CAR model allows imputation at unmeasured locations,
    creating enhanced high-resolution spatial maps.

    ``params.feature`` selects what is drawn: "dominant" (also the default
    when no feature is given) colours each imputed location by its dominant
    cell type; a cell type name colours locations by that type's imputed
    proportion.

    Args:
        adata: AnnData object with CARD imputation results in
            ``adata.uns["card_imputation"]``
        params: Visualization parameters
        context: MCP context for logging

    Returns:
        matplotlib Figure object

    Raises:
        DataNotFoundError: If CARD imputation data not found or feature not found
    """
    if context:
        await context.info("Creating CARD imputation visualization")

    # Check if CARD imputation data exists
    if "card_imputation" not in adata.uns:
        raise DataNotFoundError(
            "CARD imputation data not found. Run CARD with card_imputation=True."
        )

    # Extract imputation data
    impute_data = adata.uns["card_imputation"]
    imputed_proportions = impute_data["proportions"]
    imputed_coords = impute_data["coordinates"]

    # Determine what to visualize
    feature = params.feature or "dominant"

    # Create figure using centralized utility
    fig, axes = create_figure_from_params(params, "deconvolution")
    ax = axes[0]

    if feature == "dominant":
        # Show dominant cell types
        dominant_types = imputed_proportions.idxmax(axis=1)
        unique_types = dominant_types.unique()

        # Use centralized colormap utility
        colors = get_category_colors(len(unique_types), params.colormap)
        color_map = {ct: colors[i] for i, ct in enumerate(unique_types)}
        point_colors = [color_map[ct] for ct in dominant_types]

        ax.scatter(
            imputed_coords["x"],
            imputed_coords["y"],
            c=point_colors,
            s=25,
            edgecolors="none",
            alpha=0.7,
        )

        ax.set_title(
            f"CARD Imputation: Dominant Cell Types\n"
            f"({len(imputed_coords)} locations, "
            f"{impute_data['resolution_increase']:.1f}x resolution)",
            fontsize=14,
            fontweight="bold",
        )

        legend_elements = [
            Patch(facecolor=color_map[ct], label=ct) for ct in sorted(unique_types)
        ]
        ax.legend(
            handles=legend_elements,
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            fontsize=9,
        )

    elif feature in imputed_proportions.columns:
        # Show specific cell type proportion
        feature_values = imputed_proportions[feature]

        # Clip the colour scale at the 95th percentile to limit the effect of
        # outliers. For rare cell types that percentile can be 0, which would
        # collapse the scale (vmin == vmax); fall back to the maximum, and to
        # 1.0 when every value is zero or undefined.
        color_max = feature_values.quantile(0.95)
        if not color_max > 0:
            color_max = feature_values.max()
        if not color_max > 0:
            color_max = 1.0

        scatter = ax.scatter(
            imputed_coords["x"],
            imputed_coords["y"],
            c=feature_values,
            s=30,
            cmap=params.colormap or "viridis",
            vmin=0,
            vmax=color_max,
            edgecolors="none",
            alpha=0.8,
        )

        ax.set_title(
            f"CARD Imputation: {feature}\n"
            f"(Mean: {feature_values.mean():.3f}, "
            f"{len(imputed_coords)} locations)",
            fontsize=14,
            fontweight="bold",
        )

        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label("Proportion", fontsize=12)

    else:
        raise DataNotFoundError(
            f"Feature '{feature}' not found. "
            f"Available: {list(imputed_proportions.columns)[:5]}..."
        )

    ax.set_xlabel("X coordinate", fontsize=12)
    ax.set_ylabel("Y coordinate", fontsize=12)
    ax.set_aspect("equal")
    plt.tight_layout()

    if context:
        await context.info("CARD imputation visualization created successfully")

    return fig