spatialcore 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatialcore/__init__.py +122 -0
- spatialcore/annotation/__init__.py +253 -0
- spatialcore/annotation/acquisition.py +529 -0
- spatialcore/annotation/annotate.py +603 -0
- spatialcore/annotation/cellxgene.py +365 -0
- spatialcore/annotation/confidence.py +802 -0
- spatialcore/annotation/discovery.py +529 -0
- spatialcore/annotation/expression.py +363 -0
- spatialcore/annotation/loading.py +529 -0
- spatialcore/annotation/markers.py +297 -0
- spatialcore/annotation/ontology.py +1282 -0
- spatialcore/annotation/patterns.py +247 -0
- spatialcore/annotation/pipeline.py +620 -0
- spatialcore/annotation/synapse.py +380 -0
- spatialcore/annotation/training.py +1457 -0
- spatialcore/annotation/validation.py +422 -0
- spatialcore/core/__init__.py +34 -0
- spatialcore/core/cache.py +118 -0
- spatialcore/core/logging.py +135 -0
- spatialcore/core/metadata.py +149 -0
- spatialcore/core/utils.py +768 -0
- spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
- spatialcore/data/markers/canonical_markers.json +83 -0
- spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
- spatialcore/plotting/__init__.py +109 -0
- spatialcore/plotting/benchmark.py +477 -0
- spatialcore/plotting/celltype.py +329 -0
- spatialcore/plotting/confidence.py +413 -0
- spatialcore/plotting/spatial.py +505 -0
- spatialcore/plotting/utils.py +411 -0
- spatialcore/plotting/validation.py +1342 -0
- spatialcore-0.1.9.dist-info/METADATA +213 -0
- spatialcore-0.1.9.dist-info/RECORD +36 -0
- spatialcore-0.1.9.dist-info/WHEEL +5 -0
- spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
- spatialcore-0.1.9.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,620 @@
|
|
|
1
|
+
"""
High-level cell typing pipeline for spatial transcriptomics.

This module provides consolidated entry points for the full annotation workflow:
- train_and_annotate(): Train custom model + annotate in one call
- train_and_annotate_config(): Same workflow, driven by a TrainingConfig object
- TrainingConfig: YAML-serializable configuration for reproducibility

The pipeline integrates:
1. Reference loading and combination
2. Ontology ID filling (for semantic grouping)
3. Source-aware "Cap & Fill" balancing
4. CellTypist model training
5. Annotation with z-score confidence transformation
6. Confidence thresholding (low-confidence cells are labeled "Unassigned")
7. Ontology mapping
8. Validation plot generation

Column Naming (CellxGene Standard):
- cell_type: Final predicted cell type
- cell_type_confidence: Z-score transformed confidence
- cell_type_ontology_term_id: CL:XXXXX ontology ID
- cell_type_ontology_label: Canonical ontology name
"""
|
|
23
|
+
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
27
|
+
import json
|
|
28
|
+
import gc
|
|
29
|
+
|
|
30
|
+
import anndata as ad
|
|
31
|
+
import numpy as np
|
|
32
|
+
import pandas as pd
|
|
33
|
+
|
|
34
|
+
from spatialcore.core.logging import get_logger
|
|
35
|
+
|
|
36
|
+
# Module-level logger used for pipeline progress messages and config-save notices.
logger = get_logger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ============================================================================
|
|
40
|
+
# Configuration
|
|
41
|
+
# ============================================================================
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class TrainingConfig:
    """
    YAML-serializable configuration for reproducible training + annotation.

    This configuration captures all parameters needed to reproduce a training
    run, from reference selection to confidence thresholds.

    Parameters
    ----------
    tissue : str, default "unknown"
        Tissue type for the model (used in model naming and selection).
    references : List[str]
        Paths to reference h5ad files. Required.
    label_columns : List[str], optional
        Cell type label column for each reference. If None, auto-detect.
    balance_strategy : {"proportional", "equal"}, default "proportional"
        How to distribute sampling across sources when balancing.
    max_cells_per_type : int, default 10000
        Maximum cells per cell type after balancing.
    max_cells_per_ref : int, default 100000
        Maximum cells to load per reference (memory management).
    confidence_threshold : float, default 0.8
        Threshold for marking low-confidence predictions as Unassigned.
    add_ontology : bool, default True
        Whether to add ontology IDs to predictions.
    generate_plots : bool, default True
        Whether to generate validation plots.

    Examples
    --------
    >>> from spatialcore.annotation.pipeline import TrainingConfig
    >>> config = TrainingConfig(
    ...     tissue="lung",
    ...     references=["ref1.h5ad", "ref2.h5ad"],
    ...     balance_strategy="proportional",
    ... )
    >>> config.to_yaml("training_config.yaml")
    """

    tissue: str = "unknown"
    references: List[str] = field(default_factory=list)
    label_columns: Optional[List[str]] = None
    balance_strategy: Literal["proportional", "equal"] = "proportional"
    max_cells_per_type: int = 10000
    max_cells_per_ref: int = 100000
    confidence_threshold: float = 0.8
    add_ontology: bool = True
    generate_plots: bool = True

    @classmethod
    def from_yaml(cls, path: Union[str, Path]) -> "TrainingConfig":
        """
        Load configuration from a YAML file.

        An empty YAML file yields a config with all default values.

        Raises
        ------
        ImportError
            If PyYAML is not installed.
        TypeError
            If the YAML top level is not a mapping, or contains keys that are
            not TrainingConfig fields.
        """
        try:
            import yaml
        except ImportError as err:
            raise ImportError("PyYAML is required for YAML config loading") from err

        path = Path(path)
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty document parses to None; treat it as "use all defaults".
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {path}, "
                f"got {type(data).__name__}"
            )
        return cls.from_dict(data)

    def to_yaml(self, path: Union[str, Path]) -> None:
        """
        Save configuration to a YAML file.

        Only plain YAML types are written (paths are stored as strings), so
        the file can always be read back with ``from_yaml``.

        Raises
        ------
        ImportError
            If PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as err:
            raise ImportError("PyYAML is required for YAML config saving") from err

        path = Path(path)
        with open(path, "w", encoding="utf-8") as f:
            # safe_dump refuses non-plain objects instead of emitting
            # python/object tags that safe_load (used by from_yaml) rejects.
            yaml.safe_dump(
                self._to_serializable(),
                f,
                default_flow_style=False,
                sort_keys=False,
            )

        logger.info(f"Saved training config to: {path}")

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary.

        The result is a deep copy: mutating it (e.g. appending to the
        ``references`` list) does not affect this config.
        """
        from dataclasses import asdict

        return asdict(self)

    def _to_serializable(self) -> Dict[str, Any]:
        """Return ``to_dict()`` with path-like references converted to ``str``."""
        data = self.to_dict()
        data["references"] = [str(p) for p in data["references"]]
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TrainingConfig":
        """
        Create a config from a dictionary.

        Raises
        ------
        TypeError
            If ``data`` contains keys that are not TrainingConfig fields.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        unknown = sorted(set(data) - known)
        if unknown:
            raise TypeError(
                f"Unknown TrainingConfig field(s): {unknown}. "
                f"Valid fields: {sorted(known)}"
            )
        return cls(**data)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# ============================================================================
|
|
132
|
+
# High-Level Pipeline
|
|
133
|
+
# ============================================================================
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def train_and_annotate(
    adata: ad.AnnData,
    references: List[Union[str, Path]],
    tissue: str = "unknown",
    label_columns: Optional[List[str]] = None,
    balance_strategy: Literal["proportional", "equal"] = "proportional",
    max_cells_per_type: int = 10000,
    max_cells_per_ref: int = 100000,
    confidence_threshold: float = 0.8,
    model_output: Optional[Union[str, Path]] = None,
    plot_output: Optional[Union[str, Path]] = None,
    add_ontology: bool = True,
    generate_plots: bool = True,
    copy: bool = False,
) -> ad.AnnData:
    """
    Full workflow: train custom model on references, then annotate spatial data.

    This is the core SpatialCore value proposition - NOT a thin wrapper around
    CellTypist. The function provides significant added value:

    1. **Panel-specific training** - Subsets references to spatial panel genes,
       achieving ~100% gene overlap vs ~5-9% with pre-trained models.
    2. **Source-aware balancing** - "Cap & Fill" strategy ensures all references
       contribute to each cell type, preventing source-specific biases.
    3. **CL ID-based grouping** - Groups semantically equivalent labels
       (e.g., "CD4+ T cells" and "CD4-positive, alpha-beta T cell") by
       their Cell Ontology ID for proper balancing.
    4. **Z-score confidence** - Transforms raw logistic regression scores to
       interpretable [0,1] confidence values that handle domain shift.
    5. **Multi-tier ontology mapping** - Maps predictions to Cell Ontology
       using pattern matching, exact match, token matching, and fuzzy overlap.
    6. **Automatic validation plots** - Generates DEG heatmap, 2D marker
       validation, confidence plots, and ontology mapping table.

    Pipeline Stages
    ---------------
    1. Get panel genes from spatial data (var_names)
    2. Load + combine references (Ensembl→HUGO normalization, log1p 10k)
    3. Fill missing ontology IDs (add_ontology_ids with skip_if_exists)
    4. Balance by CL ID (source-aware "Cap & Fill")
    5. Train CellTypist model and write it to disk
    6. Annotate spatial data (z-score confidence)
    7. Mark low-confidence cells as "Unassigned"
    8. Add ontology IDs to predictions
    9. Generate validation plots

    Parameters
    ----------
    adata : AnnData
        Spatial transcriptomics data to annotate. Must have gene names in
        var_names (HUGO symbols preferred).
    references : List[str or Path]
        Paths to reference h5ad files for training. Must not be empty.
    tissue : str, default "unknown"
        Tissue type for model naming and plot file prefixes.
    label_columns : List[str], optional
        Cell type label column for each reference. If None, auto-detect
        common column names (cell_type, celltype, cell_type_ontology_term_id).
    balance_strategy : {"proportional", "equal"}, default "proportional"
        How to distribute sampling across sources:
        - "proportional": Sample proportionally to source contribution
        - "equal": Sample equally from each source
    max_cells_per_type : int, default 10000
        Maximum cells per cell type after balancing.
    max_cells_per_ref : int, default 100000
        Maximum cells to load per reference (memory management).
    confidence_threshold : float, default 0.8
        Cells with confidence below this threshold are marked "Unassigned".
        A value <= 0 disables thresholding.
    model_output : str or Path, optional
        Path to save trained CellTypist model (.pkl). If None, the model is
        written to a private temporary directory (annotation loads it from
        disk) and that directory is deleted once annotation finishes, so
        nothing is persisted.
    plot_output : str or Path, optional
        Directory to save validation plots. If None, uses current directory
        when generate_plots=True.
    add_ontology : bool, default True
        Whether to map predictions to Cell Ontology IDs.
    generate_plots : bool, default True
        Whether to generate validation plots (DEG heatmap, 2D validation,
        confidence plots, ontology mapping table). Plot failures are logged
        as warnings and do not abort the pipeline.
    copy : bool, default False
        If True, return a copy of adata; otherwise modify in-place.

    Returns
    -------
    AnnData
        Annotated data with new columns (CellxGene standard names):
        - cell_type: Predicted cell type
        - cell_type_confidence: Z-score transformed confidence [0, 1]
        - cell_type_ontology_term_id: CL:XXXXX (if add_ontology=True)
        - cell_type_ontology_label: Canonical name (if add_ontology=True)

        And metadata in uns:
        - spatialcore_annotation: Dict with training parameters and stats

    Raises
    ------
    ValueError
        If ``references`` is empty, or label columns cannot be auto-detected.

    Examples
    --------
    >>> from spatialcore.annotation.pipeline import train_and_annotate
    >>> import scanpy as sc
    >>>
    >>> # Load spatial data
    >>> adata = sc.read_h5ad("xenium_lung.h5ad")
    >>>
    >>> # Train and annotate
    >>> adata = train_and_annotate(
    ...     adata,
    ...     references=["hlca_core.h5ad", "tabula_sapiens_lung.h5ad"],
    ...     tissue="lung",
    ...     balance_strategy="proportional",
    ...     confidence_threshold=0.8,
    ...     plot_output="./qc_plots/",
    ... )
    >>>
    >>> # Check results
    >>> print(adata.obs["cell_type"].value_counts())

    See Also
    --------
    train_and_annotate_config : Config-driven version for reproducibility.
    annotate_celltypist : Lower-level annotation function.
    combine_references : Reference combination without annotation.
    """
    import shutil
    import tempfile

    from spatialcore.annotation.training import (
        combine_references,
        get_panel_genes,
        subsample_balanced,
        train_celltypist_model,
        save_model_artifacts,
    )
    from spatialcore.annotation.annotate import annotate_celltypist
    from spatialcore.annotation.ontology import add_ontology_ids
    from spatialcore.plotting import generate_annotation_plots

    # Materialize once: the references are iterated several times below
    # (label detection, loading, metadata), which would silently exhaust a
    # generator passed by the caller.
    references = list(references)
    if not references:
        raise ValueError("train_and_annotate requires at least one reference file.")

    if copy:
        adata = adata.copy()

    logger.info("=" * 60)
    logger.info("SpatialCore Cell Typing Pipeline")
    logger.info("=" * 60)

    # -------------------------------------------------------------------------
    # Stage 1: Get panel genes from spatial data
    # -------------------------------------------------------------------------
    logger.info("Stage 1: Extracting panel genes from spatial data...")
    panel_genes = get_panel_genes(adata)
    logger.info(f"  Panel genes: {len(panel_genes)}")

    # -------------------------------------------------------------------------
    # Stage 2: Load and combine references
    # -------------------------------------------------------------------------
    logger.info("Stage 2: Loading and combining references...")

    # Auto-detect label columns if not provided
    if label_columns is None:
        label_columns = _detect_label_columns(references)

    combined = combine_references(
        reference_paths=references,
        label_columns=label_columns,
        output_column="original_label",
        max_cells_per_ref=max_cells_per_ref,
        target_genes=panel_genes,
        normalize_data=True,
    )
    logger.info(f"  Combined: {combined.n_obs:,} cells, {combined.n_vars:,} genes")

    # -------------------------------------------------------------------------
    # Stage 3: Fill missing ontology IDs
    # -------------------------------------------------------------------------
    logger.info("Stage 3: Filling missing ontology IDs...")

    combined, _, _ = add_ontology_ids(
        combined,
        source_col="original_label",
        target_col="cell_type_ontology_term_id",
        name_col="cell_type_ontology_label",
        skip_if_exists=True,  # Preserve CellxGene's native IDs
        copy=False,
    )

    # -------------------------------------------------------------------------
    # Stage 4: Balance by CL ID (source-aware)
    # -------------------------------------------------------------------------
    logger.info("Stage 4: Balancing training data (source-aware)...")

    balanced = subsample_balanced(
        combined,
        label_column="original_label",
        group_by_column="cell_type_ontology_term_id",  # Group by CL ID!
        source_column="reference_source",
        source_balance=balance_strategy,
        max_cells_per_type=max_cells_per_type,
        copy=True,
    )
    logger.info(f"  Balanced: {balanced.n_obs:,} cells")

    # Release combined reference data - no longer needed after balancing
    del combined
    gc.collect()

    # -------------------------------------------------------------------------
    # Stage 5: Train CellTypist model
    # -------------------------------------------------------------------------
    logger.info("Stage 5: Training CellTypist model...")

    # Train on canonical ontology names (human-readable, semantically grouped)
    # Balancing grouped by CL ID, so synonyms are consolidated under the same
    # canonical name in cell_type_ontology_label (e.g., "CD4+ T cell" and
    # "CD4-positive, alpha-beta T cell" both become the canonical CL name)
    training_result = train_celltypist_model(
        balanced,
        label_column="cell_type_ontology_label",  # Train on canonical names
        feature_selection=False,  # Use panel genes as-is
        n_jobs=-1,
    )
    model = training_result["model"]
    n_training_cells = balanced.n_obs

    # Release training data - the fitted model no longer needs it
    del balanced
    gc.collect()

    # annotate_celltypist loads the model from a path, so the model must be
    # on disk. When the caller did not ask for it to be kept, use a private
    # temporary directory that is removed once annotation is done.
    temp_model_dir: Optional[Path] = None
    if model_output is None:
        temp_model_dir = Path(tempfile.mkdtemp(prefix="spatialcore_model_"))
        model_path = temp_model_dir / "celltypist_model.pkl"
    else:
        model_path = Path(model_output)

    try:
        model_path.parent.mkdir(parents=True, exist_ok=True)
        save_model_artifacts(
            model,
            output_dir=model_path.parent,
            model_name=model_path.stem,
            training_metadata={
                "references": [str(p) for p in references],
                "n_cells": training_result["n_cells_trained"],
                "n_genes": training_result["n_genes"],
                "n_cell_types": training_result["n_cell_types"],
                "cell_types": training_result["cell_types"],
            },
        )
        if temp_model_dir is None:
            logger.info(f"  Model saved to: {model_path}")
        else:
            logger.info(
                "  Model written to a temporary directory "
                "(model_output=None; it will not be kept)"
            )

        # The model now lives on disk; drop in-memory training artifacts
        # before annotating.
        del model
        del training_result
        gc.collect()

        # ---------------------------------------------------------------------
        # Stage 6: Annotate spatial data
        # ---------------------------------------------------------------------
        logger.info("Stage 6: Annotating spatial data...")

        adata = annotate_celltypist(
            adata,
            custom_model_path=model_path,
            confidence_transform="zscore",
            store_decision_scores=True,
            min_confidence=0.0,
            copy=False,
        )
    finally:
        if temp_model_dir is not None:
            shutil.rmtree(temp_model_dir, ignore_errors=True)

    # -------------------------------------------------------------------------
    # Stage 7: Apply confidence threshold
    # -------------------------------------------------------------------------
    if confidence_threshold > 0:
        conf = adata.obs["cell_type_confidence"].to_numpy()
        low_conf_mask = conf < confidence_threshold
        n_low = int(low_conf_mask.sum())

        if n_low > 0:
            # Mark low-confidence cells as Unassigned
            labels = adata.obs["cell_type"].astype(str).copy()
            labels[low_conf_mask] = "Unassigned"
            adata.obs["cell_type"] = pd.Categorical(labels)

            pct = 100 * n_low / adata.n_obs
            logger.info(
                f"Stage 7: Marked {n_low:,} cells ({pct:.1f}%) as Unassigned "
                f"(confidence < {confidence_threshold})"
            )

    # -------------------------------------------------------------------------
    # Stage 8: Add ontology IDs to predictions
    # -------------------------------------------------------------------------
    if add_ontology:
        logger.info("Stage 8: Mapping predictions to Cell Ontology...")

        adata, _, _ = add_ontology_ids(
            adata,
            source_col="cell_type",
            target_col="cell_type_ontology_term_id",
            name_col="cell_type_ontology_label",
            skip_if_exists=False,  # Map all predictions
            copy=False,
        )

    # -------------------------------------------------------------------------
    # Stage 9: Generate validation plots
    # -------------------------------------------------------------------------
    if generate_plots:
        logger.info("Stage 9: Generating validation plots...")

        output_dir = Path(plot_output) if plot_output else Path(".")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Prefer standardized ontology names; fall back to the raw predictions
        # when the ontology label column was never created (add_ontology=False).
        plot_label_column = (
            "cell_type_ontology_label"
            if "cell_type_ontology_label" in adata.obs.columns
            else "cell_type"
        )

        try:
            generate_annotation_plots(
                adata,
                label_column=plot_label_column,
                confidence_column="cell_type_confidence",
                output_dir=output_dir,
                prefix=f"{tissue}_celltyping",
                confidence_threshold=confidence_threshold,
                # Ontology columns for mapping table
                source_label_column="cell_type",
                ontology_name_column="cell_type_ontology_label",
                ontology_id_column="cell_type_ontology_term_id",
            )
        except Exception as e:
            # Plots are a best-effort QC artifact; never fail the annotation.
            logger.warning(f"Plot generation failed: {e}")

    # -------------------------------------------------------------------------
    # Store metadata
    # -------------------------------------------------------------------------
    adata.uns["spatialcore_annotation"] = {
        "tissue": tissue,
        "n_references": len(references),
        "references": [str(r) for r in references],
        "label_columns": [str(c) for c in label_columns],
        "panel_genes": len(panel_genes),
        "training_cells": n_training_cells,
        "balance_strategy": balance_strategy,
        "max_cells_per_type": max_cells_per_type,
        "max_cells_per_ref": max_cells_per_ref,
        "confidence_threshold": confidence_threshold,
        "n_cell_types": int(adata.obs["cell_type"].nunique()),
    }

    logger.info("=" * 60)
    logger.info("Pipeline complete!")
    logger.info(f"  Cell types: {adata.obs['cell_type'].nunique()}")
    logger.info(f"  Mean confidence: {adata.obs['cell_type_confidence'].mean():.3f}")
    logger.info("=" * 60)

    return adata
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def train_and_annotate_config(
    adata: ad.AnnData,
    config: TrainingConfig,
    model_output: Optional[Union[str, Path]] = None,
    plot_output: Optional[Union[str, Path]] = None,
    copy: bool = False,
) -> ad.AnnData:
    """
    Run train_and_annotate() with its settings taken from a TrainingConfig.

    Every training and annotation setting comes from ``config``. Only the
    output locations and the copy flag are supplied per call, so one config
    (e.g. loaded with ``TrainingConfig.from_yaml``) can be reused across runs.

    Parameters
    ----------
    adata : AnnData
        Spatial transcriptomics data to annotate.
    config : TrainingConfig
        Source of references, tissue, label columns, balancing, threshold,
        ontology and plotting settings.
    model_output : str or Path, optional
        Forwarded unchanged to train_and_annotate() as ``model_output``.
    plot_output : str or Path, optional
        Forwarded unchanged to train_and_annotate() as ``plot_output``.
    copy : bool, default False
        Forwarded unchanged; if True, a copy of adata is annotated and returned.

    Returns
    -------
    AnnData
        Whatever train_and_annotate() returns for these settings.

    Examples
    --------
    >>> from spatialcore.annotation.pipeline import (
    ...     TrainingConfig,
    ...     train_and_annotate_config,
    ... )
    >>>
    >>> # Load config from YAML
    >>> config = TrainingConfig.from_yaml("training_config.yaml")
    >>>
    >>> # Run pipeline
    >>> adata = train_and_annotate_config(adata, config)
    """
    # Config attributes whose names match train_and_annotate keywords exactly.
    forwarded_fields = (
        "references",
        "tissue",
        "label_columns",
        "balance_strategy",
        "max_cells_per_type",
        "max_cells_per_ref",
        "confidence_threshold",
        "add_ontology",
        "generate_plots",
    )
    settings = {name: getattr(config, name) for name in forwarded_fields}

    return train_and_annotate(
        adata=adata,
        model_output=model_output,
        plot_output=plot_output,
        copy=copy,
        **settings,
    )
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
# ============================================================================
|
|
540
|
+
# Helper Functions
|
|
541
|
+
# ============================================================================
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
# Candidate obs column names, in priority order, for auto-detecting the
# cell type label column of a reference.
_COMMON_LABEL_COLUMNS = (
    "cell_type",
    "celltype",
    "cell_type_ontology_term_id",
    "Cell_type",
    "CellType",
    "cell_type_label",
    "annotation",
    "cluster",
    "leiden",
)


def _list_obs_columns(ref_path: Path) -> List[str]:
    """
    Return the obs column names of an h5ad file without loading the matrix.

    Reads the HDF5 metadata directly when possible. Falls back to anndata's
    backed reader if h5py is unavailable or the layout is not recognized.
    """
    try:
        import h5py

        with h5py.File(ref_path, "r") as f:
            if "obs" not in f:
                return []
            obs = f["obs"]
            # Group-encoded dataframes record every column (categorical or
            # not) in the "column-order" attribute. Listing only the
            # "__categories" children would miss non-categorical columns.
            order = obs.attrs.get("column-order")
            if order is not None:
                if isinstance(order, (str, bytes)):
                    order = [order]
                return [
                    c.decode("utf-8") if isinstance(c, bytes) else str(c)
                    for c in order
                ]
            # No column-order attribute: list children, skipping internal
            # entries such as "_index" and "__categories".
            return [k for k in obs.keys() if not k.startswith("_")]
    except Exception:
        # Fallback: let anndata parse the file (e.g. legacy layouts where obs
        # is a compound dataset). Backed mode avoids loading X into memory.
        import anndata

        adata_ref = anndata.read_h5ad(ref_path, backed="r")
        try:
            return list(adata_ref.obs.columns)
        finally:
            adata_ref.file.close()


def _select_label_column(obs_cols: List[str], ref_name: str) -> str:
    """
    Return the highest-priority candidate label column present in obs_cols.

    Raises
    ------
    ValueError
        If none of the candidate column names is present.
    """
    available = set(obs_cols)
    for col in _COMMON_LABEL_COLUMNS:
        if col in available:
            return col

    shown = list(obs_cols[:10])
    more = "..." if len(obs_cols) > 10 else ""
    raise ValueError(
        f"Could not auto-detect label column in {ref_name}. "
        f"Available columns: {shown}{more} "
        f"Please provide label_columns explicitly."
    )


def _detect_label_columns(
    references: List[Union[str, Path]],
) -> List[str]:
    """
    Auto-detect cell type label columns in reference files.

    Searches for common column names: cell_type, celltype, cell_type_ontology_term_id,
    Cell_type, CellType, etc. The first match in that priority order wins.

    Parameters
    ----------
    references : List[str or Path]
        Paths to reference h5ad files.

    Returns
    -------
    List[str]
        Detected label column for each reference, in input order.

    Raises
    ------
    ValueError
        If no candidate column is found in some reference.
    """
    label_columns = []

    for ref in references:
        ref_path = Path(ref)
        obs_cols = _list_obs_columns(ref_path)
        found = _select_label_column(obs_cols, ref_path.name)
        label_columns.append(found)
        logger.debug(f"  {ref_path.name}: detected label column '{found}'")

    return label_columns
|