spatialcore-0.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. spatialcore/__init__.py +122 -0
  2. spatialcore/annotation/__init__.py +253 -0
  3. spatialcore/annotation/acquisition.py +529 -0
  4. spatialcore/annotation/annotate.py +603 -0
  5. spatialcore/annotation/cellxgene.py +365 -0
  6. spatialcore/annotation/confidence.py +802 -0
  7. spatialcore/annotation/discovery.py +529 -0
  8. spatialcore/annotation/expression.py +363 -0
  9. spatialcore/annotation/loading.py +529 -0
  10. spatialcore/annotation/markers.py +297 -0
  11. spatialcore/annotation/ontology.py +1282 -0
  12. spatialcore/annotation/patterns.py +247 -0
  13. spatialcore/annotation/pipeline.py +620 -0
  14. spatialcore/annotation/synapse.py +380 -0
  15. spatialcore/annotation/training.py +1457 -0
  16. spatialcore/annotation/validation.py +422 -0
  17. spatialcore/core/__init__.py +34 -0
  18. spatialcore/core/cache.py +118 -0
  19. spatialcore/core/logging.py +135 -0
  20. spatialcore/core/metadata.py +149 -0
  21. spatialcore/core/utils.py +768 -0
  22. spatialcore/data/gene_mappings/ensembl_to_hugo_human.tsv +86372 -0
  23. spatialcore/data/markers/canonical_markers.json +83 -0
  24. spatialcore/data/ontology_mappings/ontology_index.json +63865 -0
  25. spatialcore/plotting/__init__.py +109 -0
  26. spatialcore/plotting/benchmark.py +477 -0
  27. spatialcore/plotting/celltype.py +329 -0
  28. spatialcore/plotting/confidence.py +413 -0
  29. spatialcore/plotting/spatial.py +505 -0
  30. spatialcore/plotting/utils.py +411 -0
  31. spatialcore/plotting/validation.py +1342 -0
  32. spatialcore-0.1.9.dist-info/METADATA +213 -0
  33. spatialcore-0.1.9.dist-info/RECORD +36 -0
  34. spatialcore-0.1.9.dist-info/WHEEL +5 -0
  35. spatialcore-0.1.9.dist-info/licenses/LICENSE +201 -0
  36. spatialcore-0.1.9.dist-info/top_level.txt +1 -0
spatialcore/annotation/training.py (new file, +1457 lines)
"""
CellTypist model training utilities.

This module provides utilities for:
1. Combining multiple reference datasets for training
2. Training custom CellTypist models
3. Panel gene subsetting for spatial transcriptomics

For spatial data (e.g., Xenium), custom models trained on panel-specific genes
achieve ~100% gene utilization vs ~8% with pre-trained models.

References:
- CellTypist: https://www.celltypist.org/
- Domínguez Conde et al., Science (2022)
"""

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
from datetime import datetime
import json
import gc

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad

from spatialcore.core.logging import get_logger
from spatialcore.core.utils import (
    load_ensembl_to_hugo_mapping,
    normalize_gene_names,
    check_normalization_status,
)
from spatialcore.annotation.loading import (
    load_adata_backed,
    ensure_normalized,
    get_available_memory_gb,
)
from spatialcore.annotation.validation import validate_cell_type_column
from spatialcore.annotation.acquisition import resolve_uri_to_local

logger = get_logger(__name__)

# Default cache directory for downloaded references
DEFAULT_CACHE_DIR = Path.home() / ".spatialcore" / "cache" / "references"

# Default labels to exclude from training (ambiguous/uninformative)
# These are matched EXACTLY (case-sensitive, no partial matching)
# "unknown cells" would NOT be filtered (not an exact match to "unknown")
DEFAULT_EXCLUDE_LABELS = [
    "unknown",
    "Unknown",
    "UNKNOWN",
    "unassigned",
    "Unassigned",
    "na",
    "NA",
    "N/A",
    "n/a",
    "none",
    "None",
    "null",
    "doublet",
    "Doublet",
    "low quality",
    "Low quality",
]
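
# Illustration (editor's note, not part of the module): the exclusion list is
# matched exactly and case-sensitively, so composite labels survive filtering.
#
# >>> pd.Series(["unknown", "unknown cells", "T cell"]).isin(set(DEFAULT_EXCLUDE_LABELS)).tolist()
# [True, False, False]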


# ============================================================================
# Reference Combination
# ============================================================================

def combine_references(
    reference_paths: List[Union[str, Path]],
    label_columns: List[str],
    output_column: str = "original_label",
    max_cells_per_ref: int = 100000,
    target_genes: Optional[List[str]] = None,
    normalize_data: bool = True,
    random_state: int = 42,
    validate_labels: bool = True,
    min_cells_per_type: int = 10,
    strict_validation: bool = False,
    cache_dir: Optional[Path] = None,
    exclude_labels: Optional[List[str]] = None,
    filter_min_cells: bool = True,
) -> ad.AnnData:
    """
    Combine multiple reference datasets for CellTypist training.

    This function handles:
    1. Memory-efficient loading with per-reference cell caps (stratified)
    2. Gene name normalization (Ensembl → HUGO)
    3. Expression normalization (log1p to 10k)
    4. Gene intersection (with optional panel gene subsetting)
    5. Concatenation with source tracking

    IMPORTANT: This function does NOT perform post-combine balancing.
    Call subsample_balanced() on the output for source-aware balancing.
    For semantic grouping by ontology ID, first run add_ontology_ids() then use
    subsample_balanced(group_by_column="cell_type_ontology_term_id").

    Parameters
    ----------
    reference_paths : List[str or Path]
        Paths to reference h5ad files. Supports:

        - Local paths: ``/data/references/lung.h5ad``
        - GCS URIs: ``gs://bucket/references/lung.h5ad``
        - S3 URIs: ``s3://bucket/references/lung.h5ad``

        Cloud files are automatically downloaded to cache_dir and loaded
        in memory-efficient backed mode.
    label_columns : List[str]
        Cell type label column for each reference.
    output_column : str, default "original_label"
        Column name for unified cell type labels in output.
        Use "original_label" (new default) for clarity that this is the raw
        source label before any harmonization.
    max_cells_per_ref : int, default 100000
        Maximum cells to load per reference. Uses stratified sampling
        to preserve natural cell type proportions within each reference.
        This is for MEMORY MANAGEMENT during loading, not training balance.
    target_genes : List[str], optional
        Panel genes to subset to (e.g., from spatial data via get_panel_genes()).
        If provided, each reference is subset to these genes before
        finding the intersection. This ensures maximum gene utilization.
    normalize_data : bool, default True
        Whether to ensure log1p(10k) normalization.
    random_state : int, default 42
        Random seed for reproducibility.
    validate_labels : bool, default True
        Run cell type column validation before combining. Checks for
        null values, suspicious patterns, and cardinality issues.
    min_cells_per_type : int, default 10
        Minimum cells required per cell type. Used for validation warnings
        and, when ``filter_min_cells=True``, to remove low-count types
        after concatenation.
    strict_validation : bool, default False
        If True, fail on validation warnings (not just errors).
    cache_dir : Path, optional
        Directory for caching downloaded cloud files. Defaults to
        ``~/.spatialcore/cache/references/``. Only used for gs:// and s3:// URIs.
    exclude_labels : List[str], optional
        Cell type labels to exclude from the combined output. Uses exact
        case-sensitive matching (no partial matches). Cells with these labels
        are removed after concatenation. If None (default), uses
        DEFAULT_EXCLUDE_LABELS which includes common ambiguous labels like
        "unknown", "Unknown", "unassigned", "NA", "doublet", etc.
        Pass an empty list ``[]`` to disable label filtering entirely.
    filter_min_cells : bool, default True
        If True, remove cell types with fewer than ``min_cells_per_type``
        cells after concatenation. If False, only warn about low-count
        types (original behavior). Filtering is recommended for training
        as singleton or very-low-count cell types can destabilize models.

    Returns
    -------
    AnnData
        Combined reference data with:
        - Unified cell type labels in output_column
        - Source tracking in .obs["reference_source"]
        - Gene intersection applied
        - Ready for subsample_balanced()

    Notes
    -----
    **Two-stage workflow for multi-reference training:**

    1. **combine_references()** (this function): Load, normalize, intersect genes
    2. **subsample_balanced()**: Source-aware balancing for training quality

    **For spatial transcriptomics**, providing `target_genes` (panel genes)
    before training is critical:
    - Pre-trained models: ~8% gene overlap with 400-gene panel
    - Custom panel models: 100% gene overlap

    Examples
    --------
    >>> from spatialcore.annotation import (
    ...     combine_references, subsample_balanced, get_panel_genes
    ... )
    >>> # Step 1: Get panel genes from spatial data
    >>> panel_genes = get_panel_genes(xenium_adata)
    >>>
    >>> # Step 2: Combine references (loading + gene intersection)
    >>> # Supports local paths and cloud URIs (GCS, S3)
    >>> combined = combine_references(
    ...     reference_paths=[
    ...         "gs://my-bucket/references/hlca.h5ad",   # GCS
    ...         "s3://my-bucket/references/liver.h5ad",  # S3
    ...         "/local/data/study3.h5ad",               # Local
    ...     ],
    ...     label_columns=["cell_type", "cell_type", "cell_type"],
    ...     max_cells_per_ref=100000,  # Memory cap during loading
    ...     target_genes=panel_genes,  # Intersect with spatial panel
    ... )
    >>> # Output: combined AnnData with reference_source tracked in .obs
    >>>
    >>> # Step 3: Source-aware balancing (SEPARATE STEP)
    >>> balanced = subsample_balanced(
    ...     combined,
    ...     label_column="original_label",
    ...     max_cells_per_type=10000,
    ...     source_balance="proportional",
    ... )
    """
    if len(reference_paths) != len(label_columns):
        raise ValueError(
            f"Number of paths ({len(reference_paths)}) must match "
            f"number of label columns ({len(label_columns)})"
        )

    # Initialize cache directory for cloud downloads
    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Load gene mapping once at start (reused for all references)
    logger.info("Loading Ensembl to HUGO gene mapping...")
    ensembl_to_hugo = load_ensembl_to_hugo_mapping()
    logger.info(f"Loaded {len(ensembl_to_hugo):,} gene mappings")

    adatas = []
    validation_results = []  # Track validation results for each reference

    for i, (ref_path, label_col) in enumerate(zip(reference_paths, label_columns)):
        ref_path_str = str(ref_path)

        # Extract source name for logging (handle URIs and local paths)
        if ref_path_str.startswith(("gs://", "s3://")):
            # Cloud URI: extract filename from URI
            source_name = Path(ref_path_str.split("/")[-1]).stem
            logger.info(f"\n[{i+1}/{len(reference_paths)}] Loading: {ref_path_str}")
        else:
            source_name = Path(ref_path_str).stem
            logger.info(f"\n[{i+1}/{len(reference_paths)}] Loading: {Path(ref_path_str).name}")

        # Resolve URI to local path (downloads cloud files if needed)
        local_path = resolve_uri_to_local(ref_path_str, cache_dir)

        # Memory-efficient loading with backed mode for large files
        adata = load_adata_backed(
            path=local_path,
            max_cells=max_cells_per_ref,
            label_column=label_col,
            random_state=random_state,
        )

        # Validate cell type column BEFORE gene normalization
        if validate_labels:
            logger.info(f"  Validating cell type column: {label_col}")
            val_result = validate_cell_type_column(
                adata,
                label_col,
                min_cells_per_type=min_cells_per_type,
            )
            validation_results.append({
                "path": ref_path_str,
                "column": label_col,
                "is_valid": val_result.is_valid,
                "n_cell_types": val_result.n_cell_types,
                "n_cells": val_result.n_cells,
                "errors": [str(e) for e in val_result.errors],
                "warnings": [str(w) for w in val_result.warnings],
            })

            if not val_result.is_valid:
                raise ValueError(
                    f"Validation failed for {source_name}:\n{val_result.summary()}"
                )

            if strict_validation and val_result.warnings:
                raise ValueError(
                    f"Validation warnings (strict mode) for {source_name}:\n"
                    f"{val_result.summary()}"
                )

        # Normalize gene names (Ensembl → HUGO)
        adata = normalize_gene_names(adata, ensembl_to_hugo)

        # Check/apply normalization
        if normalize_data:
            status = check_normalization_status(adata)
            if status["x_state"] == "log1p_10k":
                logger.info("  Data already log-normalized")
            else:
                adata = ensure_normalized(adata)
                logger.info("  Applied log1p(10k) normalization")

        # Copy cell type labels to unified column
        if label_col not in adata.obs.columns:
            available = list(adata.obs.columns)
            raise ValueError(
                f"Label column '{label_col}' not found in {source_name}. "
                f"Available columns: {available}"
            )
        adata.obs[output_column] = adata.obs[label_col].astype(str)

        # Add source reference info (use source name from URI or local path)
        adata.obs["reference_source"] = source_name

        adatas.append(adata)
        gc.collect()

    # Subset to target genes if provided (BEFORE finding shared genes)
    if target_genes:
        logger.info(f"\nSubsetting to {len(target_genes)} target genes...")
        target_set = set(target_genes)
        for i, adata in enumerate(adatas):
            overlap = list(set(adata.var_names) & target_set)
            if len(overlap) == 0:
                raise ValueError(
                    f"No overlap between reference {i} and target genes. "
                    f"Check gene name format (HUGO symbols expected)."
                )
            adatas[i] = adata[:, overlap].copy()
            logger.info(f"  Reference {i}: {len(overlap)} genes after subset")

    # Find shared genes (inner join)
    logger.info("\nFinding shared genes across all references...")
    shared_genes = set(adatas[0].var_names)
    for adata in adatas[1:]:
        shared_genes &= set(adata.var_names)

    if len(shared_genes) == 0:
        raise ValueError(
            "No shared genes found across references! "
            "Check that gene names are in the same format (HUGO symbols)."
        )

    logger.info(f"  Shared genes: {len(shared_genes):,}")

    # Subset all to shared genes (sorted for consistency)
    shared_genes_sorted = sorted(shared_genes)
    for i in range(len(adatas)):
        adatas[i] = adatas[i][:, shared_genes_sorted].copy()

    # Memory check before concatenation
    available_gb = get_available_memory_gb()
    total_cells = sum(a.n_obs for a in adatas)
    estimated_gb = (total_cells * len(shared_genes) * 4) / (1024**3)

    logger.info("\nMemory check before concatenation:")
    logger.info(f"  Total cells: {total_cells:,}")
    if available_gb > 0:
        logger.info(f"  Available: {available_gb:.1f} GB")
    logger.info(f"  Estimated need: ~{estimated_gb:.1f} GB")

    # Concatenate
    logger.info("\nConcatenating references...")
    combined = sc.concat(adatas, join="inner", label="batch", index_unique="-")
    logger.info(f"  Combined: {combined.n_obs:,} cells × {combined.n_vars:,} genes")

    # Filter excluded labels (exact match only, no partial matching)
    if exclude_labels is None:
        exclude_labels = DEFAULT_EXCLUDE_LABELS

    labels_to_exclude = set(exclude_labels)
    label_values = combined.obs[output_column].astype(str)
    exclude_mask = label_values.isin(labels_to_exclude)
    n_excluded_labels = exclude_mask.sum()

    if n_excluded_labels > 0:
        excluded_counts = combined.obs.loc[exclude_mask, output_column].value_counts()
        logger.info("\nFiltering excluded labels:")
        for label, count in excluded_counts.items():
            logger.info(f"  Removing '{label}': {count:,} cells")
        combined = combined[~exclude_mask].copy()
        logger.info(f"  Remaining: {combined.n_obs:,} cells")

    # Filter low-count cell types
    if filter_min_cells and min_cells_per_type > 0:
        type_counts = combined.obs[output_column].value_counts()
        low_count_types = type_counts[type_counts < min_cells_per_type].index.tolist()

        if low_count_types:
            low_count_mask = combined.obs[output_column].isin(low_count_types)
            n_removed = low_count_mask.sum()
            logger.info(f"\nFiltering low-count cell types (<{min_cells_per_type} cells):")
            logger.info(f"  Removing {len(low_count_types)} types, {n_removed:,} cells")
            for ct in low_count_types[:10]:
                logger.info(f"    {ct}: {type_counts[ct]} cells")
            if len(low_count_types) > 10:
                logger.info(f"    ... and {len(low_count_types) - 10} more types")
            combined = combined[~low_count_mask].copy()
            logger.info(f"  Remaining: {combined.n_obs:,} cells, {combined.obs[output_column].nunique()} types")

    # Print cell type distribution
    logger.info("\n  Cell type distribution:")
    ct_counts = combined.obs[output_column].value_counts()
    for ct, count in ct_counts.head(10).items():
        logger.info(f"    {ct}: {count:,} cells")
    if len(ct_counts) > 10:
        logger.info(f"    ... and {len(ct_counts) - 10} more types")

    # Store validation results in uns if validation was performed
    if validate_labels and validation_results:
        combined.uns["validation_results"] = validation_results

    logger.info(
        "\nCombined reference ready. Call subsample_balanced() for "
        "source-aware balancing before training."
    )

    return combined
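
# Illustration (editor's note): fields the combined object carries, per the
# code above (names follow the defaults):
#   combined.obs["original_label"]      unified labels (``output_column``)
#   combined.obs["reference_source"]    file stem of each source reference
#   combined.obs["batch"]               per-reference batch id from sc.concat
#   combined.uns["validation_results"]  per-reference validation summaries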


def get_panel_genes(adata: ad.AnnData) -> List[str]:
    """
    Extract panel genes from AnnData (e.g., Xenium spatial data).

    Parameters
    ----------
    adata : AnnData
        AnnData object (typically from spatial platform).

    Returns
    -------
    List[str]
        List of gene names (panel genes).
    """
    return list(adata.var_names)


# ============================================================================
# CellTypist Training
# ============================================================================

def _save_model_metadata(
    metadata_path: Path,
    model,
    adata: ad.AnnData,
    label_column: str,
    training_params: Dict[str, Any],
) -> None:
    """
    Save model training metadata for reproducibility.

    Parameters
    ----------
    metadata_path : Path
        Path to save JSON metadata.
    model
        Trained CellTypist model.
    adata : AnnData
        Training data.
    label_column : str
        Cell type label column.
    training_params : Dict
        Training parameters used.
    """
    try:
        import celltypist
        celltypist_version = celltypist.__version__
    except Exception:
        celltypist_version = "unknown"

    try:
        import spatialcore
        spatialcore_version = spatialcore.__version__
    except Exception:
        spatialcore_version = "unknown"

    # Cell type counts
    cell_type_counts = adata.obs[label_column].value_counts().to_dict()

    # Reference sources if tracked
    reference_sources = []
    if "reference_source" in adata.obs.columns:
        for source in adata.obs["reference_source"].unique():
            source_mask = adata.obs["reference_source"] == source
            reference_sources.append({
                "name": source,
                "n_cells_used": int(source_mask.sum()),
            })

    metadata = {
        # Strip the "_celltypist" suffix that train_celltypist_model() appends
        # to metadata filenames, recovering the base model name.
        "model_name": metadata_path.stem.replace("_celltypist", ""),
        "created_at": datetime.now().isoformat(),
        "spatialcore_version": spatialcore_version,
        "celltypist_version": celltypist_version,
        "training": {
            "n_cells": int(adata.n_obs),
            "n_genes": int(len(model.features)),
            "n_cell_types": int(len(model.cell_types)),
            "label_column": label_column,
            **training_params,
        },
        "references": reference_sources,
        "panel_genes": {
            "n_genes": int(len(model.features)),
            "genes": list(model.features)[:50],  # First 50 for preview
            "genes_truncated": len(model.features) > 50,
        },
        "cell_type_summary": {
            str(k): int(v) for k, v in cell_type_counts.items()
        },
    }

    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
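
# Illustration (editor's note) — rough shape of the JSON written above; keys
# mirror the dict built in _save_model_metadata, values are invented:
# {
#   "model_name": "liver_xenium_v1",
#   "created_at": "2025-01-01T12:00:00",
#   "spatialcore_version": "0.1.9",
#   "celltypist_version": "...",
#   "training": {"n_cells": 250000, "n_genes": 380, "n_cell_types": 24, ...},
#   "references": [{"name": "hlca", "n_cells_used": 100000}, ...],
#   "panel_genes": {"n_genes": 380, "genes": ["..."], "genes_truncated": true},
#   "cell_type_summary": {"T cell": 41000, ...}
# }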


def train_celltypist_model(
    adata: ad.AnnData,
    label_column: str = "unified_cell_type",
    model_name: str = "custom_model",
    output_path: Optional[Union[str, Path]] = None,
    use_SGD: bool = True,
    mini_batch: bool = True,
    balance_cell_type: bool = True,
    feature_selection: bool = False,
    n_jobs: int = -1,
    max_iter: int = 100,
    epochs: int = 10,
    batch_size: int = 1000,
    batch_number: int = 200,
) -> Dict[str, Any]:
    """
    Train a custom CellTypist logistic regression model.

    Parameters
    ----------
    adata : AnnData
        Reference data (already subset to panel genes, normalized).
    label_column : str, default "unified_cell_type"
        Cell type label column in adata.obs.
    model_name : str, default "custom_model"
        Name for the trained model.
    output_path : str or Path, optional
        Where to save .pkl file. If None, doesn't save.
    use_SGD : bool, default True
        Use stochastic gradient descent (faster for large data).
    mini_batch : bool, default True
        Use mini-batch training (recommended for large datasets).
    balance_cell_type : bool, default True
        Balance rare cell types in mini-batches.
    feature_selection : bool, default False
        Perform feature selection (False = use all genes).
    n_jobs : int, default -1
        Parallel jobs (-1 = all cores).
    max_iter : int, default 100
        Maximum training iterations (for non-mini-batch).
    epochs : int, default 10
        Training epochs (for mini-batch).
    batch_size : int, default 1000
        Cells per batch (for mini-batch).
    batch_number : int, default 200
        Batches per epoch (for mini-batch).

    Returns
    -------
    Dict[str, Any]
        Model metadata including:
        - model_path: Path to saved model (if output_path provided)
        - n_cells_trained: Number of cells used for training
        - n_genes: Number of features/genes
        - n_cell_types: Number of cell types
        - cell_types: List of cell type names
        - model: The trained CellTypist model object

    Notes
    -----
    For imbalanced data (e.g., 192k T cells vs 277 mast cells), use
    `balance_cell_type=True` with `mini_batch=True` to ensure rare
    cell types are adequately represented in training batches.

    Examples
    --------
    >>> from spatialcore.annotation import train_celltypist_model
    >>> result = train_celltypist_model(
    ...     combined_adata,
    ...     label_column="unified_cell_type",
    ...     output_path="./models/liver_xenium_v1.pkl",
    ...     mini_batch=True,
    ...     balance_cell_type=True,
    ... )
    >>> print(f"Trained on {result['n_cells_trained']:,} cells")
    >>> print(f"Cell types: {result['n_cell_types']}")
    """
    try:
        import celltypist
    except ImportError:
        raise ImportError(
            "celltypist is required for model training. "
            "Install with: pip install celltypist"
        )

    # Validate label column
    if label_column not in adata.obs.columns:
        available = list(adata.obs.columns)
        raise ValueError(
            f"Label column '{label_column}' not found. Available: {available}"
        )

    # Log training parameters
    logger.info("Training CellTypist model...")
    logger.info(f"  Cells: {adata.n_obs:,}")
    logger.info(f"  Genes: {adata.n_vars:,}")
    logger.info(f"  Cell types: {adata.obs[label_column].nunique()}")
    logger.info(f"  Mini-batch: {mini_batch}")
    logger.info(f"  Balance cell types: {balance_cell_type}")

    # Train model
    if mini_batch:
        model = celltypist.train(
            adata,
            labels=label_column,
            check_expression=False,  # Already validated
            use_SGD=use_SGD,
            mini_batch=True,
            balance_cell_type=balance_cell_type,
            feature_selection=feature_selection,
            n_jobs=n_jobs,
            epochs=epochs,
            batch_size=batch_size,
            batch_number=batch_number,
        )
    else:
        model = celltypist.train(
            adata,
            labels=label_column,
            check_expression=False,
            use_SGD=use_SGD,
            feature_selection=feature_selection,
            n_jobs=n_jobs,
            max_iter=max_iter,
        )

    # Save model if output path provided
    if output_path:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        model.write(str(output_path))
        logger.info(f"  Saved model to: {output_path}")

        # Save metadata JSON for reproducibility
        metadata_path = output_path.with_suffix("").with_name(
            output_path.stem + "_celltypist.json"
        )
        _save_model_metadata(
            metadata_path=metadata_path,
            model=model,
            adata=adata,
            label_column=label_column,
            training_params={
                "use_SGD": use_SGD,
                "mini_batch": mini_batch,
                "balance_cell_type": balance_cell_type,
                "feature_selection": feature_selection,
                "n_jobs": n_jobs,
                "max_iter": max_iter,
                "epochs": epochs,
                "batch_size": batch_size,
                "batch_number": batch_number,
            },
        )
        logger.info(f"  Saved metadata to: {metadata_path}")

    return {
        "model_path": str(output_path) if output_path else None,
        "metadata_path": str(metadata_path) if output_path else None,
        "n_cells_trained": adata.n_obs,
        "n_genes": len(model.features),
        "n_cell_types": len(model.cell_types),
        "cell_types": list(model.cell_types),
        "model": model,
    }
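
# Illustration (editor's arithmetic, not a CellTypist guarantee) — rough
# sampling budget implied by the mini-batch defaults above:
# epochs * batch_number * batch_size = 10 * 200 * 1000 = 2,000,000 cell draws
# per training run, so with balance_cell_type=True rare types are drawn into
# batches repeatedly rather than being swamped by abundant types.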


def get_model_gene_overlap(
    model_path: Union[str, Path],
    query_genes: List[str],
) -> Dict[str, Any]:
    """
    Calculate gene overlap between a CellTypist model and query data.

    Parameters
    ----------
    model_path : str or Path
        Path to CellTypist model (.pkl file).
    query_genes : List[str]
        Gene names from query data (e.g., Xenium panel).

    Returns
    -------
    Dict[str, Any]
        - n_model_genes: Total genes in model
        - n_query_genes: Total genes in query
        - n_overlap: Number of overlapping genes
        - overlap_pct: Percentage of model genes in query
        - overlapping_genes: List of overlapping gene names
        - missing_genes: List of model genes missing from query

    Examples
    --------
    >>> from spatialcore.annotation import get_model_gene_overlap
    >>> overlap = get_model_gene_overlap(
    ...     "Healthy_Human_Liver.pkl",
    ...     list(xenium_adata.var_names)
    ... )
    >>> print(f"Gene overlap: {overlap['overlap_pct']:.1f}%")
    """
    try:
        from celltypist import models
    except ImportError:
        raise ImportError("celltypist is required. Install with: pip install celltypist")

    model = models.Model.load(str(model_path))
    model_genes = set(model.features)
    query_genes_set = set(query_genes)

    overlap = model_genes & query_genes_set
    missing = model_genes - query_genes_set

    return {
        "n_model_genes": len(model_genes),
        "n_query_genes": len(query_genes_set),
        "n_overlap": len(overlap),
        "overlap_pct": 100 * len(overlap) / len(model_genes) if model_genes else 0,
        "overlapping_genes": sorted(overlap),
        "missing_genes": sorted(missing),
    }
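
# Illustration (editor's arithmetic; the ~5,000-feature figure is an invented
# example) — how the "~8% gene utilization" claim in the module docstring can
# arise: if a pre-trained model carries ~5,000 features and a 400-gene spatial
# panel is fully contained in it, overlap_pct = 100 * 400 / 5000 = 8%.
# A model trained on the panel itself yields overlap_pct = 100%.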
727
+
728
+
729
+ def get_training_summary(combined_adata: ad.AnnData, label_column: str) -> pd.DataFrame:
730
+ """
731
+ Get summary of cell type distribution for training data.
732
+
733
+ Parameters
734
+ ----------
735
+ combined_adata : AnnData
736
+ Combined reference data.
737
+ label_column : str
738
+ Cell type label column.
739
+
740
+ Returns
741
+ -------
742
+ pd.DataFrame
743
+ Summary with columns: cell_type, n_cells, pct_total.
744
+ """
745
+ counts = combined_adata.obs[label_column].value_counts()
746
+ df = pd.DataFrame({
747
+ "cell_type": counts.index,
748
+ "n_cells": counts.values,
749
+ "pct_total": 100 * counts.values / combined_adata.n_obs,
750
+ })
751
+ return df
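
# Illustration (editor's sketch) — a quick class-balance check before
# training; "original_label" is the combine_references() default output column:
#
# >>> summary = get_training_summary(balanced, "original_label")
# >>> print(summary.head(10).to_string(index=False))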


# ============================================================================
# Color Palettes
# ============================================================================

# High-contrast palette optimized for dark backgrounds
# Designed for spatial maps and UMAP visualizations
HIGH_CONTRAST_PALETTE = [
    "#F5252F",  # Red
    "#FB3FFC",  # Magenta
    "#00FFFF",  # Cyan
    "#33FF33",  # Green
    "#FFB300",  # Amber
    "#9966FF",  # Purple
    "#FF6B6B",  # Coral
    "#3784FE",  # Blue
    "#FF8000",  # Orange
    "#66CCCC",  # Teal
    "#CC66FF",  # Lavender
    "#99FF99",  # Light green
    "#FF6699",  # Pink
    "#E3E1E3",  # Light gray
    "#FFB3CC",  # Light pink
    "#E6E680",  # Yellow-green
    "#CC9966",  # Tan
    "#8080FF",  # Periwinkle
    "#FF9999",  # Salmon
    "#66FF99",  # Mint
]


def generate_color_scheme(
    cell_types: List[str],
    custom_colors: Optional[Dict[str, str]] = None,
    palette: Optional[List[str]] = None,
) -> Dict[str, str]:
    """
    Generate deterministic color mapping for cell types.

    Creates a mapping from cell type names to hex colors, using
    custom overrides if provided, then filling remaining types
    from the palette.

    Parameters
    ----------
    cell_types : List[str]
        List of cell type names.
    custom_colors : Dict[str, str], optional
        Custom color overrides. Keys are cell type names, values
        are hex color codes (e.g., "#FF0000").
    palette : List[str], optional
        Color palette to use. If None, uses HIGH_CONTRAST_PALETTE.

    Returns
    -------
    Dict[str, str]
        Mapping from cell type to hex color.

    Notes
    -----
    Colors are assigned deterministically based on sorted cell type names,
    ensuring consistent colors across runs.

    Examples
    --------
    >>> from spatialcore.annotation.training import generate_color_scheme
    >>> colors = generate_color_scheme(
    ...     ["T cell", "B cell", "Macrophage"],
    ...     custom_colors={"T cell": "#FF0000"},
    ... )
    >>> print(colors)
    """
    if palette is None:
        palette = HIGH_CONTRAST_PALETTE

    if custom_colors is None:
        custom_colors = {}

    color_scheme = {}
    palette_idx = 0

    # Sort for deterministic ordering
    for cell_type in sorted(cell_types):
        if cell_type in custom_colors:
            color_scheme[cell_type] = custom_colors[cell_type]
        else:
            color_scheme[cell_type] = palette[palette_idx % len(palette)]
            palette_idx += 1

    return color_scheme


# ============================================================================
# Model Artifact Management
# ============================================================================

def save_model_artifacts(
    model,
    output_dir: Union[str, Path],
    model_name: str,
    training_metadata: Optional[Dict[str, Any]] = None,
    custom_colors: Optional[Dict[str, str]] = None,
) -> Dict[str, Path]:
    """
    Save model with metadata and color scheme.

    Creates a complete model artifact package with:
    - Model pickle file (.pkl)
    - Training metadata JSON
    - Color scheme JSON for visualization

    Parameters
    ----------
    model
        Trained CellTypist model.
    output_dir : str or Path
        Directory to save artifacts.
    model_name : str
        Base name for output files.
    training_metadata : Dict[str, Any], optional
        Additional metadata to include (e.g., training parameters,
        reference sources).
    custom_colors : Dict[str, str], optional
        Custom color overrides for specific cell types.

    Returns
    -------
    Dict[str, Path]
        Paths to saved files:
        - model_path: Path to .pkl model file
        - metadata_path: Path to metadata JSON
        - colors_path: Path to color scheme JSON

    Notes
    -----
    The color scheme is saved separately so it can be loaded by
    visualization functions without loading the full model.

    File naming:
    - {model_name}.pkl
    - {model_name}_celltypist.json
    - {model_name}_colors.json

    Examples
    --------
    >>> from spatialcore.annotation.training import save_model_artifacts
    >>> result = train_celltypist_model(adata, label_column="cell_type")
    >>> paths = save_model_artifacts(
    ...     result["model"],
    ...     output_dir="./models",
    ...     model_name="liver_v1",
    ...     training_metadata={"reference": "cellxgene"},
    ... )
    >>> print(f"Model saved to: {paths['model_path']}")
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save model
    model_path = output_dir / f"{model_name}.pkl"
    model.write(str(model_path))
    logger.info(f"Saved model to: {model_path}")

    # Generate color scheme
    cell_types = list(model.cell_types)
    colors = generate_color_scheme(cell_types, custom_colors)

    # Save color scheme
    colors_path = output_dir / f"{model_name}_colors.json"
    with open(colors_path, "w") as f:
        json.dump(colors, f, indent=2)
    logger.info(f"Saved color scheme to: {colors_path}")

    # Build metadata
    try:
        import celltypist
        celltypist_version = celltypist.__version__
    except Exception:
        celltypist_version = "unknown"

    try:
        import spatialcore
        spatialcore_version = spatialcore.__version__
    except Exception:
        spatialcore_version = "unknown"

    metadata = {
        "model_name": model_name,
        "created_at": datetime.now().isoformat(),
        "spatialcore_version": spatialcore_version,
        "celltypist_version": celltypist_version,
        "n_genes": len(model.features),
        "n_cell_types": len(model.cell_types),
        "cell_types": cell_types,
        "genes": list(model.features),
    }

    if training_metadata:
        metadata["training"] = training_metadata

    # Save metadata
    metadata_path = output_dir / f"{model_name}_celltypist.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
    logger.info(f"Saved metadata to: {metadata_path}")

    return {
        "model_path": model_path,
        "metadata_path": metadata_path,
        "colors_path": colors_path,
    }
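
# Illustration (editor's note) — resulting on-disk layout for
# model_name="liver_v1", following the "File naming" pattern documented above:
#   models/liver_v1.pkl
#   models/liver_v1_celltypist.json
#   models/liver_v1_colors.json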


# ============================================================================
# Balanced Subsampling
# ============================================================================

def _load_target_proportions(
    target_proportions: Union[Dict[str, float], str, Path, None]
) -> Optional[Dict[str, float]]:
    """
    Load and validate target proportions from dict, JSON, or CSV.

    Parameters
    ----------
    target_proportions : dict, str, Path, or None
        - Dict: Used directly
        - str/Path ending in .json: Load as JSON
        - str/Path ending in .csv: Load as CSV with columns (cell_type, proportion)
        - None: Return None

    Returns
    -------
    dict or None
        Validated proportions dict, or None if input was None.

    Raises
    ------
    ValueError
        If proportions are invalid (negative, >1.0, or wrong format).
    """
    if target_proportions is None:
        return None

    # Load from file if path provided
    if isinstance(target_proportions, (str, Path)):
        path = Path(target_proportions)

        if not path.exists():
            raise ValueError(f"Target proportions file not found: {path}")

        if path.suffix.lower() == ".json":
            with open(path, "r") as f:
                props = json.load(f)
        elif path.suffix.lower() == ".csv":
            df = pd.read_csv(path)
            if "cell_type" not in df.columns or "proportion" not in df.columns:
                raise ValueError(
                    f"CSV must have 'cell_type' and 'proportion' columns. "
                    f"Found: {list(df.columns)}"
                )
            props = dict(zip(df["cell_type"], df["proportion"]))
        else:
            raise ValueError(
                f"Unsupported file format: {path.suffix}. Use .json or .csv"
            )
    else:
        props = target_proportions

    # Validate proportions
    if not isinstance(props, dict):
        raise ValueError(
            f"target_proportions must be a dict, got {type(props).__name__}"
        )

    for cell_type, proportion in props.items():
        if not isinstance(proportion, (int, float)):
            raise ValueError(
                f"Invalid proportion for '{cell_type}': {proportion}. "
                f"Must be a number."
            )
        if proportion < 0 or proportion > 1.0:
            raise ValueError(
                f"Invalid proportion for '{cell_type}': {proportion}. "
                f"Must be between 0.0 and 1.0."
            )

    return props
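
# Illustration (editor's note) — example inputs accepted above; the JSON
# mirrors the subsample_balanced() docstring, the CSV is an equivalent form
# using the column names checked in the code:
#
# proportions.json:
#   {"NK cell": 0.0025, "plasma cell": 0.001}
#
# proportions.csv:
#   cell_type,proportion
#   NK cell,0.0025
#   plasma cell,0.001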


def subsample_balanced(
    adata: ad.AnnData,
    label_column: str,
    max_cells_per_type: int = 10000,
    min_cells_per_type: int = 50,
    source_column: Optional[str] = "reference_source",
    source_balance: str = "proportional",
    min_cells_per_source: int = 50,
    group_by_column: Optional[str] = None,
    target_proportions: Optional[Union[Dict[str, float], str, Path]] = None,
    random_state: int = 42,
    copy: bool = True,
) -> ad.AnnData:
    """
    Source-aware balanced subsampling for multi-reference training data.

    When combining multiple scRNA-seq references, this function ensures each
    cell type is sampled proportionally from all sources that contain it.
    This prevents the model from learning source-specific artifacts.

    Parameters
    ----------
    adata : AnnData
        Combined reference data (output of combine_references()).
    label_column : str
        Column in adata.obs containing cell type labels (for logging/display).
    max_cells_per_type : int, default 10000
        Maximum cells per cell type in output.
    min_cells_per_type : int, default 50
        Cell types with fewer cells than this are kept entirely
        (no subsampling applied).
    source_column : str, optional, default "reference_source"
        Column identifying which reference each cell came from.
        Set to None to disable source-aware balancing (simple capping).
    source_balance : {"proportional", "equal"}, default "proportional"
        How to distribute sampling across sources:
        - "proportional": Draw from each source proportionally to its
          contribution for that cell type. (RECOMMENDED)
        - "equal": Draw equally from each source (up to available).
    min_cells_per_source : int, default 50
        Minimum cells to draw from a source that has the cell type.
        Ensures rare sources still contribute to the model.
    group_by_column : str, optional
        If provided, group cells by values in this column instead of label_column.
        This enables semantic grouping where different text labels map to the
        same identity. For example, using "cell_type_ontology_term_id":
        - "CD4+ T cells" and "CD4-positive, alpha-beta T cell" -> CL:0000624
        - Both are grouped together for balancing purposes.
        If None, groups by label_column text (current behavior).
    target_proportions : dict, str, or Path, optional
        Target proportions for specific cell types. Accepts:
        - Dict mapping cell type names to proportions (0.0-1.0)
        - Path to JSON file: ``{"NK cell": 0.0025, "plasma cell": 0.001}``
        - Path to CSV file with columns: ``cell_type``, ``proportion``

        For cell types in this mapping, the target count is calculated as:
        ``proportion × total_input_cells`` instead of using max_cells_per_type.

        Cell types NOT in the mapping use normal max_cells_per_type capping.

        This is essential for handling pure/enriched references (e.g., FACS-sorted
        cells) where a cell type exists only in the enriched source and would
        otherwise dominate training.
    random_state : int, default 42
        Random seed for reproducibility.
    copy : bool, default True
        Return a copy or modify in-place.

    Returns
    -------
    AnnData
        Subsampled data with source-balanced cell type representation.

    Raises
    ------
    ValueError
        If label_column not found in adata.obs.
        If source_column specified but not found in adata.obs.
        If source_balance is not "proportional" or "equal".

    Examples
    --------
    >>> from spatialcore.annotation import combine_references, subsample_balanced
    >>> # Step 1: Combine (no balancing)
    >>> combined = combine_references(
    ...     reference_paths=["study1.h5ad", "study2.h5ad"],
    ...     label_columns=["cell_type", "cell_type"],
    ...     target_genes=panel_genes,
    ... )
    >>> # Step 2: Source-aware balancing
    >>> balanced = subsample_balanced(
    ...     combined,
    ...     label_column="unified_cell_type",
    ...     max_cells_per_type=10000,
    ...     source_balance="proportional",
    ... )

    Notes
    -----
    **Why source-aware balancing?**

    When combining Study 1 (30K Macrophages) + Study 2 (5K Macrophages) and
    capping at 10K, naive capping takes mostly from Study 1. The model learns
    Study 1's version of Macrophage, not a consensus.

    Source-aware balancing draws proportionally: ~8.5K from Study 1, ~1.5K
    from Study 2, ensuring both studies contribute to each shared cell type.

    **Handling FACS-Enriched / Pure Cell Type References**

    When combining tissue-derived scRNA-seq (natural cell type proportions)
    with FACS-sorted or enriched populations (e.g., 100% pure T cells, sorted
    NK cells), special considerations apply:

    - **source_balance="proportional"** (default): Best when all references
      are tissue-derived with natural proportions. Each source contributes
      proportionally to its cell count for each type.

    - **source_balance="equal"**: Recommended when combining tissue-derived
      data with FACS-enriched pure populations. Forces equal contribution
      from each source, preventing pure populations from overwhelming the
      model's concept of that cell type.

    Example scenario:
    - Reference A (tissue): 50K cells with natural proportions (5% T cells = 2.5K)
    - Reference B (FACS): 100K pure T cells (100% T cells)

    With "proportional": T cells would be 97.5% from FACS, 2.5% from tissue.
    With "equal": T cells would be 50% from each, learning a consensus.

    **Future Improvements**

    - Automatic detection of enriched/pure cell type sources based on
      cell type distribution entropy
    - Per-cell-type source_balance overrides (e.g., use "equal" only for
      T cells that appear in FACS source)
    - Integration with Harmony batch correction for multi-source training
      via PCA embeddings, with prediction on spatial data with >5,000 genes
    """
    if copy:
        adata = adata.copy()

    rng = np.random.default_rng(random_state)

    # Load and validate target proportions
    props = _load_target_proportions(target_proportions)
    if props:
        logger.info(f"Using target proportions for {len(props)} cell type(s)")
        for ct, prop in props.items():
            logger.debug(f"  {ct}: {prop:.4f} ({prop*100:.2f}%)")

    # Validate label_column
    if label_column not in adata.obs.columns:
        raise ValueError(
            f"Label column '{label_column}' not found. "
            f"Available: {list(adata.obs.columns)}"
        )

    # Determine grouping column: use group_by_column if provided, else label_column
    if group_by_column is not None:
        if group_by_column not in adata.obs.columns:
            raise ValueError(
                f"Group-by column '{group_by_column}' not found. "
                f"Available: {list(adata.obs.columns)}"
            )
        # Use group_by_column for grouping (e.g., CL IDs)
        cell_types = adata.obs[group_by_column].astype(str)
        logger.info(
            f"Grouping by '{group_by_column}' instead of '{label_column}' "
            f"(semantic grouping enabled)"
        )
    else:
        cell_types = adata.obs[label_column].astype(str)

    unique_types = cell_types.unique()

    # =========================================================================
    # Source-unaware mode (simple capping)
    # =========================================================================
    if source_column is None:
        logger.info(
            f"Subsampling {adata.n_obs:,} cells "
            f"(cap={max_cells_per_type:,}, no source balancing)"
        )
        return _subsample_simple_cap(
            adata, cell_types, unique_types,
            max_cells_per_type, min_cells_per_type, rng
        )

    # =========================================================================
    # Source-aware mode (nested balancing)
    # =========================================================================
    if source_column not in adata.obs.columns:
        raise ValueError(
            f"Source column '{source_column}' not found. "
            f"Available: {list(adata.obs.columns)}. "
            f"Set source_column=None to disable source-aware balancing."
        )

    if source_balance not in ("proportional", "equal"):
        raise ValueError(
            f"Invalid source_balance: '{source_balance}'. "
            f"Must be 'proportional' or 'equal'."
        )

    sources = adata.obs[source_column].astype(str)
    unique_sources = sources.unique()
    n_sources = len(unique_sources)

    logger.info(
        f"Subsampling {adata.n_obs:,} cells "
        f"(source-aware, {n_sources} sources, {source_balance} balance, "
        f"max={max_cells_per_type:,}/type)"
    )

    selected_indices = []

    for cell_type in unique_types:
        type_mask = cell_types == cell_type
        type_indices = np.where(type_mask)[0]
        n_available = len(type_indices)

        # Skip empty types
        if n_available == 0:
            continue

        # Keep all if below minimum
        if n_available <= min_cells_per_type:
            selected_indices.extend(type_indices)
            logger.debug(f"  {cell_type}: keeping all {n_available} (below min)")
            continue

        # Identify which sources have this cell type
        type_sources = sources.iloc[type_indices]
        sources_with_type = type_sources.unique()
        n_sources_with_type = len(sources_with_type)

        # Warn if only one source has this type (when multiple sources exist)
        if n_sources_with_type == 1 and n_sources > 1:
            logger.warning(
                f"  {cell_type}: only in '{sources_with_type[0]}' "
                f"(no cross-source balancing)"
            )

        # Calculate target per cell type
        if props and cell_type in props:
            # Use proportion-based target for specified types (Cap & Fill)
            # Cap: don't exceed proportion × total_cells
            # Fill: but ensure at least min_cells_per_type for training quality
            proportion_target = int(props[cell_type] * adata.n_obs)
            capped_target = min(proportion_target, n_available)
            target_total = max(min_cells_per_type, capped_target)
            logger.debug(
                f"  {cell_type}: proportion {props[cell_type]:.4f} -> "
                f"{proportion_target} cells (cap&fill: {target_total})"
            )
        else:
            # Use normal capping for unspecified types
            target_total = min(max_cells_per_type, n_available)

        # Calculate per-source targets
        source_targets = _calculate_source_targets(
            type_sources=type_sources,
            sources_with_type=sources_with_type,
            target_total=target_total,
            source_balance=source_balance,
            min_cells_per_source=min_cells_per_source,
        )

        # Sample from each source
        for source_name, (target, available) in source_targets.items():
            source_type_mask = (cell_types == cell_type) & (sources == source_name)
            source_indices = np.where(source_type_mask)[0]

            if target >= len(source_indices):
                selected_indices.extend(source_indices)
            else:
                sampled = rng.choice(source_indices, size=target, replace=False)
                selected_indices.extend(sampled)

        # Debug log
        total_sampled = sum(t for t, _ in source_targets.values())
        source_summary = ", ".join(f"{s}:{t}" for s, (t, _) in source_targets.items())
        logger.debug(f"  {cell_type}: {n_available} -> {total_sampled} [{source_summary}]")

    # Sort and subset
    selected_indices = sorted(set(selected_indices))
    adata_sub = adata[selected_indices].copy()

    # Log final summary
    new_counts = adata_sub.obs[label_column].value_counts()
    logger.info(
        f"Subsampled: {adata.n_obs:,} -> {adata_sub.n_obs:,} cells "
        f"({len(new_counts)} types)"
    )

    return adata_sub
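
# Illustration (editor's arithmetic; the 500,000-cell figure is an invented
# example) — the cap-and-fill target used in the loop above: with 500,000
# input cells and target_proportions={"NK cell": 0.0025},
# proportion_target = int(0.0025 * 500000) = 1250, then capped at the
# available count and raised to at least min_cells_per_type.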


def _calculate_source_targets(
    type_sources: pd.Series,
    sources_with_type: np.ndarray,
    target_total: int,
    source_balance: str,
    min_cells_per_source: int,
) -> Dict[str, Tuple[int, int]]:
    """
    Calculate how many cells to sample from each source for one cell type.

    Parameters
    ----------
    type_sources : pd.Series
        Series of source labels for cells of this type.
    sources_with_type : np.ndarray
        Unique sources that have this cell type.
    target_total : int
        Total cells to sample for this cell type.
    source_balance : str
        "proportional" or "equal".
    min_cells_per_source : int
        Minimum cells to draw from each source.

    Returns
    -------
    Dict[str, Tuple[int, int]]
        Mapping: source_name -> (target_count, available_count)
    """
    source_counts = type_sources.value_counts().to_dict()
    total_available = sum(source_counts.values())
    n_sources = len(sources_with_type)

    targets = {}

    if source_balance == "proportional":
        # Draw proportionally to each source's contribution
        for source_name in sources_with_type:
            available = source_counts[source_name]
            proportion = available / total_available
            target = int(np.ceil(target_total * proportion))

            # Enforce minimum (if source has enough cells)
            if available >= min_cells_per_source:
                target = max(target, min_cells_per_source)

            # Can't exceed available
            target = min(target, available)
            targets[source_name] = (target, available)

    elif source_balance == "equal":
        # Draw equally from each source
        per_source = target_total // n_sources
        remainder = target_total % n_sources

        for i, source_name in enumerate(sorted(sources_with_type)):
            available = source_counts[source_name]
            target = per_source + (1 if i < remainder else 0)
            target = min(target, available)
            targets[source_name] = (target, available)

        # Redistribute shortfall (when some sources can't provide enough)
        total_targeted = sum(t for t, _ in targets.values())
        shortfall = target_total - total_targeted

        if shortfall > 0:
            for source_name in sources_with_type:
                if shortfall <= 0:
                    break
                target, available = targets[source_name]
                capacity = available - target
                if capacity > 0:
                    additional = min(capacity, shortfall)
                    targets[source_name] = (target + additional, available)
                    shortfall -= additional

    return targets
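
# Illustration (editor's arithmetic) — the "proportional" branch, with the
# numbers from the subsample_balanced() docstring (Study 1: 30,000
# Macrophages, Study 2: 5,000; target_total = 10,000):
#   Study 1: ceil(10000 * 30000/35000) = 8572
#   Study 2: ceil(10000 * 5000/35000)  = 1429
# After the min_cells_per_source floor and the per-source availability cap,
# both studies contribute — the ~8.5K / ~1.5K split described above.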


def _subsample_simple_cap(
    adata: ad.AnnData,
    cell_types: pd.Series,
    unique_types: np.ndarray,
    max_cells_per_type: int,
    min_cells_per_type: int,
    rng: np.random.Generator,
) -> ad.AnnData:
    """
    Simple per-type capping without source awareness.

    Used when source_column=None is passed to subsample_balanced().
    """
    selected_indices = []

    for cell_type in unique_types:
        type_mask = cell_types == cell_type
        type_indices = np.where(type_mask)[0]
        n_available = len(type_indices)

        if n_available == 0:
            continue
        elif n_available <= min_cells_per_type:
            selected_indices.extend(type_indices)
        elif n_available <= max_cells_per_type:
            selected_indices.extend(type_indices)
        else:
            sampled = rng.choice(type_indices, size=max_cells_per_type, replace=False)
            selected_indices.extend(sampled)

    selected_indices = sorted(selected_indices)
    return adata[selected_indices].copy()


# NOTE: harmonize_labels() function was removed in favor of using:
# 1. add_ontology_ids() to fill missing CL IDs
# 2. subsample_balanced(group_by_column="cell_type_ontology_term_id") for semantic grouping
# See pipeline.py for the new recommended workflow.