spatialcore 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatialcore/__init__.py +122 -0
- spatialcore/annotation/__init__.py +253 -0
- spatialcore/annotation/acquisition.py +529 -0
- spatialcore/annotation/annotate.py +603 -0
- spatialcore/annotation/cellxgene.py +365 -0
- spatialcore/annotation/confidence.py +802 -0
- spatialcore/annotation/discovery.py +529 -0
- spatialcore/annotation/expression.py +363 -0
- spatialcore/annotation/loading.py +529 -0
- spatialcore/annotation/markers.py +297 -0
- spatialcore/annotation/ontology.py +1282 -0
- spatialcore/annotation/patterns.py +247 -0
- spatialcore/annotation/pipeline.py +620 -0
- spatialcore/annotation/synapse.py +380 -0
- spatialcore/annotation/training.py +1457 -0
- spatialcore/annotation/validation.py +422 -0
- spatialcore/core/__init__.py +34 -0
- spatialcore/core/cache.py +118 -0
- spatialcore/core/logging.py +135 -0
- spatialcore/core/metadata.py +149 -0
- spatialcore/core/utils.py +768 -0
- spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
- spatialcore/data/markers/canonical_markers.json +83 -0
- spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
- spatialcore/plotting/__init__.py +109 -0
- spatialcore/plotting/benchmark.py +477 -0
- spatialcore/plotting/celltype.py +329 -0
- spatialcore/plotting/confidence.py +413 -0
- spatialcore/plotting/spatial.py +505 -0
- spatialcore/plotting/utils.py +411 -0
- spatialcore/plotting/validation.py +1342 -0
- spatialcore-0.1.9.dist-info/METADATA +213 -0
- spatialcore-0.1.9.dist-info/RECORD +36 -0
- spatialcore-0.1.9.dist-info/WHEEL +5 -0
- spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
- spatialcore-0.1.9.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1342 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Marker validation visualization.
|
|
3
|
+
|
|
4
|
+
This module provides functions for validating cell type annotations
|
|
5
|
+
using canonical marker genes and GMM-3 thresholding.
|
|
6
|
+
|
|
7
|
+
For marker validation, we use classify_by_threshold with n_components=3
|
|
8
|
+
(trimodal GMM) which handles spatial data's dropout/moderate/high
|
|
9
|
+
expression patterns better than bimodal GMM.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
14
|
+
import gc
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import matplotlib.pyplot as plt
|
|
19
|
+
from matplotlib.figure import Figure
|
|
20
|
+
import anndata as ad
|
|
21
|
+
|
|
22
|
+
from spatialcore.core.logging import get_logger
|
|
23
|
+
from spatialcore.plotting.utils import (
|
|
24
|
+
generate_celltype_palette,
|
|
25
|
+
setup_figure,
|
|
26
|
+
setup_multi_figure,
|
|
27
|
+
save_figure,
|
|
28
|
+
despine,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = get_logger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def plot_marker_heatmap(
    adata: ad.AnnData,
    label_column: str,
    markers: Optional[Dict[str, List[str]]] = None,
    cluster: bool = True,
    layer: Optional[str] = None,
    figsize: Optional[tuple] = None,
    cmap: str = "RdBu_r",
    center: float = 0,
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot marker gene expression heatmap by cell type.

    The mean expression of every available marker gene is computed per cell
    type, z-scored per gene (column) across cell types, and drawn either as
    a clustered map (``cluster=True``) or as a plain heatmap.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels.
    label_column : str
        Column in adata.obs containing cell type labels. Cells with a
        missing label are ignored.
    markers : Dict[str, List[str]], optional
        Marker genes per cell type. If None, uses canonical markers.
        Keys are looked up with the lowercased label
        (``str(label).lower()``), so they must be lowercase to match.
    cluster : bool, default True
        Hierarchically cluster cell types (rows) and genes (columns).
    layer : str, optional
        Layer to use. If None, uses adata.X.
    figsize : tuple, optional
        Figure size. Auto-calculated if None.
    cmap : str, default "RdBu_r"
        Colormap.
    center : float, default 0
        Value to center colormap on.
    title : str, optional
        Plot title.
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ImportError
        If seaborn is not installed.
    ValueError
        If ``label_column`` is missing or no marker gene is present in
        ``adata.var_names``.

    Examples
    --------
    >>> from spatialcore.plotting.validation import plot_marker_heatmap
    >>> from spatialcore.annotation.markers import CANONICAL_MARKERS
    >>> fig = plot_marker_heatmap(
    ...     adata,
    ...     label_column="cell_type",
    ...     markers=CANONICAL_MARKERS,
    ... )
    """
    try:
        import seaborn as sns
    except ImportError as err:
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError("seaborn is required for heatmaps") from err

    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")

    # Load canonical markers if not provided
    if markers is None:
        from spatialcore.annotation.markers import load_canonical_markers
        markers = load_canonical_markers()

    labels = adata.obs[label_column]
    # A missing label never equals itself, so it would select zero cells and
    # yield an all-NaN mean row; drop such labels up front.
    cell_types = labels.dropna().unique()

    # Collect all marker genes that exist in the data
    all_genes = []
    for cell_type in cell_types:
        ct_markers = markers.get(str(cell_type).lower(), [])
        all_genes.extend(g for g in ct_markers if g in adata.var_names)

    all_genes = list(dict.fromkeys(all_genes))  # Remove duplicates, preserve order

    if not all_genes:
        raise ValueError("No marker genes found in data.")

    # Calculate mean expression per cell type
    mean_expr = pd.DataFrame(index=cell_types, columns=all_genes, dtype=float)

    for ct in cell_types:
        subset = adata[labels == ct, all_genes]
        X = subset.layers[layer] if layer else subset.X
        if hasattr(X, "toarray"):
            # Sparse matrices are densified; only the marker columns are held.
            X = X.toarray()
        mean_expr.loc[ct] = np.mean(X, axis=0)

    # Z-score normalize columns; constant columns give NaN and are set to 0
    mean_expr_z = (mean_expr - mean_expr.mean()) / mean_expr.std()
    mean_expr_z = mean_expr_z.fillna(0)

    # Figure size
    if figsize is None:
        n_types = len(cell_types)
        n_genes = len(all_genes)
        figsize = (max(10, n_genes * 0.3), max(6, n_types * 0.4))

    # Cluster if requested
    if cluster:
        g = sns.clustermap(
            mean_expr_z,
            cmap=cmap,
            center=center,
            figsize=figsize,
            row_cluster=True,
            col_cluster=True,
        )
        if title:
            g.fig.suptitle(title, y=1.02)
        if save:
            g.savefig(save)
        return g.fig

    fig, ax = setup_figure(figsize=figsize)
    sns.heatmap(
        mean_expr_z,
        cmap=cmap,
        center=center,
        ax=ax,
        xticklabels=True,
        yticklabels=True,
    )
    ax.set_xlabel("Marker Genes")
    ax.set_ylabel("Cell Types")

    if title is None:
        title = "Marker Expression Heatmap"
    ax.set_title(title)

    # Lay out this figure explicitly rather than the implicit current figure.
    fig.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def plot_2d_validation(
    adata: ad.AnnData,
    label_column: str,
    confidence_column: str,
    markers: Optional[Dict[str, List[str]]] = None,
    confidence_threshold: float = 0.8,
    min_cells_per_type: int = 15,
    n_components: int = 3,
    ncols: int = 4,
    figsize_per_panel: Tuple[float, float] = (3, 3),
    save: Optional[Union[str, Path]] = None,
) -> Tuple[Figure, pd.DataFrame]:
    """
    2D marker validation plot per cell type.

    For each cell type, plots confidence (x-axis) vs marker metagene score
    (y-axis). Cells are drawn in three groups:

    - red: confidence below ``confidence_threshold``
    - green: confidence at/above the threshold, marker score below the
      marker threshold
    - gold: confidence and marker score both at/above their thresholds

    The marker threshold is taken from classify_by_threshold()
    (``uns["threshold_params"]["threshold"]``); if that entry is absent the
    median metagene score is used instead.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels and confidence.
    label_column : str
        Column in adata.obs containing cell type labels.
    confidence_column : str
        Column in adata.obs containing confidence values.
    markers : Dict[str, List[str]], optional
        Marker genes per cell type. If None, uses canonical markers.
        Keys are looked up with the lowercased label
        (``str(label).lower()``). A cell type needs at least 2 markers
        present in ``adata.var_names`` to be plotted.
    confidence_threshold : float, default 0.8
        Confidence threshold for validation.
    min_cells_per_type : int, default 15
        Minimum cells required to plot a cell type.
    n_components : int, default 3
        Number of GMM components (3 for trimodal spatial data).
    ncols : int, default 4
        Number of columns in subplot grid.
    figsize_per_panel : Tuple[float, float], default (3, 3)
        Size per panel.
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Tuple[Figure, pd.DataFrame]
        Figure and validation summary DataFrame (one row per plotted cell
        type with counts, percentages, marker threshold and marker count).

    Raises
    ------
    ValueError
        If a required column is missing, no cell type has enough cells and
        markers, or thresholding fails for every cell type.

    Examples
    --------
    >>> from spatialcore.plotting.validation import plot_2d_validation
    >>> fig, summary = plot_2d_validation(
    ...     adata,
    ...     label_column="cell_type",
    ...     confidence_column="confidence",
    ...     confidence_threshold=0.7,
    ... )
    >>> print(summary)
    """
    # NOTE(review): confirm that spatialcore.stats.classify ships with this
    # package version; the import is deferred, so a missing module only
    # surfaces when this function is called.
    from spatialcore.stats.classify import classify_by_threshold

    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")
    if confidence_column not in adata.obs.columns:
        raise ValueError(f"Confidence column '{confidence_column}' not found.")

    # Load canonical markers if not provided
    if markers is None:
        from spatialcore.annotation.markers import load_canonical_markers
        markers = load_canonical_markers()

    labels = adata.obs[label_column]

    # Find cell types with enough cells and available markers
    type_counts = labels.value_counts()
    cell_types = type_counts[type_counts >= min_cells_per_type].index.tolist()

    types_to_plot = []
    for ct in cell_types:
        ct_markers = markers.get(str(ct).lower(), [])
        available = [g for g in ct_markers if g in adata.var_names]
        if len(available) >= 2:  # Need at least 2 markers
            types_to_plot.append((ct, available))

    if not types_to_plot:
        raise ValueError("No cell types with sufficient markers found.")

    # Pre-compute GMM results for all cell types to know which will succeed.
    # Only the arrays needed for plotting are kept; retaining each subset
    # AnnData would hold a second copy of the data until the function returns.
    successful_types = []
    for cell_type, ct_markers in types_to_plot:
        subset = adata[labels == cell_type].copy()
        confidence = subset.obs[confidence_column].values
        n_cells = len(subset)

        try:
            subset_result = classify_by_threshold(
                subset,
                feature_columns=ct_markers,
                threshold_method="gmm",
                n_components=n_components,
                metagene_method="shifted_geometric_mean",
                plot=False,
                copy=True,
            )
            metagene_scores = subset_result.obs["threshold_score"].values
            marker_threshold = subset_result.uns.get(
                "threshold_params", {}
            ).get("threshold", np.median(metagene_scores))
        except Exception as e:
            # Best effort: a failure for one cell type must not abort the rest.
            logger.warning(f"Skipping {cell_type}: GMM failed - {e}")
            continue

        successful_types.append({
            "cell_type": cell_type,
            "markers": ct_markers,
            "n_cells": n_cells,
            "confidence": confidence,
            "metagene_scores": metagene_scores,
            "marker_threshold": marker_threshold,
        })

    if not successful_types:
        raise ValueError("GMM failed for all cell types.")

    # Create grid with only successful cell types
    n_types = len(successful_types)
    nrows = int(np.ceil(n_types / ncols))
    figsize = (figsize_per_panel[0] * ncols, figsize_per_panel[1] * nrows)

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array
    axes = np.atleast_2d(axes).flatten()

    summary_rows = []

    for ax, data in zip(axes, successful_types):
        cell_type = data["cell_type"]
        ct_markers = data["markers"]
        n_cells = data["n_cells"]
        confidence = data["confidence"]
        metagene_scores = data["metagene_scores"]
        marker_threshold = data["marker_threshold"]

        # Classify cells into three groups per spec section 3.1:
        # - Red: Low confidence (uncertain)
        # - Green: High confidence only (validated)
        # - Yellow: High conf + High marker (strongly validated)
        high_conf = confidence >= confidence_threshold
        high_marker = metagene_scores >= marker_threshold

        low_conf = ~high_conf  # Red
        high_conf_low_marker = high_conf & ~high_marker  # Green
        high_conf_high_marker = high_conf & high_marker  # Yellow

        # Plot in order: red (back), green, yellow (front)
        groups = (
            (low_conf, "red", "Low Conf"),
            (high_conf_low_marker, "green", "High Conf"),
            (high_conf_high_marker, "gold", "High Conf + Marker"),
        )
        for group_mask, group_color, group_label in groups:
            ax.scatter(
                confidence[group_mask],
                metagene_scores[group_mask],
                c=group_color,
                s=5,
                alpha=0.5,
                label=group_label,
            )

        # Threshold lines
        ax.axvline(confidence_threshold, color="gray", linestyle="--", alpha=0.5)
        ax.axhline(marker_threshold, color="gray", linestyle="--", alpha=0.5)

        ax.set_xlabel("Confidence")
        ax.set_ylabel("Marker Score")
        ax.set_xlim(0, 1)  # Fixed x-axis for consistent comparison
        ax.set_xticks([0, 0.5, 1.0])
        ax.set_title(f"{cell_type}\n(n={n_cells})", fontsize=9)

        # Summary stats
        n_low_conf = low_conf.sum()
        n_high_conf = high_conf_low_marker.sum()
        n_validated = high_conf_high_marker.sum()
        pct_validated = 100 * n_validated / n_cells if n_cells > 0 else 0
        pct_high_conf = 100 * (n_high_conf + n_validated) / n_cells if n_cells > 0 else 0

        summary_rows.append({
            "cell_type": cell_type,
            "n_cells": n_cells,
            "n_low_conf": n_low_conf,
            "n_high_conf": n_high_conf,
            "n_validated": n_validated,
            "pct_validated": pct_validated,
            "pct_high_conf": pct_high_conf,
            "marker_threshold": marker_threshold,
            "n_markers": len(ct_markers),
        })

    # Hide unused panels (when n_types doesn't fill full grid)
    for unused_ax in axes[n_types:]:
        unused_ax.set_visible(False)

    # Lay out this figure explicitly rather than the implicit current figure.
    fig.tight_layout()

    if save:
        save_figure(fig, save)

    summary_df = pd.DataFrame(summary_rows)
    return fig, summary_df
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def plot_marker_dotplot(
    adata: ad.AnnData,
    label_column: str,
    markers: Optional[Dict[str, List[str]]] = None,
    layer: Optional[str] = None,
    figsize: Optional[tuple] = None,
    cmap: str = "Reds",
    title: Optional[str] = None,
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot marker expression as dot plot.

    Dot size represents fraction of cells expressing the marker (value > 0),
    color intensity represents mean expression, min-max scaled per gene
    across cell types.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels.
    label_column : str
        Column in adata.obs containing cell type labels. Cells with a
        missing label are ignored.
    markers : Dict[str, List[str]], optional
        Marker genes per cell type. If None, uses canonical markers.
        Keys are looked up with the lowercased label
        (``str(label).lower()``).
    layer : str, optional
        Layer to use. If None, uses adata.X.
    figsize : tuple, optional
        Figure size. Auto-calculated if None.
    cmap : str, default "Reds"
        Colormap.
    title : str, optional
        Plot title.
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``label_column`` is missing or no marker gene is present in
        ``adata.var_names``.
    """
    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")

    # Load canonical markers if not provided
    if markers is None:
        from spatialcore.annotation.markers import load_canonical_markers
        markers = load_canonical_markers()

    labels = adata.obs[label_column]
    # Missing labels cannot be sorted together with strings and would select
    # no cells, so they are dropped before anything else.
    observed_types = labels.dropna().unique()

    # Collect all marker genes
    all_genes = []
    for ct in observed_types:
        ct_markers = markers.get(str(ct).lower(), [])
        all_genes.extend(g for g in ct_markers if g in adata.var_names)

    all_genes = list(dict.fromkeys(all_genes))

    if not all_genes:
        raise ValueError("No marker genes found in data.")

    cell_types = sorted(observed_types)

    # Calculate expression fraction and mean
    n_types = len(cell_types)
    n_genes = len(all_genes)

    frac_expr = np.zeros((n_types, n_genes))
    mean_expr = np.zeros((n_types, n_genes))

    for i, ct in enumerate(cell_types):
        subset = adata[labels == ct, all_genes]
        X = subset.layers[layer] if layer else subset.X
        if hasattr(X, "toarray"):
            X = X.toarray()

        frac_expr[i] = (X > 0).mean(axis=0)
        mean_expr[i] = X.mean(axis=0)

    # Normalize mean expression per gene; the epsilon guards constant columns
    col_min = mean_expr.min(axis=0)
    mean_expr_norm = (mean_expr - col_min) / (
        mean_expr.max(axis=0) - col_min + 1e-10
    )

    if figsize is None:
        figsize = (max(10, n_genes * 0.4), max(6, n_types * 0.4))

    fig, ax = setup_figure(figsize=figsize)

    # Create dot plot with a single scatter call. The grids are shaped
    # (n_types, n_genes), so ravel() lines up with frac_expr/mean_expr_norm.
    gene_pos, type_pos = np.meshgrid(np.arange(n_genes), np.arange(n_types))
    ax.scatter(
        gene_pos.ravel(),
        type_pos.ravel(),
        s=frac_expr.ravel() * 300,  # Scale for visibility
        c=mean_expr_norm.ravel(),
        cmap=cmap,
        vmin=0,
        vmax=1,
    )

    ax.set_xticks(range(n_genes))
    ax.set_xticklabels(all_genes, rotation=90)
    ax.set_yticks(range(n_types))
    ax.set_yticklabels(cell_types)

    ax.set_xlabel("Marker Genes")
    ax.set_ylabel("Cell Types")

    if title is None:
        title = "Marker Expression Dotplot"
    ax.set_title(title)

    # Add legend for size, placed to the right of the last gene column
    size_legend = [0.25, 0.5, 0.75, 1.0]
    legend_x = n_genes + 0.5
    for k, frac in enumerate(size_legend):
        ax.scatter(legend_x, k, s=frac * 300, c="gray")
        ax.text(legend_x + 0.3, k, f"{int(frac*100)}%", va="center")
    ax.text(legend_x, len(size_legend) + 0.5, "% Expressing", fontsize=9)

    # Extra room on the right for the size legend
    ax.set_xlim(-0.5, n_genes + 2)

    # Lay out this figure explicitly rather than the implicit current figure.
    fig.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def plot_celltype_confidence(
    adata: ad.AnnData,
    label_column: str,
    confidence_column: str,
    spatial_key: str = "spatial",
    threshold: float = 0.8,
    max_cell_types: int = 20,
    figsize: Tuple[float, float] = (14, 6),
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Cell type confidence visualization with spatial and jitter plots.

    Creates a two-panel figure:
    - Left panel: Spatial scatter plot colored by confidence, with the color
      range clipped to the 5th-95th percentile. If ``spatial_key`` is not in
      adata.obsm a placeholder message is shown instead.
    - Right panel: Jitter plot (x=confidence, y=cell type) with a vertical
      threshold line; cell types are ordered by median confidence with the
      highest at the top.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels and confidence scores.
    label_column : str
        Column in adata.obs containing cell type labels.
    confidence_column : str
        Column in adata.obs containing confidence values. The jitter panel
        uses a fixed 0-1 x-axis.
    spatial_key : str, default "spatial"
        Key in adata.obsm for spatial coordinates.
    threshold : float, default 0.8
        Confidence threshold line to display.
    max_cell_types : int, default 20
        Maximum number of cell types to show in jitter plot. When exceeded,
        the cell types with the highest median confidence are kept.
    figsize : Tuple[float, float], default (14, 6)
        Figure size (width, height).
    save : str or Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``label_column`` or ``confidence_column`` is missing.

    Examples
    --------
    >>> from spatialcore.plotting.validation import plot_celltype_confidence
    >>> fig = plot_celltype_confidence(
    ...     adata,
    ...     label_column="celltypist",
    ...     confidence_column="celltypist_confidence_transformed",
    ...     threshold=0.8,
    ... )
    """
    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")
    if confidence_column not in adata.obs.columns:
        raise ValueError(f"Confidence column '{confidence_column}' not found.")

    fig, (ax_spatial, ax_jitter) = plt.subplots(1, 2, figsize=figsize)

    labels = adata.obs[label_column]

    # Get cell types and palette
    cell_types = labels.unique()
    colors = generate_celltype_palette(cell_types)

    # Left: Spatial plot colored by the confidence column
    if spatial_key in adata.obsm:
        coords = adata.obsm[spatial_key]
        conf_values = adata.obs[confidence_column].values

        # Percentile limits keep a few extreme values from washing out the map
        scatter = ax_spatial.scatter(
            coords[:, 0],
            coords[:, 1],
            s=1,
            alpha=0.7,
            c=conf_values,
            cmap="RdYlGn",  # Red (low) -> Yellow -> Green (high)
            vmin=np.nanpercentile(conf_values, 5),
            vmax=np.nanpercentile(conf_values, 95),
            rasterized=True,
        )
        cbar = fig.colorbar(scatter, ax=ax_spatial, shrink=0.7, pad=0.02)
        cbar.set_label("Confidence (z-score)", fontsize=9)
        ax_spatial.set_title("Spatial Confidence")
        ax_spatial.set_xlabel("X")
        ax_spatial.set_ylabel("Y")
        ax_spatial.axis("equal")
        despine(ax_spatial)
    else:
        ax_spatial.text(
            0.5, 0.5,
            f"No spatial coordinates found\n(key: '{spatial_key}')",
            ha="center", va="center", transform=ax_spatial.transAxes
        )
        ax_spatial.set_title("Spatial Confidence (N/A)")

    # Right: Jitter plot - sort by median confidence (flipped: cell types on Y-axis)
    ct_median = adata.obs.groupby(label_column, observed=True)[confidence_column].median()
    ct_order = ct_median.sort_values(ascending=True).index.tolist()  # Ascending so highest at top

    # Limit number of cell types shown
    if len(ct_order) > max_cell_types:
        # Keep top by median confidence (will be at top of plot)
        ct_order = ct_median.sort_values(ascending=False).index.tolist()[:max_cell_types]
        ct_order = ct_order[::-1]  # Reverse so highest at top
        logger.info(f"Showing top {max_cell_types} cell types by median confidence")

    # A local, seeded generator keeps the jitter reproducible between calls
    # and leaves the caller's global NumPy random state untouched.
    rng = np.random.default_rng(0)

    for i, ct in enumerate(ct_order):
        conf = adata.obs.loc[labels == ct, confidence_column].values
        # Add jitter on Y-axis (cell type position)
        y = rng.normal(i, 0.15, len(conf))
        color = colors.get(str(ct), "gray")
        ax_jitter.scatter(conf, y, s=3, alpha=0.3, c=[color], rasterized=True)

    # Threshold line (vertical now)
    ax_jitter.axvline(
        threshold, color="red", linestyle="--", lw=2, label=f"Threshold={threshold}"
    )

    ax_jitter.set_yticks(range(len(ct_order)))
    ax_jitter.set_yticklabels(
        [str(ct)[:25] for ct in ct_order],  # Truncate long names
        fontsize=8
    )

    # Set x-axis to fixed 0-1 range with 0.5 increments for consistent comparison
    ax_jitter.set_xlim(0, 1)
    ax_jitter.set_xticks([0, 0.5, 1.0])

    ax_jitter.set_xlabel("Confidence (z-score)")
    ax_jitter.set_ylabel("")
    ax_jitter.set_title("Confidence by Cell Type")
    ax_jitter.legend(loc="lower right")
    despine(ax_jitter)

    # Lay out this figure explicitly rather than the implicit current figure.
    fig.tight_layout()

    if save:
        save_figure(fig, save)

    return fig
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def plot_deg_heatmap(
    adata: ad.AnnData,
    label_column: str,
    n_genes: int = 5,
    method: str = "wilcoxon",
    layer: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    cmap: str = "viridis",
    save: Optional[Union[str, Path]] = None,
    title: Optional[str] = None,
) -> Figure:
    """
    DEG heatmap showing top marker genes per cell type.

    Runs ``scanpy.tl.rank_genes_groups`` on the cells whose label is not
    "Unassigned" and whose cell type has at least 10 cells, then draws a
    heatmap of row-wise z-scored mean expression with genes on rows and
    cell types on columns, a row color annotation bar, a colorbar and a
    cell type legend.

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels.
    label_column : str
        Column in adata.obs containing cell type labels.
    n_genes : int, default 5
        Number of top DEGs per cell type to include. Must be >= 1.
        A gene that is a top DEG for several cell types is shown once,
        assigned to the first cell type (alphabetical order) that lists it.
    method : str, default "wilcoxon"
        DEG method ("wilcoxon", "t-test", "t-test_overestim_var", "logreg").
    layer : str, optional
        Expression layer to use. If None, or if the layer is not present in
        adata.layers (a warning is logged), uses adata.X.
    figsize : Tuple[float, float], optional
        Figure size. If None, the width is 12 and the height is
        ``max(10, 0.11 * number_of_genes)``.
    cmap : str, default "viridis"
        Colormap for heatmap.
    save : str or Path, optional
        Path to save figure.
    title : str, optional
        Plot title. Defaults to "Marker Genes ({label_column})".

    Returns
    -------
    Figure
        Matplotlib figure.

    Raises
    ------
    ValueError
        If ``label_column`` is missing, ``n_genes`` is smaller than 1, fewer
        than 2 cell types have >= 10 cells, or no marker gene is found.
    """
    import scanpy as sc
    import seaborn as sns

    if label_column not in adata.obs.columns:
        raise ValueError(f"Label column '{label_column}' not found.")
    if n_genes < 1:
        # A non-positive value would otherwise produce an empty or
        # oddly-truncated slice of the ranking below.
        raise ValueError(f"n_genes must be >= 1, got {n_genes}.")

    # Build filter mask on original data (no copies)
    labels = adata.obs[label_column]
    mask_assigned = labels != "Unassigned"

    # Keep only cell types that have enough cells for a meaningful test
    ct_counts = labels[mask_assigned].value_counts()
    valid_cts = ct_counts[ct_counts >= 10].index.tolist()
    if len(valid_cts) < 2:
        raise ValueError("Need at least 2 cell types with >= 10 cells each for DEG analysis.")

    # Combined mask - applied once
    mask = mask_assigned & labels.isin(valid_cts)

    # Manual AnnData construction to avoid anndata 0.12.x memory bug in adata[mask].copy()
    indices = np.where(mask)[0]
    if layer and layer in adata.layers:
        X_sub = adata.layers[layer][indices]
    else:
        if layer:
            logger.warning(f"Layer '{layer}' not found in adata.layers; using adata.X instead.")
        X_sub = adata.X[indices]
    obs_sub = adata.obs.iloc[indices].copy()
    # Round-trip through str so that filtered-out categories are dropped
    obs_sub[label_column] = obs_sub[label_column].astype(str).astype("category")
    adata_deg = ad.AnnData(X=X_sub, obs=obs_sub, var=adata.var.copy())

    # Rank at least as many genes as will be displayed (50 was the previous
    # fixed value, which silently capped n_genes).
    n_ranked = max(50, n_genes)
    logger.info(f"Running rank_genes_groups with method={method} on {len(valid_cts)} cell types...")
    sc.tl.rank_genes_groups(adata_deg, label_column, method=method, n_genes=n_ranked)

    # Get marker genes per cell type
    ranked_names = adata_deg.uns["rank_genes_groups"]["names"]
    cell_types = sorted(ranked_names.dtype.names)
    var_names = adata_deg.var_names

    # Collect genes grouped by cell type (preserve order); the dict doubles
    # as the "already seen" set.
    all_genes = []
    gene_to_celltype = {}
    for ct in cell_types:
        for gene in ranked_names[ct][:n_genes]:
            if gene not in gene_to_celltype and gene in var_names:
                all_genes.append(gene)
                gene_to_celltype[gene] = ct

    if not all_genes:
        raise ValueError("No valid marker genes found")

    # Calculate mean expression per cell type. Column positions are looked
    # up once, and each cell type's rows are sliced once instead of once per
    # gene.
    gene_locs = [var_names.get_loc(gene) for gene in all_genes]
    X_deg = adata_deg.X
    ct_labels = adata_deg.obs[label_column].to_numpy()
    expr_matrix = np.zeros((len(all_genes), len(cell_types)))
    for j, ct in enumerate(cell_types):
        rows = np.flatnonzero(ct_labels == ct)
        if rows.size == 0:
            continue
        X_ct = X_deg[rows]
        for i, loc in enumerate(gene_locs):
            column = X_ct[:, loc]
            if hasattr(column, "toarray"):  # sparse storage
                column = column.toarray()
            expr_matrix[i, j] = np.mean(np.asarray(column))
        del X_ct

    # Release the copy - no longer needed after expression matrix is built
    del adata_deg, X_deg
    gc.collect()

    # Z-score normalize across cell types (rows); constant rows stay at 0
    row_mean = expr_matrix.mean(axis=1, keepdims=True)
    row_std = expr_matrix.std(axis=1, keepdims=True)
    expr_scaled = np.zeros_like(expr_matrix)
    np.divide(expr_matrix - row_mean, row_std, out=expr_scaled, where=row_std > 0)

    # Create color palette for cell types
    n_cts = len(cell_types)
    palette = sns.color_palette("tab20", n_cts)
    ct_to_color = {ct: palette[i] for i, ct in enumerate(cell_types)}

    # Create row colors array
    row_colors = [ct_to_color[gene_to_celltype[gene]] for gene in all_genes]

    # Create figure with gridspec for custom layout; honour a caller-supplied
    # size and only auto-size when none is given.
    if figsize is None:
        figsize = (12, max(10, len(all_genes) * 0.11))
    fig = plt.figure(figsize=figsize)

    # GridSpec with annotation bar flush against heatmap
    gs = fig.add_gridspec(
        1, 5,
        width_ratios=[0.06, 0.012, 1, 0.02, 0.015],
        wspace=0.0,
        left=0.01,
        right=0.78,
        top=0.95,
        bottom=0.12,
    )

    # Gene labels axis (far left)
    ax_labels = fig.add_subplot(gs[0, 0])
    ax_labels.set_ylim(0, len(all_genes))
    ax_labels.set_xlim(0, 1)
    ax_labels.invert_yaxis()
    for i, gene in enumerate(all_genes):
        ax_labels.text(
            0.98, i + 0.5, gene,
            ha="right", va="center",
            fontsize=5,
            color="black",
        )
    ax_labels.axis("off")

    # Row colors axis (annotation bar)
    ax_rowcolors = fig.add_subplot(gs[0, 1])
    for i, color in enumerate(row_colors):
        ax_rowcolors.add_patch(plt.Rectangle(
            (0, i), 1, 1,
            facecolor=color,
            edgecolor="none",
        ))
    ax_rowcolors.set_xlim(0, 1)
    ax_rowcolors.set_ylim(0, len(all_genes))
    ax_rowcolors.invert_yaxis()
    ax_rowcolors.axis("off")

    # Main heatmap axis
    ax_heatmap = fig.add_subplot(gs[0, 2])
    im = ax_heatmap.imshow(
        expr_scaled,
        aspect="auto",
        cmap=cmap,
        vmin=-2.5,
        vmax=2.5,
    )

    ax_heatmap.set_yticks([])
    ax_heatmap.set_xticks(range(len(cell_types)))
    ax_heatmap.set_xticklabels(cell_types, rotation=45, ha="right", fontsize=8)

    # Gap between heatmap and colorbar
    ax_gap = fig.add_subplot(gs[0, 3])
    ax_gap.axis("off")

    # Colorbar axis
    ax_cbar = fig.add_subplot(gs[0, 4])
    cbar = plt.colorbar(im, cax=ax_cbar)
    cbar.set_label("Scaled expression", fontsize=9)
    cbar.ax.tick_params(labelsize=7)

    # Cell type legend
    legend_handles = [
        plt.Rectangle((0, 0), 1, 1, facecolor=ct_to_color[ct], label=ct)
        for ct in cell_types
    ]
    fig.legend(
        handles=legend_handles,
        loc="upper left",
        bbox_to_anchor=(0.80, 0.95),
        fontsize=6,
        title="Cell type",
        title_fontsize=7,
        frameon=True,
        fancybox=True,
    )

    # Title
    if title is None:
        title = f"Marker Genes ({label_column})"
    fig.suptitle(title, fontsize=14, fontweight="bold")

    if save:
        save_figure(fig, save)

    return fig
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
def plot_ontology_mapping(
    adata: ad.AnnData,
    source_label_column: str,
    ontology_name_column: str,
    ontology_id_column: str,
    mapping_table: Optional[pd.DataFrame] = None,
    title: Optional[str] = None,
    figsize: Tuple[float, float] = (14, 8),
    save: Optional[Union[str, Path]] = None,
) -> Figure:
    """
    Plot ontology mapping table showing original labels mapped to Cell Ontology.

    Creates a table visualization with:
    - Original CellTypist labels
    - Mapped ontology names and CL IDs
    - Match tier (tier0=pattern, tier1=exact, etc.)
    - Matching score (actual score from matching when available, otherwise
      a per-tier estimate)
    - Cell counts

    Parameters
    ----------
    adata : AnnData
        Annotated data with ontology mapping columns. Only read when
        ``mapping_table`` is None.
    source_label_column : str
        Column with original cell type labels (e.g., "celltypist").
    ontology_name_column : str
        Column with mapped ontology names.
    ontology_id_column : str
        Column with Cell Ontology IDs (CL:XXXXXXX). Tier and score columns
        are looked up under the names obtained by replacing "_id" with
        "_tier" / "_score" in this name.
    mapping_table : pd.DataFrame, optional
        Pre-computed mapping table from OntologyMappingResult.table.
        If provided, uses this directly instead of building from adata.
        Must have columns: input_label, ontology_name, ontology_id, n_cells.
        Columns match_tier and score are used when present and derived
        otherwise.
    title : str, optional
        Plot title. Auto-generated if None.
    figsize : tuple
        Figure size. Only the width is used; the height is always computed
        from the number of table rows.
    save : Path, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure with ontology mapping table.

    Raises
    ------
    ValueError
        If a required column is missing from adata.obs, or if there is no
        label to display.
    KeyError
        If ``mapping_table`` lacks one of its required columns.

    Examples
    --------
    >>> # Using adata columns (scores read from {id_column}_score)
    >>> fig = plot_ontology_mapping(
    ...     adata,
    ...     source_label_column="celltypist",
    ...     ontology_name_column="celltypist_ontology_name",
    ...     ontology_id_column="celltypist_ontology_id",
    ... )

    >>> # Using pre-computed mapping table
    >>> from spatialcore.annotation import add_ontology_ids
    >>> adata, mappings, result = add_ontology_ids(adata, "celltypist", save_mapping="./")
    >>> fig = plot_ontology_mapping(
    ...     adata,
    ...     source_label_column="celltypist",
    ...     ontology_name_column="celltypist_ontology_name",
    ...     ontology_id_column="celltypist_ontology_id",
    ...     mapping_table=result.table,
    ... )
    """
    display_cols = ["CellTypist Label", "Ontology Name", "CL ID", "Match Tier", "Score", "Cells"]
    # CL ID values that mean "no ontology term was assigned"
    unmapped_ids = ["", "-", "nan", "unknown", "skipped"]

    def _format_score(value) -> str:
        """Render a positive numeric score with two decimals, anything else as "-"."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return "-"
        # NaN compares False here, so it is rendered as "-" as well
        return f"{number:.2f}" if number > 0 else "-"

    if mapping_table is not None:
        # Rename columns to display names (rename returns a new frame, so the
        # caller's table is left untouched)
        summary = mapping_table.rename(columns={
            "input_label": "CellTypist Label",
            "ontology_name": "Ontology Name",
            "ontology_id": "CL ID",
            "match_tier": "Match Tier",
            "score": "Score",
            "n_cells": "Cells",
        })
        required = ["CellTypist Label", "Ontology Name", "CL ID", "Cells"]
        missing = [col for col in required if col not in summary.columns]
        if missing:
            raise KeyError(f"mapping_table is missing required columns: {missing}")
    else:
        # Build from adata columns
        for col in [source_label_column, ontology_name_column, ontology_id_column]:
            if col not in adata.obs.columns:
                raise ValueError(f"Column '{col}' not found in adata.obs")

        # Check for tier and score columns (derived from ontology_id_column).
        # If the id column contains no "_id", replace() returns it unchanged;
        # such a "derived" name must not be treated as a tier/score column.
        tier_column = ontology_id_column.replace("_id", "_tier")
        score_column = ontology_id_column.replace("_id", "_score")
        has_tier_column = tier_column != ontology_id_column and tier_column in adata.obs.columns
        has_score_column = score_column != ontology_id_column and score_column in adata.obs.columns

        # Build mapping summary - convert to string to avoid categorical issues
        text_cols = [source_label_column, ontology_name_column, ontology_id_column]
        if has_tier_column:
            text_cols.append(tier_column)
        cols_to_use = list(text_cols)
        if has_score_column:
            cols_to_use.append(score_column)

        df = adata.obs[cols_to_use].copy()
        for col in text_cols:
            df[col] = df[col].astype(str)
        df["count"] = 1

        # Replace 'nan' strings with empty string
        df = df.replace("nan", "")

        # Aggregate by source label
        agg_dict = {
            ontology_name_column: "first",
            ontology_id_column: "first",
            "count": "sum",
        }
        rename_map = {
            source_label_column: "CellTypist Label",
            ontology_name_column: "Ontology Name",
            ontology_id_column: "CL ID",
            "count": "Cells",
        }
        if has_tier_column:
            agg_dict[tier_column] = "first"
            rename_map[tier_column] = "Match Tier"
        if has_score_column:
            agg_dict[score_column] = "first"
            rename_map[score_column] = "Score"

        summary = df.groupby(source_label_column).agg(agg_dict).reset_index()
        # Rename by column name rather than by position
        summary = summary.rename(columns=rename_map)

    if len(summary) == 0:
        raise ValueError("No labels available to build the ontology mapping table.")

    # Add tier column if missing
    if "Match Tier" not in summary.columns:
        # Simplified - no tier info available: any assigned ID counts as tier0
        id_is_unmapped = summary["CL ID"].astype(str).str.strip().isin(unmapped_ids)
        summary["Match Tier"] = id_is_unmapped.map({True: "unmapped", False: "tier0_pattern"})

    # Format score column - use actual scores if available
    if "Score" in summary.columns:
        summary["Score"] = summary["Score"].apply(_format_score)
    else:
        # Fallback: estimate score based on tier (less accurate)
        tier_scores = {
            "tier0_pattern": "0.95",
            "tier1_exact": "1.00",
            "tier2_token": "0.75",
            "tier3_overlap": "0.60",
            "unmapped": "-",
            "skipped": "-",
        }
        summary["Score"] = summary["Match Tier"].map(lambda t: tier_scores.get(t, "-"))

    # Fill empty values with placeholder
    summary["CL ID"] = summary["CL ID"].replace("", "-")
    name_missing = summary["Ontology Name"] == ""
    summary.loc[name_missing, "Ontology Name"] = summary.loc[name_missing, "CellTypist Label"]

    # Sort by cell count descending
    summary = summary.sort_values("Cells", ascending=False).reset_index(drop=True)

    # Calculate stats
    n_labels = len(summary)
    is_mapped = ~summary["Match Tier"].isin(["unmapped", "skipped"])
    n_mapped = is_mapped.sum()
    total_cells = summary["Cells"].sum()
    mapped_cells = summary.loc[is_mapped, "Cells"].sum()
    pct_labels = 100 * n_mapped / n_labels
    # Guard against a table whose cell counts are all zero
    pct_cells = 100 * mapped_cells / total_cells if total_cells else 0.0

    # Tier colors
    tier_colors = {
        "tier0_pattern": "#d4edda",  # Green
        "tier1_exact": "#cce5ff",  # Blue
        "tier2_token": "#fff3cd",  # Orange/Yellow
        "tier3_overlap": "#f8d7da",  # Red
        "unmapped": "#e9ecef",  # Gray
        "skipped": "#e9ecef",  # Gray (for Unassigned, Unknown, etc.)
    }

    # Calculate figure height based on number of rows
    n_rows = n_labels
    fig_height = max(6, 2 + n_rows * 0.4)  # Dynamic height
    figsize = (figsize[0], fig_height)

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)
    ax.axis("off")

    # Title and stats at top
    if title is None:
        title = "CellTypist to Cell Ontology Mapping"

    fig.text(
        0.5, 0.97,
        f"Labels: {n_mapped}/{n_labels} mapped ({pct_labels:.1f}%) | "
        f"Cells: {mapped_cells:,}/{total_cells:,} mapped ({pct_cells:.1f}%)",
        ha="center", fontsize=10, color="green"
    )
    fig.text(0.5, 0.93, title, ha="center", fontsize=14, fontweight="bold")

    # Reorder columns for display
    summary = summary[display_cols]

    # Create table
    table = ax.table(
        cellText=summary.values,
        colLabels=summary.columns,
        cellLoc="left",
        loc="upper center",
        colColours=["#2c3e50"] * len(display_cols),
    )

    # Style table
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1.0, 1.5)

    # Set column widths (row 0 is the header, data rows start at 1)
    col_widths = [0.22, 0.22, 0.12, 0.14, 0.08, 0.10]
    for j, width in enumerate(col_widths):
        for i in range(n_rows + 1):
            table[(i, j)].set_width(width)

    # Color header text white
    for j in range(len(display_cols)):
        header_text = table[(0, j)].get_text()
        header_text.set_color("white")
        header_text.set_fontweight("bold")

    # Color rows by tier
    for i, tier in enumerate(summary["Match Tier"]):
        color = tier_colors.get(tier, "#ffffff")
        for j in range(len(display_cols)):
            table[(i + 1, j)].set_facecolor(color)

    # Add legend at bottom
    legend_text = (
        "Tier Colors: Green = Pattern Match (tier0) | Blue = Exact Match (tier1) | "
        "Orange = Token Match (tier2) | Red = Overlap (tier3) | Gray = Unmapped"
    )
    fig.text(0.5, 0.02, legend_text, ha="center", fontsize=8, style="italic")

    if save:
        save_figure(fig, save)

    return fig
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
def generate_annotation_plots(
    adata: ad.AnnData,
    label_column: str = "cell_type",
    confidence_column: str = "cell_type_confidence",
    output_dir: Optional[Union[str, Path]] = None,
    prefix: str = "celltyping",
    confidence_threshold: float = 0.8,
    markers: Optional[Dict[str, List[str]]] = None,
    n_deg_genes: int = 10,
    spatial_key: str = "spatial",
    source_label_column: Optional[str] = None,
    ontology_name_column: Optional[str] = None,
    ontology_id_column: Optional[str] = None,
) -> Dict:
    """
    Generate all cell typing validation plots.

    Runs four plotting steps in order; each step is independent, and a step
    that raises is logged as a warning and skipped rather than aborting the
    others:

    1. Ontology mapping table (``plot_ontology_mapping``) - only when the
       ontology name and ID columns are given or can be inferred
    2. 2D marker validation with a 3-component GMM (``plot_2d_validation``)
    3. Cell type confidence (``plot_celltype_confidence``)
    4. DEG heatmap (``plot_deg_heatmap``)

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell type labels and confidence scores.
    label_column : str, default "cell_type"
        Column containing cell type labels.
    confidence_column : str, default "cell_type_confidence"
        Column containing confidence values.
    output_dir : str or Path, optional
        Directory to save plots (created if needed). If None, plots are
        returned but not saved.
    prefix : str, default "celltyping"
        Filename prefix for saved plots; files are named
        ``{prefix}_ontology_mapping.png``, ``{prefix}_2d_validation.png``,
        ``{prefix}_confidence.png`` and ``{prefix}_deg_heatmap.png``.
    confidence_threshold : float, default 0.8
        Threshold forwarded to the 2D validation and confidence plots.
    markers : Dict[str, List[str]], optional
        Custom marker genes per cell type, forwarded unchanged to
        ``plot_2d_validation`` (which presumably falls back to the packaged
        canonical markers when None - confirm against that function).
    n_deg_genes : int, default 10
        Number of top DEGs per cell type for heatmap.
    spatial_key : str, default "spatial"
        Key for spatial coordinates, forwarded to the confidence plot.
    source_label_column : str, optional
        Original label column for the ontology mapping table.
        If None, ``label_column`` is used.
    ontology_name_column : str, optional
        Ontology name column. If None, tries "cell_type_ontology_label",
        then "{label_column}_ontology_name".
    ontology_id_column : str, optional
        Ontology ID column. If None, tries "cell_type_ontology_term_id",
        then "{label_column}_ontology_id".

    Returns
    -------
    Dict
        Dictionary with keys:
        - "figures": dict of matplotlib Figure objects, one per step that
          succeeded
        - "summary": validation summary returned by the 2D validation step
          (None if that step failed)
        - "paths": dict of target file paths per succeeded step (values are
          None when output_dir is not provided)

    Examples
    --------
    >>> from spatialcore.plotting import generate_annotation_plots
    >>> results = generate_annotation_plots(
    ...     adata,
    ...     output_dir="./qc_plots",
    ...     prefix="lung_cancer",
    ... )
    >>> print(results["summary"])
    """
    out_dir = Path(output_dir) if output_dir else None
    if out_dir:
        out_dir.mkdir(parents=True, exist_ok=True)

    def _target(stem):
        """Return the PNG path for one plot, or None when nothing is saved."""
        if not out_dir:
            return None
        return out_dir / f"{prefix}_{stem}.png"

    def _infer_column(standard_name, legacy_suffix):
        """Pick the CellxGene-standard column if present, else the legacy one."""
        if standard_name in adata.obs.columns:
            return standard_name
        stem = label_column.replace("_ontology_name", "")
        legacy_name = f"{stem}_{legacy_suffix}"
        if legacy_name in adata.obs.columns:
            return legacy_name
        return None

    figures = {}
    paths = {}
    results = {"figures": figures, "summary": None, "paths": paths}

    # Resolve the columns needed for the ontology table when not given
    if source_label_column is None:
        # The predicted label column doubles as the source label
        source_label_column = label_column
    if ontology_name_column is None:
        ontology_name_column = _infer_column("cell_type_ontology_label", "ontology_name")
    if ontology_id_column is None:
        ontology_id_column = _infer_column("cell_type_ontology_term_id", "ontology_id")

    # 0. Ontology Mapping Table
    logger.info("Generating ontology mapping table...")
    if not (source_label_column and ontology_name_column and ontology_id_column):
        logger.info(" Skipping ontology table - columns not found")
        logger.info(f" source_label_column: {source_label_column}")
        logger.info(f" ontology_name_column: {ontology_name_column}")
        logger.info(f" ontology_id_column: {ontology_id_column}")
    else:
        try:
            ontology_path = _target("ontology_mapping")
            ontology_fig = plot_ontology_mapping(
                adata,
                source_label_column=source_label_column,
                ontology_name_column=ontology_name_column,
                ontology_id_column=ontology_id_column,
                save=ontology_path,
            )
            figures["ontology_mapping"] = ontology_fig
            paths["ontology_mapping"] = ontology_path
            logger.info(" Ontology mapping table generated")
        except Exception as e:
            logger.warning(f"Ontology mapping table failed: {e}")

    # 1. 2D Marker Validation (GMM-3)
    logger.info("Generating 2D marker validation plot...")
    try:
        validation_path = _target("2d_validation")
        validation_fig, summary = plot_2d_validation(
            adata,
            label_column=label_column,
            confidence_column=confidence_column,
            markers=markers,
            confidence_threshold=confidence_threshold,
            n_components=3,  # GMM-3 per spec
            save=validation_path,
        )
        figures["2d_validation"] = validation_fig
        results["summary"] = summary
        paths["2d_validation"] = validation_path
        logger.info(f" 2D validation: {len(summary)} cell types analyzed")
    except Exception as e:
        logger.warning(f"2D validation plot failed: {e}")

    # 2. Cell Type Confidence
    logger.info("Generating confidence plot...")
    try:
        confidence_path = _target("confidence")
        confidence_fig = plot_celltype_confidence(
            adata,
            label_column=label_column,
            confidence_column=confidence_column,
            spatial_key=spatial_key,
            threshold=confidence_threshold,
            save=confidence_path,
        )
        figures["confidence"] = confidence_fig
        paths["confidence"] = confidence_path
        logger.info(" Confidence plot generated")
    except Exception as e:
        logger.warning(f"Confidence plot failed: {e}")

    # 3. DEG Heatmap
    logger.info("Generating DEG heatmap...")
    try:
        heatmap_path = _target("deg_heatmap")
        heatmap_fig = plot_deg_heatmap(
            adata,
            label_column=label_column,
            n_genes=n_deg_genes,
            save=heatmap_path,
        )
        figures["deg_heatmap"] = heatmap_fig
        paths["deg_heatmap"] = heatmap_path
        logger.info(" DEG heatmap generated")
    except Exception as e:
        logger.warning(f"DEG heatmap failed: {e}")

    if out_dir:
        logger.info(f"Annotation plots saved to: {out_dir}")

    return results
|