chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,684 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core visualization utilities and shared functions.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- Figure setup and utility functions
|
|
6
|
+
- Shared data structures
|
|
7
|
+
- Common visualization helpers
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import TYPE_CHECKING, NamedTuple, Optional
|
|
11
|
+
|
|
12
|
+
import anndata as ad
|
|
13
|
+
import matplotlib
|
|
14
|
+
|
|
15
|
+
matplotlib.use("Agg")
|
|
16
|
+
import matplotlib.pyplot as plt
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pandas as pd
|
|
19
|
+
import seaborn as sns
|
|
20
|
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
21
|
+
|
|
22
|
+
from ...models.data import VisualizationParameters
|
|
23
|
+
from ...utils.adata_utils import get_gene_expression, require_spatial_coords
|
|
24
|
+
from ...utils.exceptions import DataNotFoundError, ParameterError
|
|
25
|
+
|
|
26
|
+
plt.ioff()
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# =============================================================================
|
|
33
|
+
# Figure Creation Utilities
|
|
34
|
+
# =============================================================================
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Fallback figure sizes, as (width, height) in inches, keyed by plot type.
# The "default" entry is the size used for plot types without their own entry.
FIGURE_DEFAULTS = dict(
    spatial=(10, 8),
    umap=(10, 8),
    heatmap=(12, 10),
    violin=(12, 6),
    dotplot=(10, 8),
    trajectory=(10, 10),
    gene_trends=(12, 6),
    velocity=(10, 8),
    deconvolution=(10, 8),
    cell_communication=(10, 10),
    enrichment=(6, 8),
    cnv=(12, 8),
    integration=(16, 12),
    default=(10, 8),
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def resolve_figure_size(
    params: VisualizationParameters,
    plot_type: str = "default",
    n_panels: Optional[int] = None,
    panel_width: float = 5.0,
    panel_height: float = 4.0,
) -> tuple[float, float]:
    """Resolve figure size from params with smart defaults.

    Resolution order:

    1. ``params.figure_size`` when it is set (truthy) - returned unchanged.
    2. For more than one panel, a size computed from the panel grid
       (at most 3 columns), capped at 15 inches wide and 16 inches tall.
    3. The ``FIGURE_DEFAULTS`` entry for ``plot_type``, falling back to
       ``FIGURE_DEFAULTS["default"]`` for unknown plot types.

    Args:
        params: VisualizationParameters with optional figure_size
        plot_type: Type of plot for default selection (e.g., "spatial", "heatmap")
        n_panels: Number of panels for multi-panel figures
        panel_width: Width per panel for multi-panel figures
        panel_height: Height per panel for multi-panel figures

    Returns:
        Tuple of (width, height) in inches. Computed multi-panel sizes keep
        their fractional part; whole-number sizes are returned as ints.

    Examples:
        >>> resolve_figure_size(params, "spatial")  # User override or (10, 8)
        >>> resolve_figure_size(params, n_panels=4)  # Compute from panel count
    """
    # User-specified size always takes precedence
    if params.figure_size:
        return params.figure_size

    # Multi-panel figure: compute from panel dimensions
    if n_panels is not None and n_panels > 1:
        n_cols = min(3, n_panels)
        n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
        width = min(panel_width * n_cols, 15)
        height = min(panel_height * n_rows, 16)
        # Truncating with int() would shrink fractional sizes and collapse
        # sub-inch panels to 0, which is not a valid figure size. Only
        # whole-number values are converted to int.
        width, height = (
            int(extent) if float(extent).is_integer() else float(extent)
            for extent in (width, height)
        )
        return (width, height)

    # Use plot-type specific default
    return FIGURE_DEFAULTS.get(plot_type, FIGURE_DEFAULTS["default"])
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def create_figure(figsize: tuple[int, int] = (10, 8)) -> tuple[plt.Figure, plt.Axes]:
    """Build a new matplotlib figure that holds a single axes.

    Args:
        figsize: Figure size passed straight to ``plt.subplots``.

    Returns:
        Tuple of (Figure, Axes) as produced by ``plt.subplots``.
    """
    figure, single_axes = plt.subplots(figsize=figsize)
    return figure, single_axes
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def create_figure_from_params(
    params: VisualizationParameters,
    plot_type: str = "default",
    n_panels: Optional[int] = None,
    n_rows: int = 1,
    n_cols: int = 1,
    squeeze: bool = True,
) -> tuple[plt.Figure, np.ndarray]:
    """Build a figure and its axes grid from visualization parameters.

    The figure size comes from ``resolve_figure_size`` and the resolution
    from ``params.dpi``; the grid shape is given by ``n_rows`` x ``n_cols``.

    Args:
        params: VisualizationParameters (``figure_size`` and ``dpi`` are read)
        plot_type: Plot type used to pick a default size
        n_panels: Panel count forwarded to the size resolution
        n_rows: Number of subplot rows
        n_cols: Number of subplot columns
        squeeze: Forwarded to ``plt.subplots``; when True, a single Axes is
            re-wrapped so the second return value is still an array

    Returns:
        Tuple of (Figure, array of Axes)

    Examples:
        >>> fig, axes = create_figure_from_params(params, "spatial")
        >>> fig, axes = create_figure_from_params(params, n_rows=2, n_cols=3)
    """
    size = resolve_figure_size(params, plot_type, n_panels)
    figure, grid = plt.subplots(
        n_rows,
        n_cols,
        figsize=size,
        dpi=params.dpi,
        squeeze=squeeze,
    )

    # With squeeze=True matplotlib hands back a bare Axes for a 1x1 grid;
    # wrap it so callers can always treat the result as an array.
    if squeeze:
        if n_rows == 1 and n_cols == 1:
            grid = np.array([grid])
        elif n_rows == 1 or n_cols == 1:
            grid = np.atleast_1d(grid)

    return figure, grid
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def setup_multi_panel_figure(
    n_panels: int,
    params: VisualizationParameters,
    default_title: str,
    use_tight_layout: bool = False,
) -> tuple[plt.Figure, np.ndarray]:
    """Sets up a multi-panel matplotlib figure.

    The grid is ``params.panel_layout`` when set; otherwise it is laid out
    with at most 3 columns and as many rows as needed. Axes beyond
    ``n_panels`` are switched off so unused grid cells stay blank.

    Args:
        n_panels: The total number of panels required.
        params: VisualizationParameters object with GridSpec spacing parameters.
        default_title: Title used for the figure when ``params.title`` is empty.
        use_tight_layout: If True, the ``wspace``/``hspace`` GridSpec settings
            from ``params`` are not applied. This function does not call
            ``tight_layout`` itself.

    Returns:
        A tuple of (matplotlib.Figure, flattened numpy.ndarray of Axes).
    """
    if params.panel_layout:
        # NOTE(review): an explicit layout smaller than n_panels is not
        # rejected here; callers indexing past the grid will fail later.
        n_rows, n_cols = params.panel_layout
    else:
        # Clamp to a 1x1 grid so n_panels == 0 does not divide by zero.
        n_cols = max(1, min(3, n_panels))
        n_rows = max(1, (n_panels + n_cols - 1) // n_cols)

    if params.figure_size:
        figsize = params.figure_size
    else:
        figsize = (min(5 * n_cols, 15), min(4 * n_rows, 16))

    subplot_kwargs = {"figsize": figsize, "dpi": params.dpi, "squeeze": False}
    if not use_tight_layout:
        subplot_kwargs["gridspec_kw"] = {
            "wspace": params.subplot_wspace,
            "hspace": params.subplot_hspace,
        }
    fig, axes = plt.subplots(n_rows, n_cols, **subplot_kwargs)

    axes = axes.flatten()

    # params.title wins; default_title is the fallback. The suptitle is
    # skipped only when both are empty.
    title = params.title or default_title
    if title:
        fig.suptitle(title, fontsize=16)

    # Hide the grid cells that will not receive a panel.
    for unused_ax in axes[max(n_panels, 0):]:
        unused_ax.axis("off")

    return fig, axes
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def add_colorbar(
    fig: plt.Figure,
    ax: plt.Axes,
    mappable,
    params: VisualizationParameters,
    label: str = "",
) -> None:
    """Attach a colorbar to the right-hand side of an axes.

    A new axes is appended to the right of ``ax`` and the colorbar for
    ``mappable`` is drawn into it.

    Args:
        fig: Figure that owns ``ax``
        ax: Axes the colorbar is attached to
        mappable: The mappable object (from scatter, imshow, etc.)
        params: Visualization parameters; ``colorbar_size`` and
            ``colorbar_pad`` set the size and padding of the appended axes
        label: Colorbar label; no label is set when empty
    """
    colorbar_axes = make_axes_locatable(ax).append_axes(
        "right", size=params.colorbar_size, pad=params.colorbar_pad
    )
    colorbar = fig.colorbar(mappable, cax=colorbar_axes)
    if not label:
        return
    colorbar.set_label(label, fontsize=10)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# =============================================================================
|
|
235
|
+
# Data Structures for Unified Data Access
|
|
236
|
+
# =============================================================================
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class DeconvolutionData(NamedTuple):
    """Immutable bundle describing one set of deconvolution results.

    Only ``dominant_type_key`` has a default; the other fields are required
    and positional, in the order listed below.

    Attributes:
        proportions: Cell type proportions as a DataFrame (n_spots x n_cell_types)
        method: Name of the deconvolution method (e.g., "cell2location", "rctd")
        cell_types: Names of the cell types
        proportions_key: adata.obsm key under which the proportions are stored
        dominant_type_key: adata.obs key of the dominant cell type, if any
    """

    # Required fields
    proportions: pd.DataFrame
    method: str
    cell_types: list[str]
    proportions_key: str
    # Optional field
    dominant_type_key: Optional[str] = None
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class CellCommunicationData(NamedTuple):
    """Immutable bundle describing cell communication analysis results.

    The first four fields are required; everything after ``lr_pairs`` is
    optional and defaults to ``None`` (or ``""`` for ``results_key``).

    Attributes:
        results: Main results DataFrame (format varies by method)
        method: Analysis method name ("liana_cluster", "liana_spatial", "cellphonedb")
        analysis_type: Kind of analysis ("cluster" or "spatial")
        lr_pairs: Names of the ligand-receptor pairs
        spatial_scores: Spatial communication scores array (n_spots x n_pairs)
        spatial_pvals: P-values for the spatial scores (optional)
        source_labels: Labels of the source cell types
        target_labels: Labels of the target cell types
        results_key: adata.uns key under which the results are stored
    """

    # Required fields
    results: pd.DataFrame
    method: str
    analysis_type: str  # "cluster" or "spatial"
    lr_pairs: list[str]
    # Optional fields
    spatial_scores: Optional[np.ndarray] = None
    spatial_pvals: Optional[np.ndarray] = None
    source_labels: Optional[list[str]] = None
    target_labels: Optional[list[str]] = None
    results_key: str = ""
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
# =============================================================================
|
|
284
|
+
# Feature Validation and Preparation
|
|
285
|
+
# =============================================================================
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
async def get_validated_features(
    adata: ad.AnnData,
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
    max_features: Optional[int] = None,
    genes_only: bool = False,
) -> list[str]:
    """Filter the requested features down to those present in ``adata``.

    ``params.feature`` may be ``None`` (nothing requested), a list of names,
    or a single name. Names that cannot be found are dropped; a warning is
    sent through ``context`` for each of them when a context is given.

    Args:
        adata: AnnData object
        params: Visualization parameters containing feature specification
        context: Optional tool context used to emit warnings
        max_features: Maximum number of features to return (truncates if exceeded)
        genes_only: If True, only names in var_names (genes) are accepted.
            If False, obs columns and obsm keys are accepted as well.

    Returns:
        List of validated feature names, in the order they were requested
    """
    requested = params.feature
    if requested is None:
        candidates: list[str] = []
    elif isinstance(requested, list):
        candidates = requested
    else:
        candidates = [requested]

    accepted: list[str] = []
    for name in candidates:
        # Genes are checked first; obs columns and obsm keys only count
        # when the caller did not restrict the lookup to genes.
        found = name in adata.var_names
        if not found and not genes_only:
            found = name in adata.obs.columns or name in adata.obsm

        if found:
            accepted.append(name)
            continue

        if context:
            if genes_only:
                await context.warning(f"Gene '{name}' not found in var_names")
            else:
                await context.warning(
                    f"Feature '{name}' not found in genes, obs, or obsm"
                )

    # Keep only the first max_features entries when a cap was given
    over_limit = max_features is not None and len(accepted) > max_features
    if over_limit:
        if context:
            await context.warning(
                f"Too many features ({len(accepted)}), limiting to {max_features}"
            )
        accepted = accepted[:max_features]

    return accepted
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def validate_and_prepare_feature(
    adata: ad.AnnData,
    feature: str,
    context: Optional["ToolContext"] = None,
) -> tuple[np.ndarray, str, bool]:
    """Validate a single feature and prepare its data for visualization.

    Genes (names in ``adata.var_names``) take precedence over observation
    columns of the same name and are always reported as non-categorical.

    Args:
        adata: AnnData object
        feature: Feature name to validate
        context: Optional tool context (currently not used by this function)

    Returns:
        Tuple of (data array, display name, is_categorical). For obs columns
        ``is_categorical`` is True for categorical, object and pandas string
        dtypes.

    Raises:
        DataNotFoundError: If the feature is neither a gene nor an obs column.
    """
    # Gene expression - use unified utility
    if feature in adata.var_names:
        return get_gene_expression(adata, feature), feature, False

    # Observation column
    if feature in adata.obs.columns:
        column = adata.obs[feature]
        dtype = column.dtype
        # isinstance replaces the deprecated pd.api.types.is_categorical_dtype.
        # StringDtype is included so text columns count as categorical even
        # when pandas does not store them with object dtype.
        is_categorical = (
            isinstance(dtype, (pd.CategoricalDtype, pd.StringDtype))
            or dtype == object
        )
        return column.values, feature, bool(is_categorical)

    raise DataNotFoundError(f"Feature '{feature}' not found in data")
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
# =============================================================================
|
|
376
|
+
# Colormap Utilities
|
|
377
|
+
# =============================================================================
|
|
378
|
+
|
|
379
|
+
# Categorical colormaps by size threshold
|
|
380
|
+
_CATEGORICAL_CMAPS = {
|
|
381
|
+
10: "tab10", # Best for <= 10 categories
|
|
382
|
+
20: "tab20", # Best for 11-20 categories
|
|
383
|
+
40: "tab20b", # Extended palette for more categories
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def get_categorical_cmap(n_categories: int, user_cmap: Optional[str] = None) -> str:
    """Pick a categorical colormap name for a given number of categories.

    A user-supplied name is honoured only when it is one of the known
    categorical palettes listed below; any other name is ignored and the
    choice falls back to the size thresholds in ``_CATEGORICAL_CMAPS``.

    Args:
        n_categories: Number of distinct categories to color
        user_cmap: Colormap name requested by the user (optional)

    Returns:
        Colormap name suitable for categorical data; "tab20" when the count
        exceeds every threshold

    Examples:
        >>> get_categorical_cmap(5)  # Returns "tab10"
        >>> get_categorical_cmap(15)  # Returns "tab20"
        >>> get_categorical_cmap(8, user_cmap="Set2")  # Returns "Set2"
    """
    known_palettes = frozenset(
        (
            "tab10", "tab20", "tab20b", "tab20c",
            "Set1", "Set2", "Set3", "Paired", "Accent",
            "Dark2", "Pastel1", "Pastel2",
        )
    )

    # A recognised user choice wins over the automatic selection
    if user_cmap and user_cmap in known_palettes:
        return user_cmap

    # Smallest threshold that still fits the category count, else "tab20"
    return next(
        (
            name
            for limit, name in sorted(_CATEGORICAL_CMAPS.items())
            if n_categories <= limit
        ),
        "tab20",
    )
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def get_category_colors(
    n_categories: int,
    cmap_name: Optional[str] = None,
) -> list:
    """Return one color per category.

    When no colormap name is given, one is chosen with
    ``get_categorical_cmap``. A fixed set of palette names is served by
    seaborn; every other name is looked up as a matplotlib colormap and
    sampled at evenly spaced positions from 0 to 1.

    Args:
        n_categories: Number of categories to color
        cmap_name: Colormap name (auto-selected if None)

    Returns:
        List of colors (can be used with matplotlib scatter, legend, etc.)

    Examples:
        >>> colors = get_category_colors(5)  # 5 distinct colors
        >>> colors = get_category_colors(15, "tab20")  # 15 colors from tab20
    """
    palette_name = (
        get_categorical_cmap(n_categories) if cmap_name is None else cmap_name
    )

    # Names handled through seaborn's palette machinery
    seaborn_names = ["tab10", "tab20", "Set1", "Set2", "Set3", "Paired", "husl"]
    if palette_name in seaborn_names:
        return sns.color_palette(palette_name, n_colors=n_categories)

    # Everything else: sample the matplotlib colormap. The max() guard keeps
    # the single-category case from dividing by zero.
    colormap = plt.get_cmap(palette_name)
    last_index = max(n_categories - 1, 1)
    return [colormap(index / last_index) for index in range(n_categories)]
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def get_colormap(name: str, n_colors: Optional[int] = None):
    """Look up a colormap or palette by name.

    For categorical data, prefer get_category_colors(); this function is
    kept for backward compatibility and for continuous colormaps.

    Args:
        name: Colormap name (supports matplotlib and seaborn palettes)
        n_colors: Number of discrete colors (for categorical data)

    Returns:
        If n_colors is given and non-zero: the list of colors from
        get_category_colors(). Otherwise a seaborn palette for the palette
        names listed below, or a matplotlib colormap object for any other name.
    """
    # A discrete color request is handled by the categorical helper
    if n_colors:
        return get_category_colors(n_colors, name)

    seaborn_names = ["tab10", "tab20", "Set1", "Set2", "Set3", "Paired", "husl"]
    if name not in seaborn_names:
        # Plain matplotlib colormap object
        return plt.get_cmap(name)

    # Seaborn palette (returned as a palette for consistency)
    return sns.color_palette(name)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def get_diverging_colormap(center: float = 0.0) -> str:
    """Return the name of the diverging colormap used by the visualizations.

    Args:
        center: Value the colormap is meant to be centered on. It is accepted
            but not consulted; the same name is returned for every value.

    Returns:
        The colormap name "RdBu_r"
    """
    diverging_name = "RdBu_r"
    return diverging_name
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
# =============================================================================
|
|
491
|
+
# Spatial Plot Utilities
|
|
492
|
+
# =============================================================================
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def _has_categorical_dtype(data, include_text: bool = False) -> bool:
    """Report whether *data* carries a categorical (optionally also text) dtype."""
    dtype = getattr(data, "dtype", None)
    if isinstance(dtype, pd.CategoricalDtype):
        return True
    if include_text and dtype is not None:
        return isinstance(dtype, pd.StringDtype) or dtype == object
    return False


def plot_spatial_feature(
    adata: ad.AnnData,
    ax: plt.Axes,
    feature: Optional[str] = None,
    values: Optional[np.ndarray] = None,
    params: Optional[VisualizationParameters] = None,
    spatial_key: str = "spatial",
    show_colorbar: bool = True,
    title: Optional[str] = None,
) -> Optional[plt.cm.ScalarMappable]:
    """Plot a feature on spatial coordinates.

    Categorical data is drawn with one color per category (plus an optional
    legend); anything else is drawn as a continuous scatter using
    ``params.colormap``, ``params.vmin`` and ``params.vmax``.

    Args:
        adata: AnnData object with spatial coordinates
        ax: Matplotlib axes to plot on
        feature: Feature name (gene or obs column); genes take precedence
        values: Pre-computed values to plot (overrides feature). Treated as
            categorical only when they carry a pandas categorical dtype.
        params: Visualization parameters (defaults are used when None)
        spatial_key: Key for spatial coordinates in obsm
        show_colorbar: Accepted for API compatibility; this function does not
            draw a colorbar itself. Use the returned mappable to add one.
        title: Plot title

    Returns:
        ScalarMappable for colorbar creation, or None for categorical data

    Raises:
        DataNotFoundError: If ``feature`` is neither a gene nor an obs column.
        ParameterError: If neither ``feature`` nor ``values`` is provided.
    """
    if params is None:
        params = VisualizationParameters()  # type: ignore[call-arg]

    # Get spatial coordinates; the first two columns are used as x and y
    coords = require_spatial_coords(adata, spatial_key=spatial_key)

    # Get values to plot
    if values is not None:
        plot_values = values
        is_categorical = _has_categorical_dtype(values)
    elif feature is not None:
        if feature in adata.var_names:
            # Use unified utility for gene expression extraction
            plot_values = get_gene_expression(adata, feature)
            is_categorical = False
        elif feature in adata.obs.columns:
            column = adata.obs[feature]
            plot_values = column.values
            # Text columns count as categorical too, matching
            # validate_and_prepare_feature and get_categorical_columns.
            is_categorical = _has_categorical_dtype(column, include_text=True)
        else:
            raise DataNotFoundError(f"Feature '{feature}' not found")
    else:
        raise ParameterError("Either feature or values must be provided")

    if is_categorical:
        # Categorical arrays expose their categories; plain arrays do not
        categories = (
            plot_values.categories
            if hasattr(plot_values, "categories")
            else np.unique(plot_values)
        )
        n_cats = len(categories)
        colors = get_colormap(params.colormap, n_colors=n_cats)
        cat_to_idx = {cat: i for i, cat in enumerate(categories)}

        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=[colors[cat_to_idx[value]] for value in plot_values],
            s=params.spot_size,
            alpha=params.alpha,
        )

        # Add legend for categorical
        if params.show_legend:
            handles = [
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="w",
                    markerfacecolor=colors[i],
                    markersize=8,
                )
                for i in range(n_cats)
            ]
            ax.legend(
                handles,
                categories,
                loc="center left",
                bbox_to_anchor=(1, 0.5),
                fontsize=8,
            )
        # A discrete legend replaces the colorbar, so there is no mappable
        mappable = None
    else:
        # Continuous data
        cmap = get_colormap(params.colormap)
        mappable = ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=plot_values,
            cmap=cmap,
            s=params.spot_size,
            alpha=params.alpha,
            vmin=params.vmin,
            vmax=params.vmax,
        )

    ax.set_aspect("equal")
    ax.set_xlabel("")
    ax.set_ylabel("")

    if not params.show_axes:
        ax.axis("off")

    if title:
        ax.set_title(title, fontsize=12)

    return mappable
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
# =============================================================================
|
|
613
|
+
# Data Inference Utilities
|
|
614
|
+
# =============================================================================
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
def get_categorical_columns(
    adata: ad.AnnData,
    limit: Optional[int] = None,
) -> list[str]:
    """List the ``adata.obs`` columns whose dtype is object or category.

    Args:
        adata: AnnData object
        limit: Maximum number of columns to return (None for all)

    Returns:
        Matching column names, in the order they appear in ``adata.obs``
    """
    accepted_dtypes = ["object", "category"]

    matching: list[str] = []
    for column in adata.obs.columns:
        if adata.obs[column].dtype.name in accepted_dtypes:
            matching.append(column)

    return matching if limit is None else matching[:limit]
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def infer_basis(
    adata: ad.AnnData,
    preferred: Optional[str] = None,
    priority: Optional[list[str]] = None,
) -> Optional[str]:
    """Infer the best embedding basis from available options.

    A basis named ``"spatial"`` is looked up in ``adata.obsm`` under that
    exact key; any other basis ``name`` is looked up under ``"X_" + name``.

    Args:
        adata: AnnData object
        preferred: User-specified preferred basis (returned if valid). A name
            that already carries the ``X_`` prefix (e.g. ``"X_umap"``) is
            accepted as well and returned without the prefix.
        priority: Priority order for basis selection.
            Default: ["spatial", "umap", "pca"]

    Returns:
        Best available basis name (without X_ prefix), or None if none found

    Examples:
        >>> infer_basis(adata)  # Auto-detect: spatial > umap > pca
        'umap'
        >>> infer_basis(adata, preferred='tsne')  # Use if valid
        'tsne'
        >>> infer_basis(adata, priority=['umap', 'spatial'])  # Custom order
        'umap'
    """
    if priority is None:
        priority = ["spatial", "umap", "pca"]

    def obsm_key(basis: str) -> str:
        # "spatial" is stored under its own name; other bases use an X_ prefix
        return basis if basis == "spatial" else f"X_{basis}"

    # Check preferred basis first
    if preferred:
        if obsm_key(preferred) in adata.obsm:
            return preferred
        # The caller passed the obsm key itself ("X_umap"): honour it
        # instead of silently falling through to the priority list.
        if preferred.startswith("X_") and preferred in adata.obsm:
            return preferred[2:]

    # Check priority list
    for basis in priority:
        if obsm_key(basis) in adata.obsm:
            return basis

    # Fallback: return first available X_* key
    for key in adata.obsm.keys():
        if key.startswith("X_"):
            return key[2:]  # Strip X_ prefix

    return None
|