chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,684 @@
1
+ """
2
+ Core visualization utilities and shared functions.
3
+
4
+ This module contains:
5
+ - Figure setup and utility functions
6
+ - Shared data structures
7
+ - Common visualization helpers
8
+ """
9
+
10
+ from typing import TYPE_CHECKING, NamedTuple, Optional
11
+
12
+ import anndata as ad
13
+ import matplotlib
14
+
15
+ matplotlib.use("Agg")
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import pandas as pd
19
+ import seaborn as sns
20
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
21
+
22
+ from ...models.data import VisualizationParameters
23
+ from ...utils.adata_utils import get_gene_expression, require_spatial_coords
24
+ from ...utils.exceptions import DataNotFoundError, ParameterError
25
+
26
+ plt.ioff()
27
+
28
+ if TYPE_CHECKING:
29
+ from ...spatial_mcp_adapter import ToolContext
30
+
31
+
32
+ # =============================================================================
33
+ # Figure Creation Utilities
34
+ # =============================================================================
35
+
36
+
37
# Default (width, height) in inches for each plot type. These are used only
# when the caller does not supply ``figure_size``; unknown plot types fall back
# to the "default" entry (see ``resolve_figure_size``).
FIGURE_DEFAULTS: dict[str, tuple[int, int]] = {
    "spatial": (10, 8),
    "umap": (10, 8),
    "heatmap": (12, 10),
    "violin": (12, 6),
    "dotplot": (10, 8),
    "trajectory": (10, 10),
    "gene_trends": (12, 6),
    "velocity": (10, 8),
    "deconvolution": (10, 8),
    "cell_communication": (10, 10),
    "enrichment": (6, 8),
    "cnv": (12, 8),
    "integration": (16, 12),
    "default": (10, 8),
}
54
+
55
+
56
def resolve_figure_size(
    params: VisualizationParameters,
    plot_type: str = "default",
    n_panels: Optional[int] = None,
    panel_width: float = 5.0,
    panel_height: float = 4.0,
) -> tuple[int, int]:
    """Work out the figure size (in inches) to use for a plot.

    Resolution order:

    1. ``params.figure_size`` when it is set (truthy); returned unchanged.
    2. When ``n_panels > 1``: a grid of at most three columns sized
       ``panel_width`` x ``panel_height`` per panel, capped at 15 x 16 inches
       and truncated to whole inches.
    3. Otherwise the ``FIGURE_DEFAULTS`` entry for ``plot_type``, or
       ``FIGURE_DEFAULTS["default"]`` for unknown plot types.

    Args:
        params: VisualizationParameters with optional ``figure_size``.
        plot_type: Plot type used to pick a default size (e.g. "spatial").
        n_panels: Number of panels for multi-panel figures.
        panel_width: Width per panel for multi-panel figures.
        panel_height: Height per panel for multi-panel figures.

    Returns:
        Tuple of (width, height) in inches.

    Examples:
        >>> resolve_figure_size(params, "spatial")  # user override or (10, 8)
        >>> resolve_figure_size(params, n_panels=4)  # derived from panel count
    """
    user_size = params.figure_size
    if user_size:
        return user_size

    is_multi_panel = n_panels is not None and n_panels > 1
    if not is_multi_panel:
        return FIGURE_DEFAULTS.get(plot_type, FIGURE_DEFAULTS["default"])

    cols = min(3, n_panels)
    rows = -(-n_panels // cols)  # ceiling division
    return (
        int(min(panel_width * cols, 15)),
        int(min(panel_height * rows, 16)),
    )
96
+
97
+
98
def create_figure(figsize: tuple[int, int] = (10, 8)) -> tuple[plt.Figure, plt.Axes]:
    """Return a new ``(figure, axes)`` pair holding one subplot of size ``figsize``."""
    return plt.subplots(figsize=figsize)
102
+
103
+
104
def create_figure_from_params(
    params: VisualizationParameters,
    plot_type: str = "default",
    n_panels: Optional[int] = None,
    n_rows: int = 1,
    n_cols: int = 1,
    squeeze: bool = True,
) -> tuple[plt.Figure, np.ndarray]:
    """Create a figure and its axes from visualization parameters.

    This is the preferred way to create figures in visualization modules.
    The size comes from ``resolve_figure_size(params, plot_type, n_panels)``
    (note: it is not derived from ``n_rows``/``n_cols``) and the DPI from
    ``params.dpi``.

    Args:
        params: VisualizationParameters.
        plot_type: Plot type used for default size selection.
        n_panels: Number of panels, forwarded to ``resolve_figure_size``.
        n_rows: Number of subplot rows.
        n_cols: Number of subplot columns.
        squeeze: Passed to ``plt.subplots``.

    Returns:
        Tuple of (Figure, axes). With ``squeeze=True`` the axes are always a
        numpy array (a 1x1 grid is wrapped into a 1-element array); with
        ``squeeze=False`` they are the 2-D array from ``plt.subplots``.

    Examples:
        >>> fig, axes = create_figure_from_params(params, "spatial")
        >>> fig, axes = create_figure_from_params(params, n_rows=2, n_cols=3)
    """
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=resolve_figure_size(params, plot_type, n_panels),
        dpi=params.dpi,
        squeeze=squeeze,
    )

    if not squeeze:
        return fig, axes

    # With squeeze=True a 1x1 grid comes back as a bare Axes; wrap it so
    # callers can always index axes[0].
    if n_rows == 1 and n_cols == 1:
        return fig, np.array([axes])
    if n_rows == 1 or n_cols == 1:
        return fig, np.atleast_1d(axes)
    return fig, axes
149
+
150
+
151
def setup_multi_panel_figure(
    n_panels: int,
    params: VisualizationParameters,
    default_title: str,
    use_tight_layout: bool = False,
) -> tuple[plt.Figure, np.ndarray]:
    """Create a grid of subplots with room for ``n_panels`` panels.

    The grid shape is ``params.panel_layout`` (rows, cols) when set; otherwise
    it has at most three columns and as many rows as needed. The figure size
    is ``params.figure_size`` when set, otherwise 5 x 4 inches per panel
    capped at 15 x 16 inches. Axes beyond ``n_panels`` are turned off.

    Args:
        n_panels: Number of panels that will be drawn.
        params: Visualization parameters; reads ``panel_layout``,
            ``figure_size``, ``dpi``, ``title``, ``subplot_wspace`` and
            ``subplot_hspace``.
        default_title: Suptitle used when ``params.title`` is empty. If both
            are empty, no suptitle is set.
        use_tight_layout: If True, do not pass ``wspace``/``hspace`` through
            ``gridspec_kw``. This function does not call ``tight_layout``
            itself; the caller is expected to do that.

    Returns:
        Tuple of (Figure, flattened 1-D numpy array of Axes). The array has
        ``n_rows * n_cols`` entries, which may be more than ``n_panels``.

    Raises:
        ParameterError: If ``params.panel_layout`` has fewer cells than
            ``n_panels`` (or a non-positive dimension), or if no layout is
            given and ``n_panels`` < 1.
    """
    if params.panel_layout:
        n_rows, n_cols = params.panel_layout
        # A too-small layout used to be accepted silently, handing callers
        # fewer axes than panels to draw.
        if n_rows < 1 or n_cols < 1 or n_rows * n_cols < n_panels:
            raise ParameterError(
                f"panel_layout {n_rows} x {n_cols} cannot hold {n_panels} "
                f"panels; use a layout with at least {n_panels} cells or "
                f"leave panel_layout unset"
            )
    elif n_panels < 1:
        # Without this check the auto layout below divides by zero.
        raise ParameterError(f"n_panels must be at least 1, got {n_panels}")
    else:
        n_cols = min(3, n_panels)
        n_rows = (n_panels + n_cols - 1) // n_cols

    figsize = params.figure_size or (min(5 * n_cols, 15), min(4 * n_rows, 16))

    # gridspec_kw=None is plt.subplots' default, i.e. no explicit spacing.
    gridspec_kw = (
        None
        if use_tight_layout
        else {"wspace": params.subplot_wspace, "hspace": params.subplot_hspace}
    )
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=figsize,
        dpi=params.dpi,
        squeeze=False,
        gridspec_kw=gridspec_kw,
    )
    axes = axes.flatten()

    # Use the explicit title if given, else the default; skip when both empty.
    title = params.title or default_title
    if title:
        fig.suptitle(title, fontsize=16)

    # Hide the grid cells that will not receive a panel.
    for i in range(n_panels, len(axes)):
        axes[i].axis("off")

    return fig, axes
207
+
208
+
209
def add_colorbar(
    fig: plt.Figure,
    ax: plt.Axes,
    mappable,
    params: VisualizationParameters,
    label: str = "",
) -> None:
    """Attach a colorbar to the right-hand side of ``ax``.

    A new axes is appended to the right of ``ax`` using
    ``make_axes_locatable``; its size and spacing come from
    ``params.colorbar_size`` and ``params.colorbar_pad``.

    Args:
        fig: Figure that owns ``ax``.
        ax: Axes the colorbar is placed next to.
        mappable: Mappable to describe (e.g. the result of ``scatter``/``imshow``).
        params: Visualization parameters supplying colorbar size and padding.
        label: Colorbar label; no label is set when empty.
    """
    colorbar_axes = make_axes_locatable(ax).append_axes(
        "right", size=params.colorbar_size, pad=params.colorbar_pad
    )
    colorbar = fig.colorbar(mappable, cax=colorbar_axes)
    if not label:
        return
    colorbar.set_label(label, fontsize=10)
232
+
233
+
234
+ # =============================================================================
235
+ # Data Structures for Unified Data Access
236
+ # =============================================================================
237
+
238
+
239
class DeconvolutionData(NamedTuple):
    """Unified representation of deconvolution results.

    This is a NamedTuple, so the field order below is part of its positional
    interface (construction and unpacking); keep it stable.

    Attributes:
        proportions: DataFrame with cell type proportions (n_spots x n_cell_types)
        method: Deconvolution method name (e.g., "cell2location", "rctd")
        cell_types: List of cell type names
        proportions_key: Key in adata.obsm where proportions are stored
        dominant_type_key: Key in adata.obs for dominant cell type (if exists);
            defaults to None
    """

    proportions: pd.DataFrame
    method: str
    cell_types: list[str]
    proportions_key: str
    # Only field with a default; may be omitted when constructing.
    dominant_type_key: Optional[str] = None
255
+
256
+
257
class CellCommunicationData(NamedTuple):
    """Unified representation of cell communication analysis results.

    This is a NamedTuple, so the field order below is part of its positional
    interface (construction and unpacking); keep it stable.

    Attributes:
        results: Main results DataFrame (format varies by method)
        method: Analysis method name ("liana_cluster", "liana_spatial", "cellphonedb")
        analysis_type: Type of analysis ("cluster" or "spatial")
        lr_pairs: List of ligand-receptor pair names
        spatial_scores: Spatial communication scores array (n_spots x n_pairs);
            defaults to None
        spatial_pvals: P-values for spatial scores; defaults to None
        source_labels: List of source cell type labels; defaults to None
        target_labels: List of target cell type labels; defaults to None
        results_key: Key in adata.uns where results are stored; defaults to ""
    """

    results: pd.DataFrame
    method: str
    analysis_type: str  # "cluster" or "spatial"
    lr_pairs: list[str]
    # Fields from here on have defaults and may be omitted when constructing.
    spatial_scores: Optional[np.ndarray] = None
    spatial_pvals: Optional[np.ndarray] = None
    source_labels: Optional[list[str]] = None
    target_labels: Optional[list[str]] = None
    results_key: str = ""
281
+
282
+
283
+ # =============================================================================
284
+ # Feature Validation and Preparation
285
+ # =============================================================================
286
+
287
+
288
async def get_validated_features(
    adata: ad.AnnData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
    max_features: Optional[int] = None,
    genes_only: bool = False,
) -> list[str]:
    """Return the requested features that actually exist in ``adata``.

    ``params.feature`` may be None (no features), a single name, or a list of
    names. A name is kept when it is in ``adata.var_names``; unless
    ``genes_only`` is set, names found among ``adata.obs`` columns or
    ``adata.obsm`` keys are kept as well. Names that are not found are dropped,
    and a warning is sent through ``context`` when one is given. Input order
    is preserved.

    Args:
        adata: AnnData object.
        params: Visualization parameters holding the ``feature`` spec.
        context: Optional tool context used for warnings.
        max_features: If given, keep at most this many validated features
            (with a warning when truncating).
        genes_only: If True, validate only against ``var_names``.

    Returns:
        List of validated feature names.
    """
    requested = params.feature
    if requested is None:
        candidates: list[str] = []
    elif isinstance(requested, list):
        candidates = requested
    else:
        candidates = [requested]

    kept: list[str] = []
    for name in candidates:
        if name in adata.var_names:
            kept.append(name)
            continue

        if genes_only:
            if context:
                await context.warning(f"Gene '{name}' not found in var_names")
            continue

        if name in adata.obs.columns or name in adata.obsm:
            kept.append(name)
        elif context:
            await context.warning(
                f"Feature '{name}' not found in genes, obs, or obsm"
            )

    if max_features is not None and len(kept) > max_features:
        if context:
            await context.warning(
                f"Too many features ({len(kept)}), limiting to {max_features}"
            )
        kept = kept[:max_features]

    return kept
344
+
345
+
346
def validate_and_prepare_feature(
    adata: "ad.AnnData",
    feature: str,
    context: Optional["ToolContext"] = None,
) -> tuple[np.ndarray, str, bool]:
    """Resolve a single feature name to plottable values.

    Genes (``adata.var_names``) are checked first and fetched with
    ``get_gene_expression``; they are always reported as non-categorical.
    Otherwise the name is looked up among the ``adata.obs`` columns. Such a
    column is reported as categorical when its dtype is categorical, object,
    or a pandas string dtype (``pd.StringDtype``).

    Args:
        adata: AnnData object.
        feature: Gene name or ``adata.obs`` column name.
        context: Optional tool context; accepted for API symmetry and not
            used by this function.

    Returns:
        Tuple of (data array, display name, is_categorical).

    Raises:
        DataNotFoundError: If ``feature`` is neither a gene nor an obs column.
    """
    if feature in adata.var_names:
        return get_gene_expression(adata, feature), feature, False

    if feature in adata.obs.columns:
        column = adata.obs[feature]
        dtype = column.dtype
        # pd.api.types.is_categorical_dtype is deprecated (pandas >= 2.1);
        # isinstance checks are the supported replacement. StringDtype covers
        # pandas' dedicated string dtype, which is neither categorical nor
        # object and was previously misclassified as continuous.
        is_cat = bool(
            isinstance(dtype, (pd.CategoricalDtype, pd.StringDtype))
            or dtype == object
        )
        return column.values, feature, is_cat

    raise DataNotFoundError(f"Feature '{feature}' not found in data")
373
+
374
+
375
+ # =============================================================================
376
+ # Colormap Utilities
377
+ # =============================================================================
378
+
379
# Categorical colormaps keyed by the largest category count each is chosen for.
_CATEGORICAL_CMAPS = {
    10: "tab10",  # up to 10 categories
    20: "tab20",  # 11-20 categories
    40: "tab20b",  # 21-40 categories (tab20b itself has only 20 colors)
}


def get_categorical_cmap(n_categories: int, user_cmap: Optional[str] = None) -> str:
    """Choose a categorical colormap name for ``n_categories`` categories.

    If ``user_cmap`` is one of the recognised qualitative palettes it wins.
    Otherwise the smallest ``_CATEGORICAL_CMAPS`` threshold that is at least
    ``n_categories`` picks the palette; counts above every threshold fall back
    to ``"tab20"``.

    Args:
        n_categories: Number of distinct categories to color.
        user_cmap: User-specified colormap; used only if it is a known
            categorical palette.

    Returns:
        Colormap name suitable for categorical data.

    Examples:
        >>> get_categorical_cmap(5)
        'tab10'
        >>> get_categorical_cmap(15)
        'tab20'
        >>> get_categorical_cmap(8, user_cmap="Set2")
        'Set2'
    """
    qualitative_names = frozenset(
        (
            "tab10", "tab20", "tab20b", "tab20c",
            "Set1", "Set2", "Set3", "Paired", "Accent",
            "Dark2", "Pastel1", "Pastel2",
        )
    )
    if user_cmap and user_cmap in qualitative_names:
        return user_cmap

    return next(
        (
            _CATEGORICAL_CMAPS[limit]
            for limit in sorted(_CATEGORICAL_CMAPS)
            if n_categories <= limit
        ),
        "tab20",
    )
424
+
425
+
426
def get_category_colors(
    n_categories: int,
    cmap_name: Optional[str] = None,
) -> list:
    """Return a list of ``n_categories`` colors for categorical data.

    This is the primary function for obtaining colors for categorical
    visualizations. It handles colormap selection and color extraction.

    Selection rules:

    - ``cmap_name`` is None and ``n_categories`` > 20: seaborn's ``"husl"``
      palette, which yields ``n_categories`` distinct evenly spaced hues
      (no built-in qualitative palette offered here has more than 20 colors).
    - ``cmap_name`` is None otherwise: ``get_categorical_cmap`` picks a name.
    - Names in a small seaborn-handled set go to ``sns.color_palette``.
    - Any other name is loaded with ``plt.get_cmap``. If the colormap has at
      least ``n_categories`` entries, colors are sampled evenly across it;
      if it has fewer, its entries are cycled in order so repeated colors are
      never adjacent.

    Args:
        n_categories: Number of categories to color.
        cmap_name: Colormap name (auto-selected if None).

    Returns:
        List of colors usable with matplotlib scatter, legend, etc.

    Examples:
        >>> colors = get_category_colors(5)  # 5 distinct colors
        >>> colors = get_category_colors(15, "tab20")  # 15 colors from tab20
    """
    if cmap_name is None:
        if n_categories > 20:
            # tab20/tab20b top out at 20 colors, so auto-selecting them for
            # more categories produced duplicates; husl scales to any count.
            return sns.color_palette("husl", n_colors=n_categories)
        cmap_name = get_categorical_cmap(n_categories)

    if cmap_name in ["tab10", "tab20", "Set1", "Set2", "Set3", "Paired", "husl"]:
        return sns.color_palette(cmap_name, n_colors=n_categories)

    cmap = plt.get_cmap(cmap_name)
    if n_categories > cmap.N:
        # Evenly resampling a colormap that has fewer entries than categories
        # maps neighbouring categories onto the same entry. Cycle instead: an
        # integer argument indexes the colormap's lookup table directly.
        return [cmap(i % cmap.N) for i in range(n_categories)]
    return [cmap(i / max(n_categories - 1, 1)) for i in range(n_categories)]
457
+
458
+
459
def get_colormap(name: str, n_colors: Optional[int] = None):
    """Look up a colormap or palette by name.

    For categorical data, prefer ``get_category_colors()``; this function is
    kept for backward compatibility and continuous colormaps.

    Args:
        name: Colormap name (matplotlib colormap or seaborn palette).
        n_colors: Number of discrete colors (for categorical data).

    Returns:
        - If ``n_colors`` is truthy: the list from
          ``get_category_colors(n_colors, name)``.
        - Else, for "tab10", "tab20", "Set1", "Set2", "Set3", "Paired" or
          "husl": ``sns.color_palette(name)`` (a list of colors).
        - Otherwise: the matplotlib Colormap from ``plt.get_cmap(name)``.
    """
    seaborn_palette_names = ("tab10", "tab20", "Set1", "Set2", "Set3", "Paired", "husl")

    if n_colors:
        return get_category_colors(n_colors, name)
    if name in seaborn_palette_names:
        return sns.color_palette(name)
    return plt.get_cmap(name)
483
+
484
+
485
def get_diverging_colormap(center: float = 0.0) -> str:
    """Return the name of the diverging colormap used for centered data.

    The result is always ``"RdBu_r"``; ``center`` is accepted for interface
    compatibility but does not influence the choice.
    """
    diverging_cmap_name = "RdBu_r"
    return diverging_cmap_name
488
+
489
+
490
+ # =============================================================================
491
+ # Spatial Plot Utilities
492
+ # =============================================================================
493
+
494
+
495
def plot_spatial_feature(
    adata: ad.AnnData,
    ax: plt.Axes,
    feature: Optional[str] = None,
    values: Optional[np.ndarray] = None,
    params: Optional[VisualizationParameters] = None,
    spatial_key: str = "spatial",
    show_colorbar: bool = True,
    title: Optional[str] = None,
) -> Optional[plt.cm.ScalarMappable]:
    """Scatter a feature over the spatial coordinates of ``adata``.

    Values come from ``values`` when given. Otherwise they come from
    ``feature``, which is looked up first in ``adata.var_names`` (gene
    expression, always continuous) and then among the ``adata.obs`` columns.
    Categorical, object and string data get one color per category (with an
    optional legend); missing categorical values are drawn in light gray.
    Everything else uses the continuous colormap ``params.colormap``.

    Args:
        adata: AnnData object with spatial coordinates.
        ax: Matplotlib axes to plot on.
        feature: Feature name (gene or obs column).
        values: Pre-computed values to plot (overrides ``feature``).
        params: Visualization parameters (defaults are used when None).
        spatial_key: Key for spatial coordinates in ``obsm``.
        show_colorbar: Accepted for API compatibility; this function does not
            draw a colorbar itself. Use the returned mappable (e.g. with
            ``add_colorbar``) to add one.
        title: Plot title.

    Returns:
        The scatter artist for continuous data (usable as a colorbar
        mappable), or None for categorical data.

    Raises:
        DataNotFoundError: If ``feature`` is neither a gene nor an obs column.
        ParameterError: If neither ``feature`` nor ``values`` is provided.
    """
    if params is None:
        params = VisualizationParameters()  # type: ignore[call-arg]

    coords = require_spatial_coords(adata, spatial_key=spatial_key)

    if values is not None:
        plot_values = values
        values_dtype = getattr(values, "dtype", None)
        if values_dtype is None:  # plain Python sequence
            values_dtype = np.asarray(values).dtype
        is_categorical = _is_categorical_like(values_dtype)
    elif feature is not None:
        if feature in adata.var_names:
            plot_values = get_gene_expression(adata, feature)
            is_categorical = False
        elif feature in adata.obs.columns:
            plot_values = adata.obs[feature].values
            is_categorical = _is_categorical_like(adata.obs[feature].dtype)
        else:
            raise DataNotFoundError(f"Feature '{feature}' not found")
    else:
        raise ParameterError("Either feature or values must be provided")

    if is_categorical:
        # pd.Categorical keeps existing categories (and their order) for
        # categorical input and infers sorted categories otherwise. Missing
        # values get code -1 instead of raising a KeyError in a lookup dict.
        labels = pd.Categorical(plot_values)
        categories = labels.categories
        n_cats = len(categories)
        colors = get_colormap(params.colormap, n_colors=n_cats)
        point_colors = [
            colors[code] if code >= 0 else "lightgray" for code in labels.codes
        ]

        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=point_colors,
            s=params.spot_size,
            alpha=params.alpha,
        )

        if params.show_legend:
            handles = [
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="w",
                    markerfacecolor=colors[i],
                    markersize=8,
                )
                for i in range(n_cats)
            ]
            ax.legend(
                handles,
                list(categories),
                loc="center left",
                bbox_to_anchor=(1, 0.5),
                fontsize=8,
            )
        mappable = None
    else:
        cmap = get_colormap(params.colormap)
        if isinstance(cmap, list):
            # get_colormap returns a seaborn palette (a list) for names such
            # as "tab10"; scatter's ``cmap`` needs a Colormap object.
            from matplotlib.colors import ListedColormap

            cmap = ListedColormap(cmap)
        mappable = ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=plot_values,
            cmap=cmap,
            s=params.spot_size,
            alpha=params.alpha,
            vmin=params.vmin,
            vmax=params.vmax,
        )

    ax.set_aspect("equal")
    ax.set_xlabel("")
    ax.set_ylabel("")

    if not params.show_axes:
        ax.axis("off")

    if title:
        ax.set_title(title, fontsize=12)

    return mappable


def _is_categorical_like(dtype) -> bool:
    """Return True for dtypes holding labels rather than numbers.

    Covers pandas categorical and string dtypes plus numpy object, unicode
    and bytes dtypes. Uses isinstance checks instead of the deprecated
    ``pd.api.types.is_categorical_dtype``.
    """
    if isinstance(dtype, (pd.CategoricalDtype, pd.StringDtype)):
        return True
    return isinstance(dtype, np.dtype) and dtype.kind in "OUS"
610
+
611
+
612
+ # =============================================================================
613
+ # Data Inference Utilities
614
+ # =============================================================================
615
+
616
+
617
def get_categorical_columns(
    adata: "ad.AnnData",
    limit: Optional[int] = None,
) -> list[str]:
    """Return the names of label-like columns in ``adata.obs``.

    A column counts as categorical when its dtype is ``category``, ``object``
    or a pandas string dtype (``pd.StringDtype``, shown as "string"/"str").
    Previously only the "object" and "category" dtype names were recognised,
    so string-dtype columns were silently skipped. Columns are returned in
    ``adata.obs`` order.

    Args:
        adata: AnnData object.
        limit: Maximum number of columns to return (None for all).

    Returns:
        List of categorical column names.
    """
    categorical_cols = [
        col
        for col, dtype in adata.obs.dtypes.items()
        if isinstance(dtype, (pd.CategoricalDtype, pd.StringDtype))
        or dtype == object
    ]
    if limit is not None:
        return categorical_cols[:limit]
    return categorical_cols
638
+
639
+
640
def infer_basis(
    adata: "ad.AnnData",
    preferred: Optional[str] = None,
    priority: Optional[list[str]] = None,
) -> Optional[str]:
    """Infer the best embedding basis from the keys of ``adata.obsm``.

    "spatial" is looked up under the key ``"spatial"``; every other basis
    ``b`` is looked up under ``"X_" + b``. Names in ``preferred`` and
    ``priority`` may be given with or without the ``X_`` prefix (e.g.
    "umap" or "X_umap"). Previously a prefixed name was looked up as
    "X_X_umap" and never matched.

    Args:
        adata: AnnData object.
        preferred: User-specified preferred basis (returned if present).
        priority: Priority order for basis selection.
            Default: ["spatial", "umap", "pca"]

    Returns:
        Best available basis name (without X_ prefix). If none of
        ``preferred``/``priority`` is present, the first ``X_*`` key in
        ``adata.obsm`` (prefix stripped); None if there is no such key.

    Examples:
        >>> infer_basis(adata)  # Auto-detect: spatial > umap > pca
        'umap'
        >>> infer_basis(adata, preferred='tsne')  # Use if valid
        'tsne'
        >>> infer_basis(adata, priority=['umap', 'spatial'])  # Custom order
        'umap'
    """
    if priority is None:
        priority = ["spatial", "umap", "pca"]

    candidates = ([preferred] if preferred else []) + list(priority)
    for basis in candidates:
        # Normalize "X_umap" -> "umap"; the return value never carries X_.
        name = basis[2:] if basis.startswith("X_") else basis
        key = name if name == "spatial" else f"X_{name}"
        if key in adata.obsm:
            return name

    # Fallback: first available X_* embedding.
    for key in adata.obsm.keys():
        if key.startswith("X_"):
            return key[2:]

    return None