spatialcore-0.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatialcore/__init__.py +122 -0
- spatialcore/annotation/__init__.py +253 -0
- spatialcore/annotation/acquisition.py +529 -0
- spatialcore/annotation/annotate.py +603 -0
- spatialcore/annotation/cellxgene.py +365 -0
- spatialcore/annotation/confidence.py +802 -0
- spatialcore/annotation/discovery.py +529 -0
- spatialcore/annotation/expression.py +363 -0
- spatialcore/annotation/loading.py +529 -0
- spatialcore/annotation/markers.py +297 -0
- spatialcore/annotation/ontology.py +1282 -0
- spatialcore/annotation/patterns.py +247 -0
- spatialcore/annotation/pipeline.py +620 -0
- spatialcore/annotation/synapse.py +380 -0
- spatialcore/annotation/training.py +1457 -0
- spatialcore/annotation/validation.py +422 -0
- spatialcore/core/__init__.py +34 -0
- spatialcore/core/cache.py +118 -0
- spatialcore/core/logging.py +135 -0
- spatialcore/core/metadata.py +149 -0
- spatialcore/core/utils.py +768 -0
- spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
- spatialcore/data/markers/canonical_markers.json +83 -0
- spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
- spatialcore/plotting/__init__.py +109 -0
- spatialcore/plotting/benchmark.py +477 -0
- spatialcore/plotting/celltype.py +329 -0
- spatialcore/plotting/confidence.py +413 -0
- spatialcore/plotting/spatial.py +505 -0
- spatialcore/plotting/utils.py +411 -0
- spatialcore/plotting/validation.py +1342 -0
- spatialcore-0.1.9.dist-info/METADATA +213 -0
- spatialcore-0.1.9.dist-info/RECORD +36 -0
- spatialcore-0.1.9.dist-info/WHEEL +5 -0
- spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
- spatialcore-0.1.9.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Plotting utilities for SpatialCore.
|
|
3
|
+
|
|
4
|
+
This module provides visualization functions for:
|
|
5
|
+
- Cell type distributions and UMAP plots
|
|
6
|
+
- Confidence score visualization
|
|
7
|
+
- Spatial maps of cell types and gene expression
|
|
8
|
+
- Marker validation heatmaps and dot plots
|
|
9
|
+
- Benchmark comparisons and confusion matrices
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from spatialcore.plotting.utils import (
|
|
13
|
+
# Color palettes
|
|
14
|
+
DEFAULT_PALETTE,
|
|
15
|
+
COLORBLIND_PALETTE,
|
|
16
|
+
generate_celltype_palette,
|
|
17
|
+
load_celltype_palette,
|
|
18
|
+
save_celltype_palette,
|
|
19
|
+
# Figure setup
|
|
20
|
+
setup_figure,
|
|
21
|
+
setup_multi_figure,
|
|
22
|
+
save_figure,
|
|
23
|
+
close_figure,
|
|
24
|
+
# Axis formatting
|
|
25
|
+
format_axis_labels,
|
|
26
|
+
despine,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from spatialcore.plotting.celltype import (
|
|
30
|
+
plot_celltype_distribution,
|
|
31
|
+
plot_celltype_pie,
|
|
32
|
+
plot_celltype_umap,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
from spatialcore.plotting.confidence import (
|
|
36
|
+
plot_confidence_histogram,
|
|
37
|
+
plot_confidence_by_celltype,
|
|
38
|
+
plot_confidence_violin,
|
|
39
|
+
plot_model_contribution,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
from spatialcore.plotting.spatial import (
|
|
43
|
+
plot_spatial_celltype,
|
|
44
|
+
plot_spatial_confidence,
|
|
45
|
+
plot_spatial_gene,
|
|
46
|
+
plot_spatial_multi_gene,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
from spatialcore.plotting.validation import (
|
|
50
|
+
plot_marker_heatmap,
|
|
51
|
+
plot_2d_validation,
|
|
52
|
+
plot_marker_dotplot,
|
|
53
|
+
plot_celltype_confidence,
|
|
54
|
+
plot_deg_heatmap,
|
|
55
|
+
plot_ontology_mapping,
|
|
56
|
+
generate_annotation_plots,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
from spatialcore.plotting.benchmark import (
|
|
60
|
+
plot_method_comparison,
|
|
61
|
+
plot_confusion_matrix,
|
|
62
|
+
plot_classification_report,
|
|
63
|
+
plot_agreement_heatmap,
|
|
64
|
+
plot_silhouette_by_type,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
__all__ = [
|
|
68
|
+
# Utils - Palettes
|
|
69
|
+
"DEFAULT_PALETTE",
|
|
70
|
+
"COLORBLIND_PALETTE",
|
|
71
|
+
"generate_celltype_palette",
|
|
72
|
+
"load_celltype_palette",
|
|
73
|
+
"save_celltype_palette",
|
|
74
|
+
# Utils - Figures
|
|
75
|
+
"setup_figure",
|
|
76
|
+
"setup_multi_figure",
|
|
77
|
+
"save_figure",
|
|
78
|
+
"close_figure",
|
|
79
|
+
"format_axis_labels",
|
|
80
|
+
"despine",
|
|
81
|
+
# Cell type
|
|
82
|
+
"plot_celltype_distribution",
|
|
83
|
+
"plot_celltype_pie",
|
|
84
|
+
"plot_celltype_umap",
|
|
85
|
+
# Confidence
|
|
86
|
+
"plot_confidence_histogram",
|
|
87
|
+
"plot_confidence_by_celltype",
|
|
88
|
+
"plot_confidence_violin",
|
|
89
|
+
"plot_model_contribution",
|
|
90
|
+
# Spatial
|
|
91
|
+
"plot_spatial_celltype",
|
|
92
|
+
"plot_spatial_confidence",
|
|
93
|
+
"plot_spatial_gene",
|
|
94
|
+
"plot_spatial_multi_gene",
|
|
95
|
+
# Validation
|
|
96
|
+
"plot_marker_heatmap",
|
|
97
|
+
"plot_2d_validation",
|
|
98
|
+
"plot_marker_dotplot",
|
|
99
|
+
"plot_celltype_confidence",
|
|
100
|
+
"plot_deg_heatmap",
|
|
101
|
+
"plot_ontology_mapping",
|
|
102
|
+
"generate_annotation_plots",
|
|
103
|
+
# Benchmark
|
|
104
|
+
"plot_method_comparison",
|
|
105
|
+
"plot_confusion_matrix",
|
|
106
|
+
"plot_classification_report",
|
|
107
|
+
"plot_agreement_heatmap",
|
|
108
|
+
"plot_silhouette_by_type",
|
|
109
|
+
]
|
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Benchmark visualization for annotation method comparison.
|
|
3
|
+
|
|
4
|
+
This module provides functions for comparing annotation methods
|
|
5
|
+
and visualizing classification performance.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Dict, List, Optional, Union
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
from matplotlib.figure import Figure
|
|
15
|
+
import anndata as ad
|
|
16
|
+
|
|
17
|
+
from spatialcore.core.logging import get_logger
|
|
18
|
+
from spatialcore.plotting.utils import (
|
|
19
|
+
generate_celltype_palette,
|
|
20
|
+
setup_figure,
|
|
21
|
+
save_figure,
|
|
22
|
+
despine,
|
|
23
|
+
format_axis_labels,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
logger = get_logger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def plot_method_comparison(
    df: pd.DataFrame,
    metrics: Optional[List[str]] = None,
    method_column: str = "method",
    figsize: tuple = (10, 6),
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot comparison of annotation methods across metrics.

    Draws a grouped bar chart: one group per metric, one bar per row of
    ``df`` (i.e. per method), with the numeric value printed next to
    each bar.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with methods as rows and metrics as columns.
        Should have a column identifying the method.
    metrics : List[str], optional
        Metrics to compare. Default: all numeric columns.
    method_column : str, default "method"
        Column containing method names.
    figsize : tuple, default (10, 6)
        Figure size.
    title : str, optional
        Plot title. Defaults to "Method Comparison".
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``method_column`` or a requested metric is not a column of
        ``df``, if no metrics are available, or if ``df`` has no rows.

    Examples
    --------
    >>> from spatialcore.plotting.benchmark import plot_method_comparison
    >>> df = pd.DataFrame({
    ...     "method": ["CellTypist", "HieraType", "Manual"],
    ...     "Accuracy": [0.85, 0.88, 0.92],
    ...     "Silhouette": [0.45, 0.52, 0.48],
    ... })
    >>> fig = plot_method_comparison(df, metrics=["Accuracy", "Silhouette"])
    """
    if method_column not in df.columns:
        raise ValueError(f"Method column '{method_column}' not found.")

    if metrics is None:
        # select_dtypes copes with pandas extension dtypes (category,
        # string, ...) on which np.issubdtype raises TypeError.
        numeric_columns = df.select_dtypes(include=[np.number]).columns
        metrics = [c for c in numeric_columns if c != method_column]

    if not metrics:
        raise ValueError("No numeric metrics found.")

    missing = [m for m in metrics if m not in df.columns]
    if missing:
        raise ValueError(f"Metric column(s) not found: {missing}")

    methods = df[method_column].tolist()
    n_methods = len(methods)
    if n_methods == 0:
        # Without this guard the bar width below divides by zero.
        raise ValueError("No methods to compare: DataFrame has no rows.")
    n_metrics = len(metrics)

    fig, ax = setup_figure(figsize=figsize)

    x = np.arange(n_metrics)
    width = 0.8 / n_methods  # all bars of one metric share 80% of a slot

    colors = generate_celltype_palette(methods)

    # Take values row by row (positionally) rather than by matching the
    # method name, so duplicated method names still yield exactly one
    # value per metric for each bar group.
    metric_values = df[metrics].to_numpy()

    for i, (method, values) in enumerate(zip(methods, metric_values)):
        # Centre the group of bars on the metric's tick position.
        offset = (i - n_methods / 2 + 0.5) * width
        bars = ax.bar(
            x + offset,
            values,
            width,
            label=method,
            color=colors.get(method, "#888888"),
        )

        # Add value labels just outside the end of each bar: above for
        # non-negative values, below for negative ones.
        for bar, val in zip(bars, values):
            height = bar.get_height()
            is_negative = height < 0
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height - 0.01 if is_negative else height + 0.01,
                f"{val:.2f}",
                ha="center",
                va="top" if is_negative else "bottom",
                fontsize=8,
            )

    ax.set_xticks(x)
    ax.set_xticklabels(metrics)
    ax.set_ylabel("Score")
    ax.legend()

    despine(ax)

    if title is None:
        title = "Method Comparison"
    ax.set_title(title)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def plot_confusion_matrix(
    true_labels: np.ndarray,
    pred_labels: np.ndarray,
    labels: Optional[List[str]] = None,
    normalize: bool = True,
    cmap: str = "Blues",
    figsize: Optional[tuple] = None,
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot confusion matrix.

    Rows are true classes, columns are predicted classes, and every
    cell is annotated with its value.

    Parameters
    ----------
    true_labels : np.ndarray
        True class labels.
    pred_labels : np.ndarray
        Predicted class labels.
    labels : List[str], optional
        Class labels. If None, the sorted union of the values found in
        ``true_labels`` and ``pred_labels`` is used.
    normalize : bool, default True
        Normalize by true class (row sums to 1). Rows for classes with
        no true samples are left as all zeros.
    cmap : str, default "Blues"
        Colormap.
    figsize : tuple, optional
        Figure size. If None, it grows with the number of labels.
    title : str, optional
        Plot title. Defaults to "Confusion Matrix".
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Examples
    --------
    >>> from spatialcore.plotting.benchmark import plot_confusion_matrix
    >>> true = adata.obs["true_label"].values
    >>> pred = adata.obs["predicted_label"].values
    >>> fig = plot_confusion_matrix(true, pred, normalize=True)
    """
    from sklearn.metrics import confusion_matrix

    if labels is None:
        labels = sorted(set(true_labels) | set(pred_labels))

    cm = confusion_matrix(true_labels, pred_labels, labels=labels)

    if normalize:
        # Divide only where the row total is positive; rows with no true
        # samples keep the zeros from `out` instead of triggering a 0/0
        # RuntimeWarning and NaNs.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm.astype(float),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals > 0,
        )

    n_labels = len(labels)

    if figsize is None:
        figsize = (max(8, n_labels * 0.5), max(6, n_labels * 0.5))

    fig, ax = setup_figure(figsize=figsize)

    im = ax.imshow(cm, cmap=cmap, aspect="auto")

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label("Fraction" if normalize else "Count")

    # Add text annotations; cells darker than half the maximum get white
    # text so the number stays readable. max() of an empty matrix would
    # raise, hence the size guard.
    thresh = cm.max() / 2 if cm.size else 0.0
    for i in range(n_labels):
        for j in range(n_labels):
            val = cm[i, j]
            if normalize:
                text = f"{val:.2f}"
            else:
                text = f"{int(val)}"
            ax.text(
                j,
                i,
                text,
                ha="center",
                va="center",
                color="white" if val > thresh else "black",
                fontsize=8,
            )

    ax.set_xticks(range(n_labels))
    ax.set_yticks(range(n_labels))
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_yticklabels(labels)

    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    if title is None:
        title = "Confusion Matrix"
    ax.set_title(title)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def plot_classification_report(
    true_labels: np.ndarray,
    pred_labels: np.ndarray,
    labels: Optional[List[str]] = None,
    figsize: Optional[tuple] = None,
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot classification metrics (precision, recall, F1) per class.

    For every class three side-by-side bars are drawn, one each for
    precision, recall and F1, as computed by scikit-learn's
    ``precision_recall_fscore_support`` with ``zero_division=0``.

    Parameters
    ----------
    true_labels : np.ndarray
        True class labels.
    pred_labels : np.ndarray
        Predicted class labels.
    labels : List[str], optional
        Class labels. If None, the sorted union of the values found in
        ``true_labels`` and ``pred_labels`` is used.
    figsize : tuple, optional
        Figure size. If None, the width grows with the number of labels.
    title : str, optional
        Plot title. Defaults to "Classification Metrics by Class".
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    from sklearn.metrics import precision_recall_fscore_support

    if labels is None:
        labels = sorted(set(true_labels).union(set(pred_labels)))

    # The per-class support is returned too but is not plotted.
    precision, recall, f1, _support = precision_recall_fscore_support(
        true_labels, pred_labels, labels=labels, zero_division=0
    )

    n_labels = len(labels)
    if figsize is None:
        figsize = (max(10, n_labels * 0.5), 6)

    fig, ax = setup_figure(figsize=figsize)

    positions = np.arange(n_labels)
    bar_width = 0.25

    # One bar series per metric: left of the tick, on it, right of it.
    ax.bar(positions - bar_width, precision, bar_width, label="Precision", color="#3784FE")
    ax.bar(positions, recall, bar_width, label="Recall", color="#33FF33")
    ax.bar(positions + bar_width, f1, bar_width, label="F1", color="#FF6B6B")

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel("Score")
    # Leave headroom above the maximum possible score of 1.0.
    ax.set_ylim(0, 1.1)
    ax.legend()

    despine(ax)

    ax.set_title("Classification Metrics by Class" if title is None else title)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def plot_agreement_heatmap(
    adata: ad.AnnData,
    columns: List[str],
    figsize: Optional[tuple] = None,
    cmap: str = "Greens",
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot agreement matrix between annotation methods.

    Cell ``(i, j)`` of the heatmap is the fraction of observations for
    which ``columns[i]`` and ``columns[j]`` carry equal labels.

    Parameters
    ----------
    adata : AnnData
        Annotated data with multiple annotation columns.
    columns : List[str]
        Columns in adata.obs to compare.
    figsize : tuple, optional
        Figure size. If None, it grows with the number of columns.
    cmap : str, default "Greens"
        Colormap.
    title : str, optional
        Plot title. Defaults to "Method Agreement".
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If any entry of ``columns`` is not a column of ``adata.obs``.
    """
    for col in columns:
        if col not in adata.obs.columns:
            raise ValueError(f"Column '{col}' not found.")

    n_methods = len(columns)

    # Compare plain object values rather than the stored dtype: pandas
    # refuses to compare two categorical columns whose category sets
    # differ, which is the usual situation when the columns come from
    # different annotation methods.
    label_series = [adata.obs[col].astype(object) for col in columns]

    # Calculate agreement matrix. Agreement is symmetric, so only the
    # upper triangle is computed and then mirrored.
    agreement = np.zeros((n_methods, n_methods))
    for i in range(n_methods):
        for j in range(i, n_methods):
            score = (label_series[i] == label_series[j]).mean()
            agreement[i, j] = score
            agreement[j, i] = score

    if figsize is None:
        figsize = (max(6, n_methods * 1.2), max(5, n_methods))

    fig, ax = setup_figure(figsize=figsize)

    im = ax.imshow(agreement, cmap=cmap, vmin=0, vmax=1)
    plt.colorbar(im, ax=ax, label="Agreement")

    # Add text; switch to white on the darker (high-agreement) cells.
    for i in range(n_methods):
        for j in range(n_methods):
            ax.text(
                j,
                i,
                f"{agreement[i, j]:.2f}",
                ha="center",
                va="center",
                color="white" if agreement[i, j] > 0.5 else "black",
            )

    ax.set_xticks(range(n_methods))
    ax.set_yticks(range(n_methods))
    ax.set_xticklabels(columns, rotation=45, ha="right")
    ax.set_yticklabels(columns)

    if title is None:
        title = "Method Agreement"
    ax.set_title(title)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def plot_silhouette_by_type(
    adata: ad.AnnData,
    label_column: str,
    embedding_key: str = "X_pca",
    sample_size: int = 5000,
    random_state: int = 42,
    figsize: tuple = (10, 6),
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot silhouette scores by cell type.

    Computes a silhouette score per cell with scikit-learn's
    ``silhouette_samples`` on the chosen embedding, averages the scores
    within each label, and draws the means as a horizontal bar chart
    sorted in ascending order.

    Parameters
    ----------
    adata : AnnData
        Annotated data.
    label_column : str
        Column in adata.obs containing cell type labels.
    embedding_key : str, default "X_pca"
        Key in adata.obsm for embedding.
    sample_size : int, default 5000
        Number of cells to sample (for speed). Sampling is without
        replacement and only happens when ``adata.n_obs`` exceeds this.
    random_state : int, default 42
        Random seed for the subsampling.
    figsize : tuple, default (10, 6)
        Figure size.
    title : str, optional
        Plot title. Defaults to a title that includes the overall mean
        silhouette score.
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``label_column`` is not in ``adata.obs`` or ``embedding_key``
        is not in ``adata.obsm``.
    """
    from sklearn.metrics import silhouette_samples

    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")
    if embedding_key not in adata.obsm:
        raise ValueError(f"Embedding '{embedding_key}' not found.")

    # Sample if too large
    if adata.n_obs > sample_size:
        # A private legacy generator draws the same indices as seeding
        # the global NumPy RNG would, without disturbing the caller's
        # global random state.
        rng = np.random.RandomState(random_state)
        idx = rng.choice(adata.n_obs, sample_size, replace=False)
        X = adata.obsm[embedding_key][idx]
        labels = adata.obs[label_column].values[idx]
    else:
        X = adata.obsm[embedding_key]
        labels = adata.obs[label_column].values

    # Calculate silhouette scores
    sil_scores = silhouette_samples(X, labels)

    # Get mean per type. observed=True keeps categories that have no
    # cells (e.g. lost during subsampling) from showing up as NaN rows.
    df = pd.DataFrame({"label": labels, "silhouette": sil_scores})
    type_scores = (
        df.groupby("label", observed=True)["silhouette"].mean().sort_values()
    )

    fig, ax = setup_figure(figsize=figsize)

    colors = generate_celltype_palette(type_scores.index.tolist())
    bar_colors = [colors.get(ct, "#888888") for ct in type_scores.index]

    y_pos = np.arange(len(type_scores))
    ax.barh(y_pos, type_scores.values, color=bar_colors)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(type_scores.index)
    # Reference line separating negative from positive mean scores.
    ax.axvline(0, color="gray", linestyle="--")

    format_axis_labels(ax, xlabel="Silhouette Score")
    despine(ax)

    if title is None:
        title = f"Silhouette Scores by Cell Type\n(mean={sil_scores.mean():.3f})"
    ax.set_title(title)

    plt.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|