spatialcore 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatialcore/__init__.py +122 -0
- spatialcore/annotation/__init__.py +253 -0
- spatialcore/annotation/acquisition.py +529 -0
- spatialcore/annotation/annotate.py +603 -0
- spatialcore/annotation/cellxgene.py +365 -0
- spatialcore/annotation/confidence.py +802 -0
- spatialcore/annotation/discovery.py +529 -0
- spatialcore/annotation/expression.py +363 -0
- spatialcore/annotation/loading.py +529 -0
- spatialcore/annotation/markers.py +297 -0
- spatialcore/annotation/ontology.py +1282 -0
- spatialcore/annotation/patterns.py +247 -0
- spatialcore/annotation/pipeline.py +620 -0
- spatialcore/annotation/synapse.py +380 -0
- spatialcore/annotation/training.py +1457 -0
- spatialcore/annotation/validation.py +422 -0
- spatialcore/core/__init__.py +34 -0
- spatialcore/core/cache.py +118 -0
- spatialcore/core/logging.py +135 -0
- spatialcore/core/metadata.py +149 -0
- spatialcore/core/utils.py +768 -0
- spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
- spatialcore/data/markers/canonical_markers.json +83 -0
- spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
- spatialcore/plotting/__init__.py +109 -0
- spatialcore/plotting/benchmark.py +477 -0
- spatialcore/plotting/celltype.py +329 -0
- spatialcore/plotting/confidence.py +413 -0
- spatialcore/plotting/spatial.py +505 -0
- spatialcore/plotting/utils.py +411 -0
- spatialcore/plotting/validation.py +1342 -0
- spatialcore-0.1.9.dist-info/METADATA +213 -0
- spatialcore-0.1.9.dist-info/RECORD +36 -0
- spatialcore-0.1.9.dist-info/WHEEL +5 -0
- spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
- spatialcore-0.1.9.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1457 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CellTypist model training utilities.
|
|
3
|
+
|
|
4
|
+
This module provides utilities for:
|
|
5
|
+
1. Combining multiple reference datasets for training
|
|
6
|
+
2. Training custom CellTypist models
|
|
7
|
+
3. Panel gene subsetting for spatial transcriptomics
|
|
8
|
+
|
|
9
|
+
For spatial data (e.g., Xenium), custom models trained on panel-specific genes
|
|
10
|
+
achieve ~100% gene utilization vs ~8% with pre-trained models.
|
|
11
|
+
|
|
12
|
+
References:
|
|
13
|
+
- CellTypist: https://www.celltypist.org/
|
|
14
|
+
- Domínguez Conde et al., Science (2022)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
import json
|
|
21
|
+
import gc
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import pandas as pd
|
|
25
|
+
import scanpy as sc
|
|
26
|
+
import anndata as ad
|
|
27
|
+
|
|
28
|
+
from spatialcore.core.logging import get_logger
|
|
29
|
+
from spatialcore.core.utils import (
|
|
30
|
+
load_ensembl_to_hugo_mapping,
|
|
31
|
+
normalize_gene_names,
|
|
32
|
+
check_normalization_status,
|
|
33
|
+
)
|
|
34
|
+
from spatialcore.annotation.loading import (
|
|
35
|
+
load_adata_backed,
|
|
36
|
+
ensure_normalized,
|
|
37
|
+
get_available_memory_gb,
|
|
38
|
+
)
|
|
39
|
+
from spatialcore.annotation.validation import validate_cell_type_column
|
|
40
|
+
from spatialcore.annotation.acquisition import resolve_uri_to_local
|
|
41
|
+
|
|
42
|
+
# Module-level logger, obtained from the project's logging helper and keyed
# by this module's import name.
logger = get_logger(__name__)
|
|
43
|
+
|
|
44
|
+
# Default cache directory for downloaded references (used by
# combine_references when no cache_dir is given).
DEFAULT_CACHE_DIR = Path.home() / ".spatialcore" / "cache" / "references"

# Default labels to exclude from training (ambiguous/uninformative).
# These are matched EXACTLY (case-sensitive, no partial matching):
# "unknown cells" would NOT be filtered (not an exact match to "unknown").
# Labels are compared after conversion with ``astype(str)``, so missing values
# arrive as their string forms: NaN -> "nan", None -> "None", pandas NA ->
# "<NA>". Those forms are listed here so unlabeled cells are dropped instead of
# becoming a spurious "nan" cell type.
DEFAULT_EXCLUDE_LABELS = [
    "unknown",
    "Unknown",
    "UNKNOWN",
    "unassigned",
    "Unassigned",
    "na",
    "NA",
    "N/A",
    "n/a",
    "none",
    "None",
    "null",
    "doublet",
    "Doublet",
    "low quality",
    "Low quality",
    # String forms of missing values produced by ``astype(str)``
    "nan",
    "NaN",
    "<NA>",
]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# ============================================================================
|
|
71
|
+
# Reference Combination
|
|
72
|
+
# ============================================================================
|
|
73
|
+
|
|
74
|
+
def combine_references(
    reference_paths: List[Union[str, Path]],
    label_columns: List[str],
    output_column: str = "original_label",
    max_cells_per_ref: int = 100000,
    target_genes: Optional[List[str]] = None,
    normalize_data: bool = True,
    random_state: int = 42,
    validate_labels: bool = True,
    min_cells_per_type: int = 10,
    strict_validation: bool = False,
    cache_dir: Optional[Path] = None,
    exclude_labels: Optional[List[str]] = None,
    filter_min_cells: bool = True,
) -> ad.AnnData:
    """
    Combine multiple reference datasets for CellTypist training.

    This function handles:
    1. Memory-efficient loading with per-reference cell caps (stratified)
    2. Gene name normalization (Ensembl → HUGO)
    3. Expression normalization (log1p to 10k)
    4. Gene intersection (with optional panel gene subsetting)
    5. Concatenation with source tracking
    6. Removal of excluded labels and (optionally) low-count cell types

    IMPORTANT: This function does NOT perform post-combine balancing.
    Call subsample_balanced() on the output for source-aware balancing.
    For semantic grouping by ontology ID, first run add_ontology_ids() then use
    subsample_balanced(group_by_column="cell_type_ontology_term_id").

    Parameters
    ----------
    reference_paths : List[str or Path]
        Paths to reference h5ad files. Supports:

        - Local paths: ``/data/references/lung.h5ad``
        - GCS URIs: ``gs://bucket/references/lung.h5ad``
        - S3 URIs: ``s3://bucket/references/lung.h5ad``

        Cloud files are automatically downloaded to cache_dir and loaded
        in memory-efficient backed mode.
    label_columns : List[str]
        Cell type label column for each reference.
    output_column : str, default "original_label"
        Column name for unified cell type labels in output.
        Use "original_label" (new default) for clarity that this is the raw
        source label before any harmonization.
    max_cells_per_ref : int, default 100000
        Maximum cells to load per reference. Uses stratified sampling
        to preserve natural cell type proportions within each reference.
        This is for MEMORY MANAGEMENT during loading, not training balance.
    target_genes : List[str], optional
        Panel genes to subset to (e.g., from spatial data via get_panel_genes()).
        If provided, each reference is subset to these genes before
        finding the intersection. This ensures maximum gene utilization.
    normalize_data : bool, default True
        Whether to ensure log1p(10k) normalization.
    random_state : int, default 42
        Random seed for reproducibility.
    validate_labels : bool, default True
        Run cell type column validation before combining. Checks for
        null values, suspicious patterns, and cardinality issues.
    min_cells_per_type : int, default 10
        Minimum cells required per cell type. Used for validation warnings
        and, when ``filter_min_cells=True``, to remove low-count types
        after concatenation.
    strict_validation : bool, default False
        If True, fail on validation warnings (not just errors).
    cache_dir : Path, optional
        Directory for caching downloaded cloud files. Defaults to
        ``~/.spatialcore/cache/references/``. Only used for gs:// and s3:// URIs.
    exclude_labels : List[str], optional
        Cell type labels to exclude from the combined output. Uses exact
        case-sensitive matching (no partial matches). Cells with these labels
        are removed after concatenation. If None (default), uses
        DEFAULT_EXCLUDE_LABELS (common ambiguous labels like "unknown",
        "Unknown", "unassigned", "NA", "doublet", etc.) plus the string forms
        that missing labels take after ``astype(str)`` ("nan", "NaN", "<NA>"),
        so unlabeled cells never become a "nan" cell type.
        Pass an empty list ``[]`` to disable label filtering entirely.
    filter_min_cells : bool, default True
        If True, remove cell types with fewer than ``min_cells_per_type``
        cells after concatenation. If False, only warn about low-count
        types (original behavior). Filtering is recommended for training
        as singleton or very-low-count cell types can destabilize models.

    Returns
    -------
    AnnData
        Combined reference data with:
        - Unified cell type labels in output_column
        - Source tracking in .obs["reference_source"]
        - Gene intersection applied
        - Ready for subsample_balanced()

    Raises
    ------
    ValueError
        If ``reference_paths`` is empty or its length differs from
        ``label_columns``; if a label column is missing or fails validation;
        if a reference shares no genes with ``target_genes`` or the references
        share no genes with each other; or if no cells remain after filtering.

    Notes
    -----
    **Two-stage workflow for multi-reference training:**

    1. **combine_references()** (this function): Load, normalize, intersect genes
    2. **subsample_balanced()**: Source-aware balancing for training quality

    **For spatial transcriptomics**, providing `target_genes` (panel genes)
    before training is critical:
    - Pre-trained models: ~8% gene overlap with 400-gene panel
    - Custom panel models: 100% gene overlap

    Examples
    --------
    >>> from spatialcore.annotation import (
    ...     combine_references, subsample_balanced, get_panel_genes
    ... )
    >>> # Step 1: Get panel genes from spatial data
    >>> panel_genes = get_panel_genes(xenium_adata)
    >>>
    >>> # Step 2: Combine references (loading + gene intersection)
    >>> # Supports local paths and cloud URIs (GCS, S3)
    >>> combined = combine_references(
    ...     reference_paths=[
    ...         "gs://my-bucket/references/hlca.h5ad",      # GCS
    ...         "s3://my-bucket/references/liver.h5ad",     # S3
    ...         "/local/data/study3.h5ad",                  # Local
    ...     ],
    ...     label_columns=["cell_type", "cell_type", "cell_type"],
    ...     max_cells_per_ref=100000,  # Memory cap during loading
    ...     target_genes=panel_genes,  # Intersect with spatial panel
    ... )
    >>> # Output: combined AnnData with reference_source tracked in .obs
    >>>
    >>> # Step 3: Source-aware balancing (SEPARATE STEP)
    >>> balanced = subsample_balanced(
    ...     combined,
    ...     label_column="original_label",
    ...     max_cells_per_type=10000,
    ...     source_balance="proportional",
    ... )
    """
    if len(reference_paths) != len(label_columns):
        raise ValueError(
            f"Number of paths ({len(reference_paths)}) must match "
            f"number of label columns ({len(label_columns)})"
        )
    if not reference_paths:
        # Without this guard the shared-gene step would fail with IndexError.
        raise ValueError("At least one reference path is required.")

    # String forms that missing labels take after ``astype(str)``
    # (NaN -> "nan", pandas NA -> "<NA>"); excluded by default.
    null_label_strings = {"nan", "NaN", "<NA>"}

    # Initialize cache directory for cloud downloads
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Load gene mapping once at start (reused for all references)
    logger.info("Loading Ensembl to HUGO gene mapping...")
    ensembl_to_hugo = load_ensembl_to_hugo_mapping()
    logger.info(f"Loaded {len(ensembl_to_hugo):,} gene mappings")

    n_refs = len(reference_paths)
    adatas: List[ad.AnnData] = []
    source_names: List[str] = []
    validation_results: List[Dict[str, Any]] = []  # One entry per reference

    for i, (ref_path, label_col) in enumerate(zip(reference_paths, label_columns)):
        ref_path_str = str(ref_path)

        # Extract source name for logging (handle URIs and local paths)
        if ref_path_str.startswith(("gs://", "s3://")):
            source_name = Path(ref_path_str.split("/")[-1]).stem
            logger.info(f"\n[{i+1}/{n_refs}] Loading: {ref_path_str}")
        else:
            source_name = Path(ref_path_str).stem
            logger.info(f"\n[{i+1}/{n_refs}] Loading: {Path(ref_path_str).name}")

        # Resolve URI to local path (downloads cloud files if needed)
        local_path = resolve_uri_to_local(ref_path_str, cache_dir)

        # Memory-efficient loading with backed mode for large files
        ref_adata = load_adata_backed(
            path=local_path,
            max_cells=max_cells_per_ref,
            label_column=label_col,
            random_state=random_state,
        )

        # Check the label column before validation/normalization so a typo
        # produces this clear message rather than a failure further down.
        if label_col not in ref_adata.obs.columns:
            available = list(ref_adata.obs.columns)
            raise ValueError(
                f"Label column '{label_col}' not found in {source_name}. "
                f"Available columns: {available}"
            )

        # Validate cell type column BEFORE gene normalization
        if validate_labels:
            logger.info(f"  Validating cell type column: {label_col}")
            val_result = validate_cell_type_column(
                ref_adata,
                label_col,
                min_cells_per_type=min_cells_per_type,
            )
            validation_results.append({
                "path": ref_path_str,
                "column": label_col,
                "is_valid": val_result.is_valid,
                "n_cell_types": val_result.n_cell_types,
                "n_cells": val_result.n_cells,
                "errors": [str(e) for e in val_result.errors],
                "warnings": [str(w) for w in val_result.warnings],
            })

            if not val_result.is_valid:
                raise ValueError(
                    f"Validation failed for {source_name}:\n{val_result.summary()}"
                )

            if strict_validation and val_result.warnings:
                raise ValueError(
                    f"Validation warnings (strict mode) for {source_name}:\n"
                    f"{val_result.summary()}"
                )

        # Normalize gene names (Ensembl → HUGO)
        ref_adata = normalize_gene_names(ref_adata, ensembl_to_hugo)

        # Check/apply normalization
        if normalize_data:
            status = check_normalization_status(ref_adata)
            if status["x_state"] == "log1p_10k":
                logger.info("  Data already log-normalized")
            else:
                ref_adata = ensure_normalized(ref_adata)
                logger.info("  Applied log1p(10k) normalization")

        # Copy cell type labels to unified column. astype(str) turns missing
        # labels into strings ("nan", "None", "<NA>"); see exclude_labels.
        ref_adata.obs[output_column] = ref_adata.obs[label_col].astype(str)

        # Add source reference info (use source name from URI or local path)
        ref_adata.obs["reference_source"] = source_name

        adatas.append(ref_adata)
        source_names.append(source_name)
        gc.collect()

    # Drop the loop's extra handle on the last reference so the original can
    # be freed once its list entry is replaced by a subset copy below.
    del ref_adata

    # Subset to target genes if provided (BEFORE finding shared genes)
    if target_genes:
        logger.info(f"\nSubsetting to {len(target_genes)} target genes...")
        target_set = set(target_genes)
        for i in range(len(adatas)):
            # Sorted so the intermediate column order is deterministic.
            overlap = sorted(set(adatas[i].var_names) & target_set)
            if not overlap:
                raise ValueError(
                    f"No overlap between reference {i} ({source_names[i]}) and "
                    "target genes. Check gene name format (HUGO symbols expected)."
                )
            adatas[i] = adatas[i][:, overlap].copy()
            logger.info(f"  Reference {i}: {len(overlap)} genes after subset")

    # Find shared genes (inner join)
    logger.info("\nFinding shared genes across all references...")
    shared_genes = set(adatas[0].var_names).intersection(
        *(ref.var_names for ref in adatas[1:])
    )

    if not shared_genes:
        raise ValueError(
            "No shared genes found across references! "
            "Check that gene names are in the same format (HUGO symbols)."
        )

    logger.info(f"  Shared genes: {len(shared_genes):,}")

    # Subset all to shared genes (sorted for consistency). Replace entries
    # one at a time so each original can be released as soon as it's copied.
    shared_genes_sorted = sorted(shared_genes)
    for i in range(len(adatas)):
        adatas[i] = adatas[i][:, shared_genes_sorted].copy()

    # Memory check before concatenation (dense float32 estimate; sparse
    # matrices will need less).
    available_gb = get_available_memory_gb()
    total_cells = sum(ref.n_obs for ref in adatas)
    estimated_gb = (total_cells * len(shared_genes) * 4) / (1024**3)

    logger.info("\nMemory check before concatenation:")
    logger.info(f"  Total cells: {total_cells:,}")
    if available_gb > 0:
        logger.info(f"  Available: {available_gb:.1f} GB")
        logger.info(f"  Estimated need: ~{estimated_gb:.1f} GB")
        if estimated_gb > available_gb:
            logger.warning(
                f"  Estimated memory (~{estimated_gb:.1f} GB) exceeds available "
                f"memory ({available_gb:.1f} GB); consider lowering "
                "max_cells_per_ref or providing target_genes."
            )

    # Concatenate
    logger.info("\nConcatenating references...")
    combined = sc.concat(adatas, join="inner", label="batch", index_unique="-")
    logger.info(f"  Combined: {combined.n_obs:,} cells × {combined.n_vars:,} genes")

    # The per-reference objects are no longer needed; release them before
    # the filtering steps below allocate further copies.
    adatas.clear()
    gc.collect()

    # Filter excluded labels (exact match only, no partial matching)
    if exclude_labels is None:
        labels_to_exclude = set(DEFAULT_EXCLUDE_LABELS) | null_label_strings
    else:
        labels_to_exclude = set(exclude_labels)

    label_values = combined.obs[output_column].astype(str)
    exclude_mask = label_values.isin(labels_to_exclude).to_numpy()
    n_excluded_labels = int(exclude_mask.sum())

    if n_excluded_labels > 0:
        excluded_counts = label_values[exclude_mask].value_counts()
        logger.info("\nFiltering excluded labels:")
        for label, count in excluded_counts.items():
            logger.info(f"  Removing '{label}': {count:,} cells")
        combined = combined[~exclude_mask].copy()
        logger.info(f"  Remaining: {combined.n_obs:,} cells")

    # Filter low-count cell types
    if filter_min_cells and min_cells_per_type > 0:
        type_counts = combined.obs[output_column].value_counts()
        low_count_types = type_counts[type_counts < min_cells_per_type].index.tolist()

        if low_count_types:
            low_count_mask = (
                combined.obs[output_column].isin(low_count_types).to_numpy()
            )
            n_removed = int(low_count_mask.sum())
            logger.info(f"\nFiltering low-count cell types (<{min_cells_per_type} cells):")
            logger.info(f"  Removing {len(low_count_types)} types, {n_removed:,} cells")
            for ct in low_count_types[:10]:
                logger.info(f"    {ct}: {type_counts[ct]} cells")
            if len(low_count_types) > 10:
                logger.info(f"    ... and {len(low_count_types) - 10} more types")
            combined = combined[~low_count_mask].copy()
            logger.info(
                f"  Remaining: {combined.n_obs:,} cells, "
                f"{combined.obs[output_column].nunique()} types"
            )

    # An empty result cannot be trained on; fail here with a clear cause.
    if combined.n_obs == 0:
        raise ValueError(
            "No cells remain after filtering excluded labels and low-count "
            "cell types; check exclude_labels and min_cells_per_type."
        )

    # Print cell type distribution
    logger.info("\n  Cell type distribution:")
    ct_counts = combined.obs[output_column].value_counts()
    for ct, count in ct_counts.head(10).items():
        logger.info(f"    {ct}: {count:,} cells")
    if len(ct_counts) > 10:
        logger.info(f"    ... and {len(ct_counts) - 10} more types")

    # Store validation results in uns if validation was performed
    if validate_labels and validation_results:
        combined.uns["validation_results"] = validation_results

    logger.info(
        "\nCombined reference ready. Call subsample_balanced() for "
        "source-aware balancing before training."
    )

    return combined
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def get_panel_genes(adata: ad.AnnData) -> List[str]:
    """
    Return the gene names measured in ``adata`` (e.g., a Xenium panel).

    Parameters
    ----------
    adata : AnnData
        AnnData object, typically loaded from a spatial platform.

    Returns
    -------
    List[str]
        A new list holding every entry of ``adata.var_names``, in order.
    """
    gene_index = adata.var_names
    return gene_index.tolist()
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
# ============================================================================
|
|
429
|
+
# CellTypist Training
|
|
430
|
+
# ============================================================================
|
|
431
|
+
|
|
432
|
+
def _save_model_metadata(
    metadata_path: Path,
    model,
    adata: ad.AnnData,
    label_column: str,
    training_params: Dict[str, Any],
) -> None:
    """
    Save model training metadata for reproducibility.

    Writes a JSON document with package versions, training-set size, the
    model's gene list (first 50 as a preview), per-source cell counts (when
    ``adata.obs["reference_source"]`` exists) and per-cell-type counts.

    Parameters
    ----------
    metadata_path : Path or str
        Path to save JSON metadata. The recorded ``model_name`` is the file
        stem with a trailing ``_celltypist`` or ``_metadata`` suffix removed.
    model
        Trained CellTypist model; its ``features`` and ``cell_types``
        attributes are read.
    adata : AnnData
        Training data.
    label_column : str
        Cell type label column.
    training_params : Dict
        Training parameters used; merged into the ``"training"`` section.
    """
    # Version lookups are best-effort: metadata is still written without them.
    try:
        import celltypist
        celltypist_version = celltypist.__version__
    except Exception:
        celltypist_version = "unknown"

    try:
        import spatialcore
        spatialcore_version = spatialcore.__version__
    except Exception:
        spatialcore_version = "unknown"

    metadata_path = Path(metadata_path)

    # Strip a known suffix only at the end of the stem (a substring replace
    # would also mangle names that merely contain it).
    model_name = metadata_path.stem
    for suffix in ("_celltypist", "_metadata"):
        if model_name.endswith(suffix):
            model_name = model_name[: -len(suffix)]
            break

    # Cell type counts; drop zero-count entries (unused categorical levels)
    # so the summary only lists types actually present in the training data.
    cell_type_counts = {
        str(k): int(v)
        for k, v in adata.obs[label_column].value_counts().items()
        if v > 0
    }

    # Reference sources if tracked
    reference_sources = []
    if "reference_source" in adata.obs.columns:
        for source in adata.obs["reference_source"].unique():
            source_mask = adata.obs["reference_source"] == source
            reference_sources.append({
                "name": str(source),
                "n_cells_used": int(source_mask.sum()),
            })

    features = list(model.features)
    metadata = {
        "model_name": model_name,
        "created_at": datetime.now().isoformat(),
        "spatialcore_version": spatialcore_version,
        "celltypist_version": celltypist_version,
        "training": {
            "n_cells": int(adata.n_obs),
            "n_genes": len(features),
            "n_cell_types": int(len(model.cell_types)),
            "label_column": label_column,
            **training_params,
        },
        "references": reference_sources,
        "panel_genes": {
            "n_genes": len(features),
            "genes": features[:50],  # First 50 for preview
            "genes_truncated": len(features) > 50,
        },
        "cell_type_summary": cell_type_counts,
    }

    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def train_celltypist_model(
    adata: ad.AnnData,
    label_column: str = "unified_cell_type",
    model_name: str = "custom_model",
    output_path: Optional[Union[str, Path]] = None,
    use_SGD: bool = True,
    mini_batch: bool = True,
    balance_cell_type: bool = True,
    feature_selection: bool = False,
    n_jobs: int = -1,
    max_iter: int = 100,
    epochs: int = 10,
    batch_size: int = 1000,
    batch_number: int = 200,
) -> Dict[str, Any]:
    """
    Train a custom CellTypist logistic regression model.

    Parameters
    ----------
    adata : AnnData
        Reference data (already subset to panel genes, normalized).
    label_column : str, default "unified_cell_type"
        Cell type label column in adata.obs.
    model_name : str, default "custom_model"
        Name for the trained model. Recorded as ``training.model_name`` in the
        metadata JSON (when saved) and returned under ``"model_name"``.
    output_path : str or Path, optional
        Where to save .pkl file. If None, doesn't save. When saved, a metadata
        JSON is written next to it as ``<stem>_celltypist.json``.
    use_SGD : bool, default True
        Use stochastic gradient descent (faster for large data).
    mini_batch : bool, default True
        Use mini-batch training (recommended for large datasets). CellTypist
        only applies mini-batching to the SGD classifier, so this has no
        effect when ``use_SGD=False`` (a warning is logged).
    balance_cell_type : bool, default True
        Balance rare cell types in mini-batches.
    feature_selection : bool, default False
        Perform feature selection (False = use all genes).
    n_jobs : int, default -1
        Parallel jobs (-1 = all cores).
    max_iter : int, default 100
        Maximum training iterations (for non-mini-batch).
    epochs : int, default 10
        Training epochs (for mini-batch).
    batch_size : int, default 1000
        Cells per batch (for mini-batch).
    batch_number : int, default 200
        Batches per epoch (for mini-batch).

    Returns
    -------
    Dict[str, Any]
        Model metadata including:
        - model_name: The ``model_name`` argument
        - model_path: Path to saved model (if output_path provided)
        - metadata_path: Path to saved metadata JSON (if output_path provided)
        - n_cells_trained: Number of cells used for training
        - n_genes: Number of features/genes
        - n_cell_types: Number of cell types
        - cell_types: List of cell type names
        - model: The trained CellTypist model object

    Raises
    ------
    ImportError
        If celltypist is not installed.
    ValueError
        If ``label_column`` is missing or has fewer than 2 distinct labels.

    Notes
    -----
    For imbalanced data (e.g., 192k T cells vs 277 mast cells), use
    `balance_cell_type=True` with `mini_batch=True` to ensure rare
    cell types are adequately represented in training batches.

    Examples
    --------
    >>> from spatialcore.annotation import train_celltypist_model
    >>> result = train_celltypist_model(
    ...     combined_adata,
    ...     label_column="unified_cell_type",
    ...     output_path="./models/liver_xenium_v1.pkl",
    ...     mini_batch=True,
    ...     balance_cell_type=True,
    ... )
    >>> print(f"Trained on {result['n_cells_trained']:,} cells")
    >>> print(f"Cell types: {result['n_cell_types']}")
    """
    try:
        import celltypist
    except ImportError:
        raise ImportError(
            "celltypist is required for model training. "
            "Install with: pip install celltypist"
        )

    # Validate label column
    if label_column not in adata.obs.columns:
        available = list(adata.obs.columns)
        raise ValueError(
            f"Label column '{label_column}' not found. Available: {available}"
        )

    # A classifier needs at least two classes; fail early with a clear message.
    n_label_types = adata.obs[label_column].nunique()
    if n_label_types < 2:
        raise ValueError(
            f"Label column '{label_column}' has {n_label_types} distinct cell "
            "type(s); at least 2 are required to train a classifier."
        )

    if mini_batch and not use_SGD:
        logger.warning(
            "mini_batch=True has no effect when use_SGD=False: CellTypist "
            "only applies mini-batch training to the SGD classifier."
        )

    # Log training parameters
    logger.info("Training CellTypist model...")
    logger.info(f"  Cells: {adata.n_obs:,}")
    logger.info(f"  Genes: {adata.n_vars:,}")
    logger.info(f"  Cell types: {n_label_types}")
    logger.info(f"  Mini-batch: {mini_batch}")
    logger.info(f"  Balance cell types: {balance_cell_type}")

    # Train model
    if mini_batch:
        model = celltypist.train(
            adata,
            labels=label_column,
            check_expression=False,  # Already validated
            use_SGD=use_SGD,
            mini_batch=True,
            balance_cell_type=balance_cell_type,
            feature_selection=feature_selection,
            n_jobs=n_jobs,
            epochs=epochs,
            batch_size=batch_size,
            batch_number=batch_number,
        )
    else:
        model = celltypist.train(
            adata,
            labels=label_column,
            check_expression=False,
            use_SGD=use_SGD,
            feature_selection=feature_selection,
            n_jobs=n_jobs,
            max_iter=max_iter,
        )

    # Save model (and metadata JSON for reproducibility) if a path is given
    metadata_path: Optional[Path] = None
    if output_path:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        model.write(str(output_path))
        logger.info(f"  Saved model to: {output_path}")

        metadata_path = output_path.with_name(output_path.stem + "_celltypist.json")
        _save_model_metadata(
            metadata_path=metadata_path,
            model=model,
            adata=adata,
            label_column=label_column,
            training_params={
                "model_name": model_name,
                "use_SGD": use_SGD,
                "mini_batch": mini_batch,
                "balance_cell_type": balance_cell_type,
                "feature_selection": feature_selection,
                "n_jobs": n_jobs,
                "max_iter": max_iter,
                "epochs": epochs,
                "batch_size": batch_size,
                "batch_number": batch_number,
            },
        )
        logger.info(f"  Saved metadata to: {metadata_path}")

    return {
        "model_name": model_name,
        "model_path": str(output_path) if output_path else None,
        "metadata_path": str(metadata_path) if metadata_path is not None else None,
        "n_cells_trained": adata.n_obs,
        "n_genes": len(model.features),
        "n_cell_types": len(model.cell_types),
        "cell_types": list(model.cell_types),
        "model": model,
    }
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def get_model_gene_overlap(
    model_path: Union[str, Path],
    query_genes: List[str],
) -> Dict[str, Any]:
    """
    Calculate gene overlap between a CellTypist model and query data.

    Parameters
    ----------
    model_path : str or Path
        Path to CellTypist model (.pkl file), passed to
        ``celltypist.models.Model.load`` as a string.
    query_genes : List[str]
        Gene names from query data (e.g., Xenium panel). Duplicates are
        counted once.

    Returns
    -------
    Dict[str, Any]
        - n_model_genes: Total (unique) genes in model
        - n_query_genes: Total (unique) genes in query
        - n_overlap: Number of overlapping genes
        - overlap_pct: Percentage of model genes in query, always a float
          (0.0 when the model has no features)
        - overlapping_genes: Sorted list of overlapping gene names
        - missing_genes: Sorted list of model genes missing from query

    Raises
    ------
    ImportError
        If celltypist is not installed.

    Examples
    --------
    >>> from spatialcore.annotation import get_model_gene_overlap
    >>> overlap = get_model_gene_overlap(
    ...     "Healthy_Human_Liver.pkl",
    ...     list(xenium_adata.var_names)
    ... )
    >>> print(f"Gene overlap: {overlap['overlap_pct']:.1f}%")
    """
    try:
        from celltypist import models
    except ImportError:
        raise ImportError("celltypist is required. Install with: pip install celltypist")

    model = models.Model.load(str(model_path))
    model_genes = set(model.features)
    query_gene_set = set(query_genes)

    overlap = model_genes & query_gene_set
    missing = model_genes - query_gene_set

    # Keep the type stable: a float in both the empty and non-empty cases.
    overlap_pct = 100.0 * len(overlap) / len(model_genes) if model_genes else 0.0

    return {
        "n_model_genes": len(model_genes),
        "n_query_genes": len(query_gene_set),
        "n_overlap": len(overlap),
        "overlap_pct": overlap_pct,
        "overlapping_genes": sorted(overlap),
        "missing_genes": sorted(missing),
    }
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
def get_training_summary(combined_adata: ad.AnnData, label_column: str) -> pd.DataFrame:
    """
    Tabulate how many training cells belong to each cell type.

    Parameters
    ----------
    combined_adata : AnnData
        Combined reference data.
    label_column : str
        Column of ``combined_adata.obs`` holding the cell type label.

    Returns
    -------
    pd.DataFrame
        One row per label (ordered as returned by ``Series.value_counts``,
        i.e. most frequent first) with columns ``cell_type``, ``n_cells`` and
        ``pct_total``. ``pct_total`` is relative to ``combined_adata.n_obs``;
        because ``value_counts`` drops missing labels by default, the
        percentages may sum to less than 100 when some labels are NaN.
    """
    label_counts = combined_adata.obs[label_column].value_counts()
    n_total = combined_adata.n_obs
    summary = pd.DataFrame(
        {
            "cell_type": label_counts.index,
            "n_cells": label_counts.values,
            "pct_total": label_counts.values * 100 / n_total,
        }
    )
    return summary
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
# ============================================================================
# Color Palettes
# ============================================================================

# High-contrast palette optimized for dark backgrounds
# Designed for spatial maps and UMAP visualizations
#
# Default palette used by ``generate_color_scheme`` when no ``palette`` is
# passed. It holds 20 colors; with more cell types than colors,
# ``generate_color_scheme`` wraps around (index modulo the list length), so
# colors repeat. This is a shared module-level list: treat it as read-only,
# because mutating it changes the default colors for every caller.
HIGH_CONTRAST_PALETTE = [
    "#F5252F",  # Red
    "#FB3FFC",  # Magenta
    "#00FFFF",  # Cyan
    "#33FF33",  # Green
    "#FFB300",  # Amber
    "#9966FF",  # Purple
    "#FF6B6B",  # Coral
    "#3784FE",  # Blue
    "#FF8000",  # Orange
    "#66CCCC",  # Teal
    "#CC66FF",  # Lavender
    "#99FF99",  # Light green
    "#FF6699",  # Pink
    "#E3E1E3",  # Light gray
    "#FFB3CC",  # Light pink
    "#E6E680",  # Yellow-green
    "#CC9966",  # Tan
    "#8080FF",  # Periwinkle
    "#FF9999",  # Salmon
    "#66FF99",  # Mint
]
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def generate_color_scheme(
    cell_types: List[str],
    custom_colors: Optional[Dict[str, str]] = None,
    palette: Optional[List[str]] = None,
) -> Dict[str, str]:
    """
    Generate deterministic color mapping for cell types.

    Creates a mapping from cell type names to hex colors, using
    custom overrides if provided, then filling remaining types
    from the palette.

    Parameters
    ----------
    cell_types : List[str]
        List of cell type names. Duplicates are ignored.
    custom_colors : Dict[str, str], optional
        Custom color overrides. Keys are cell type names, values
        are hex color codes (e.g., "#FF0000"). Overridden types do not
        consume a palette slot; keys not in ``cell_types`` are ignored.
    palette : List[str], optional
        Color palette to use. If None, uses HIGH_CONTRAST_PALETTE.
        Colors are reused cyclically when there are more types than colors.

    Returns
    -------
    Dict[str, str]
        Mapping from cell type to hex color, keyed in sorted order.

    Raises
    ------
    ValueError
        If the palette is empty but at least one cell type has no
        custom color.

    Notes
    -----
    Colors are assigned deterministically based on the sorted set of unique
    cell type names, ensuring consistent colors across runs regardless of
    input order or repeated entries.

    Examples
    --------
    >>> from spatialcore.annotation.training import generate_color_scheme
    >>> colors = generate_color_scheme(
    ...     ["T cell", "B cell", "Macrophage"],
    ...     custom_colors={"T cell": "#FF0000"},
    ... )
    >>> print(colors)
    """
    if palette is None:
        palette = HIGH_CONTRAST_PALETTE

    if custom_colors is None:
        custom_colors = {}

    color_scheme: Dict[str, str] = {}
    palette_idx = 0

    # De-duplicate before sorting: a repeated name must not consume an extra
    # palette slot, otherwise every later type would shift to another color.
    for cell_type in sorted(set(cell_types)):
        if cell_type in custom_colors:
            color_scheme[cell_type] = custom_colors[cell_type]
            continue
        if not palette:
            raise ValueError(
                "palette is empty; cannot assign a color to "
                f"'{cell_type}' (provide a non-empty palette or a custom color)"
            )
        color_scheme[cell_type] = palette[palette_idx % len(palette)]
        palette_idx += 1

    return color_scheme
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
# ============================================================================
|
|
846
|
+
# Model Artifact Management
|
|
847
|
+
# ============================================================================
|
|
848
|
+
|
|
849
|
+
def save_model_artifacts(
    model,
    output_dir: Union[str, Path],
    model_name: str,
    training_metadata: Optional[Dict[str, Any]] = None,
    custom_colors: Optional[Dict[str, str]] = None,
) -> Dict[str, Path]:
    """
    Write a trained model plus its metadata and color scheme to disk.

    Three files are produced in ``output_dir`` (created if missing):

    - ``{model_name}.pkl`` -- the model, written via ``model.write``
    - ``{model_name}_colors.json`` -- cell type -> hex color mapping from
      ``generate_color_scheme`` (kept separate so plotting code can read it
      without loading the model)
    - ``{model_name}_celltypist.json`` -- creation timestamp, package
      versions, gene and cell type lists, and optional training metadata

    Parameters
    ----------
    model
        Trained CellTypist model. Must provide ``write(path)``,
        ``cell_types`` and ``features``.
    output_dir : str or Path
        Directory to save artifacts.
    model_name : str
        Base name for output files.
    training_metadata : Dict[str, Any], optional
        Extra metadata stored under the ``"training"`` key (omitted when
        None or empty). Must be JSON-serializable.
    custom_colors : Dict[str, str], optional
        Custom color overrides for specific cell types.

    Returns
    -------
    Dict[str, Path]
        ``model_path``, ``metadata_path`` and ``colors_path`` of the written
        files.

    Notes
    -----
    Version lookups for celltypist and spatialcore fall back to
    ``"unknown"`` on any error.

    Examples
    --------
    >>> from spatialcore.annotation.training import save_model_artifacts
    >>> result = train_celltypist_model(adata, label_column="cell_type")
    >>> paths = save_model_artifacts(
    ...     result["model"],
    ...     output_dir="./models",
    ...     model_name="liver_v1",
    ...     training_metadata={"reference": "cellxgene"},
    ... )
    >>> print(f"Model saved to: {paths['model_path']}")
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    def _write_json(payload: Any, destination: Path) -> None:
        with open(destination, "w") as handle:
            json.dump(payload, handle, indent=2)

    def _installed_version(package: str) -> str:
        # Best-effort: any failure (not installed, no __version__) -> "unknown".
        try:
            return __import__(package).__version__
        except Exception:
            return "unknown"

    # 1) The model itself.
    model_path = target_dir / f"{model_name}.pkl"
    model.write(str(model_path))
    logger.info(f"Saved model to: {model_path}")

    # 2) Color scheme for visualization.
    cell_types = list(model.cell_types)
    colors_path = target_dir / f"{model_name}_colors.json"
    _write_json(generate_color_scheme(cell_types, custom_colors), colors_path)
    logger.info(f"Saved color scheme to: {colors_path}")

    # 3) Metadata sidecar.
    celltypist_version = _installed_version("celltypist")
    spatialcore_version = _installed_version("spatialcore")

    metadata = dict(
        model_name=model_name,
        created_at=datetime.now().isoformat(),
        spatialcore_version=spatialcore_version,
        celltypist_version=celltypist_version,
        n_genes=len(model.features),
        n_cell_types=len(model.cell_types),
        cell_types=cell_types,
        genes=list(model.features),
    )
    if training_metadata:
        metadata["training"] = training_metadata

    metadata_path = target_dir / f"{model_name}_celltypist.json"
    _write_json(metadata, metadata_path)
    logger.info(f"Saved metadata to: {metadata_path}")

    return dict(
        model_path=model_path,
        metadata_path=metadata_path,
        colors_path=colors_path,
    )
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
# ============================================================================
|
|
967
|
+
# Balanced Subsampling
|
|
968
|
+
# ============================================================================
|
|
969
|
+
|
|
970
|
+
def _load_target_proportions(
|
|
971
|
+
target_proportions: Union[Dict[str, float], str, Path, None]
|
|
972
|
+
) -> Optional[Dict[str, float]]:
|
|
973
|
+
"""
|
|
974
|
+
Load and validate target proportions from dict, JSON, or CSV.
|
|
975
|
+
|
|
976
|
+
Parameters
|
|
977
|
+
----------
|
|
978
|
+
target_proportions : dict, str, Path, or None
|
|
979
|
+
- Dict: Used directly
|
|
980
|
+
- str/Path ending in .json: Load as JSON
|
|
981
|
+
- str/Path ending in .csv: Load as CSV with columns (cell_type, proportion)
|
|
982
|
+
- None: Return None
|
|
983
|
+
|
|
984
|
+
Returns
|
|
985
|
+
-------
|
|
986
|
+
dict or None
|
|
987
|
+
Validated proportions dict, or None if input was None.
|
|
988
|
+
|
|
989
|
+
Raises
|
|
990
|
+
------
|
|
991
|
+
ValueError
|
|
992
|
+
If proportions are invalid (negative, >1.0, or wrong format).
|
|
993
|
+
"""
|
|
994
|
+
if target_proportions is None:
|
|
995
|
+
return None
|
|
996
|
+
|
|
997
|
+
# Load from file if path provided
|
|
998
|
+
if isinstance(target_proportions, (str, Path)):
|
|
999
|
+
path = Path(target_proportions)
|
|
1000
|
+
|
|
1001
|
+
if not path.exists():
|
|
1002
|
+
raise ValueError(f"Target proportions file not found: {path}")
|
|
1003
|
+
|
|
1004
|
+
if path.suffix.lower() == ".json":
|
|
1005
|
+
with open(path, "r") as f:
|
|
1006
|
+
props = json.load(f)
|
|
1007
|
+
elif path.suffix.lower() == ".csv":
|
|
1008
|
+
df = pd.read_csv(path)
|
|
1009
|
+
if "cell_type" not in df.columns or "proportion" not in df.columns:
|
|
1010
|
+
raise ValueError(
|
|
1011
|
+
f"CSV must have 'cell_type' and 'proportion' columns. "
|
|
1012
|
+
f"Found: {list(df.columns)}"
|
|
1013
|
+
)
|
|
1014
|
+
props = dict(zip(df["cell_type"], df["proportion"]))
|
|
1015
|
+
else:
|
|
1016
|
+
raise ValueError(
|
|
1017
|
+
f"Unsupported file format: {path.suffix}. Use .json or .csv"
|
|
1018
|
+
)
|
|
1019
|
+
else:
|
|
1020
|
+
props = target_proportions
|
|
1021
|
+
|
|
1022
|
+
# Validate proportions
|
|
1023
|
+
if not isinstance(props, dict):
|
|
1024
|
+
raise ValueError(
|
|
1025
|
+
f"target_proportions must be a dict, got {type(props).__name__}"
|
|
1026
|
+
)
|
|
1027
|
+
|
|
1028
|
+
for cell_type, proportion in props.items():
|
|
1029
|
+
if not isinstance(proportion, (int, float)):
|
|
1030
|
+
raise ValueError(
|
|
1031
|
+
f"Invalid proportion for '{cell_type}': {proportion}. "
|
|
1032
|
+
f"Must be a number."
|
|
1033
|
+
)
|
|
1034
|
+
if proportion < 0 or proportion > 1.0:
|
|
1035
|
+
raise ValueError(
|
|
1036
|
+
f"Invalid proportion for '{cell_type}': {proportion}. "
|
|
1037
|
+
f"Must be between 0.0 and 1.0."
|
|
1038
|
+
)
|
|
1039
|
+
|
|
1040
|
+
return props
|
|
1041
|
+
|
|
1042
|
+
|
|
1043
|
+
def subsample_balanced(
    adata: ad.AnnData,
    label_column: str,
    max_cells_per_type: int = 10000,
    min_cells_per_type: int = 50,
    source_column: Optional[str] = "reference_source",
    source_balance: str = "proportional",
    min_cells_per_source: int = 50,
    group_by_column: Optional[str] = None,
    target_proportions: Optional[Union[Dict[str, float], str, Path]] = None,
    random_state: int = 42,
    copy: bool = True,
) -> ad.AnnData:
    """
    Source-aware balanced subsampling for multi-reference training data.

    When combining multiple scRNA-seq references, this function ensures each
    cell type is sampled proportionally from all sources that contain it.
    This prevents the model from learning source-specific artifacts.

    Parameters
    ----------
    adata : AnnData
        Combined reference data (output of combine_references()).
    label_column : str
        Column in adata.obs containing cell type labels. Used for grouping
        unless ``group_by_column`` is given, and for the final summary log.
    max_cells_per_type : int, default 10000
        Maximum cells per cell type in output (for types not listed in
        ``target_proportions``).
    min_cells_per_type : int, default 50
        Cell types with this many cells or fewer are kept entirely
        (no subsampling applied).
    source_column : str, optional, default "reference_source"
        Column identifying which reference each cell came from.
        Set to None to disable source-aware balancing (simple capping).
    source_balance : {"proportional", "equal"}, default "proportional"
        How to distribute sampling across sources:
        - "proportional": Draw from each source proportionally to its
          contribution for that cell type. (RECOMMENDED)
        - "equal": Target an equal share per source; a source that cannot
          fill its share is topped up from sources that still have cells.
    min_cells_per_source : int, default 50
        In "proportional" mode, minimum cells to draw from a source that has
        at least this many cells of the type. Ensures rare sources still
        contribute to the model (may push the type slightly above its cap).
    group_by_column : str, optional
        If provided, group cells by values in this column instead of label_column.
        This enables semantic grouping where different text labels map to the
        same identity. For example, using "cell_type_ontology_term_id":
        - "CD4+ T cells" and "CD4-positive, alpha-beta T cell" -> CL:0000624
        - Both are grouped together for balancing purposes.
        If None, groups by label_column text (current behavior).
    target_proportions : dict, str, or Path, optional
        Target proportions for specific cell types. Accepts:
        - Dict mapping cell type names to proportions (0.0-1.0)
        - Path to JSON file: ``{"NK cell": 0.0025, "plasma cell": 0.001}``
        - Path to CSV file with columns: ``cell_type``, ``proportion``

        Keys are matched against the grouping values (``label_column`` text,
        or ``group_by_column`` values when that is given). For cell types in
        this mapping, the target count is calculated as
        ``proportion × total_input_cells`` (capped at the cells available and
        raised to at least ``min_cells_per_type``) instead of using
        max_cells_per_type. Applies with or without ``source_column``.

        Cell types NOT in the mapping use normal max_cells_per_type capping.

        This is essential for handling pure/enriched references (e.g., FACS-sorted
        cells) where a cell type exists only in the enriched source and would
        otherwise dominate training.
    random_state : int, default 42
        Random seed for reproducibility.
    copy : bool, default True
        Copy ``adata`` before processing. The result is always a new,
        subsetted AnnData.

    Returns
    -------
    AnnData
        Subsampled data with source-balanced cell type representation.

    Raises
    ------
    ValueError
        If label_column not found in adata.obs.
        If group_by_column specified but not found in adata.obs.
        If source_column specified but not found in adata.obs.
        If source_balance is not "proportional" or "equal".
        If target_proportions is invalid (see ``_load_target_proportions``).

    Examples
    --------
    >>> from spatialcore.annotation import combine_references, subsample_balanced
    >>> # Step 1: Combine (no balancing)
    >>> combined = combine_references(
    ...     reference_paths=["study1.h5ad", "study2.h5ad"],
    ...     label_columns=["cell_type", "cell_type"],
    ...     target_genes=panel_genes,
    ... )
    >>> # Step 2: Source-aware balancing
    >>> balanced = subsample_balanced(
    ...     combined,
    ...     label_column="unified_cell_type",
    ...     max_cells_per_type=10000,
    ...     source_balance="proportional",
    ... )

    Notes
    -----
    **Why source-aware balancing?**

    When combining Study 1 (30K Macrophages) + Study 2 (5K Macrophages) and
    capping at 10K, naive capping takes mostly from Study 1. The model learns
    Study 1's version of Macrophage, not a consensus.

    Source-aware balancing draws proportionally: ~8.5K from Study 1, ~1.5K
    from Study 2, ensuring both studies contribute to each shared cell type.

    **Handling FACS-Enriched / Pure Cell Type References**

    When combining tissue-derived scRNA-seq (natural cell type proportions)
    with FACS-sorted or enriched populations (e.g., 100% pure T cells, sorted
    NK cells), special considerations apply:

    - **source_balance="proportional"** (default): Best when all references
      are tissue-derived with natural proportions. Each source contributes
      proportionally to its cell count for each type.

    - **source_balance="equal"**: Recommended when combining tissue-derived
      data with FACS-enriched pure populations. Targets an equal share per
      source, limiting how much a pure population can overwhelm the
      model's concept of that cell type.

    Example scenario (T cells capped at 10K):
    - Reference A (tissue): 50K cells with natural proportions (5% T cells = 2.5K)
    - Reference B (FACS): 100K pure T cells (100% T cells)

    With "proportional": T cells would be ~97.5% from FACS, ~2.5% from tissue.
    With "equal": each source is targeted at 5K; tissue can only supply 2.5K,
    so the shortfall is refilled from FACS -> all 2.5K tissue T cells plus
    7.5K FACS T cells (25% / 75%).

    **Future Improvements**

    - Automatic detection of enriched/pure cell type sources based on
      cell type distribution entropy
    - Per-cell-type source_balance overrides (e.g., use "equal" only for
      T cells that appear in FACS source)
    - Integration with Harmony batch correction for multi-source training via PCAs and prediction on spatial data with >5000 genes
    """
    if copy:
        adata = adata.copy()

    rng = np.random.default_rng(random_state)

    # Load and validate target proportions
    props = _load_target_proportions(target_proportions)
    if props:
        logger.info(f"Using target proportions for {len(props)} cell type(s)")
        for ct, prop in props.items():
            logger.debug(f" {ct}: {prop:.4f} ({prop*100:.2f}%)")

    # Validate label_column
    if label_column not in adata.obs.columns:
        raise ValueError(
            f"Label column '{label_column}' not found. "
            f"Available: {list(adata.obs.columns)}"
        )

    # Determine grouping column: use group_by_column if provided, else label_column
    if group_by_column is not None:
        if group_by_column not in adata.obs.columns:
            raise ValueError(
                f"Group-by column '{group_by_column}' not found. "
                f"Available: {list(adata.obs.columns)}"
            )
        # Use group_by_column for grouping (e.g., CL IDs)
        cell_types = adata.obs[group_by_column].astype(str)
        logger.info(
            f"Grouping by '{group_by_column}' instead of '{label_column}' "
            f"(semantic grouping enabled)"
        )
    else:
        cell_types = adata.obs[label_column].astype(str)

    unique_types = cell_types.unique()
    n_total = adata.n_obs

    # =========================================================================
    # Source-unaware mode (simple capping)
    # =========================================================================
    if source_column is None:
        logger.info(
            f"Subsampling {adata.n_obs:,} cells "
            f"(cap={max_cells_per_type:,}, no source balancing)"
        )
        if not props:
            return _subsample_simple_cap(
                adata, cell_types, unique_types,
                max_cells_per_type, min_cells_per_type, rng
            )

        # _subsample_simple_cap knows nothing about target proportions, so
        # apply the same per-type targets as the source-aware path here.
        selected_indices = []
        for cell_type in unique_types:
            type_indices = np.where(cell_types == cell_type)[0]
            n_available = len(type_indices)
            if n_available == 0:
                continue
            if n_available <= min_cells_per_type:
                selected_indices.extend(type_indices)
                continue
            target_total = _type_target_total(
                cell_type, n_available, props, n_total,
                max_cells_per_type, min_cells_per_type,
            )
            if target_total >= n_available:
                selected_indices.extend(type_indices)
            else:
                selected_indices.extend(
                    rng.choice(type_indices, size=target_total, replace=False)
                )
        return adata[sorted(selected_indices)].copy()

    # =========================================================================
    # Source-aware mode (nested balancing)
    # =========================================================================
    if source_column not in adata.obs.columns:
        raise ValueError(
            f"Source column '{source_column}' not found. "
            f"Available: {list(adata.obs.columns)}. "
            f"Set source_column=None to disable source-aware balancing."
        )

    if source_balance not in ("proportional", "equal"):
        raise ValueError(
            f"Invalid source_balance: '{source_balance}'. "
            f"Must be 'proportional' or 'equal'."
        )

    sources = adata.obs[source_column].astype(str)
    unique_sources = sources.unique()
    n_sources = len(unique_sources)

    logger.info(
        f"Subsampling {adata.n_obs:,} cells "
        f"(source-aware, {n_sources} sources, {source_balance} balance, "
        f"max={max_cells_per_type:,}/type)"
    )

    selected_indices = []

    for cell_type in unique_types:
        type_indices = np.where(cell_types == cell_type)[0]
        n_available = len(type_indices)

        # Skip empty types
        if n_available == 0:
            continue

        # Keep all if below minimum
        if n_available <= min_cells_per_type:
            selected_indices.extend(type_indices)
            logger.debug(f" {cell_type}: keeping all {n_available} (below min)")
            continue

        # Identify which sources have this cell type
        type_sources = sources.iloc[type_indices]
        sources_with_type = type_sources.unique()
        n_sources_with_type = len(sources_with_type)

        # Warn if only one source has this type (when multiple sources exist)
        if n_sources_with_type == 1 and n_sources > 1:
            logger.warning(
                f" {cell_type}: only in '{sources_with_type[0]}' "
                f"(no cross-source balancing)"
            )

        target_total = _type_target_total(
            cell_type, n_available, props, n_total,
            max_cells_per_type, min_cells_per_type,
        )

        # Calculate per-source targets
        source_targets = _calculate_source_targets(
            type_sources=type_sources,
            sources_with_type=sources_with_type,
            target_total=target_total,
            source_balance=source_balance,
            min_cells_per_source=min_cells_per_source,
        )

        # Sample from each source. type_indices is ascending, so filtering it
        # by source yields exactly the indices (in the same order) that a
        # dataset-wide mask would, without rescanning all cells per source.
        type_source_values = type_sources.to_numpy()
        for source_name, (target, _available) in source_targets.items():
            source_indices = type_indices[type_source_values == source_name]

            if target >= len(source_indices):
                selected_indices.extend(source_indices)
            else:
                sampled = rng.choice(source_indices, size=target, replace=False)
                selected_indices.extend(sampled)

        # Debug log
        total_sampled = sum(t for t, _ in source_targets.values())
        source_summary = ", ".join(f"{s}:{t}" for s, (t, _) in source_targets.items())
        logger.debug(f" {cell_type}: {n_available} -> {total_sampled} [{source_summary}]")

    # Sort and subset
    selected_indices = sorted(set(selected_indices))
    adata_sub = adata[selected_indices].copy()

    # Log final summary. nunique() counts only labels actually present;
    # value_counts() on a categorical column would also list categories
    # that ended up with zero cells.
    n_types_kept = adata_sub.obs[label_column].nunique()
    logger.info(
        f"Subsampled: {adata.n_obs:,} -> {adata_sub.n_obs:,} cells "
        f"({n_types_kept} types)"
    )

    return adata_sub


def _type_target_total(
    cell_type: str,
    n_available: int,
    props: Optional[Dict[str, float]],
    n_total: int,
    max_cells_per_type: int,
    min_cells_per_type: int,
) -> int:
    """
    Return how many cells of one cell type to keep (before any source split).

    Types listed in ``props`` use "Cap & Fill": ``proportion × n_total``,
    capped at ``n_available`` and raised to at least ``min_cells_per_type``.
    Other types are capped at ``max_cells_per_type``. Callers only invoke
    this when ``n_available > min_cells_per_type``, so the result never
    exceeds ``n_available``.
    """
    if props and cell_type in props:
        proportion_target = int(props[cell_type] * n_total)
        capped_target = min(proportion_target, n_available)
        target_total = max(min_cells_per_type, capped_target)
        logger.debug(
            f" {cell_type}: proportion {props[cell_type]:.4f} -> "
            f"{proportion_target} cells (cap&fill: {target_total})"
        )
        return target_total
    return min(max_cells_per_type, n_available)
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
def _calculate_source_targets(
|
|
1342
|
+
type_sources: pd.Series,
|
|
1343
|
+
sources_with_type: np.ndarray,
|
|
1344
|
+
target_total: int,
|
|
1345
|
+
source_balance: str,
|
|
1346
|
+
min_cells_per_source: int,
|
|
1347
|
+
) -> Dict[str, Tuple[int, int]]:
|
|
1348
|
+
"""
|
|
1349
|
+
Calculate how many cells to sample from each source for one cell type.
|
|
1350
|
+
|
|
1351
|
+
Parameters
|
|
1352
|
+
----------
|
|
1353
|
+
type_sources : pd.Series
|
|
1354
|
+
Series of source labels for cells of this type.
|
|
1355
|
+
sources_with_type : np.ndarray
|
|
1356
|
+
Unique sources that have this cell type.
|
|
1357
|
+
target_total : int
|
|
1358
|
+
Total cells to sample for this cell type.
|
|
1359
|
+
source_balance : str
|
|
1360
|
+
"proportional" or "equal".
|
|
1361
|
+
min_cells_per_source : int
|
|
1362
|
+
Minimum cells to draw from each source.
|
|
1363
|
+
|
|
1364
|
+
Returns
|
|
1365
|
+
-------
|
|
1366
|
+
Dict[str, Tuple[int, int]]
|
|
1367
|
+
Mapping: source_name -> (target_count, available_count)
|
|
1368
|
+
"""
|
|
1369
|
+
source_counts = type_sources.value_counts().to_dict()
|
|
1370
|
+
total_available = sum(source_counts.values())
|
|
1371
|
+
n_sources = len(sources_with_type)
|
|
1372
|
+
|
|
1373
|
+
targets = {}
|
|
1374
|
+
|
|
1375
|
+
if source_balance == "proportional":
|
|
1376
|
+
# Draw proportionally to each source's contribution
|
|
1377
|
+
for source_name in sources_with_type:
|
|
1378
|
+
available = source_counts[source_name]
|
|
1379
|
+
proportion = available / total_available
|
|
1380
|
+
target = int(np.ceil(target_total * proportion))
|
|
1381
|
+
|
|
1382
|
+
# Enforce minimum (if source has enough cells)
|
|
1383
|
+
if available >= min_cells_per_source:
|
|
1384
|
+
target = max(target, min_cells_per_source)
|
|
1385
|
+
|
|
1386
|
+
# Can't exceed available
|
|
1387
|
+
target = min(target, available)
|
|
1388
|
+
targets[source_name] = (target, available)
|
|
1389
|
+
|
|
1390
|
+
elif source_balance == "equal":
|
|
1391
|
+
# Draw equally from each source
|
|
1392
|
+
per_source = target_total // n_sources
|
|
1393
|
+
remainder = target_total % n_sources
|
|
1394
|
+
|
|
1395
|
+
for i, source_name in enumerate(sorted(sources_with_type)):
|
|
1396
|
+
available = source_counts[source_name]
|
|
1397
|
+
target = per_source + (1 if i < remainder else 0)
|
|
1398
|
+
target = min(target, available)
|
|
1399
|
+
targets[source_name] = (target, available)
|
|
1400
|
+
|
|
1401
|
+
# Redistribute shortfall (when some sources can't provide enough)
|
|
1402
|
+
total_targeted = sum(t for t, _ in targets.values())
|
|
1403
|
+
shortfall = target_total - total_targeted
|
|
1404
|
+
|
|
1405
|
+
if shortfall > 0:
|
|
1406
|
+
for source_name in sources_with_type:
|
|
1407
|
+
if shortfall <= 0:
|
|
1408
|
+
break
|
|
1409
|
+
target, available = targets[source_name]
|
|
1410
|
+
capacity = available - target
|
|
1411
|
+
if capacity > 0:
|
|
1412
|
+
additional = min(capacity, shortfall)
|
|
1413
|
+
targets[source_name] = (target + additional, available)
|
|
1414
|
+
shortfall -= additional
|
|
1415
|
+
|
|
1416
|
+
return targets
|
|
1417
|
+
|
|
1418
|
+
|
|
1419
|
+
def _subsample_simple_cap(
    adata: ad.AnnData,
    cell_types: pd.Series,
    unique_types: np.ndarray,
    max_cells_per_type: int,
    min_cells_per_type: int,
    rng: np.random.Generator,
) -> ad.AnnData:
    """
    Cap each cell type at ``max_cells_per_type`` cells, ignoring sources.

    Fallback used by subsample_balanced() when source_column=None. A type is
    kept whole when its size is at or below ``min_cells_per_type`` or
    ``max_cells_per_type``; otherwise exactly ``max_cells_per_type`` of its
    cells are drawn without replacement using ``rng``. Types are visited in
    ``unique_types`` order, so the random draws are reproducible for a
    seeded generator. Returns a copy restricted to the kept cells, in their
    original order.
    """
    # Being at/below either threshold means "keep every cell".
    keep_whole_limit = max(min_cells_per_type, max_cells_per_type)
    chosen = []

    for label in unique_types:
        members = np.flatnonzero(np.asarray(cell_types == label))
        if members.size == 0:
            continue
        if members.size <= keep_whole_limit:
            chosen.extend(members)
        else:
            chosen.extend(rng.choice(members, size=max_cells_per_type, replace=False))

    chosen.sort()
    return adata[chosen].copy()
|
|
1451
|
+
|
|
1452
|
+
|
|
1453
|
+
|
|
1454
|
+
# NOTE: harmonize_labels() function was removed in favor of using:
|
|
1455
|
+
# 1. add_ontology_ids() to fill missing CL IDs
|
|
1456
|
+
# 2. subsample_balanced(group_by_column="cell_type_ontology_term_id") for semantic grouping
|
|
1457
|
+
# See pipeline.py for the new recommended workflow.
|