spatialcore-0.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. spatialcore/__init__.py +122 -0
  2. spatialcore/annotation/__init__.py +253 -0
  3. spatialcore/annotation/acquisition.py +529 -0
  4. spatialcore/annotation/annotate.py +603 -0
  5. spatialcore/annotation/cellxgene.py +365 -0
  6. spatialcore/annotation/confidence.py +802 -0
  7. spatialcore/annotation/discovery.py +529 -0
  8. spatialcore/annotation/expression.py +363 -0
  9. spatialcore/annotation/loading.py +529 -0
  10. spatialcore/annotation/markers.py +297 -0
  11. spatialcore/annotation/ontology.py +1282 -0
  12. spatialcore/annotation/patterns.py +247 -0
  13. spatialcore/annotation/pipeline.py +620 -0
  14. spatialcore/annotation/synapse.py +380 -0
  15. spatialcore/annotation/training.py +1457 -0
  16. spatialcore/annotation/validation.py +422 -0
  17. spatialcore/core/__init__.py +34 -0
  18. spatialcore/core/cache.py +118 -0
  19. spatialcore/core/logging.py +135 -0
  20. spatialcore/core/metadata.py +149 -0
  21. spatialcore/core/utils.py +768 -0
  22. spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
  23. spatialcore/data/markers/canonical_markers.json +83 -0
  24. spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
  25. spatialcore/plotting/__init__.py +109 -0
  26. spatialcore/plotting/benchmark.py +477 -0
  27. spatialcore/plotting/celltype.py +329 -0
  28. spatialcore/plotting/confidence.py +413 -0
  29. spatialcore/plotting/spatial.py +505 -0
  30. spatialcore/plotting/utils.py +411 -0
  31. spatialcore/plotting/validation.py +1342 -0
  32. spatialcore-0.1.9.dist-info/METADATA +213 -0
  33. spatialcore-0.1.9.dist-info/RECORD +36 -0
  34. spatialcore-0.1.9.dist-info/WHEEL +5 -0
  35. spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
  36. spatialcore-0.1.9.dist-info/top_level.txt +1 -0
spatialcore/plotting/__init__.py
@@ -0,0 +1,109 @@
+ """
+ Plotting utilities for SpatialCore.
+
+ This module provides visualization functions for:
+ - Cell type distributions and UMAP plots
+ - Confidence score visualization
+ - Spatial maps of cell types and gene expression
+ - Marker validation heatmaps and dot plots
+ - Benchmark comparisons and confusion matrices
+ """
+
+ from spatialcore.plotting.utils import (
+     # Color palettes
+     DEFAULT_PALETTE,
+     COLORBLIND_PALETTE,
+     generate_celltype_palette,
+     load_celltype_palette,
+     save_celltype_palette,
+     # Figure setup
+     setup_figure,
+     setup_multi_figure,
+     save_figure,
+     close_figure,
+     # Axis formatting
+     format_axis_labels,
+     despine,
+ )
+
+ from spatialcore.plotting.celltype import (
+     plot_celltype_distribution,
+     plot_celltype_pie,
+     plot_celltype_umap,
+ )
+
+ from spatialcore.plotting.confidence import (
+     plot_confidence_histogram,
+     plot_confidence_by_celltype,
+     plot_confidence_violin,
+     plot_model_contribution,
+ )
+
+ from spatialcore.plotting.spatial import (
+     plot_spatial_celltype,
+     plot_spatial_confidence,
+     plot_spatial_gene,
+     plot_spatial_multi_gene,
+ )
+
+ from spatialcore.plotting.validation import (
+     plot_marker_heatmap,
+     plot_2d_validation,
+     plot_marker_dotplot,
+     plot_celltype_confidence,
+     plot_deg_heatmap,
+     plot_ontology_mapping,
+     generate_annotation_plots,
+ )
+
+ from spatialcore.plotting.benchmark import (
+     plot_method_comparison,
+     plot_confusion_matrix,
+     plot_classification_report,
+     plot_agreement_heatmap,
+     plot_silhouette_by_type,
+ )
+
+ __all__ = [
+     # Utils - Palettes
+     "DEFAULT_PALETTE",
+     "COLORBLIND_PALETTE",
+     "generate_celltype_palette",
+     "load_celltype_palette",
+     "save_celltype_palette",
+     # Utils - Figures
+     "setup_figure",
+     "setup_multi_figure",
+     "save_figure",
+     "close_figure",
+     "format_axis_labels",
+     "despine",
+     # Cell type
+     "plot_celltype_distribution",
+     "plot_celltype_pie",
+     "plot_celltype_umap",
+     # Confidence
+     "plot_confidence_histogram",
+     "plot_confidence_by_celltype",
+     "plot_confidence_violin",
+     "plot_model_contribution",
+     # Spatial
+     "plot_spatial_celltype",
+     "plot_spatial_confidence",
+     "plot_spatial_gene",
+     "plot_spatial_multi_gene",
+     # Validation
+     "plot_marker_heatmap",
+     "plot_2d_validation",
+     "plot_marker_dotplot",
+     "plot_celltype_confidence",
+     "plot_deg_heatmap",
+     "plot_ontology_mapping",
+     "generate_annotation_plots",
+     # Benchmark
+     "plot_method_comparison",
+     "plot_confusion_matrix",
+     "plot_classification_report",
+     "plot_agreement_heatmap",
+     "plot_silhouette_by_type",
+ ]
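The re-exports above make the plotting utilities importable directly from the spatialcore.plotting namespace. A minimal sketch of that namespace in use (editorial example, not part of the wheel; it assumes generate_celltype_palette returns a label-to-color dict and setup_figure returns a (Figure, Axes) pair, as the benchmark module below relies on):

from spatialcore.plotting import (
    generate_celltype_palette,
    setup_figure,
    save_figure,
)

labels = ["B cell", "T cell", "Macrophage"]   # illustrative labels
palette = generate_celltype_palette(labels)   # assumed: {label: hex color}

fig, ax = setup_figure(figsize=(4, 3))        # assumed: returns (fig, ax)
ax.bar(labels, [120, 340, 85], color=[palette[l] for l in labels])
save_figure(fig, "celltype_counts.png")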
spatialcore/plotting/benchmark.py
@@ -0,0 +1,477 @@
+ """
+ Benchmark visualization for annotation method comparison.
+
+ This module provides functions for comparing annotation methods
+ and visualizing classification performance.
+ """
+
+ from pathlib import Path
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from matplotlib.figure import Figure
+ import anndata as ad
+
+ from spatialcore.core.logging import get_logger
+ from spatialcore.plotting.utils import (
+     generate_celltype_palette,
+     setup_figure,
+     save_figure,
+     despine,
+     format_axis_labels,
+ )
+
+ logger = get_logger(__name__)
+
+
+ def plot_method_comparison(
+     df: pd.DataFrame,
+     metrics: List[str] = None,
+     method_column: str = "method",
+     figsize: tuple = (10, 6),
+     title: Optional[str] = None,
+     save: Optional[Union[str, Path]] = None,
+ ) -> Figure:
+     """
+     Plot comparison of annotation methods across metrics.
+
+     Parameters
+     ----------
+     df : pd.DataFrame
+         DataFrame with methods as rows and metrics as columns.
+         Should have a column identifying the method.
+     metrics : List[str], optional
+         Metrics to compare. Default: all numeric columns.
+     method_column : str, default "method"
+         Column containing method names.
+     figsize : tuple, default (10, 6)
+         Figure size.
+     title : str, optional
+         Plot title.
+     save : str or Path, optional
+         Path to save figure.
+
+     Returns
+     -------
+     Figure
+         Matplotlib figure.
+
+     Examples
+     --------
+     >>> from spatialcore.plotting.benchmark import plot_method_comparison
+     >>> df = pd.DataFrame({
+     ...     "method": ["CellTypist", "HieraType", "Manual"],
+     ...     "Accuracy": [0.85, 0.88, 0.92],
+     ...     "Silhouette": [0.45, 0.52, 0.48],
+     ... })
+     >>> fig = plot_method_comparison(df, metrics=["Accuracy", "Silhouette"])
+     """
+     if method_column not in df.columns:
+         raise ValueError(f"Method column '{method_column}' not found.")
+
+     if metrics is None:
+         metrics = [c for c in df.columns if c != method_column and np.issubdtype(df[c].dtype, np.number)]
+
+     if not metrics:
+         raise ValueError("No numeric metrics found.")
+
+     methods = df[method_column].tolist()
+     n_methods = len(methods)
+     n_metrics = len(metrics)
+
+     fig, ax = setup_figure(figsize=figsize)
+
+     x = np.arange(n_metrics)
+     width = 0.8 / n_methods
+
+     colors = generate_celltype_palette(methods)
+
+     for i, method in enumerate(methods):
+         values = df[df[method_column] == method][metrics].values.flatten()
+         offset = (i - n_methods / 2 + 0.5) * width
+         bars = ax.bar(
+             x + offset,
+             values,
+             width,
+             label=method,
+             color=colors.get(method, "#888888"),
+         )
+
+         # Add value labels on bars
+         for bar, val in zip(bars, values):
+             ax.text(
+                 bar.get_x() + bar.get_width() / 2,
+                 bar.get_height() + 0.01,
+                 f"{val:.2f}",
+                 ha="center",
+                 va="bottom",
+                 fontsize=8,
+             )
+
+     ax.set_xticks(x)
+     ax.set_xticklabels(metrics)
+     ax.set_ylabel("Score")
+     ax.legend()
+
+     despine(ax)
+
+     if title is None:
+         title = "Method Comparison"
+     ax.set_title(title)
+
+     plt.tight_layout()
+
+     if save:
+         save_figure(fig, save)
+
+     return fig
+
+
+ def plot_confusion_matrix(
+     true_labels: np.ndarray,
+     pred_labels: np.ndarray,
+     labels: Optional[List[str]] = None,
+     normalize: bool = True,
+     cmap: str = "Blues",
+     figsize: Optional[tuple] = None,
+     title: Optional[str] = None,
+     save: Optional[Union[str, Path]] = None,
+ ) -> Figure:
+     """
+     Plot confusion matrix.
+
+     Parameters
+     ----------
+     true_labels : np.ndarray
+         True class labels.
+     pred_labels : np.ndarray
+         Predicted class labels.
+     labels : List[str], optional
+         Class labels. If None, inferred from data.
+     normalize : bool, default True
+         Normalize by true class (row sums to 1).
+     cmap : str, default "Blues"
+         Colormap.
+     figsize : tuple, optional
+         Figure size.
+     title : str, optional
+         Plot title.
+     save : str or Path, optional
+         Path to save figure.
+
+     Returns
+     -------
+     Figure
+         Matplotlib figure.
+
+     Examples
+     --------
+     >>> from spatialcore.plotting.benchmark import plot_confusion_matrix
+     >>> true = adata.obs["true_label"].values
+     >>> pred = adata.obs["predicted_label"].values
+     >>> fig = plot_confusion_matrix(true, pred, normalize=True)
+     """
+     from sklearn.metrics import confusion_matrix
+
+     if labels is None:
+         labels = sorted(set(true_labels) | set(pred_labels))
+
+     cm = confusion_matrix(true_labels, pred_labels, labels=labels)
+
+     if normalize:
+         cm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
+         cm = np.nan_to_num(cm)
+
+     n_labels = len(labels)
+
+     if figsize is None:
+         figsize = (max(8, n_labels * 0.5), max(6, n_labels * 0.5))
+
+     fig, ax = setup_figure(figsize=figsize)
+
+     im = ax.imshow(cm, cmap=cmap, aspect="auto")
+
+     # Add colorbar
+     cbar = plt.colorbar(im, ax=ax)
+     cbar.set_label("Fraction" if normalize else "Count")
+
+     # Add text annotations
+     thresh = cm.max() / 2
+     for i in range(n_labels):
+         for j in range(n_labels):
+             val = cm[i, j]
+             if normalize:
+                 text = f"{val:.2f}"
+             else:
+                 text = f"{int(val)}"
+             ax.text(
+                 j,
+                 i,
+                 text,
+                 ha="center",
+                 va="center",
+                 color="white" if val > thresh else "black",
+                 fontsize=8,
+             )
+
+     ax.set_xticks(range(n_labels))
+     ax.set_yticks(range(n_labels))
+     ax.set_xticklabels(labels, rotation=45, ha="right")
+     ax.set_yticklabels(labels)
+
+     ax.set_xlabel("Predicted")
+     ax.set_ylabel("True")
+
+     if title is None:
+         title = "Confusion Matrix"
+     ax.set_title(title)
+
+     plt.tight_layout()
+
+     if save:
+         save_figure(fig, save)
+
+     return fig
+
+
+ def plot_classification_report(
+     true_labels: np.ndarray,
+     pred_labels: np.ndarray,
+     labels: Optional[List[str]] = None,
+     figsize: Optional[tuple] = None,
+     title: Optional[str] = None,
+     save: Optional[Union[str, Path]] = None,
+ ) -> Figure:
+     """
+     Plot classification metrics (precision, recall, F1) per class.
+
+     Parameters
+     ----------
+     true_labels : np.ndarray
+         True class labels.
+     pred_labels : np.ndarray
+         Predicted class labels.
+     labels : List[str], optional
+         Class labels. If None, inferred from data.
+     figsize : tuple, optional
+         Figure size.
+     title : str, optional
+         Plot title.
+     save : str or Path, optional
+         Path to save figure.
+
+     Returns
+     -------
+     Figure
+         Matplotlib figure.
+     """
+     from sklearn.metrics import precision_recall_fscore_support
+
+     if labels is None:
+         labels = sorted(set(true_labels) | set(pred_labels))
+
+     precision, recall, f1, support = precision_recall_fscore_support(
+         true_labels, pred_labels, labels=labels, zero_division=0
+     )
+
+     n_labels = len(labels)
+
+     if figsize is None:
+         figsize = (max(10, n_labels * 0.5), 6)
+
+     fig, ax = setup_figure(figsize=figsize)
+
+     x = np.arange(n_labels)
+     width = 0.25
+
+     ax.bar(x - width, precision, width, label="Precision", color="#3784FE")
+     ax.bar(x, recall, width, label="Recall", color="#33FF33")
+     ax.bar(x + width, f1, width, label="F1", color="#FF6B6B")
+
+     ax.set_xticks(x)
+     ax.set_xticklabels(labels, rotation=45, ha="right")
+     ax.set_ylabel("Score")
+     ax.set_ylim(0, 1.1)
+     ax.legend()
+
+     despine(ax)
+
+     if title is None:
+         title = "Classification Metrics by Class"
+     ax.set_title(title)
+
+     plt.tight_layout()
+
+     if save:
+         save_figure(fig, save)
+
+     return fig
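Unlike the two functions above, plot_classification_report ships without a docstring example; a hedged usage sketch with synthetic labels (everything below is illustrative, not from the package):

import numpy as np
from spatialcore.plotting.benchmark import plot_classification_report

rng = np.random.default_rng(0)
cell_types = np.array(["B cell", "T cell", "Macrophage"])
true = rng.choice(cell_types, size=500)
# Simulate a classifier that keeps the true label ~80% of the time.
pred = np.where(rng.random(500) < 0.8, true, rng.choice(cell_types, size=500))

fig = plot_classification_report(true, pred, save="classification_report.png")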
+
+
+ def plot_agreement_heatmap(
+     adata: ad.AnnData,
+     columns: List[str],
+     figsize: Optional[tuple] = None,
+     cmap: str = "Greens",
+     title: Optional[str] = None,
+     save: Optional[Union[str, Path]] = None,
+ ) -> Figure:
+     """
+     Plot agreement matrix between annotation methods.
+
+     Parameters
+     ----------
+     adata : AnnData
+         Annotated data with multiple annotation columns.
+     columns : List[str]
+         Columns in adata.obs to compare.
+     figsize : tuple, optional
+         Figure size.
+     cmap : str, default "Greens"
+         Colormap.
+     title : str, optional
+         Plot title.
+     save : str or Path, optional
+         Path to save figure.
+
+     Returns
+     -------
+     Figure
+         Matplotlib figure.
+     """
+     for col in columns:
+         if col not in adata.obs.columns:
+             raise ValueError(f"Column '{col}' not found.")
+
+     n_methods = len(columns)
+
+     # Calculate agreement matrix
+     agreement = np.zeros((n_methods, n_methods))
+     for i, col1 in enumerate(columns):
+         for j, col2 in enumerate(columns):
+             agreement[i, j] = (adata.obs[col1] == adata.obs[col2]).mean()
+
+     if figsize is None:
+         figsize = (max(6, n_methods * 1.2), max(5, n_methods))
+
+     fig, ax = setup_figure(figsize=figsize)
+
+     im = ax.imshow(agreement, cmap=cmap, vmin=0, vmax=1)
+     plt.colorbar(im, ax=ax, label="Agreement")
+
+     # Add text
+     for i in range(n_methods):
+         for j in range(n_methods):
+             ax.text(
+                 j,
+                 i,
+                 f"{agreement[i, j]:.2f}",
+                 ha="center",
+                 va="center",
+                 color="white" if agreement[i, j] > 0.5 else "black",
+             )
+
+     ax.set_xticks(range(n_methods))
+     ax.set_yticks(range(n_methods))
+     ax.set_xticklabels(columns, rotation=45, ha="right")
+     ax.set_yticklabels(columns)
+
+     if title is None:
+         title = "Method Agreement"
+     ax.set_title(title)
+
+     plt.tight_layout()
+
+     if save:
+         save_figure(fig, save)
+
+     return fig
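plot_agreement_heatmap also lacks a docstring example; a sketch assuming an AnnData whose .obs holds one label column per annotation method (file and column names are illustrative):

import anndata as ad
from spatialcore.plotting.benchmark import plot_agreement_heatmap

adata = ad.read_h5ad("annotated.h5ad")  # hypothetical dataset with several label columns
fig = plot_agreement_heatmap(
    adata,
    columns=["celltypist_label", "hieratype_label", "manual_label"],
    save="method_agreement.png",
)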
+
+
+ def plot_silhouette_by_type(
+     adata: ad.AnnData,
+     label_column: str,
+     embedding_key: str = "X_pca",
+     sample_size: int = 5000,
+     random_state: int = 42,
+     figsize: tuple = (10, 6),
+     title: Optional[str] = None,
+     save: Optional[Union[str, Path]] = None,
+ ) -> Figure:
+     """
+     Plot silhouette scores by cell type.
+
+     Parameters
+     ----------
+     adata : AnnData
+         Annotated data.
+     label_column : str
+         Column in adata.obs containing cell type labels.
+     embedding_key : str, default "X_pca"
+         Key in adata.obsm for embedding.
+     sample_size : int, default 5000
+         Number of cells to sample (for speed).
+     random_state : int, default 42
+         Random seed.
+     figsize : tuple, default (10, 6)
+         Figure size.
+     title : str, optional
+         Plot title.
+     save : str or Path, optional
+         Path to save figure.
+
+     Returns
+     -------
+     Figure
+         Matplotlib figure.
+     """
+     from sklearn.metrics import silhouette_samples
+
+     if label_column not in adata.obs.columns:
+         raise ValueError(f"Label column '{label_column}' not found.")
+     if embedding_key not in adata.obsm:
+         raise ValueError(f"Embedding '{embedding_key}' not found.")
+
+     # Sample if too large
+     if adata.n_obs > sample_size:
+         np.random.seed(random_state)
+         idx = np.random.choice(adata.n_obs, sample_size, replace=False)
+         X = adata.obsm[embedding_key][idx]
+         labels = adata.obs[label_column].values[idx]
+     else:
+         X = adata.obsm[embedding_key]
+         labels = adata.obs[label_column].values
+
+     # Calculate silhouette scores
+     sil_scores = silhouette_samples(X, labels)
+
+     # Get mean per type
+     df = pd.DataFrame({"label": labels, "silhouette": sil_scores})
+     type_scores = df.groupby("label")["silhouette"].mean().sort_values()
+
+     fig, ax = setup_figure(figsize=figsize)
+
+     colors = generate_celltype_palette(type_scores.index.tolist())
+     bar_colors = [colors.get(ct, "#888888") for ct in type_scores.index]
+
+     y_pos = np.arange(len(type_scores))
+     ax.barh(y_pos, type_scores.values, color=bar_colors)
+     ax.set_yticks(y_pos)
+     ax.set_yticklabels(type_scores.index)
+     ax.axvline(0, color="gray", linestyle="--")
+
+     format_axis_labels(ax, xlabel="Silhouette Score")
+     despine(ax)
+
+     if title is None:
+         title = f"Silhouette Scores by Cell Type\n(mean={sil_scores.mean():.3f})"
+     ax.set_title(title)
+
+     plt.tight_layout()
+
+     if save:
+         save_figure(fig, save)
+
+     return fig
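Finally, plot_silhouette_by_type expects a precomputed embedding in adata.obsm; a sketch assuming scanpy-style preprocessing (file, column, and key names are illustrative):

import scanpy as sc
from spatialcore.plotting.benchmark import plot_silhouette_by_type

adata = sc.read_h5ad("annotated.h5ad")   # hypothetical annotated dataset
sc.pp.pca(adata, n_comps=50)             # writes adata.obsm["X_pca"]
fig = plot_silhouette_by_type(
    adata,
    label_column="cell_type",            # assumed label column
    embedding_key="X_pca",
    sample_size=5000,
    save="silhouette_by_type.png",
)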