spatialcore 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. spatialcore/__init__.py +122 -0
  2. spatialcore/annotation/__init__.py +253 -0
  3. spatialcore/annotation/acquisition.py +529 -0
  4. spatialcore/annotation/annotate.py +603 -0
  5. spatialcore/annotation/cellxgene.py +365 -0
  6. spatialcore/annotation/confidence.py +802 -0
  7. spatialcore/annotation/discovery.py +529 -0
  8. spatialcore/annotation/expression.py +363 -0
  9. spatialcore/annotation/loading.py +529 -0
  10. spatialcore/annotation/markers.py +297 -0
  11. spatialcore/annotation/ontology.py +1282 -0
  12. spatialcore/annotation/patterns.py +247 -0
  13. spatialcore/annotation/pipeline.py +620 -0
  14. spatialcore/annotation/synapse.py +380 -0
  15. spatialcore/annotation/training.py +1457 -0
  16. spatialcore/annotation/validation.py +422 -0
  17. spatialcore/core/__init__.py +34 -0
  18. spatialcore/core/cache.py +118 -0
  19. spatialcore/core/logging.py +135 -0
  20. spatialcore/core/metadata.py +149 -0
  21. spatialcore/core/utils.py +768 -0
  22. spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
  23. spatialcore/data/markers/canonical_markers.json +83 -0
  24. spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
  25. spatialcore/plotting/__init__.py +109 -0
  26. spatialcore/plotting/benchmark.py +477 -0
  27. spatialcore/plotting/celltype.py +329 -0
  28. spatialcore/plotting/confidence.py +413 -0
  29. spatialcore/plotting/spatial.py +505 -0
  30. spatialcore/plotting/utils.py +411 -0
  31. spatialcore/plotting/validation.py +1342 -0
  32. spatialcore-0.1.9.dist-info/METADATA +213 -0
  33. spatialcore-0.1.9.dist-info/RECORD +36 -0
  34. spatialcore-0.1.9.dist-info/WHEEL +5 -0
  35. spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
  36. spatialcore-0.1.9.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1342 @@
1
+ """
2
+ Marker validation visualization.
3
+
4
+ This module provides functions for validating cell type annotations
5
+ using canonical marker genes and GMM-3 thresholding.
6
+
7
+ For marker validation, we use classify_by_threshold with n_components=3
8
+ (trimodal GMM) which handles spatial data's dropout/moderate/high
9
+ expression patterns better than bimodal GMM.
10
+ """
11
+
12
+ from pathlib import Path
13
+ from typing import Any, Dict, List, Optional, Tuple, Union
14
+ import gc
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ import matplotlib.pyplot as plt
19
+ from matplotlib.figure import Figure
20
+ import anndata as ad
21
+
22
+ from spatialcore.core.logging import get_logger
23
+ from spatialcore.plotting.utils import (
24
+ generate_celltype_palette,
25
+ setup_figure,
26
+ setup_multi_figure,
27
+ save_figure,
28
+ despine,
29
+ )
30
+
31
+ logger = get_logger(__name__)
32
+
33
+
34
def plot_marker_heatmap(
    adata: ad.AnnData,
    label_column: str,
    markers: Optional[Dict[str, List[str]]] = None,
    cluster: bool = True,
    layer: Optional[str] = None,
    figsize: Optional[tuple] = None,
    cmap: str = "RdBu_r",
    center: float = 0,
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot marker gene expression heatmap by cell type.

    Mean expression of every available marker gene is computed per cell
    type, z-scored per gene (column), and drawn either as a clustered
    heatmap (``cluster=True``) or as a plain heatmap.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels.
    label_column : str
        Column in adata.obs containing cell type labels.
    markers : Dict[str, List[str]], optional
        Marker genes per cell type. Keys are looked up with the lowercased
        cell type label. If None, uses canonical markers.
    cluster : bool, default True
        Hierarchically cluster cell types and genes. Clustering is skipped
        on an axis that has fewer than two entries.
    layer : str, optional
        Layer to use. If None, uses adata.X.
    figsize : tuple, optional
        Figure size. Auto-calculated if None.
    cmap : str, default "RdBu_r"
        Colormap.
    center : float, default 0
        Value to center colormap on.
    title : str, optional
        Plot title.
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ImportError
        If seaborn is not installed.
    ValueError
        If ``label_column`` is missing or no marker gene is present in
        ``adata.var_names``.

    Examples
    --------
    >>> from spatialcore.plotting.validation import plot_marker_heatmap
    >>> from spatialcore.annotation.markers import CANONICAL_MARKERS
    >>> fig = plot_marker_heatmap(
    ...     adata,
    ...     label_column="cell_type",
    ...     markers=CANONICAL_MARKERS,
    ... )
    """
    try:
        import seaborn as sns
    except ImportError as err:
        raise ImportError("seaborn is required for heatmaps") from err

    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")

    # Load canonical markers if not provided
    if markers is None:
        from spatialcore.annotation.markers import load_canonical_markers
        markers = load_canonical_markers()

    cell_types = adata.obs[label_column].unique()

    # Collect all marker genes that exist in the data
    all_genes = []
    for cell_type in cell_types:
        ct_markers = markers.get(str(cell_type).lower(), [])
        all_genes.extend(g for g in ct_markers if g in adata.var_names)

    all_genes = list(dict.fromkeys(all_genes))  # Remove duplicates, preserve order

    if not all_genes:
        raise ValueError("No marker genes found in data.")

    # Calculate mean expression per cell type
    mean_expr = pd.DataFrame(index=cell_types, columns=all_genes, dtype=float)

    for ct in cell_types:
        mask = adata.obs[label_column] == ct
        subset = adata[mask, all_genes]
        X = subset.layers[layer] if layer else subset.X
        if hasattr(X, "toarray"):
            X = X.toarray()
        mean_expr.loc[ct] = np.mean(X, axis=0)

    # Z-score normalize columns; constant columns produce NaN and are set to 0
    mean_expr_z = (mean_expr - mean_expr.mean()) / mean_expr.std()
    mean_expr_z = mean_expr_z.fillna(0)

    # Figure size
    if figsize is None:
        n_types = len(cell_types)
        n_genes = len(all_genes)
        figsize = (max(10, n_genes * 0.3), max(6, n_types * 0.4))

    if cluster:
        # Hierarchical clustering needs at least two observations on an axis,
        # so only cluster the axes that have them (single cell type / single
        # gene inputs would otherwise fail inside the linkage computation).
        n_rows, n_cols = mean_expr_z.shape
        g = sns.clustermap(
            mean_expr_z,
            cmap=cmap,
            center=center,
            figsize=figsize,
            row_cluster=n_rows > 1,
            col_cluster=n_cols > 1,
        )
        if title:
            g.fig.suptitle(title, y=1.02)
        if save:
            g.savefig(save)
        return g.fig

    fig, ax = setup_figure(figsize=figsize)
    sns.heatmap(
        mean_expr_z,
        cmap=cmap,
        center=center,
        ax=ax,
        xticklabels=True,
        yticklabels=True,
    )
    ax.set_xlabel("Marker Genes")
    ax.set_ylabel("Cell Types")

    if title is None:
        title = "Marker Expression Heatmap"
    ax.set_title(title)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
177
+
178
+
179
def plot_2d_validation(
    adata: ad.AnnData,
    label_column: str,
    confidence_column: str,
    markers: Optional[Dict[str, List[str]]] = None,
    confidence_threshold: float = 0.8,
    min_cells_per_type: int = 15,
    n_components: int = 3,
    ncols: int = 4,
    figsize_per_panel: Tuple[float, float] = (3, 3),
    save: Optional[Union[str, Path]] = None,
) -> Tuple[Figure, pd.DataFrame]:
    """
    2D marker validation plot per cell type.

    For each cell type, plots confidence (x-axis) vs marker metagene score
    (y-axis). Cells are drawn in three groups:

    - red: confidence below ``confidence_threshold``
    - green: confidence at/above threshold, marker score below its threshold
    - gold: confidence and marker score both at/above their thresholds

    Uses classify_by_threshold() with a GMM to find the marker threshold.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels and confidence.
    label_column : str
        Column in adata.obs containing cell type labels.
    confidence_column : str
        Column in adata.obs containing confidence values.
    markers : Dict[str, List[str]], optional
        Marker genes per cell type. Keys are looked up with the lowercased
        cell type label. If None, uses canonical markers.
    confidence_threshold : float, default 0.8
        Confidence threshold for validation.
    min_cells_per_type : int, default 15
        Minimum cells required to plot a cell type.
    n_components : int, default 3
        Number of GMM components (3 for trimodal spatial data).
    ncols : int, default 4
        Number of columns in subplot grid.
    figsize_per_panel : Tuple[float, float], default (3, 3)
        Size per panel.
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Tuple[Figure, pd.DataFrame]
        Figure and validation summary DataFrame with one row per plotted
        cell type (columns: cell_type, n_cells, n_low_conf, n_high_conf,
        n_validated, pct_validated, pct_high_conf, marker_threshold,
        n_markers).

    Raises
    ------
    ValueError
        If a required column is missing, no cell type has at least two
        available markers, or thresholding fails for every cell type.

    Examples
    --------
    >>> from spatialcore.plotting.validation import plot_2d_validation
    >>> fig, summary = plot_2d_validation(
    ...     adata,
    ...     label_column="cell_type",
    ...     confidence_column="confidence",
    ...     confidence_threshold=0.7,
    ... )
    >>> print(summary)
    """
    # NOTE(review): the file listing for this release shows no
    # spatialcore/stats module - confirm this import resolves at runtime.
    from spatialcore.stats.classify import classify_by_threshold

    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")
    if confidence_column not in adata.obs.columns:
        raise ValueError(f"Confidence column '{confidence_column}' not found.")

    # Load canonical markers if not provided
    if markers is None:
        from spatialcore.annotation.markers import load_canonical_markers
        markers = load_canonical_markers()

    # Find cell types with enough cells and available markers
    cell_types = adata.obs[label_column].value_counts()
    cell_types = cell_types[cell_types >= min_cells_per_type].index.tolist()

    types_to_plot = []
    for ct in cell_types:
        ct_markers = markers.get(str(ct).lower(), [])
        available = [g for g in ct_markers if g in adata.var_names]
        if len(available) >= 2:  # Need at least 2 markers
            types_to_plot.append((ct, available))

    if not types_to_plot:
        raise ValueError("No cell types with sufficient markers found.")

    # Pre-compute GMM results for all cell types to know which will succeed.
    # Only the per-cell arrays and the cell count are retained; the AnnData
    # copy of each subset is dropped as soon as its scores are extracted.
    successful_types = []
    for cell_type, ct_markers in types_to_plot:
        mask = adata.obs[label_column] == cell_type
        subset = adata[mask].copy()
        confidence = subset.obs[confidence_column].values
        n_cells = len(subset)

        try:
            subset_result = classify_by_threshold(
                subset,
                feature_columns=ct_markers,
                threshold_method="gmm",
                n_components=n_components,
                metagene_method="shifted_geometric_mean",
                plot=False,
                copy=True,
            )
            metagene_scores = subset_result.obs["threshold_score"].values
            # Fall back to the median score when no threshold was recorded
            marker_threshold = subset_result.uns.get(
                "threshold_params", {}
            ).get("threshold", np.median(metagene_scores))
        except Exception as e:
            # Best effort: a failure for one cell type must not abort the plot
            logger.warning(f"Skipping {cell_type}: GMM failed - {e}")
            continue
        finally:
            del subset

        successful_types.append({
            "cell_type": cell_type,
            "markers": ct_markers,
            "n_cells": n_cells,
            "confidence": confidence,
            "metagene_scores": metagene_scores,
            "marker_threshold": marker_threshold,
        })

    if not successful_types:
        raise ValueError("GMM failed for all cell types.")

    # Create grid with only successful cell types
    n_types = len(successful_types)
    nrows = int(np.ceil(n_types / ncols))
    figsize = (figsize_per_panel[0] * ncols, figsize_per_panel[1] * nrows)

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = np.atleast_2d(axes).flatten()

    summary_rows = []

    for i, data in enumerate(successful_types):
        ax = axes[i]
        cell_type = data["cell_type"]
        ct_markers = data["markers"]
        n_cells = data["n_cells"]
        confidence = data["confidence"]
        metagene_scores = data["metagene_scores"]
        marker_threshold = data["marker_threshold"]

        # Classify cells into three groups:
        # - Red: Low confidence (uncertain)
        # - Green: High confidence only (validated)
        # - Yellow: High conf + High marker (strongly validated)
        high_conf = confidence >= confidence_threshold
        high_marker = metagene_scores >= marker_threshold

        low_conf = ~high_conf  # Red
        high_conf_low_marker = high_conf & ~high_marker  # Green
        high_conf_high_marker = high_conf & high_marker  # Yellow

        # Plot in order: red (back), green, yellow (front)
        ax.scatter(
            confidence[low_conf],
            metagene_scores[low_conf],
            c="red",
            s=5,
            alpha=0.5,
            label="Low Conf",
        )
        ax.scatter(
            confidence[high_conf_low_marker],
            metagene_scores[high_conf_low_marker],
            c="green",
            s=5,
            alpha=0.5,
            label="High Conf",
        )
        ax.scatter(
            confidence[high_conf_high_marker],
            metagene_scores[high_conf_high_marker],
            c="gold",
            s=5,
            alpha=0.5,
            label="High Conf + Marker",
        )

        # Threshold lines
        ax.axvline(confidence_threshold, color="gray", linestyle="--", alpha=0.5)
        ax.axhline(marker_threshold, color="gray", linestyle="--", alpha=0.5)

        ax.set_xlabel("Confidence")
        ax.set_ylabel("Marker Score")
        ax.set_xlim(0, 1)  # Fixed x-axis for consistent comparison
        ax.set_xticks([0, 0.5, 1.0])
        ax.set_title(f"{cell_type}\n(n={n_cells})", fontsize=9)

        # Summary stats
        n_low_conf = low_conf.sum()
        n_high_conf = high_conf_low_marker.sum()
        n_validated = high_conf_high_marker.sum()
        pct_validated = 100 * n_validated / n_cells if n_cells > 0 else 0
        pct_high_conf = 100 * (n_high_conf + n_validated) / n_cells if n_cells > 0 else 0

        summary_rows.append({
            "cell_type": cell_type,
            "n_cells": n_cells,
            "n_low_conf": n_low_conf,
            "n_high_conf": n_high_conf,
            "n_validated": n_validated,
            "pct_validated": pct_validated,
            "pct_high_conf": pct_high_conf,
            "marker_threshold": marker_threshold,
            "n_markers": len(ct_markers),
        })

    # Hide unused panels (when n_types doesn't fill full grid)
    for i in range(len(successful_types), len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    summary_df = pd.DataFrame(summary_rows)
    return fig, summary_df
398
+
399
+
400
def plot_marker_dotplot(
    adata: ad.AnnData,
    label_column: str,
    markers: Optional[Dict[str, List[str]]] = None,
    layer: Optional[str] = None,
    figsize: Optional[tuple] = None,
    cmap: str = "Reds",
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot marker expression as dot plot.

    Dot size represents fraction of cells expressing the marker (value > 0),
    color intensity represents mean expression, min-max scaled per gene
    across cell types.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels.
    label_column : str
        Column in adata.obs containing cell type labels.
    markers : Dict[str, List[str]], optional
        Marker genes per cell type. Keys are looked up with the lowercased
        cell type label. If None, uses canonical markers.
    layer : str, optional
        Layer to use. If None, uses adata.X.
    figsize : tuple, optional
        Figure size. Auto-calculated if None.
    cmap : str, default "Reds"
        Colormap.
    title : str, optional
        Plot title.
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``label_column`` is missing or no marker gene is present in
        ``adata.var_names``.
    """
    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")

    # Load canonical markers if not provided
    if markers is None:
        from spatialcore.annotation.markers import load_canonical_markers
        markers = load_canonical_markers()

    # Collect all marker genes
    all_genes = []
    for ct in adata.obs[label_column].unique():
        ct_markers = markers.get(str(ct).lower(), [])
        all_genes.extend(g for g in ct_markers if g in adata.var_names)

    all_genes = list(dict.fromkeys(all_genes))

    if not all_genes:
        raise ValueError("No marker genes found in data.")

    cell_types = sorted(adata.obs[label_column].unique())

    # Calculate expression fraction and mean
    n_types = len(cell_types)
    n_genes = len(all_genes)

    frac_expr = np.zeros((n_types, n_genes))
    mean_expr = np.zeros((n_types, n_genes))

    for i, ct in enumerate(cell_types):
        mask = adata.obs[label_column] == ct
        subset = adata[mask, all_genes]
        X = subset.layers[layer] if layer else subset.X
        if hasattr(X, "toarray"):
            X = X.toarray()

        frac_expr[i] = (X > 0).mean(axis=0)
        mean_expr[i] = X.mean(axis=0)

    # Normalize mean expression per gene; the epsilon avoids division by zero
    # for genes with identical mean expression in every cell type
    mean_expr_norm = (mean_expr - mean_expr.min(axis=0)) / (
        mean_expr.max(axis=0) - mean_expr.min(axis=0) + 1e-10
    )

    if figsize is None:
        figsize = (max(10, n_genes * 0.4), max(6, n_types * 0.4))

    fig, ax = setup_figure(figsize=figsize)

    # Create dot plot with a single scatter call: x = gene index,
    # y = cell type index, flattened in the same row-major order as the
    # (n_types, n_genes) statistic arrays.
    gene_pos, type_pos = np.meshgrid(np.arange(n_genes), np.arange(n_types))
    ax.scatter(
        gene_pos.ravel(),
        type_pos.ravel(),
        s=(frac_expr * 300).ravel(),  # Scale for visibility
        c=mean_expr_norm.ravel(),
        cmap=cmap,
        vmin=0,
        vmax=1,
    )

    ax.set_xticks(range(n_genes))
    ax.set_xticklabels(all_genes, rotation=90)
    ax.set_yticks(range(n_types))
    ax.set_yticklabels(cell_types)

    ax.set_xlabel("Marker Genes")
    ax.set_ylabel("Cell Types")

    if title is None:
        title = "Marker Expression Dotplot"
    ax.set_title(title)

    # Add legend for size, drawn to the right of the last gene column
    size_legend = [0.25, 0.5, 0.75, 1.0]
    legend_x = n_genes + 0.5
    for k, frac in enumerate(size_legend):
        ax.scatter(legend_x, k, s=frac * 300, c="gray")
        ax.text(legend_x + 0.3, k, f"{int(frac*100)}%", va="center")
    ax.text(legend_x, len(size_legend) + 0.5, "% Expressing", fontsize=9)

    ax.set_xlim(-0.5, n_genes + 2)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
525
+
526
+
527
def plot_celltype_confidence(
    adata: ad.AnnData,
    label_column: str,
    confidence_column: str,
    spatial_key: str = "spatial",
    threshold: float = 0.8,
    max_cell_types: int = 20,
    figsize: Tuple[float, float] = (14, 6),
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Two-panel overview of annotation confidence.

    - Left panel: spatial scatter of all cells, colored by their confidence
      value (color range clipped to the 5th-95th percentile). A placeholder
      message is shown when ``spatial_key`` is absent from ``adata.obsm``.
    - Right panel: jitter plot with confidence on the x-axis and one row per
      cell type, ordered by median confidence, plus a vertical threshold line.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels and confidence scores.
    label_column : str
        Column in adata.obs containing cell type labels.
    confidence_column : str
        Column in adata.obs containing confidence values.
    spatial_key : str, default "spatial"
        Key in adata.obsm for spatial coordinates.
    threshold : float, default 0.8
        Confidence threshold line to display.
    max_cell_types : int, default 20
        Maximum number of cell types to show in jitter plot; when exceeded,
        the cell types with the highest median confidence are kept.
    figsize : Tuple[float, float], default (14, 6)
        Figure size (width, height).
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``label_column`` or ``confidence_column`` is not in adata.obs.

    Examples
    --------
    >>> from spatialcore.plotting.validation import plot_celltype_confidence
    >>> fig = plot_celltype_confidence(
    ...     adata,
    ...     label_column="celltypist",
    ...     confidence_column="celltypist_confidence_transformed",
    ...     threshold=0.8,
    ... )
    """
    obs_columns = adata.obs.columns
    if label_column not in obs_columns:
        raise ValueError(f"Label column '{label_column}' not found.")
    if confidence_column not in obs_columns:
        raise ValueError(f"Confidence column '{confidence_column}' not found.")

    fig, (ax_spatial, ax_jitter) = plt.subplots(1, 2, figsize=figsize)

    # Palette for the jitter panel, one entry per observed label
    palette = generate_celltype_palette(adata.obs[label_column].unique())

    # ---- Left panel: spatial map colored by confidence -------------------
    if spatial_key not in adata.obsm:
        ax_spatial.text(
            0.5, 0.5,
            f"No spatial coordinates found\n(key: '{spatial_key}')",
            ha="center", va="center", transform=ax_spatial.transAxes
        )
        ax_spatial.set_title("Spatial Confidence (N/A)")
    else:
        xy = adata.obsm[spatial_key]
        conf_all = adata.obs[confidence_column].values

        # Clip the color range to the central 90% so outliers do not
        # wash out the colormap: red (low) -> yellow -> green (high)
        color_lo = np.nanpercentile(conf_all, 5)
        color_hi = np.nanpercentile(conf_all, 95)
        points = ax_spatial.scatter(
            xy[:, 0],
            xy[:, 1],
            s=1,
            alpha=0.7,
            c=conf_all,
            cmap="RdYlGn",
            vmin=color_lo,
            vmax=color_hi,
            rasterized=True,
        )
        colorbar = plt.colorbar(points, ax=ax_spatial, shrink=0.7, pad=0.02)
        colorbar.set_label("Confidence (z-score)", fontsize=9)
        ax_spatial.set_title("Spatial Confidence")
        ax_spatial.set_xlabel("X")
        ax_spatial.set_ylabel("Y")
        ax_spatial.axis("equal")
        despine(ax_spatial)

    # ---- Right panel: jitter plot, cell types on the y-axis --------------
    median_by_type = adata.obs.groupby(label_column, observed=True)[confidence_column].median()

    if len(median_by_type) > max_cell_types:
        # Keep the highest-median cell types, then reverse so the best one
        # ends up at the top of the axis
        top_types = median_by_type.sort_values(ascending=False).index.tolist()[:max_cell_types]
        ct_order = top_types[::-1]
        logger.info(f"Showing top {max_cell_types} cell types by median confidence")
    else:
        # Ascending order puts the highest median at the top of the axis
        ct_order = median_by_type.sort_values(ascending=True).index.tolist()

    for row, ct in enumerate(ct_order):
        in_type = adata.obs[label_column] == ct
        conf_values = adata.obs.loc[in_type, confidence_column].values
        # Vertical jitter around the row position of this cell type
        y_jittered = np.random.normal(row, 0.15, len(conf_values))
        dot_color = palette.get(str(ct), "gray")
        ax_jitter.scatter(
            conf_values, y_jittered, s=3, alpha=0.3, c=[dot_color], rasterized=True
        )

    # Vertical threshold marker
    ax_jitter.axvline(
        threshold, color="red", linestyle="--", lw=2, label=f"Threshold={threshold}"
    )

    # Truncate long names so tick labels stay readable
    tick_labels = [str(ct)[:25] for ct in ct_order]
    ax_jitter.set_yticks(range(len(ct_order)))
    ax_jitter.set_yticklabels(tick_labels, fontsize=8)

    # Fixed 0-1 x-range with 0.5 increments for consistent comparison
    ax_jitter.set_xlim(0, 1)
    ax_jitter.set_xticks([0, 0.5, 1.0])

    ax_jitter.set_xlabel("Confidence (z-score)")
    ax_jitter.set_ylabel("")
    ax_jitter.set_title("Confidence by Cell Type")
    ax_jitter.legend(loc="lower right")
    despine(ax_jitter)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
667
+
668
+
669
def plot_deg_heatmap(
    adata: ad.AnnData,
    label_column: str,
    n_genes: int = 5,
    method: str = "wilcoxon",
    layer: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    cmap: str = "viridis",
    save: Optional[Union[str, Path]] = None,
    title: Optional[str] = None,
) -> Figure:
    """
    DEG heatmap showing top marker genes per cell type.

    Runs differential expression analysis (scanpy ``rank_genes_groups``) on
    cell types with at least 10 cells, excluding the "Unassigned" label, and
    displays a heatmap with genes on rows, cell types on columns, a row color
    annotation bar, and a cell type legend. Values are mean expression per
    cell type, z-scored per gene across cell types.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels.
    label_column : str
        Column in adata.obs containing cell type labels.
    n_genes : int, default 5
        Number of top DEGs per cell type to include.
    method : str, default "wilcoxon"
        DEG method ("wilcoxon", "t-test", "t-test_overestim_var", "logreg").
    layer : str, optional
        Expression layer to use. If None, or if the layer does not exist
        (a warning is logged in that case), uses adata.X.
    figsize : Tuple[float, float], optional
        Figure size. If None, uses width 12 and a height that grows with
        the number of genes shown.
    cmap : str, default "viridis"
        Colormap for heatmap.
    save : str or Path, optional
        Path to save figure.
    title : str, optional
        Plot title. Defaults to "Marker Genes ({label_column})".

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``label_column`` is missing, fewer than 2 cell types have at
        least 10 cells, or no marker genes are found.
    """
    import scanpy as sc
    import seaborn as sns

    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")

    # Build filter mask on original data (no copies)
    mask_assigned = adata.obs[label_column] != "Unassigned"

    # Keep only cell types with enough cells for a meaningful comparison
    ct_counts = adata.obs.loc[mask_assigned, label_column].value_counts()
    valid_cts = ct_counts[ct_counts >= 10].index.tolist()
    if len(valid_cts) < 2:
        raise ValueError("Need at least 2 cell types with >= 10 cells each for DEG analysis.")

    # Combined mask - applied once
    mask = mask_assigned & adata.obs[label_column].isin(valid_cts)

    # Manual AnnData construction to avoid anndata 0.12.x memory bug in adata[mask].copy()
    indices = np.where(mask)[0]
    if layer and layer in adata.layers:
        X_sub = adata.layers[layer][indices]
    else:
        if layer:
            logger.warning(
                f"Layer '{layer}' not found in adata.layers; using adata.X instead."
            )
        X_sub = adata.X[indices]
    obs_sub = adata.obs.iloc[indices].copy()
    obs_sub[label_column] = obs_sub[label_column].astype(str).astype("category")
    adata_deg = ad.AnnData(X=X_sub, obs=obs_sub, var=adata.var.copy())

    logger.info(f"Running rank_genes_groups with method={method} on {len(valid_cts)} cell types...")
    # Rank at least as many genes as requested so n_genes > 50 is honored
    sc.tl.rank_genes_groups(
        adata_deg, label_column, method=method, n_genes=max(50, n_genes)
    )

    # Get marker genes per cell type
    results = adata_deg.uns["rank_genes_groups"]
    cell_types = sorted(results["names"].dtype.names)

    # Collect genes grouped by cell type (preserve order); a gene shared by
    # several cell types is attributed to the first one that lists it
    all_genes = []
    gene_to_celltype = {}
    seen = set()
    for ct in cell_types:
        for gene in results["names"][ct][:n_genes]:
            if gene not in seen and gene in adata_deg.var_names:
                all_genes.append(gene)
                gene_to_celltype[gene] = ct
                seen.add(gene)

    if len(all_genes) == 0:
        raise ValueError("No valid marker genes found")

    # Calculate mean expression per cell type. Column means are computed once
    # per cell type and then indexed per gene, instead of slicing the matrix
    # for every (gene, cell type) pair.
    X_deg = adata_deg.X
    labels = adata_deg.obs[label_column].to_numpy()
    gene_locs = [adata_deg.var_names.get_loc(gene) for gene in all_genes]

    expr_matrix = np.zeros((len(all_genes), len(cell_types)))
    for j, ct in enumerate(cell_types):
        rows = np.flatnonzero(labels == ct)
        if rows.size == 0:
            continue
        # np.asarray(...).ravel() flattens the matrix result of sparse means
        col_means = np.asarray(X_deg[rows].mean(axis=0)).ravel()
        for i, loc in enumerate(gene_locs):
            # np.mean also covers a slice/mask loc from duplicated var names
            expr_matrix[i, j] = np.mean(col_means[loc])

    # Release the copy - no longer needed after expression matrix is built
    del adata_deg, X_deg
    gc.collect()

    # Z-score normalize across cell types (rows); constant rows stay at 0
    expr_scaled = np.zeros_like(expr_matrix)
    for i in range(expr_matrix.shape[0]):
        row = expr_matrix[i, :]
        row_std = row.std()
        if row_std > 0:
            expr_scaled[i, :] = (row - row.mean()) / row_std

    # Create color palette for cell types
    n_cts = len(cell_types)
    palette = sns.color_palette("tab20", n_cts)
    ct_to_color = {ct: palette[i] for i, ct in enumerate(cell_types)}

    # Create row colors array
    row_colors = [ct_to_color[gene_to_celltype[gene]] for gene in all_genes]

    # Create figure with gridspec for custom layout
    if figsize is None:
        figsize = (12, max(10, len(all_genes) * 0.11))
    fig = plt.figure(figsize=figsize)

    # GridSpec with annotation bar flush against heatmap
    gs = fig.add_gridspec(
        1, 5,
        width_ratios=[0.06, 0.012, 1, 0.02, 0.015],
        wspace=0.0,
        left=0.01,
        right=0.78,
        top=0.95,
        bottom=0.12,
    )

    # Gene labels axis (far left)
    ax_labels = fig.add_subplot(gs[0, 0])
    ax_labels.set_ylim(0, len(all_genes))
    ax_labels.set_xlim(0, 1)
    ax_labels.invert_yaxis()
    for i, gene in enumerate(all_genes):
        ax_labels.text(
            0.98, i + 0.5, gene,
            ha="right", va="center",
            fontsize=5,
            color="black",
        )
    ax_labels.axis("off")

    # Row colors axis (annotation bar)
    ax_rowcolors = fig.add_subplot(gs[0, 1])
    for i, color in enumerate(row_colors):
        ax_rowcolors.add_patch(plt.Rectangle(
            (0, i), 1, 1,
            facecolor=color,
            edgecolor="none",
        ))
    ax_rowcolors.set_xlim(0, 1)
    ax_rowcolors.set_ylim(0, len(all_genes))
    ax_rowcolors.invert_yaxis()
    ax_rowcolors.axis("off")

    # Main heatmap axis; color range is clipped to +/- 2.5 z-score units
    ax_heatmap = fig.add_subplot(gs[0, 2])
    im = ax_heatmap.imshow(
        expr_scaled,
        aspect="auto",
        cmap=cmap,
        vmin=-2.5,
        vmax=2.5,
    )

    ax_heatmap.set_yticks([])
    ax_heatmap.set_xticks(range(len(cell_types)))
    ax_heatmap.set_xticklabels(cell_types, rotation=45, ha="right", fontsize=8)

    # Gap between heatmap and colorbar
    ax_gap = fig.add_subplot(gs[0, 3])
    ax_gap.axis("off")

    # Colorbar axis
    ax_cbar = fig.add_subplot(gs[0, 4])
    cbar = plt.colorbar(im, cax=ax_cbar)
    cbar.set_label("Scaled expression", fontsize=9)
    cbar.ax.tick_params(labelsize=7)

    # Cell type legend
    legend_handles = [
        plt.Rectangle((0, 0), 1, 1, facecolor=ct_to_color[ct], label=ct)
        for ct in cell_types
    ]
    fig.legend(
        handles=legend_handles,
        loc="upper left",
        bbox_to_anchor=(0.80, 0.95),
        fontsize=6,
        title="Cell type",
        title_fontsize=7,
        frameon=True,
        fancybox=True,
    )

    # Title
    if title is None:
        title = f"Marker Genes ({label_column})"
    fig.suptitle(title, fontsize=14, fontweight="bold")

    if save:
        save_figure(fig, save)

    return fig
889
+
890
+
891
# Source-column -> display-header mapping for a pre-computed mapping table.
_ONTOLOGY_TABLE_HEADERS = {
    "input_label": "CellTypist Label",
    "ontology_name": "Ontology Name",
    "ontology_id": "CL ID",
    "match_tier": "Match Tier",
    "score": "Score",
    "n_cells": "Cells",
}

# CL ID values that mean "no ontology term was assigned".
_UNMAPPED_CL_IDS = frozenset({"", "-", "nan", "unknown", "skipped"})

# Displayed score per tier, used only when no real score is available.
_TIER_FALLBACK_SCORES = {
    "tier0_pattern": "0.95",
    "tier1_exact": "1.00",
    "tier2_token": "0.75",
    "tier3_overlap": "0.60",
    "unmapped": "-",
    "skipped": "-",
}


def _companion_column(id_column: str, suffix: str) -> Optional[str]:
    """Derive the tier/score column name that accompanies ``id_column``.

    A trailing ``_id`` is swapped for ``suffix``.  Names that only contain
    ``_id`` elsewhere keep the historical replace-all behaviour.  ``None`` is
    returned when no distinct name can be derived, so the ID column is never
    mistaken for its own companion.
    """
    if id_column.endswith("_id"):
        return id_column[: -len("_id")] + suffix
    derived = id_column.replace("_id", suffix)
    return derived if derived != id_column else None


def _format_mapping_score(value) -> str:
    """Format a matching score as ``"0.00"``; anything unusable becomes ``"-"``."""
    try:
        score = float(value)
    except (TypeError, ValueError):
        return "-"
    # ``score != score`` is the NaN check.
    if score != score or score <= 0:
        return "-"
    return f"{score:.2f}"


def _infer_match_tier(cl_id) -> str:
    """Coarse tier used when no tier information exists: mapped or unmapped."""
    if str(cl_id).strip() in _UNMAPPED_CL_IDS:
        return "unmapped"
    return "tier0_pattern"  # Simplified - no tier info available


def plot_ontology_mapping(
    adata: ad.AnnData,
    source_label_column: str,
    ontology_name_column: str,
    ontology_id_column: str,
    mapping_table: Optional[pd.DataFrame] = None,
    title: Optional[str] = None,
    figsize: Tuple[float, float] = (14, 8),
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot ontology mapping table showing original labels mapped to Cell Ontology.

    Creates a table visualization with:
    - Original CellTypist labels
    - Mapped ontology names and CL IDs
    - Match tier (tier0=pattern, tier1=exact, etc.)
    - Matching score (actual score from matching when available)
    - Cell counts

    Rows are sorted by cell count (descending) and coloured by match tier.

    Parameters
    ----------
    adata : AnnData
        Annotated data with ontology mapping columns. Only ``adata.obs`` is
        read, and only when ``mapping_table`` is None.
    source_label_column : str
        Column with original cell type labels (e.g., "celltypist").
    ontology_name_column : str
        Column with mapped ontology names.
    ontology_id_column : str
        Column with Cell Ontology IDs (CL:XXXXXXX). Companion tier and score
        columns are looked up by swapping the trailing ``_id`` for ``_tier``
        and ``_score``; they are optional.
    mapping_table : pd.DataFrame, optional
        Pre-computed mapping table from OntologyMappingResult.table.
        If provided, uses this directly instead of building from adata.
        Must have columns: input_label, ontology_name, ontology_id, n_cells.
        ``match_tier`` and ``score`` are optional; when absent the tier is
        inferred from the CL ID and the score is estimated from the tier.
    title : str, optional
        Plot title. Auto-generated if None.
    figsize : tuple
        Figure size. Only the width is honoured; the height is computed from
        the number of rows (``max(6, 2 + 0.4 * n_rows)``).
    save : Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure with ontology mapping table.

    Raises
    ------
    ValueError
        If a required column is missing from ``adata.obs`` or from
        ``mapping_table``, or if there are no labels to display.

    Examples
    --------
    >>> # Using adata columns (scores read from {id_column}_score)
    >>> fig = plot_ontology_mapping(
    ...     adata,
    ...     source_label_column="celltypist",
    ...     ontology_name_column="celltypist_ontology_name",
    ...     ontology_id_column="celltypist_ontology_id",
    ... )

    >>> # Using pre-computed mapping table
    >>> from spatialcore.annotation import add_ontology_ids
    >>> adata, mappings, result = add_ontology_ids(adata, "celltypist", save_mapping="./")
    >>> fig = plot_ontology_mapping(
    ...     adata,
    ...     source_label_column="celltypist",
    ...     ontology_name_column="celltypist_ontology_name",
    ...     ontology_id_column="celltypist_ontology_id",
    ...     mapping_table=result.table,
    ... )
    """
    if mapping_table is not None:
        # Work on a copy so the caller's table is never modified.
        summary = mapping_table.copy().rename(columns=_ONTOLOGY_TABLE_HEADERS)
        required = ("input_label", "ontology_name", "ontology_id", "n_cells")
        missing = [
            name
            for name in required
            if _ONTOLOGY_TABLE_HEADERS[name] not in summary.columns
        ]
        if missing:
            raise ValueError(
                f"mapping_table is missing required column(s): {missing}"
            )
    else:
        obs = adata.obs
        for col in (source_label_column, ontology_name_column, ontology_id_column):
            if col not in obs.columns:
                raise ValueError(f"Column '{col}' not found in adata.obs")

        # Optional companion columns (derived from ontology_id_column)
        tier_column = _companion_column(ontology_id_column, "_tier")
        score_column = _companion_column(ontology_id_column, "_score")

        # Build the frame under display names directly.  Converting to plain
        # strings avoids categorical issues, and naming columns explicitly
        # (rather than positionally) stays correct even if two of the input
        # column names coincide.
        df = pd.DataFrame(
            {
                "CellTypist Label": obs[source_label_column].astype(str).to_numpy(),
                "Ontology Name": obs[ontology_name_column].astype(str).to_numpy(),
                "CL ID": obs[ontology_id_column].astype(str).to_numpy(),
            }
        )
        text_columns = ["CellTypist Label", "Ontology Name", "CL ID"]
        if tier_column is not None and tier_column in obs.columns:
            df["Match Tier"] = obs[tier_column].astype(str).to_numpy()
            text_columns.append("Match Tier")
        if score_column is not None and score_column in obs.columns:
            df["Score"] = obs[score_column].to_numpy()

        # Replace 'nan' strings (stringified missing values) with empty string
        df[text_columns] = df[text_columns].replace("nan", "")
        df["Cells"] = 1

        # One row per source label: first value of each attribute, cell count
        aggregations = {
            col: "first"
            for col in df.columns
            if col not in ("CellTypist Label", "Cells")
        }
        aggregations["Cells"] = "sum"
        summary = df.groupby("CellTypist Label").agg(aggregations).reset_index()

    if summary.empty:
        # Fail before a figure is created; matplotlib cannot draw a table
        # without body rows.
        raise ValueError("Ontology mapping summary is empty; nothing to plot")

    # Missing names/IDs become empty strings so they get placeholders below
    # instead of being rendered as "nan".
    for col in ("Ontology Name", "CL ID"):
        summary[col] = summary[col].astype(object).where(summary[col].notna(), "")

    # Add tier column if missing
    if "Match Tier" not in summary.columns:
        summary["Match Tier"] = [_infer_match_tier(v) for v in summary["CL ID"]]

    # Format score column - use actual scores if available
    if "Score" in summary.columns:
        summary["Score"] = [_format_mapping_score(v) for v in summary["Score"]]
    else:
        # Fallback: estimate score based on tier (less accurate)
        summary["Score"] = [
            _TIER_FALLBACK_SCORES.get(tier, "-") for tier in summary["Match Tier"]
        ]

    # Fill empty values with placeholder
    summary["CL ID"] = summary["CL ID"].replace("", "-")
    unnamed = summary["Ontology Name"] == ""
    summary.loc[unnamed, "Ontology Name"] = summary.loc[unnamed, "CellTypist Label"]

    # Sort by cell count descending
    summary = summary.sort_values("Cells", ascending=False).reset_index(drop=True)

    # Calculate stats
    is_mapped = ~summary["Match Tier"].isin(["unmapped", "skipped"])
    n_labels = len(summary)
    n_mapped = int(is_mapped.sum())
    total_cells = summary["Cells"].sum()
    mapped_cells = summary.loc[is_mapped, "Cells"].sum()
    label_pct = 100 * n_mapped / n_labels  # n_labels > 0 (checked above)
    cell_pct = 100 * mapped_cells / total_cells if total_cells else 0.0

    # Tier colors
    tier_colors = {
        "tier0_pattern": "#d4edda",  # Green
        "tier1_exact": "#cce5ff",  # Blue
        "tier2_token": "#fff3cd",  # Orange/Yellow
        "tier3_overlap": "#f8d7da",  # Red
        "unmapped": "#e9ecef",  # Gray
        "skipped": "#e9ecef",  # Gray (for Unassigned, Unknown, etc.)
    }

    # Figure height grows with the number of rows; the requested width is kept
    n_rows = n_labels
    figsize = (figsize[0], max(6, 2 + n_rows * 0.4))

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)
    ax.axis("off")

    # Title and stats at top
    if title is None:
        title = "CellTypist to Cell Ontology Mapping"

    fig.text(
        0.5, 0.97,
        f"Labels: {n_mapped}/{n_labels} mapped ({label_pct:.1f}%) | "
        f"Cells: {mapped_cells:,}/{total_cells:,} mapped ({cell_pct:.1f}%)",
        ha="center", fontsize=10, color="green"
    )
    fig.text(0.5, 0.93, title, ha="center", fontsize=14, fontweight="bold")

    # Reorder columns for display (extra mapping_table columns are dropped)
    display_cols = ["CellTypist Label", "Ontology Name", "CL ID", "Match Tier", "Score", "Cells"]
    summary = summary[display_cols]

    # Create table
    table = ax.table(
        cellText=summary.values,
        colLabels=summary.columns,
        cellLoc="left",
        loc="upper center",
        colColours=["#2c3e50"] * len(display_cols),
    )

    # Style table
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1.0, 1.5)

    # Set column widths (row 0 is the header, hence n_rows + 1)
    col_widths = [0.22, 0.22, 0.12, 0.14, 0.08, 0.10]
    for j, width in enumerate(col_widths):
        for i in range(n_rows + 1):
            table[(i, j)].set_width(width)

    # Color header text white
    for j in range(len(display_cols)):
        header_text = table[(0, j)].get_text()
        header_text.set_color("white")
        header_text.set_fontweight("bold")

    # Color rows by tier (table row 0 is the header, so body rows start at 1)
    for i, tier in enumerate(summary["Match Tier"], start=1):
        color = tier_colors.get(tier, "#ffffff")
        for j in range(len(display_cols)):
            table[(i, j)].set_facecolor(color)

    # Add legend at bottom
    legend_text = (
        "Tier Colors: Green = Pattern Match (tier0) | Blue = Exact Match (tier1) | "
        "Orange = Token Match (tier2) | Red = Overlap (tier3) | Gray = Unmapped"
    )
    fig.text(0.5, 0.02, legend_text, ha="center", fontsize=8, style="italic")

    if save:
        save_figure(fig, save)

    return fig
1151
+
1152
+
1153
def generate_annotation_plots(
    adata: ad.AnnData,
    label_column: str = "cell_type",
    confidence_column: str = "cell_type_confidence",
    output_dir: Optional[Union[str, Path]] = None,
    prefix: str = "celltyping",
    confidence_threshold: float = 0.8,
    markers: Optional[Dict[str, List[str]]] = None,
    n_deg_genes: int = 10,
    spatial_key: str = "spatial",
    source_label_column: Optional[str] = None,
    ontology_name_column: Optional[str] = None,
    ontology_id_column: Optional[str] = None,
) -> Dict:
    """
    Run the full set of cell typing validation plots in one call.

    Four plots are attempted, in this order, each independently of the
    others (a failure is logged as a warning and the next plot still runs):

    - Ontology mapping table (``plot_ontology_mapping``); skipped when the
      ontology columns cannot be determined
    - 2D marker validation with 3 GMM components (``plot_2d_validation``)
    - Cell type confidence (``plot_celltype_confidence``)
    - DEG heatmap (``plot_deg_heatmap``)

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels and confidence scores.
    label_column : str, default "cell_type"
        Column containing cell type labels.
    confidence_column : str, default "cell_type_confidence"
        Column containing confidence values.
    output_dir : str or Path, optional
        Directory for the saved PNG files; created (with parents) if needed.
        If None (or empty), nothing is saved and figures are only returned.
    prefix : str, default "celltyping"
        Filename prefix; files are named ``{prefix}_<plot>.png``.
    confidence_threshold : float, default 0.8
        Threshold forwarded to the 2D validation and confidence plots.
    markers : Dict[str, List[str]], optional
        Custom marker genes per cell type, forwarded unchanged to
        ``plot_2d_validation``. NOTE(review): with None that function
        presumably falls back to the packaged canonical markers - confirm.
    n_deg_genes : int, default 10
        Number of top DEGs per cell type for the heatmap.
    spatial_key : str, default "spatial"
        Key for spatial coordinates, forwarded to the confidence plot.
    source_label_column : str, optional
        Original label column for the ontology mapping table.
        If None, ``label_column`` is used.
    ontology_name_column : str, optional
        Ontology name column. If None, the first of
        "cell_type_ontology_label" and "{label_column}_ontology_name"
        present in ``adata.obs`` is used.
    ontology_id_column : str, optional
        Ontology ID column. If None, the first of
        "cell_type_ontology_term_id" and "{label_column}_ontology_id"
        present in ``adata.obs`` is used.

    Returns
    -------
    Dict
        Dictionary with keys:
        - "figures": Figure objects keyed by plot name, successful plots only
        - "summary": summary returned by the 2D validation (None if it failed)
        - "paths": target path per successful plot (None values when
          ``output_dir`` was not given)

    Examples
    --------
    >>> from spatialcore.plotting import generate_annotation_plots
    >>> results = generate_annotation_plots(
    ...     adata,
    ...     output_dir="./qc_plots",
    ...     prefix="lung_cancer",
    ... )
    >>> print(results["summary"])
    """
    out_root = Path(output_dir) if output_dir else None
    if out_root:
        out_root.mkdir(parents=True, exist_ok=True)

    def _png(stem):
        # Destination file for one plot, or None when nothing is being saved.
        return out_root / f"{prefix}_{stem}.png" if out_root else None

    def _first_present(*candidates):
        # First candidate that is an obs column, else None.
        return next((c for c in candidates if c in adata.obs.columns), None)

    results = {"figures": {}, "summary": None, "paths": {}}
    figures = results["figures"]
    paths = results["paths"]

    # Fill in ontology-related column names that were not given explicitly.
    # The predicted label column doubles as the source label column.
    if source_label_column is None:
        source_label_column = label_column

    # CellxGene-standard name first, then the legacy "{base}_ontology_*" form.
    legacy_base = label_column.replace("_ontology_name", "")
    if ontology_name_column is None:
        ontology_name_column = _first_present(
            "cell_type_ontology_label",
            f"{legacy_base}_ontology_name",
        )
    if ontology_id_column is None:
        ontology_id_column = _first_present(
            "cell_type_ontology_term_id",
            f"{legacy_base}_ontology_id",
        )

    # --- Ontology mapping table ---
    logger.info("Generating ontology mapping table...")
    if not (source_label_column and ontology_name_column and ontology_id_column):
        logger.info(f" Skipping ontology table - columns not found")
        logger.info(f" source_label_column: {source_label_column}")
        logger.info(f" ontology_name_column: {ontology_name_column}")
        logger.info(f" ontology_id_column: {ontology_id_column}")
    else:
        try:
            target = _png("ontology_mapping")
            figures["ontology_mapping"] = plot_ontology_mapping(
                adata,
                source_label_column=source_label_column,
                ontology_name_column=ontology_name_column,
                ontology_id_column=ontology_id_column,
                save=target,
            )
            paths["ontology_mapping"] = target
            logger.info(" Ontology mapping table generated")
        except Exception as err:
            logger.warning(f"Ontology mapping table failed: {err}")

    # --- 2D marker validation (GMM with 3 components) ---
    logger.info("Generating 2D marker validation plot...")
    try:
        target = _png("2d_validation")
        fig_2d, summary = plot_2d_validation(
            adata,
            label_column=label_column,
            confidence_column=confidence_column,
            markers=markers,
            confidence_threshold=confidence_threshold,
            n_components=3,
            save=target,
        )
        figures["2d_validation"] = fig_2d
        results["summary"] = summary
        paths["2d_validation"] = target
        logger.info(f" 2D validation: {len(summary)} cell types analyzed")
    except Exception as err:
        logger.warning(f"2D validation plot failed: {err}")

    # --- Cell type confidence ---
    logger.info("Generating confidence plot...")
    try:
        target = _png("confidence")
        figures["confidence"] = plot_celltype_confidence(
            adata,
            label_column=label_column,
            confidence_column=confidence_column,
            spatial_key=spatial_key,
            threshold=confidence_threshold,
            save=target,
        )
        paths["confidence"] = target
        logger.info(" Confidence plot generated")
    except Exception as err:
        logger.warning(f"Confidence plot failed: {err}")

    # --- DEG heatmap ---
    logger.info("Generating DEG heatmap...")
    try:
        target = _png("deg_heatmap")
        figures["deg_heatmap"] = plot_deg_heatmap(
            adata,
            label_column=label_column,
            n_genes=n_deg_genes,
            save=target,
        )
        paths["deg_heatmap"] = target
        logger.info(" DEG heatmap generated")
    except Exception as err:
        logger.warning(f"DEG heatmap failed: {err}")

    if out_root:
        logger.info(f"Annotation plots saved to: {out_root}")

    return results