spatialcore 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatialcore/__init__.py +122 -0
- spatialcore/annotation/__init__.py +253 -0
- spatialcore/annotation/acquisition.py +529 -0
- spatialcore/annotation/annotate.py +603 -0
- spatialcore/annotation/cellxgene.py +365 -0
- spatialcore/annotation/confidence.py +802 -0
- spatialcore/annotation/discovery.py +529 -0
- spatialcore/annotation/expression.py +363 -0
- spatialcore/annotation/loading.py +529 -0
- spatialcore/annotation/markers.py +297 -0
- spatialcore/annotation/ontology.py +1282 -0
- spatialcore/annotation/patterns.py +247 -0
- spatialcore/annotation/pipeline.py +620 -0
- spatialcore/annotation/synapse.py +380 -0
- spatialcore/annotation/training.py +1457 -0
- spatialcore/annotation/validation.py +422 -0
- spatialcore/core/__init__.py +34 -0
- spatialcore/core/cache.py +118 -0
- spatialcore/core/logging.py +135 -0
- spatialcore/core/metadata.py +149 -0
- spatialcore/core/utils.py +768 -0
- spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
- spatialcore/data/markers/canonical_markers.json +83 -0
- spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
- spatialcore/plotting/__init__.py +109 -0
- spatialcore/plotting/benchmark.py +477 -0
- spatialcore/plotting/celltype.py +329 -0
- spatialcore/plotting/confidence.py +413 -0
- spatialcore/plotting/spatial.py +505 -0
- spatialcore/plotting/utils.py +411 -0
- spatialcore/plotting/validation.py +1342 -0
- spatialcore-0.1.9.dist-info/METADATA +213 -0
- spatialcore-0.1.9.dist-info/RECORD +36 -0
- spatialcore-0.1.9.dist-info/WHEEL +5 -0
- spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
- spatialcore-0.1.9.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,603 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CellTypist annotation wrapper for cell type annotation.
|
|
3
|
+
|
|
4
|
+
This module provides a convenience wrapper around CellTypist for:
|
|
5
|
+
1. Tissue-specific model selection
|
|
6
|
+
2. Ensemble annotation across multiple models
|
|
7
|
+
3. Gene overlap validation
|
|
8
|
+
4. Normalization validation (log1p, ~10k total counts) for CellTypist compatibility
|
|
9
|
+
|
|
10
|
+
References:
|
|
11
|
+
- CellTypist: https://www.celltypist.org/
|
|
12
|
+
- Model documentation: See docs/CELLTYPIST_MODELS.md
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Dict, List, Literal, Optional, Union, Any
|
|
17
|
+
import gc
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import pandas as pd
|
|
21
|
+
import scanpy as sc
|
|
22
|
+
import anndata as ad
|
|
23
|
+
|
|
24
|
+
from spatialcore.core.logging import get_logger
|
|
25
|
+
from spatialcore.annotation.confidence import (
|
|
26
|
+
extract_decision_scores,
|
|
27
|
+
transform_confidence,
|
|
28
|
+
ConfidenceMethod,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = get_logger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ============================================================================
|
|
35
|
+
# Tissue-Specific Model Presets
|
|
36
|
+
# ============================================================================
|
|
37
|
+
|
|
38
|
+
# Maps a lower-case tissue key to the ordered list of CellTypist model files
# run for that tissue. Most presets combine the broad immune and fetal models
# with one tissue-specific atlas; "unknown" is the fallback for tissues not
# listed here. Treat this as read-only: get_models_for_tissue() hands out
# copies so callers cannot modify the presets by accident.
TISSUE_MODEL_PRESETS: Dict[str, List[str]] = {
    # General (use for unknown tissues)
    "unknown": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
    ],
    # Digestive system
    "colon": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Cells_Intestinal_Tract.pkl",
    ],
    "intestine": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Cells_Intestinal_Tract.pkl",
    ],
    # Liver
    "liver": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Healthy_Human_Liver.pkl",
    ],
    # Lung
    "lung": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Human_Lung_Atlas.pkl",
    ],
    "lung_airway": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Cells_Lung_Airway.pkl",
    ],
    "lung_cancer": [
        "Immune_All_Low.pkl",
        "Human_Lung_Atlas.pkl",
    ],
    # Heart
    "heart": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Healthy_Adult_Heart.pkl",
    ],
    # Breast
    "breast": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Cells_Adult_Breast.pkl",
    ],
    # Skin
    "skin": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Adult_Human_Skin.pkl",
    ],
    # Pancreas
    "pancreas": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Adult_Human_PancreaticIslet.pkl",
    ],
    # Brain
    "brain": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Adult_Human_MTG.pkl",
    ],
    # Tonsil
    "tonsil": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
        "Cells_Human_Tonsil.pkl",
    ],
    # Blood/Immune
    "blood": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
    ],
    "pbmc": [
        "Immune_All_Low.pkl",
        "Pan_Fetal_Human.pkl",
    ],
}


def get_models_for_tissue(tissue: str) -> List[str]:
    """
    Get recommended CellTypist models for a tissue type.

    The tissue name is matched case-insensitively. Surrounding whitespace is
    ignored, and internal spaces/hyphens are treated as underscores, so
    "Lung Airway" and "lung-airway" both resolve to the "lung_airway" preset.
    Unknown, empty or None tissue names fall back to the "unknown" preset.

    Parameters
    ----------
    tissue : str
        Tissue name (e.g., "liver", "lung", "colon").

    Returns
    -------
    List[str]
        A new list of model names/paths. Mutating it does not affect
        TISSUE_MODEL_PRESETS.

    Examples
    --------
    >>> from spatialcore.annotation import get_models_for_tissue
    >>> models = get_models_for_tissue("liver")
    >>> print(models)
    ['Immune_All_Low.pkl', 'Pan_Fetal_Human.pkl', 'Healthy_Human_Liver.pkl']
    """
    if not tissue:
        return list(TISSUE_MODEL_PRESETS["unknown"])

    # Normalize "Lung-Airway " / "lung airway" -> "lung_airway".
    key = "_".join(tissue.lower().replace("-", " ").split())
    preset = TISSUE_MODEL_PRESETS.get(key, TISSUE_MODEL_PRESETS["unknown"])
    # Return a copy: handing out the dict's own list would let callers
    # corrupt the shared preset for every later call.
    return list(preset)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# ============================================================================
|
|
150
|
+
# Model Validation
|
|
151
|
+
# ============================================================================
|
|
152
|
+
|
|
153
|
+
def _validate_gene_overlap(
|
|
154
|
+
model,
|
|
155
|
+
data_genes: set,
|
|
156
|
+
min_overlap_pct: float = 25.0,
|
|
157
|
+
) -> Dict[str, Any]:
|
|
158
|
+
"""
|
|
159
|
+
Validate gene overlap between model and data.
|
|
160
|
+
|
|
161
|
+
Parameters
|
|
162
|
+
----------
|
|
163
|
+
model
|
|
164
|
+
Loaded CellTypist model.
|
|
165
|
+
data_genes : set
|
|
166
|
+
Gene names from query data.
|
|
167
|
+
min_overlap_pct : float, default 25.0
|
|
168
|
+
Minimum required overlap percentage.
|
|
169
|
+
|
|
170
|
+
Returns
|
|
171
|
+
-------
|
|
172
|
+
Dict[str, Any]
|
|
173
|
+
Overlap statistics and pass/fail status.
|
|
174
|
+
"""
|
|
175
|
+
model_genes = set(model.features)
|
|
176
|
+
overlap_genes = model_genes & data_genes
|
|
177
|
+
overlap_pct = 100 * len(overlap_genes) / len(model_genes) if model_genes else 0
|
|
178
|
+
|
|
179
|
+
return {
|
|
180
|
+
"n_model_genes": len(model_genes),
|
|
181
|
+
"n_data_genes": len(data_genes),
|
|
182
|
+
"n_overlap": len(overlap_genes),
|
|
183
|
+
"overlap_pct": overlap_pct,
|
|
184
|
+
"passes_threshold": overlap_pct >= min_overlap_pct,
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _validate_celltypist_input(
    adata: ad.AnnData,
    norm_layer: str = "norm",
    *,
    target_sum: float = 10000.0,
    sum_tolerance: float = 0.1,
    max_log_value: float = 50.0,
    n_sample_cells: int = 1000,
) -> ad.AnnData:
    """
    Validate and prepare AnnData for CellTypist.

    Validates that data is properly normalized (log1p, ~10k sum).
    Does NOT modify or normalize data - errors if validation fails.
    Checks run on the first ``n_sample_cells`` cells only, which keeps the
    dense conversion of sparse layers cheap.

    Parameters
    ----------
    adata : AnnData
        Input data with normalized layer.
    norm_layer : str, default "norm"
        Layer containing log1p(10k) normalized data.
    target_sum : float, default 10000.0
        Expected mean per-cell total after undoing log1p (``expm1``).
    sum_tolerance : float, default 0.1
        Allowed relative deviation of the observed mean total from
        ``target_sum`` (0.1 = 10%).
    max_log_value : float, default 50.0
        Largest value accepted as "log-transformed". Real log1p data is
        usually well below 15; the limit is deliberately lenient.
    n_sample_cells : int, default 1000
        Number of leading cells inspected by the checks.

    Returns
    -------
    AnnData
        Copy with validated layer in X.

    Raises
    ------
    ValueError
        If the layer doesn't exist, is empty, contains NaN/inf values, or is
        not properly normalized.
    """
    # Step 1: Check layer exists (NO fallback)
    if norm_layer not in adata.layers:
        available = list(adata.layers.keys())
        raise ValueError(
            f"Layer '{norm_layer}' not found in adata.layers.\n"
            f"Available layers: {available}\n"
            f"Ensure normalization has been run before CellTypist annotation."
        )

    logger.info(f"Validating adata.layers['{norm_layer}'] for CellTypist...")
    data = adata.layers[norm_layer]

    # An empty layer would otherwise fail deep inside np.max with an opaque error.
    if data.shape[0] == 0 or data.shape[1] == 0:
        raise ValueError(
            f"Layer '{norm_layer}' is empty (shape {data.shape}); nothing to annotate."
        )

    # Sample leading cells; densify sparse matrices only for the sample.
    n_sample = max(int(n_sample_cells), 1)
    sample = data[:n_sample] if data.shape[0] > n_sample else data
    if hasattr(sample, "toarray"):
        sample = sample.toarray()
    sample = np.asarray(sample)

    # NaN compares False against every threshold below, so without this
    # check corrupted data would pass validation silently.
    if not np.isfinite(sample).all():
        raise ValueError(
            f"Data in layer '{norm_layer}' contains NaN or infinite values.\n"
            f" Clean or re-normalize the data before CellTypist annotation."
        )

    # Step 2: Check log-transformed (NO fallback)
    data_max = float(np.max(sample))
    if data_max > max_log_value:
        raise ValueError(
            f"Data in layer '{norm_layer}' does not appear to be log-transformed.\n"
            f" Max value: {data_max:.2f} (limit {max_log_value:g}; "
            f"log1p data is typically < 15)\n"
            f" Run normalization pipeline before CellTypist annotation."
        )
    logger.info(f" [OK] Log-transformed (max={data_max:.2f})")

    # Step 3: Check the mean per-cell total after undoing log1p (NO fallback)
    original_sum = np.expm1(sample).sum(axis=1)
    mean_sum = float(np.mean(original_sum))
    if abs(mean_sum - target_sum) / target_sum > sum_tolerance:
        raise ValueError(
            f"Data in layer '{norm_layer}' not normalized to {target_sum:g} counts.\n"
            f" Observed mean sum: {mean_sum:.0f} (expected ~{target_sum:g})\n"
            f" Normalize with: sc.pp.normalize_total(adata, target_sum={target_sum:g})"
        )
    logger.info(f" [OK] Sum validation passed (mean={mean_sum:.0f})")

    # Step 4: Create copy with validated layer in X
    adata_ct = adata.copy()
    adata_ct.X = adata.layers[norm_layer].copy()
    logger.info(f" [OK] Prepared: {adata_ct.n_obs:,} cells x {adata_ct.n_vars:,} genes")

    return adata_ct
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
# ============================================================================
|
|
263
|
+
# Main Annotation Function
|
|
264
|
+
# ============================================================================
|
|
265
|
+
|
|
266
|
+
def _load_celltypist_model(models_module, model_name: str):
    """
    Load one CellTypist model by file path or by collection name.

    If ``model_name`` is an existing file it is loaded directly. Otherwise it
    is loaded from CellTypist's model collection; if that fails, the model is
    downloaded and the load is retried once. Exceptions from CellTypist
    propagate to the caller.

    Parameters
    ----------
    models_module
        The ``celltypist.models`` module. It is passed in because celltypist is
        an optional dependency imported lazily by the caller.
    model_name : str
        Path to a ``.pkl`` model, or the name of a CellTypist model.

    Returns
    -------
    The loaded CellTypist model.
    """
    if Path(model_name).exists():
        return models_module.Model.load(model_name)
    try:
        return models_module.Model.load(model=model_name)
    except Exception:
        # Not available locally: fetch it, then retry once.
        logger.info(f"Downloading model: {model_name}")
        models_module.download_models(model=model_name)
        return models_module.Model.load(model=model_name)


def _combine_model_predictions(all_model_predictions: Dict[str, tuple]) -> tuple:
    """
    Merge per-model predictions, keeping the most confident model per cell.

    Parameters
    ----------
    all_model_predictions : dict
        Maps model name -> (labels, per_cell_labels, confidence). ``labels``
        are the final (possibly majority-voted) labels and ``per_cell_labels``
        the labels before voting; both are pd.Series sharing one index.
        ``confidence`` is a 1-D array aligned with them.

    Returns
    -------
    tuple
        (labels, per_cell_labels, confidence, source_model): two pd.Series of
        labels, a 1-D confidence array, and a pd.Series naming the model
        chosen for each cell. Ties go to the model listed first. A NaN
        confidence never wins. Cells where every model reports NaN get label
        "Unknown", confidence -1.0 and model "none".
    """
    model_names = list(all_model_predictions)

    if len(model_names) == 1:
        name = model_names[0]
        labels, per_cell_labels, confidence = all_model_predictions[name]
        source = pd.Series([name] * len(labels), index=labels.index)
        return labels, per_cell_labels, confidence, source

    entries = [all_model_predictions[name] for name in model_names]
    cell_index = entries[0][0].index
    rows = np.arange(len(cell_index))

    # cells x models matrix; NaN is mapped to -inf so it can never be chosen.
    conf_matrix = np.column_stack(
        [np.asarray(conf, dtype=float) for _, _, conf in entries]
    )
    has_conf = ~np.isnan(conf_matrix)
    # np.argmax returns the first maximum, i.e. the earliest model wins ties.
    best_idx = np.argmax(np.where(has_conf, conf_matrix, -np.inf), axis=1)

    label_matrix = np.column_stack([lab.to_numpy(dtype=object) for lab, _, _ in entries])
    per_cell_matrix = np.column_stack([pc.to_numpy(dtype=object) for _, pc, _ in entries])

    best_labels = label_matrix[rows, best_idx]
    best_per_cell = per_cell_matrix[rows, best_idx]
    best_conf = conf_matrix[rows, best_idx]
    best_model = np.asarray(model_names, dtype=object)[best_idx]

    no_conf = ~has_conf.any(axis=1)
    if no_conf.any():
        best_labels[no_conf] = "Unknown"
        best_per_cell[no_conf] = "Unknown"
        best_conf[no_conf] = -1.0
        best_model[no_conf] = "none"

    return (
        pd.Series(best_labels, index=cell_index),
        pd.Series(best_per_cell, index=cell_index),
        best_conf,
        pd.Series(best_model, index=cell_index),
    )


def annotate_celltypist(
    adata: ad.AnnData,
    tissue: str = "unknown",
    ensemble_mode: bool = True,
    custom_model_path: Optional[Union[str, Path]] = None,
    majority_voting: bool = False,
    over_clustering: Optional[str] = None,
    min_prop: float = 0.0,
    min_gene_overlap_pct: float = 25.0,
    min_confidence: float = 0.5,
    norm_layer: str = "norm",
    store_decision_scores: bool = True,
    confidence_transform: Optional[ConfidenceMethod] = "zscore",
    copy: bool = False,
) -> ad.AnnData:
    """
    Annotate cells using CellTypist with tissue-specific models.

    Algorithm:
    1. Load tissue-specific model preset (or custom model)
    2. Validate gene overlap for each model (skip if < min_gene_overlap_pct)
    3. Validate normalization in specified layer (log1p, ~10k sum)
    4. Run prediction with native celltypist.annotate()
    5. Ensemble: take highest confidence per cell across models

    Parameters
    ----------
    adata : AnnData
        AnnData object to annotate.
    tissue : str, default "unknown"
        Tissue type for model selection (e.g., "liver", "lung", "colon").
    ensemble_mode : bool, default True
        Use multiple tissue-specific models and ensemble results.
    custom_model_path : str or Path, optional
        Path to custom .pkl model (overrides tissue preset).
    majority_voting : bool, default False
        Use CellTypist's native majority voting within clusters.
        **Default False for spatial data** - voting can collapse cell types.
    over_clustering : str, optional
        Column in adata.obs for cluster-based voting (e.g., "leiden").
        If not given and voting is enabled, "leiden" is used when present.
    min_prop : float, default 0.0
        Minimum proportion for subcluster assignment (0.0 = no threshold).
    min_gene_overlap_pct : float, default 25.0
        Skip models with less than this gene overlap.
    min_confidence : float, default 0.5
        Minimum raw confidence for cell type assignment.
        Cells below this threshold are labeled "Unassigned".
        Set to 0.0 to disable filtering (assign all cells).
    norm_layer : str, default "norm"
        Layer containing log1p(10k) normalized data. Must exist in adata.layers.
    store_decision_scores : bool, default True
        Store the full decision score matrix in
        adata.obsm["cell_type_decision_scores"] for downstream analysis.
        Only supported when a single model runs. In multi-model ensembles
        nothing is stored, and a warning is logged.
    confidence_transform : {"raw", "zscore", "softmax", "minmax"} or None, default "zscore"
        Transform method for confidence scores. "zscore" is recommended for
        spatial data. Applied only when this run stored decision scores;
        otherwise cell_type_confidence holds the raw confidence. Set to None
        to skip transformation.
    copy : bool, default False
        If True, return a copy.

    Returns
    -------
    AnnData
        AnnData with new columns in obs (CellxGene standard names):
        - cell_type: Final cell type labels ("Unassigned" below min_confidence)
        - cell_type_confidence: Transformed confidence when decision scores
          were stored in this run, otherwise the raw confidence
        - cell_type_confidence_raw: Raw confidence scores from CellTypist
        - cell_type_model: Which model contributed each prediction
        - cell_type_original: Per-cell predictions (before any voting)

        And optionally in obsm (single-model runs with store_decision_scores=True):
        - cell_type_decision_scores: Full decision score matrix (n_cells x n_types)

    Raises
    ------
    ImportError
        If celltypist is not installed.
    ValueError
        If no model could be loaded, no model passed the gene overlap
        threshold, or the normalized layer fails validation.

    Notes
    -----
    For spatial data, majority_voting=False is recommended because:
    1. Spatial clustering may be coarse (few clusters)
    2. Voting assigns dominant type to ALL cells in cluster
    3. This can collapse many cell types into just a few

    Examples
    --------
    >>> from spatialcore.annotation import annotate_celltypist
    >>> adata = annotate_celltypist(
    ...     adata,
    ...     tissue="liver",
    ...     ensemble_mode=True,
    ...     majority_voting=False,  # Default for spatial
    ... )
    >>> adata.obs[["cell_type", "cell_type_confidence"]].head()
    """
    try:
        import celltypist
        from celltypist import models
    except ImportError as err:
        raise ImportError(
            "celltypist is required. Install with: pip install celltypist"
        ) from err

    if copy:
        adata = adata.copy()

    # Determine models to run
    if custom_model_path:
        models_to_run = [str(custom_model_path)]
        logger.info(f"Using custom model: {custom_model_path}")
    elif ensemble_mode:
        models_to_run = get_models_for_tissue(tissue)
        logger.info(f"Using {len(models_to_run)} models for tissue '{tissue}'")
    else:
        models_to_run = ["Immune_All_Low.pkl"]
        logger.info("Using single model: Immune_All_Low.pkl")

    # Load models and validate gene overlap
    loaded_models = {}
    all_overlap_genes = set()
    data_genes = set(adata.var_names)
    failed_to_load = []

    for model_name in models_to_run:
        try:
            loaded_model = _load_celltypist_model(models, model_name)
        except Exception as e:
            # Best effort: one unavailable model should not abort the ensemble.
            logger.warning(f"Failed to load model {model_name}: {e}")
            failed_to_load.append(model_name)
            continue

        overlap_info = _validate_gene_overlap(
            loaded_model, data_genes, min_gene_overlap_pct
        )
        if not overlap_info["passes_threshold"]:
            logger.warning(
                f"Skipping {model_name}: only {overlap_info['overlap_pct']:.1f}% gene overlap"
            )
            continue

        logger.info(
            f" {model_name}: {overlap_info['overlap_pct']:.1f}% overlap "
            f"({overlap_info['n_overlap']}/{overlap_info['n_model_genes']} genes)"
        )
        loaded_models[model_name] = loaded_model
        all_overlap_genes.update(set(loaded_model.features) & data_genes)

    if not loaded_models:
        if len(failed_to_load) == len(models_to_run):
            # Report load failures as such, not as a gene-overlap problem.
            raise ValueError(
                f"None of the requested CellTypist models could be loaded: "
                f"{failed_to_load}. Check the model names/paths and network access."
            )
        raise ValueError(
            "No models passed gene overlap threshold. "
            "Consider training a custom model for your panel genes."
        )

    # Validate and prepare data for CellTypist
    adata_for_prediction = _validate_celltypist_input(adata, norm_layer=norm_layer)

    # Subset to overlapping genes, then drop the full-width copy to cap memory.
    genes_mask = adata_for_prediction.var_names.isin(all_overlap_genes)
    adata_subset = adata_for_prediction[:, genes_mask].copy()
    del adata_for_prediction
    logger.info(f"Predicting on {adata_subset.n_obs:,} cells × {adata_subset.n_vars:,} genes")

    # Determine cluster column for voting
    cluster_col = over_clustering or (
        "leiden" if "leiden" in adata.obs.columns else None
    )

    # Copy cluster info to subset if using voting
    if majority_voting and cluster_col and cluster_col in adata.obs.columns:
        adata_subset.obs[cluster_col] = adata.obs[cluster_col].values

    # Run predictions for each model
    single_model = len(loaded_models) == 1
    all_model_predictions = {}
    scores_prediction = None

    for model_name, loaded_model in loaded_models.items():
        logger.info(f" Running {model_name}...")

        prediction = celltypist.annotate(
            adata_subset,
            model=loaded_model,
            mode="best match",
            majority_voting=majority_voting,
            over_clustering=cluster_col if majority_voting else None,
            min_prop=min_prop,
        )

        predicted = prediction.predicted_labels
        # Keep the pre-voting labels separately: cell_type_original must not
        # contain voted labels.
        per_cell_labels = predicted["predicted_labels"]
        if majority_voting and "majority_voting" in predicted.columns:
            labels = predicted["majority_voting"]
        else:
            labels = per_cell_labels

        confidence = prediction.probability_matrix.max(axis=1).values
        all_model_predictions[model_name] = (labels, per_cell_labels, confidence)

        if store_decision_scores and single_model and not majority_voting:
            # Without voting this call already yields the decision matrix;
            # reusing it avoids a second identical prediction run.
            scores_prediction = prediction

        gc.collect()

    # Combine predictions (ensemble: highest confidence wins)
    (
        per_cell_predictions,
        per_cell_original,
        per_cell_confidence,
        per_cell_source_model,
    ) = _combine_model_predictions(all_model_predictions)

    # Store results (CellxGene standard column names)
    adata.obs["cell_type_original"] = per_cell_original.values
    adata.obs["cell_type_confidence_raw"] = per_cell_confidence
    adata.obs["cell_type_model"] = per_cell_source_model.values

    # Apply confidence threshold (post-hoc filter)
    # Convert to numpy array of strings to allow "Unassigned" assignment
    final_labels = np.array(per_cell_predictions.values, dtype=object)
    if min_confidence > 0.0:
        low_conf_mask = per_cell_confidence < min_confidence
        n_unassigned = low_conf_mask.sum()
        if n_unassigned > 0:
            final_labels[low_conf_mask] = "Unassigned"
            logger.info(
                f"Confidence filter: {n_unassigned:,} cells ({100*n_unassigned/len(final_labels):.1f}%) "
                f"below {min_confidence} threshold -> 'Unassigned'"
            )
    adata.obs["cell_type"] = pd.Categorical(final_labels)

    # Store decision scores (single-model runs only)
    scores_stored = False
    if store_decision_scores:
        if single_model:
            if scores_prediction is None:
                # Voting was enabled: re-run once without voting to get a
                # plain per-cell prediction for the decision matrix.
                loaded_model = next(iter(loaded_models.values()))
                scores_prediction = celltypist.annotate(
                    adata_subset,
                    model=loaded_model,
                    mode="best match",
                    majority_voting=False,
                )
            adata = extract_decision_scores(
                adata,
                scores_prediction,
                key_added="cell_type",
            )
            scores_stored = True
            logger.info(
                "Stored decision scores in adata.obsm['cell_type_decision_scores']"
            )
        else:
            logger.warning(
                "store_decision_scores is not supported for multi-model ensembles; "
                "decision scores were not stored."
            )

    # Apply confidence transformation if requested.
    # Only use scores produced by THIS run: a stale obsm entry from an earlier
    # call would otherwise be transformed and paired with the new labels.
    if (
        confidence_transform is not None
        and scores_stored
        and "cell_type_decision_scores" in adata.obsm
    ):
        decision_scores = adata.obsm["cell_type_decision_scores"]
        transformed_conf = transform_confidence(decision_scores, method=confidence_transform)
        adata.obs["cell_type_confidence"] = transformed_conf
        logger.info(
            f"Applied {confidence_transform} confidence transform "
            f"(mean={transformed_conf.mean():.3f})"
        )
    else:
        if confidence_transform is not None:
            logger.info(
                f"Confidence transform '{confidence_transform}' skipped (no decision "
                f"scores from this run); cell_type_confidence holds raw confidence."
            )
        # Use raw confidence if no transform available
        adata.obs["cell_type_confidence"] = per_cell_confidence

    # Log summary
    n_types = adata.obs["cell_type"].nunique()
    mean_conf = np.mean(per_cell_confidence)
    logger.info(f"Annotation complete: {n_types} cell types, mean confidence: {mean_conf:.3f}")

    return adata
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def get_annotation_summary(adata: ad.AnnData) -> pd.DataFrame:
    """
    Get summary of CellTypist annotations.

    Reads ``cell_type`` / ``cell_type_confidence`` from ``adata.obs``. It falls
    back to the legacy ``celltypist`` / ``celltypist_confidence`` columns when
    the new names are absent.

    Parameters
    ----------
    adata : AnnData
        Annotated AnnData object.

    Returns
    -------
    pd.DataFrame
        One row per observed cell type, sorted by cell count (descending),
        with columns: cell_type, n_cells, mean_confidence, pct_total.
        ``n_cells`` counts every cell carrying the label, even if its
        confidence is NaN. ``mean_confidence`` ignores NaN values.

    Raises
    ------
    ValueError
        If no label column or no confidence column is present.

    Examples
    --------
    >>> from spatialcore.annotation import get_annotation_summary
    >>> summary = get_annotation_summary(adata)
    >>> print(summary.head())
    """
    # Support both old (celltypist) and new (cell_type) column names
    label_col = "cell_type" if "cell_type" in adata.obs.columns else "celltypist"
    conf_col = "cell_type_confidence" if "cell_type_confidence" in adata.obs.columns else "celltypist_confidence"

    if label_col not in adata.obs.columns:
        raise ValueError("No cell type annotations found. Run annotate_celltypist first.")
    if conf_col not in adata.obs.columns:
        raise ValueError(
            "No confidence column found ('cell_type_confidence' or "
            "'celltypist_confidence'). Run annotate_celltypist first."
        )

    # observed=True: skip categories with no cells (e.g. after subsetting).
    grouped = adata.obs.groupby(label_col, observed=True)[conf_col]
    summary = pd.DataFrame({
        # size() counts cells; count() would skip cells with NaN confidence.
        "n_cells": grouped.size(),
        "mean_confidence": grouped.mean(),
    })
    summary["pct_total"] = 100 * summary["n_cells"] / adata.n_obs
    # A stable sort keeps ties in group order, so output is deterministic.
    summary = summary.sort_values("n_cells", ascending=False, kind="stable").reset_index()
    summary.columns = ["cell_type", "n_cells", "mean_confidence", "pct_total"]

    return summary
|