scdataloader 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,39 +69,98 @@ class DataModule(L.LightningDataModule):
  **kwargs,
  ):
  """
- DataModule a pytorch lighting datamodule directly from a lamin Collection.
- it can work with bare pytorch too
+ PyTorch Lightning DataModule for loading single-cell data from a LaminDB Collection.

- It implements train / val / test dataloaders. the train is weighted random, val is random, test is one to many separated datasets.
- This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
+ This DataModule provides train/val/test dataloaders with configurable sampling strategies.
+ It combines MappedCollection, Dataset, and Collator to create efficient data pipelines
+ for training single-cell foundation models.
+
+ The training dataloader uses weighted random sampling based on class frequencies,
+ validation uses random sampling, and test uses sequential sampling on held-out datasets.

  Args:
- collection_name (str): The lamindb collection to be used.
- weight_scaler (int, optional): how much more you will see the most present vs less present category.
- n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
- validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
- test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
- it will use a full dataset and will round to the nearest dataset's cell count.
- use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
- clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
- assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
- gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
- the file must have ensembl_gene_id as index.
- This is used to subset the available genes further to the ones that have embeddings in your model.
- max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
- how (str, optional): The method to use for the collator. Defaults to "random expr".
- organism_col (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
- tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
- hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
- metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
- clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
- get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
- store_location (str, optional): The location to store the sampler indices. Defaults to None.
- force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
- sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
- sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
- **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
- see @file data.py and @file collator.py for more details about some of the parameters
+ collection_name (str): Key of the LaminDB Collection to load.
+ clss_to_weight (List[str], optional): Label columns to use for weighted sampling
+ in the training dataloader. Supports "nnz" for weighting by number of
+ non-zero genes. Defaults to ["organism_ontology_term_id"].
+ weight_scaler (int, optional): Controls balance between rare and common classes.
+ Higher values lead to more uniform sampling across classes. Set to 0 to
+ disable weighted sampling. Defaults to 10.
+ n_samples_per_epoch (int, optional): Number of samples to draw per training epoch.
+ Defaults to 2,000,000.
+ validation_split (float | int, optional): Proportion (float) or absolute number (int)
+ of samples for validation. Defaults to 0.2.
+ test_split (float | int, optional): Proportion (float) or absolute number (int)
+ of samples for testing. Uses entire datasets as test sets, rounding to
+ nearest dataset boundary. Defaults to 0.
+ use_default_col (bool, optional): Whether to use the default Collator for batch
+ preparation. If False, no collate_fn is applied. Defaults to True.
+ clss_to_predict (List[str], optional): Observation columns to encode as prediction
+ targets. Must include "organism_ontology_term_id". Defaults to
+ ["organism_ontology_term_id"].
+ hierarchical_clss (List[str], optional): Observation columns with hierarchical
+ ontology structure to be processed. Defaults to [].
+ how (str, optional): Gene selection strategy passed to Collator. One of
+ "most expr", "random expr", "all", "some". Defaults to "random expr".
+ organism_col (str, optional): Column name for organism ontology term ID.
+ Defaults to "organism_ontology_term_id".
+ max_len (int, optional): Maximum number of genes per sample passed to Collator.
+ Defaults to 1000.
+ replacement (bool, optional): Whether to sample with replacement in training.
+ Defaults to True.
+ gene_subset (List[str], optional): List of genes to restrict the dataset to.
+ Useful when model only supports specific genes. Defaults to None.
+ tp_name (str, optional): Column name for time point or heat diffusion values.
+ Defaults to None.
+ assays_to_drop (List[str], optional): List of assay ontology term IDs to exclude
+ from training. Defaults to ["EFO:0030007"] (ATAC-seq).
+ metacell_mode (float, optional): Probability of using metacell aggregation mode.
+ Cannot be used with get_knn_cells. Defaults to 0.0.
+ get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
+ expression data. Cannot be used with metacell_mode. Defaults to False.
+ store_location (str, optional): Directory path to cache sampler indices and
+ labels for faster subsequent loading. Defaults to None.
+ force_recompute_indices (bool, optional): Force recomputation of cached indices
+ even if they exist. Defaults to False.
+ sampler_workers (int, optional): Number of parallel workers for building sampler
+ indices. Auto-determined based on available CPUs if None. Defaults to None.
+ sampler_chunk_size (int, optional): Chunk size for parallel sampler processing.
+ Auto-determined based on available memory if None. Defaults to None.
+ organisms (List[str], optional): List of organisms to include. If None, uses
+ all organisms in the dataset. Defaults to None.
+ genedf (pd.DataFrame, optional): Gene information DataFrame. If None, loaded
+ automatically. Defaults to None.
+ n_bins (int, optional): Number of bins for expression discretization. 0 means
+ no binning. Defaults to 0.
+ curiculum (int, optional): Curriculum learning parameter. If > 0, gradually
+ increases sampling weight balance over epochs. Defaults to 0.
+ start_at (int, optional): Starting index for resuming inference. Requires same
+ number of GPUs as previous run. Defaults to 0.
+ **kwargs: Additional arguments passed to PyTorch DataLoader (e.g., batch_size,
+ num_workers, pin_memory).
+
+ Attributes:
+ dataset (Dataset): The underlying Dataset instance.
+ classes (dict[str, int]): Mapping from class names to number of categories.
+ train_labels (np.ndarray): Label array for weighted sampling.
+ idx_full (np.ndarray): Indices for training samples.
+ valid_idx (np.ndarray): Indices for validation samples.
+ test_idx (np.ndarray): Indices for test samples.
+ test_datasets (List[str]): Paths to datasets used for testing.
+
+ Raises:
+ ValueError: If "organism_ontology_term_id" not in clss_to_predict.
+ ValueError: If both metacell_mode > 0 and get_knn_cells are True.
+
+ Example:
+ >>> dm = DataModule(
+ ... collection_name="my_collection",
+ ... batch_size=32,
+ ... num_workers=4,
+ ... max_len=2000,
+ ... )
+ >>> dm.setup()
+ >>> train_loader = dm.train_dataloader()
  """
  if "organism_ontology_term_id" not in clss_to_predict:
  raise ValueError(
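The new docstring's example stops at building the train dataloader. As a complement, here is a hedged sketch of how such a DataModule is typically consumed, assuming the class is importable as `scdataloader.DataModule` and that `MyModel` is a placeholder for any LightningModule compatible with the Collator's batches:

    import lightning as L
    from scdataloader import DataModule  # assumed top-level export

    dm = DataModule(
        collection_name="my_collection",  # placeholder LaminDB collection key
        batch_size=32,
        num_workers=4,
    )
    dm.setup()  # builds the weighted train split and the val/test splits

    # Either pull batches directly ...
    batch = next(iter(dm.train_dataloader()))

    # ... or hand the datamodule to a Lightning Trainer (MyModel is hypothetical).
    trainer = L.Trainer(max_epochs=1)
    trainer.fit(MyModel(), datamodule=dm)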
@@ -131,6 +190,7 @@ class DataModule(L.LightningDataModule):
  self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
  # we might want not to order the genes by expression (or do it?)
  # we might want to not introduce zeros and
+
  if use_default_col:
  kwargs["collate_fn"] = Collator(
  organisms=mdataset.organisms if organisms is None else organisms,
@@ -285,12 +345,24 @@ class DataModule(L.LightningDataModule):

  def setup(self, stage=None):
  """
- setup method is used to prepare the data for the training, validation, and test sets.
- It shuffles the data, calculates weights for each set, and creates samplers for each set.
+ Prepare data splits for training, validation, and testing.
+
+ This method shuffles the data, computes sample weights for weighted sampling,
+ removes samples from dropped assays, and creates train/val/test splits.
+ Test splits use entire datasets to ensure evaluation on unseen data sources.
+
+ Results can be cached to `store_location` for faster subsequent runs.

  Args:
- stage (str, optional): The stage of the model training process.
- It can be either 'fit' or 'test'. Defaults to None.
+ stage (str, optional): Training stage ('fit', 'test', or None for both).
+ Currently not used but kept for Lightning compatibility. Defaults to None.
+
+ Returns:
+ List[str]: List of paths to test datasets.
+
+ Note:
+ Must be called before using dataloaders. The train/val/test split is
+ deterministic when loading from cache.
  """
  print("setting up the datamodule")
  start_time = time.time()
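setup() turns the validation_split / test_split values (a float proportion or an int count, per the constructor docstring) into concrete index sets. A minimal sketch of how such a value can be resolved into a number of samples (an illustrative helper with a made-up name, not scdataloader's actual code):

    def resolve_split(n_total: int, split: float | int) -> int:
        # Floats are read as a fraction of the dataset, ints as an absolute count,
        # mirroring the semantics described in the docstrings above.
        if isinstance(split, float):
            return int(n_total * split)
        return min(int(split), n_total)

    n_valid = resolve_split(1_000_000, 0.2)     # -> 200000
    n_test = resolve_split(1_000_000, 50_000)   # -> 50000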
@@ -441,6 +513,22 @@ class DataModule(L.LightningDataModule):
  return self.test_datasets

  def train_dataloader(self, **kwargs):
+ """
+ Create the training DataLoader with weighted random sampling.
+
+ Uses LabelWeightedSampler for class-balanced sampling when weight_scaler > 0
+ and clss_to_weight is specified. Otherwise uses RankShardSampler for
+ distributed training without weighting.
+
+ Args:
+ **kwargs: Additional arguments passed to DataLoader, overriding defaults.
+
+ Returns:
+ DataLoader: Training DataLoader instance.
+
+ Raises:
+ ValueError: If setup() has not been called.
+ """
  if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
  try:
  print("Setting up the parallel train sampler...")
@@ -473,6 +561,12 @@ class DataModule(L.LightningDataModule):
  )

  def val_dataloader(self):
+ """
+ Create the validation DataLoader.
+
+ Returns:
+ DataLoader | List: Validation DataLoader, or empty list if no validation split.
+ """
  return (
  DataLoader(
  Subset(self.dataset, self.valid_idx),
@@ -483,6 +577,12 @@ class DataModule(L.LightningDataModule):
  )

  def test_dataloader(self):
+ """
+ Create the test DataLoader with sequential sampling.
+
+ Returns:
+ DataLoader | List: Test DataLoader, or empty list if no test split.
+ """
  return (
  DataLoader(
  self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
@@ -492,6 +592,14 @@ class DataModule(L.LightningDataModule):
  )

  def predict_dataloader(self):
+ """
+ Create a DataLoader for prediction over all training data.
+
+ Uses RankShardSampler for distributed inference.
+
+ Returns:
+ DataLoader: Prediction DataLoader instance.
+ """
  subset = Subset(self.dataset, self.idx_full)
  return DataLoader(
  self.dataset,
@@ -501,15 +609,6 @@ class DataModule(L.LightningDataModule):


  class LabelWeightedSampler(Sampler[int]):
- """
- A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
-
- This sampler is designed to handle very large datasets efficiently, with optimizations for:
- 1. Parallel building of class indices
- 2. Chunked processing for large arrays
- 3. Efficient memory management
- 4. Proper handling of replacement and non-replacement sampling
- """

  label_weights: torch.Tensor
  klass_indices: dict[int, torch.Tensor]
@@ -531,16 +630,58 @@ class LabelWeightedSampler(Sampler[int]):
  curiculum: int = 0,
  ) -> None:
  """
- Initialize the sampler with parallel processing for large datasets.
+ Weighted random sampler balancing both class frequencies and element weights.
+
+ This sampler is optimized for very large datasets (millions of samples) with:
+ - Parallel construction of class indices using multiple CPU workers
+ - Chunked processing to manage memory usage
+ - Support for curriculum learning via progressive weight scaling
+ - Optional per-element weights (e.g., by number of expressed genes)
+
+ The sampling process:
+ 1. Sample class labels according to class weights
+ 2. For each sampled class, sample elements according to element weights
+ 3. Shuffle all sampled indices

  Args:
- weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
- labels: Class label for each dataset element (length = dataset size)
- num_samples: Number of samples to draw
- replacement: Whether to sample with replacement
- element_weights: Optional weights for each element within classes
- n_workers: Number of parallel workers to use (default: number of CPUs-1)
- chunk_size: Size of chunks to process in parallel (default: 10M elements)
+ labels (np.ndarray): Integer class label for each dataset element.
+ Shape: (dataset_size,). The last unique label is treated as
+ "excluded" with weight 0.
+ num_samples (int): Number of samples to draw per epoch.
+ replacement (bool, optional): Whether to sample with replacement.
+ Defaults to True.
+ weight_scaler (float, optional): Controls class weight balance.
+ Weight formula: (scaler * count) / (count + scaler).
+ Higher values = more uniform sampling. Defaults to None.
+ element_weights (Sequence[float], optional): Per-element sampling weights.
+ Shape: (dataset_size,). Defaults to None (uniform within class).
+ n_workers (int, optional): Number of parallel workers for index building.
+ Defaults to min(20, num_cpus - 1).
+ chunk_size (int, optional): Elements per chunk for parallel processing.
+ Auto-determined based on available memory if None.
+ store_location (str, optional): Directory to cache computed indices.
+ Defaults to None.
+ force_recompute_indices (bool, optional): Recompute indices even if cached.
+ Defaults to False.
+ curiculum (int, optional): Curriculum learning epochs. If > 0, weight
+ exponent increases from 0 to 1 over this many epochs. Defaults to 0.
+
+ Attributes:
+ label_weights (torch.Tensor): Computed weights per class label.
+ klass_indices (torch.Tensor): Concatenated indices for all classes.
+ klass_offsets (torch.Tensor): Starting offset for each class in klass_indices.
+ count (int): Number of times __iter__ has been called (for curriculum).
+
+ Example:
+ >>> sampler = LabelWeightedSampler(
+ ... labels=train_labels,
+ ... num_samples=1_000_000,
+ ... weight_scaler=10,
+ ... element_weights=nnz_weights,
+ ... )
+ >>> for idx in sampler:
+ ... # Process sample at idx
+ ... pass
  """
  print("Initializing optimized parallel weighted sampler...")
  super(LabelWeightedSampler, self).__init__(None)
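The weight formula quoted in the new docstring, (scaler * count) / (count + scaler), saturates at the scaler for frequent classes while staying close to the raw count for rare ones, which caps the dominance of the most common labels. A quick numeric check (illustration only, not package code):

    def class_weight(count: int, scaler: float) -> float:
        # Formula quoted in the docstring above.
        return (scaler * count) / (count + scaler)

    for count in (5, 100, 10_000, 1_000_000):
        print(count, round(class_weight(count, scaler=10), 2))
    # 5 -> 3.33, 100 -> 9.09, 10_000 -> 9.99, 1_000_000 -> 10.0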
@@ -851,8 +992,32 @@ class LabelWeightedSampler(Sampler[int]):


  class RankShardSampler(Sampler[int]):
- """Shards a dataset contiguously across ranks without padding or duplicates.
- Preserves the existing order (e.g., your pre-shuffled idx_full)."""
+ """
+ Sampler that shards data contiguously across distributed ranks.
+
+ Divides the dataset into contiguous chunks, one per rank, without
+ padding or duplicating samples. Preserves the original data order
+ within each shard (useful for pre-shuffled data).
+
+ Args:
+ data_len (int): Total number of samples in the dataset.
+ start_at (int, optional): Global starting index for resuming training.
+ Requires the same number of GPUs as the previous run. Defaults to 0.
+
+ Attributes:
+ rank (int): Current process rank (0 if not distributed).
+ world_size (int): Total number of processes (1 if not distributed).
+ start (int): Starting index for this rank's shard.
+ end (int): Ending index (exclusive) for this rank's shard.
+
+ Note:
+ The last rank may have fewer samples than others if the dataset
+ size is not evenly divisible by world_size.
+
+ Example:
+ >>> sampler = RankShardSampler(len(dataset))
+ >>> loader = DataLoader(dataset, sampler=sampler)
+ """

  def __init__(self, data_len: int, start_at: int = 0) -> None:
  self.data_len = data_len
@@ -866,13 +1031,13 @@ class RankShardSampler(Sampler[int]):
  # contiguous chunk per rank (last rank may be shorter)
  if self.start_at > 0:
  print(
- "!!!!ATTTENTION: make sure that you are running on the exact same \
- number of GPU as your previous run!!!!!"
+ "!!!!ATTENTION: make sure that you are running on the exact same \
+ number of GPU as your previous run!!!!!"
  )
  print(f"Sharding data of size {data_len} over {self.world_size} ranks")
- per_rank = math.ceil((self.data_len - self.start_at) / self.world_size)
+ per_rank = math.ceil(self.data_len / self.world_size)
  self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
- self.end = min(self.start + per_rank, self.data_len)
+ self.end = min((self.rank + 1) * per_rank, self.data_len)
  print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")

  def __iter__(self):
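The changed lines above compute contiguous per-rank shards. A standalone sketch of the same arithmetic for the simple case start_at = 0 (an illustrative helper, not the package class itself):

    import math

    def shard_bounds(data_len: int, rank: int, world_size: int) -> tuple[int, int]:
        # Contiguous chunk per rank; the last rank may be shorter.
        per_rank = math.ceil(data_len / world_size)
        start = rank * per_rank
        end = min((rank + 1) * per_rank, data_len)
        return start, end

    # 10 samples over 3 ranks -> (0, 4), (4, 8), (8, 10): no padding, no overlap.
    print([shard_bounds(10, r, 3) for r in range(3)])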
@@ -1,3 +1,10 @@
+ """
+ Preprocessing utilities for single-cell gene expression data.
+
+ This module provides functions for normalizing, transforming, and discretizing
+ gene expression values for use with scPRINT and similar models.
+ """
+
  import gc
  import time
  from typing import Callable, List, Optional, Union
@@ -664,39 +671,41 @@ def is_log1p(adata: AnnData) -> bool:
  return True


- def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray:
+ def _digitize(values: np.ndarray, bins: np.ndarray) -> np.ndarray:
  """
- Digitize the data into bins. This method spreads data uniformly when bins
- have same values.
+ Digitize values into discrete bins with 1-based indexing.

- Args:
+ Similar to np.digitize but ensures output is 1-indexed (bin 0 reserved for
+ zero values) and handles edge cases for expression binning.

- x (:class:`np.ndarray`):
- The data to digitize.
- bins (:class:`np.ndarray`):
- The bins to use for digitization, in increasing order.
- side (:class:`str`, optional):
- The side to use for digitization. If "one", the left side is used. If
- "both", the left and right side are used. Default to "one".
+ Args:
+ values (np.ndarray): Array of values to discretize. Should be non-zero
+ expression values.
+ bins (np.ndarray): Bin edges from np.quantile or similar. Values are
+ assigned to bins based on which edges they fall between.

  Returns:
-
- :class:`np.ndarray`:
- The digitized data.
+ np.ndarray: Integer bin indices, 1-indexed. Values equal to bins[i]
+ are assigned to bin i+1.
+
+ Example:
+ >>> values = np.array([0.5, 1.5, 2.5, 3.5])
+ >>> bins = np.array([1.0, 2.0, 3.0])
+ >>> _digitize(values, bins)
+ array([1, 2, 3, 3])
+
+ Note:
+ This function is used internally by the Collator for expression binning.
+ Zero values should be handled separately before calling this function.
  """
- assert x.ndim == 1 and bins.ndim == 1
-
- left_digits = np.digitize(x, bins)
- if side == "one":
- return left_digits
+ assert values.ndim == 1 and bins.ndim == 1

- right_difits = np.digitize(x, bins, right=True)
+ left_digits = np.digitize(values, bins)
+ return left_digits

- rands = np.random.rand(len(x)) # uniform random numbers

- digits = rands * (right_difits - left_digits) + left_digits
- digits = np.ceil(digits).astype(np.int64)
- return digits
+ # Add documentation for any other functions in preprocess.py
+ # ...existing code...


  def binning(row: np.ndarray, n_bins: int) -> np.ndarray:
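With _digitize reduced to a thin wrapper around np.digitize, the surrounding binning step amounts to quantile discretization of the non-zero expression values, with bin 0 reserved for zeros. A generic sketch of that idea (illustrative only; the package's binning() may differ in its details):

    import numpy as np

    def quantile_bin(row: np.ndarray, n_bins: int) -> np.ndarray:
        binned = np.zeros_like(row, dtype=np.int64)   # zeros stay in bin 0
        nonzero = row > 0
        if n_bins <= 0 or not nonzero.any():
            return binned
        # Interior bin edges taken from quantiles of the non-zero values.
        edges = np.quantile(row[nonzero], np.linspace(0, 1, n_bins + 1)[1:-1])
        # np.digitize yields 0..n_bins-1 here; shift so non-zero values land in 1..n_bins.
        binned[nonzero] = np.digitize(row[nonzero], edges) + 1
        return binned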