scdataloader 2.0.2.tar.gz → 2.0.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: scdataloader
- Version: 2.0.2
+ Version: 2.0.4
  Summary: a dataloader for single cell data in lamindb
  Project-URL: repository, https://github.com/jkobject/scDataLoader
  Author-email: jkobject <jkobject@gmail.com>
@@ -1,6 +1,6 @@
  [project]
  name = "scdataloader"
- version = "2.0.2"
+ version = "2.0.4"
  description = "a dataloader for single cell data in lamindb"
  authors = [
  {name = "jkobject", email = "jkobject@gmail.com"}
@@ -15,6 +15,7 @@ dependencies = [
  "cellxgene-census>=0.1.0",
  "torch>=2.2.0",
  "pytorch-lightning>=2.3.0",
+ "lightning>=2.3.0",
  "anndata>=0.9.0",
  "zarr>=2.10.0",
  "matplotlib>=3.5.0",
@@ -27,8 +28,6 @@ dependencies = [
  "django>=4.0.0",
  "scikit-misc>=0.5.0",
  "jupytext>=1.16.0",
- "lightning>=2.3.0",
- "pytorch-lightning>=2.3.0",
  ]

  [project.optional-dependencies]
@@ -27,37 +27,60 @@ class Collator:
  genedf: Optional[pd.DataFrame] = None,
  ):
  """
- This class is responsible for collating data for the scPRINT model. It handles the
- organization and preparation of gene expression data from different organisms,
- allowing for various configurations such as maximum gene list length, normalization,
- and selection method for gene expression.
+ Collator for preparing gene expression data batches for the scPRINT model.

- This Collator should work with scVI's dataloader as well!
+ This class handles the organization and preparation of gene expression data from
+ different organisms, allowing for various configurations such as maximum gene list
+ length, normalization, binning, and gene selection strategies.
+
+ Compatible with scVI's dataloader and other PyTorch data loading pipelines.

  Args:
- organisms (list): List of organisms to be considered for gene expression data.
- it will drop any other organism it sees (might lead to batches of different sizes!)
- how (flag, optional): Method for selecting gene expression. Defaults to "most expr".
- one of ["most expr", "random expr", "all", "some"]:
- "most expr": selects the max_len most expressed genes,
- if less genes are expressed, will sample random unexpressed genes,
- "random expr": uses a random set of max_len expressed genes.
- if less genes are expressed, will sample random unexpressed genes
- "all": uses all genes
- "some": uses only the genes provided through the genelist param
- org_to_id (dict): Dictionary mapping organisms to their respective IDs.
- valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
- it will drop any other genes from the input expression data (usefull when your model only works on some genes)
- max_len (int, optional): Total number of genes to use (for random expr and most expr). Defaults to 2000.
- n_bins (int, optional): Number of bins for binning the data. Defaults to 0. meaning, no binning of expression.
- add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
- logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
- norm_to (float, optional): Rescaling value of the normalization to be applied. Defaults to None.
- organism_name (str, optional): Name of the organism ontology term id. Defaults to "organism_ontology_term_id".
- tp_name (str, optional): Name of the heat diff. Defaults to None.
- class_names (list, optional): List of other classes to be considered. Defaults to [].
- genelist (list, optional): List of genes to be considered. Defaults to [].
- If [] all genes will be considered
+ organisms (List[str]): List of organism ontology term IDs to include.
+ Samples from other organisms will be dropped (may lead to variable batch sizes).
+ how (str, optional): Gene selection strategy. Defaults to "all".
+ - "most expr": Select the `max_len` most expressed genes. If fewer genes
+ are expressed, randomly sample unexpressed genes to fill.
+ - "random expr": Randomly select `max_len` expressed genes. If fewer genes
+ are expressed, randomly sample unexpressed genes to fill.
+ - "all": Use all genes without filtering.
+ - "some": Use only genes specified in the `genelist` parameter.
+ org_to_id (dict[str, int], optional): Mapping from organism names to integer IDs.
+ If None, organism names are used directly. Defaults to None.
+ valid_genes (List[str], optional): List of gene names to consider from input data.
+ Genes not in this list will be dropped. Useful when the model only supports
+ specific genes. Defaults to None (use all genes).
+ max_len (int, optional): Maximum number of genes to include when using "most expr"
+ or "random expr" selection methods. Defaults to 2000.
+ add_zero_genes (int, optional): Number of additional unexpressed genes to include
+ in the output. Only applies when `how` is "most expr" or "random expr".
+ Defaults to 0.
+ logp1 (bool, optional): Apply log2(1 + x) transformation to expression values.
+ Applied after normalization if both are enabled. Defaults to False.
+ norm_to (float, optional): Target sum for count normalization. Expression values
+ are scaled so that total counts equal this value. Defaults to None (no normalization).
+ n_bins (int, optional): Number of bins for expression value binning. If 0, no
+ binning is applied. Binning uses quantile-based discretization. Defaults to 0.
+ tp_name (str, optional): Column name in batch data for time point or heat diffusion
+ values. If None, time point values default to 0. Defaults to None.
+ organism_name (str, optional): Column name in batch data for organism ontology
+ term ID. Defaults to "organism_ontology_term_id".
+ class_names (List[str], optional): List of additional metadata column names to
+ include in the output. Defaults to [].
+ genelist (List[str], optional): List of specific genes to use when `how="some"`.
+ Required if `how="some"`. Defaults to [].
+ genedf (pd.DataFrame, optional): DataFrame containing gene information indexed by
+ gene name with an 'organism' column. If None, loaded automatically using
+ `load_genes()`. Defaults to None.
+
+ Attributes:
+ organism_ids (set): Set of organism IDs being processed.
+ start_idx (dict): Mapping from organism ID to starting gene index in the model.
+ accepted_genes (dict): Boolean masks for valid genes per organism.
+ to_subset (dict): Boolean masks for genelist filtering per organism.
+
+ Raises:
+ AssertionError: If `how="some"` but `genelist` is empty.
  """
  self.organisms = organisms
  self.max_len = max_len
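The "most expr" selection strategy documented above (take the `max_len` most expressed genes, then pad with random unexpressed ones) can be sketched in a few lines of NumPy. This illustrates the documented behavior only; it is not the package's implementation, and the helper name is made up.

```python
import numpy as np

rng = np.random.default_rng(0)

def select_most_expressed(expr: np.ndarray, max_len: int) -> np.ndarray:
    """Return indices of the max_len most expressed genes, padding with
    randomly chosen unexpressed genes when fewer than max_len are expressed."""
    expressed = np.flatnonzero(expr > 0)
    if len(expressed) >= max_len:
        top = np.argsort(expr[expressed])[::-1][:max_len]  # highest expression first
        return expressed[top]
    unexpressed = np.flatnonzero(expr == 0)
    pad = rng.choice(unexpressed, size=max_len - len(expressed), replace=False)
    return np.concatenate([expressed, pad])

expr = np.array([0.0, 5.0, 0.0, 2.0, 1.0, 0.0])
print(select_most_expressed(expr, max_len=4))  # e.g. [1 3 4 0] -- three expressed genes + one random zero gene
```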
@@ -77,6 +100,19 @@ class Collator:
  self._setup(genedf, org_to_id, valid_genes, genelist)

  def _setup(self, genedf=None, org_to_id=None, valid_genes=[], genelist=[]):
+ """
+ Initialize gene mappings and indices for each organism.
+
+ Sets up internal data structures for gene filtering, organism-specific
+ gene indices, and gene subsetting based on the provided configuration.
+
+ Args:
+ genedf (pd.DataFrame, optional): Gene information DataFrame. If None,
+ loaded via `load_genes()`. Defaults to None.
+ org_to_id (dict, optional): Organism name to ID mapping. Defaults to None.
+ valid_genes (List[str], optional): Genes to accept from input. Defaults to [].
+ genelist (List[str], optional): Genes to subset to when `how="some"`. Defaults to [].
+ """
  if genedf is None:
  genedf = load_genes(self.organisms)
  self.organism_ids = (
@@ -108,18 +144,45 @@ class Collator:

  def __call__(self, batch) -> dict[str, Tensor]:
  """
- __call__ applies the collator to a minibatch of data
+ Collate a minibatch of gene expression data.
+
+ Processes a list of sample dictionaries, applying gene selection, normalization,
+ log transformation, and binning as configured. Filters out samples from organisms
+ not in the configured organism list.

  Args:
- batch (List[dict[str: array]]): List of dicts of arrays containing gene expression data.
- the first list is for the different samples, the second list is for the different elements with
- elem["X"]: gene expression
- elem["organism_name"]: organism ontology term id
- elem["tp_name"]: heat diff
- elem["class_names.."]: other classes
+ batch (List[dict]): List of sample dictionaries, each containing:
+ - "X" (array): Gene expression values.
+ - organism_name (any): Organism identifier (column name set by `organism_name`).
+ - tp_name (float, optional): Time point value (column name set by `tp_name`).
+ - class_names... (any, optional): Additional class labels.
+ - "_storage_idx" (int, optional): Dataset storage index.
+ - "is_meta" (int, optional): Metadata flag.
+ - "knn_cells" (array, optional): KNN neighbor expression data.
+ - "knn_cells_info" (array, optional): KNN neighbor metadata.

  Returns:
- List[Tensor]: List of tensors containing the collated data.
+ dict[str, Tensor]: Dictionary containing collated tensors:
+ - "x" (Tensor): Gene expression matrix of shape (batch_size, n_genes).
+ Values may be raw counts, normalized, log-transformed, or binned
+ depending on configuration.
+ - "genes" (Tensor): Gene indices of shape (batch_size, n_genes) as int32.
+ Indices correspond to positions in the model's gene vocabulary.
+ - "class" (Tensor): Class labels of shape (batch_size, n_classes) as int32.
+ - "tp" (Tensor): Time point values of shape (batch_size,).
+ - "depth" (Tensor): Total counts per cell of shape (batch_size,).
+ - "is_meta" (Tensor, optional): Metadata flags as int32. Present if input
+ contains "is_meta".
+ - "knn_cells" (Tensor, optional): KNN expression data. Present if input
+ contains "knn_cells".
+ - "knn_cells_info" (Tensor, optional): KNN metadata. Present if input
+ contains "knn_cells_info".
+ - "dataset" (Tensor, optional): Dataset indices as int64. Present if input
+ contains "_storage_idx".
+
+ Note:
+ Batch size in output may be smaller than input if some samples are filtered
+ out due to organism mismatch.
  """
  # do count selection
  # get the unseen info and don't add any unseen
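Together with the SimpleAnnDataset example further down in this diff (DataLoader(dataset, batch_size=32, collate_fn=collator)), the documented contract suggests wiring along these lines. Treat it as a hedged sketch: the import paths, the placeholder gene IDs, and the minimal genedf/AnnData setup are assumptions based only on the docstrings shown here, not verified against the package.

```python
import anndata as ad
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

# Assumed import locations: the old docstrings mention collator.py and data.py,
# but the exact module paths are not confirmed by this diff.
from scdataloader.collator import Collator
from scdataloader.data import SimpleAnnDataset

genes = [f"ENSG{i:011d}" for i in range(5)]  # placeholder gene IDs
adata = ad.AnnData(
    X=np.random.poisson(1.0, size=(8, 5)).astype(np.float32),
    obs=pd.DataFrame({"organism_ontology_term_id": ["NCBITaxon:9606"] * 8}),
    var=pd.DataFrame(index=genes),
)

# genedf schema follows the docstring above: indexed by gene name, with an 'organism' column.
genedf = pd.DataFrame({"organism": ["NCBITaxon:9606"] * 5}, index=genes)

collator = Collator(
    organisms=["NCBITaxon:9606"],
    how="most expr",
    max_len=4,
    genedf=genedf,
)
dataset = SimpleAnnDataset(adata, obs_to_output=["organism_ontology_term_id"])
loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
batch = next(iter(loader))  # expected keys per the docstring: "x", "genes", "class", "tp", "depth"
```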
@@ -22,30 +22,66 @@ from .mapped import MappedCollection, _Connect
  @dataclass
  class Dataset(torchDataset):
  """
- Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory efficient way.
+ PyTorch Dataset for loading single-cell data from a LaminDB Collection.

- This serves as a wrapper around lamin's mappedCollection to provide more features,
- mostly, the management of hierarchical labels, the encoding of labels, the management of multiple species
+ This class wraps LaminDB's MappedCollection to provide additional features:
+ - Management of hierarchical ontology labels (cell type, tissue, disease, etc.)
+ - Automatic encoding of categorical labels to integers
+ - Multi-species gene handling with unified gene indexing
+ - Optional metacell aggregation and KNN neighbor retrieval

- For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.
-
- .. note::
-
- A related data loader exists `here
- <https://github.com/Genentech/scimilarity>`__.
+ The dataset lazily loads data from storage, making it memory-efficient for
+ large collections spanning multiple files.

  Args:
- ----
- lamin_dataset (lamindb.Dataset): lamin dataset to load
- genedf (pd.Dataframe): dataframe containing the genes to load
- obs (List[str]): list of observations to load from the Collection
- clss_to_predict (List[str]): list of observations to encode
- join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
- hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
- metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
- get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each cell. Defaults to False.
- store_location (str, optional): The location to store the sampler indices. Defaults to None.
- force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
+ lamin_dataset (ln.Collection): LaminDB Collection containing the artifacts to load.
+ genedf (pd.DataFrame, optional): DataFrame with gene information, indexed by gene ID
+ with an 'organism' column. If None, automatically loaded based on organisms
+ in the dataset. Defaults to None.
+ clss_to_predict (List[str], optional): Observation columns to encode as prediction
+ targets. These will be integer-encoded in the output. Defaults to [].
+ hierarchical_clss (List[str], optional): Observation columns with hierarchical
+ ontology structure. These will have their ancestry relationships computed
+ using Bionty. Supported columns:
+ - "cell_type_ontology_term_id"
+ - "tissue_ontology_term_id"
+ - "disease_ontology_term_id"
+ - "development_stage_ontology_term_id"
+ - "assay_ontology_term_id"
+ - "self_reported_ethnicity_ontology_term_id"
+ Defaults to [].
+ join_vars (str, optional): How to join variables across artifacts.
+ "inner" for intersection, "outer" for union, None for no joining.
+ Defaults to None.
+ metacell_mode (float, optional): Probability of returning aggregated metacell
+ expression instead of single-cell. Defaults to 0.0.
+ get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
+ expression in the output. Requires precomputed neighbors in the data.
+ Defaults to False.
+ store_location (str, optional): Directory path to cache computed indices.
+ Defaults to None.
+ force_recompute_indices (bool, optional): Force recomputation of cached data.
+ Defaults to False.
+
+ Attributes:
+ mapped_dataset (MappedCollection): Underlying mapped collection for data access.
+ genedf (pd.DataFrame): Gene information DataFrame.
+ organisms (List[str]): List of organism ontology term IDs in the dataset.
+ class_topred (dict[str, set]): Mapping from class name to set of valid labels.
+ labels_groupings (dict[str, dict]): Hierarchical groupings for ontology classes.
+ encoder (dict[str, dict]): Label encoders mapping strings to integers.
+
+ Raises:
+ ValueError: If genedf is None and "organism_ontology_term_id" is not in clss_to_predict.
+
+ Example:
+ >>> collection = ln.Collection.filter(key="my_collection").first()
+ >>> dataset = Dataset(
+ ... lamin_dataset=collection,
+ ... clss_to_predict=["organism_ontology_term_id", "cell_type_ontology_term_id"],
+ ... hierarchical_clss=["cell_type_ontology_term_id"],
+ ... )
+ >>> sample = dataset[0]  # Returns dict with "X" and encoded labels
  """

  lamin_dataset: ln.Collection
@@ -165,7 +201,20 @@ class Dataset(torchDataset):
  self,
  obs_keys: Union[str, List[str]],
  ):
- """Get all categories for the given label keys."""
+ """
+ Get combined categorical codes for one or more label columns.
+
+ Retrieves labels from the mapped dataset and combines them into a single
+ categorical encoding. Useful for creating compound class labels for
+ stratified sampling.
+
+ Args:
+ obs_keys (str | List[str]): Column name(s) to retrieve and combine.
+
+ Returns:
+ np.ndarray: Integer codes representing the (combined) categories.
+ Shape: (n_samples,).
+ """
  if isinstance(obs_keys, str):
  obs_keys = [obs_keys]
  labels = None
@@ -179,25 +228,37 @@ class Dataset(torchDataset):

  def get_unseen_mapped_dataset_elements(self, idx: int):
  """
- get_unseen_mapped_dataset_elements is a wrapper around mappedDataset.get_unseen_mapped_dataset_elements
+ Get genes marked as unseen for a specific sample.
+
+ Retrieves the list of genes that were not observed (expression = 0 or
+ marked as unseen) for the sample at the given index.

  Args:
- idx (int): index of the element to get
+ idx (int): Sample index in the dataset.

  Returns:
- List[str]: list of unseen genes
+ List[str]: List of unseen gene identifiers.
  """
  return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

  def define_hierarchies(self, clsses: List[str]):
  """
- define_hierarchies is a method to define the hierarchies for the classes to predict
+ Define hierarchical label groupings from ontology relationships.
+
+ Uses Bionty to retrieve parent-child relationships for ontology terms,
+ then builds groupings mapping parent terms to their descendants.
+ Updates encoders to include parent terms and reorders labels so that
+ leaf terms (directly predictable) come first.

  Args:
- clsses (List[str]): list of classes to predict
+ clsses (List[str]): List of ontology column names to process.

  Raises:
- ValueError: if the class is not in the accepted classes
+ ValueError: If a class name is not in the supported ontology types.
+
+ Note:
+ Modifies self.labels_groupings, self.class_topred, and
+ self.mapped_dataset.encoders in place.
  """
  # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
  self.labels_groupings = {}
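The grouping step that define_hierarchies describes (mapping each ancestor ontology term to all of its descendants) boils down to a transitive walk over parent links. The sketch below uses a small hand-written parent map as a hypothetical stand-in for what Bionty would return; it is not the package's code.

```python
from collections import defaultdict

# Hypothetical parent relationships: child term -> list of parent terms.
parents = {
    "CL:T_cell": ["CL:lymphocyte"],
    "CL:B_cell": ["CL:lymphocyte"],
    "CL:lymphocyte": ["CL:leukocyte"],
    "CL:monocyte": ["CL:leukocyte"],
}

def build_groupings(parents: dict[str, list[str]]) -> dict[str, set[str]]:
    """Map every ancestor term to the set of all of its descendants."""
    groupings: dict[str, set[str]] = defaultdict(set)
    for child, direct_parents in parents.items():
        # walk up the hierarchy so every ancestor also collects this child
        stack = list(direct_parents)
        while stack:
            ancestor = stack.pop()
            groupings[ancestor].add(child)
            stack.extend(parents.get(ancestor, []))
    return dict(groupings)

print(build_groupings(parents))
# e.g. {'CL:lymphocyte': {'CL:T_cell', 'CL:B_cell'},
#       'CL:leukocyte': {'CL:T_cell', 'CL:B_cell', 'CL:lymphocyte', 'CL:monocyte'}}
```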
@@ -318,6 +379,7 @@ class Dataset(torchDataset):


  class SimpleAnnDataset(torchDataset):
+
  def __init__(
  self,
  adata: AnnData,
@@ -327,16 +389,40 @@ class SimpleAnnDataset(torchDataset):
  encoder: Optional[dict[str, dict]] = None,
  ):
  """
- SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
- scDataloader and with your model during inference.
+ Simple PyTorch Dataset wrapper for a single AnnData object.
+
+ Provides a lightweight interface for using AnnData with PyTorch DataLoaders,
+ compatible with the scDataLoader collator. Useful for inference on new data
+ that isn't stored in LaminDB.

  Args:
- ----
- adata (anndata.AnnData): anndata object to use
- obs_to_output (List[str]): list of observations to output from anndata.obs
- layer (str): layer of the anndata to use
- get_knn_cells (bool): whether to get the knn cells
- encoder (dict[str, dict]): dictionary of encoders for the observations.
+ adata (AnnData): AnnData object containing expression data.
+ obs_to_output (List[str], optional): Observation columns to include in
+ output dictionaries. Defaults to [].
+ layer (str, optional): Layer name to use for expression values. If None,
+ uses adata.X. Defaults to None.
+ get_knn_cells (bool, optional): Whether to include k-nearest neighbor
+ expression data. Requires precomputed neighbors in adata.obsp.
+ Defaults to False.
+ encoder (dict[str, dict], optional): Dictionary mapping observation column
+ names to encoding dictionaries (str -> int). Defaults to None.
+
+ Attributes:
+ adataX (np.ndarray): Dense expression matrix.
+ encoder (dict): Label encoders.
+ obs_to_output (pd.DataFrame): Observation metadata to include.
+ distances (scipy.sparse matrix): KNN distance matrix (if get_knn_cells=True).
+
+ Raises:
+ ValueError: If get_knn_cells=True but "connectivities" not in adata.obsp.
+
+ Example:
+ >>> dataset = SimpleAnnDataset(
+ ... adata=my_adata,
+ ... obs_to_output=["cell_type", "organism_ontology_term_id"],
+ ... encoder={"organism_ontology_term_id": {"NCBITaxon:9606": 0}},
+ ... )
+ >>> loader = DataLoader(dataset, batch_size=32, collate_fn=collator)
  """
  self.adataX = adata.layers[layer] if layer is not None else adata.X
  self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
@@ -392,6 +478,46 @@ def mapped(
  store_location: str | None = None,
  force_recompute_indices: bool = False,
  ) -> MappedCollection:
+ """
+ Create a MappedCollection from a LaminDB Collection.
+
+ Factory function that handles artifact path resolution (staging or streaming)
+ and creates a MappedCollection for efficient access to multiple h5ad/zarr files.
+
+ Args:
+ dataset (ln.Collection): LaminDB Collection containing artifacts to map.
+ obs_keys (List[str], optional): Observation columns to load. Defaults to None.
+ obsm_keys (List[str], optional): Obsm keys to load. Defaults to None.
+ obs_filter (dict, optional): Filter observations by column values.
+ Keys are column names, values are allowed values. Defaults to None.
+ join (str, optional): How to join variables across files. "inner" for
+ intersection, "outer" for union. Defaults to "inner".
+ encode_labels (bool | List[str], optional): Whether/which columns to
+ integer-encode. True encodes all obs_keys. Defaults to True.
+ unknown_label (str | dict, optional): Label to use for unknown/missing
+ categories. Defaults to None.
+ cache_categories (bool, optional): Whether to cache category mappings.
+ Defaults to True.
+ parallel (bool, optional): Enable parallel data loading. Defaults to False.
+ dtype (str, optional): Data type for expression values. Defaults to None.
+ stream (bool, optional): If True, stream from cloud storage instead of
+ staging locally. Defaults to False.
+ is_run_input (bool, optional): Track as run input in LaminDB. Defaults to None.
+ metacell_mode (bool, optional): Enable metacell aggregation. Defaults to False.
+ meta_assays (List[str], optional): Assay types to treat as metacell-like.
+ Defaults to ["EFO:0022857", "EFO:0010961"].
+ get_knn_cells (bool, optional): Include KNN neighbor data. Defaults to False.
+ store_location (str, optional): Cache directory path. Defaults to None.
+ force_recompute_indices (bool, optional): Force recompute cached data.
+ Defaults to False.
+
+ Returns:
+ MappedCollection: Mapped collection for data access.
+
+ Note:
+ Artifacts with suffixes other than .h5ad, .zrad, .zarr are ignored.
+ Non-existent paths are skipped with a warning.
+ """
  path_list = []
  for artifact in dataset.artifacts.all():
  if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
@@ -425,14 +551,26 @@ def mapped(


  def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
- """Efficiently combine multiple categorical data using their codes,
- only creating categories for combinations that exist in the data.
+ """
+ Efficiently combine multiple categorical arrays into a single encoding.
+
+ Creates a combined categorical where each unique combination of input
+ categories gets a unique code. Only combinations that exist in the data
+ are assigned codes (sparse encoding).

  Args:
- series_list: List of pandas Categorical data
+ series_list (List[pd.Categorical]): List of categorical arrays to combine.
+ All arrays must have the same length.

  Returns:
- Combined Categorical with only existing combinations
+ pd.Categorical: Combined categorical with compressed codes representing
+ unique combinations present in the data.
+
+ Example:
+ >>> cat1 = pd.Categorical(["a", "a", "b", "b"])
+ >>> cat2 = pd.Categorical(["x", "y", "x", "y"])
+ >>> combined = concat_categorical_codes([cat1, cat2])
+ >>> # Results in 4 unique codes for (a,x), (a,y), (b,x), (b,y)
  """
  # Get the codes for each categorical
  codes_list = [s.codes.astype(np.int32) for s in series_list]
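One way to get the documented "only combinations that exist in the data" behavior is a NumPy unique over the stacked category codes. This is a minimal sketch of the idea, not necessarily how the package implements it.

```python
import numpy as np
import pandas as pd

cat1 = pd.Categorical(["a", "a", "b", "b"])
cat2 = pd.Categorical(["x", "y", "x", "x"])

# Stack the per-array codes into an (n_samples, n_arrays) matrix ...
codes = np.stack([cat1.codes.astype(np.int32), cat2.codes.astype(np.int32)], axis=1)
# ... then give one compact code to each combination that actually occurs.
combos, combined = np.unique(codes, axis=0, return_inverse=True)

print(combined.ravel())  # e.g. [0 1 2 2] for (a,x), (a,y), (b,x), (b,x)
print(len(combos))       # 3 -- only combinations present in the data get a code
```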
@@ -65,42 +65,102 @@ class DataModule(L.LightningDataModule):
  genedf: Optional[pd.DataFrame] = None,
  n_bins: int = 0,
  curiculum: int = 0,
+ start_at: int = 0,
  **kwargs,
  ):
  """
- DataModule a pytorch lighting datamodule directly from a lamin Collection.
- it can work with bare pytorch too
+ PyTorch Lightning DataModule for loading single-cell data from a LaminDB Collection.

- It implements train / val / test dataloaders. the train is weighted random, val is random, test is one to many separated datasets.
- This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
+ This DataModule provides train/val/test dataloaders with configurable sampling strategies.
+ It combines MappedCollection, Dataset, and Collator to create efficient data pipelines
+ for training single-cell foundation models.
+
+ The training dataloader uses weighted random sampling based on class frequencies,
+ validation uses random sampling, and test uses sequential sampling on held-out datasets.

  Args:
- collection_name (str): The lamindb collection to be used.
- weight_scaler (int, optional): how much more you will see the most present vs less present category.
- n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
- validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
- test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
- it will use a full dataset and will round to the nearest dataset's cell count.
- use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
- clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
- assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
- gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
- the file must have ensembl_gene_id as index.
- This is used to subset the available genes further to the ones that have embeddings in your model.
- max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
- how (str, optional): The method to use for the collator. Defaults to "random expr".
- organism_col (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
- tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
- hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
- metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
- clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
- get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
- store_location (str, optional): The location to store the sampler indices. Defaults to None.
- force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
- sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
- sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
- **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
- see @file data.py and @file collator.py for more details about some of the parameters
+ collection_name (str): Key of the LaminDB Collection to load.
+ clss_to_weight (List[str], optional): Label columns to use for weighted sampling
+ in the training dataloader. Supports "nnz" for weighting by number of
+ non-zero genes. Defaults to ["organism_ontology_term_id"].
+ weight_scaler (int, optional): Controls balance between rare and common classes.
+ Higher values lead to more uniform sampling across classes. Set to 0 to
+ disable weighted sampling. Defaults to 10.
+ n_samples_per_epoch (int, optional): Number of samples to draw per training epoch.
+ Defaults to 2,000,000.
+ validation_split (float | int, optional): Proportion (float) or absolute number (int)
+ of samples for validation. Defaults to 0.2.
+ test_split (float | int, optional): Proportion (float) or absolute number (int)
+ of samples for testing. Uses entire datasets as test sets, rounding to
+ nearest dataset boundary. Defaults to 0.
+ use_default_col (bool, optional): Whether to use the default Collator for batch
+ preparation. If False, no collate_fn is applied. Defaults to True.
+ clss_to_predict (List[str], optional): Observation columns to encode as prediction
+ targets. Must include "organism_ontology_term_id". Defaults to
+ ["organism_ontology_term_id"].
+ hierarchical_clss (List[str], optional): Observation columns with hierarchical
+ ontology structure to be processed. Defaults to [].
+ how (str, optional): Gene selection strategy passed to Collator. One of
+ "most expr", "random expr", "all", "some". Defaults to "random expr".
+ organism_col (str, optional): Column name for organism ontology term ID.
+ Defaults to "organism_ontology_term_id".
+ max_len (int, optional): Maximum number of genes per sample passed to Collator.
+ Defaults to 1000.
+ replacement (bool, optional): Whether to sample with replacement in training.
+ Defaults to True.
+ gene_subset (List[str], optional): List of genes to restrict the dataset to.
+ Useful when model only supports specific genes. Defaults to None.
+ tp_name (str, optional): Column name for time point or heat diffusion values.
+ Defaults to None.
+ assays_to_drop (List[str], optional): List of assay ontology term IDs to exclude
+ from training. Defaults to ["EFO:0030007"] (ATAC-seq).
+ metacell_mode (float, optional): Probability of using metacell aggregation mode.
+ Cannot be used with get_knn_cells. Defaults to 0.0.
+ get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
+ expression data. Cannot be used with metacell_mode. Defaults to False.
+ store_location (str, optional): Directory path to cache sampler indices and
+ labels for faster subsequent loading. Defaults to None.
+ force_recompute_indices (bool, optional): Force recomputation of cached indices
+ even if they exist. Defaults to False.
+ sampler_workers (int, optional): Number of parallel workers for building sampler
+ indices. Auto-determined based on available CPUs if None. Defaults to None.
+ sampler_chunk_size (int, optional): Chunk size for parallel sampler processing.
+ Auto-determined based on available memory if None. Defaults to None.
+ organisms (List[str], optional): List of organisms to include. If None, uses
+ all organisms in the dataset. Defaults to None.
+ genedf (pd.DataFrame, optional): Gene information DataFrame. If None, loaded
+ automatically. Defaults to None.
+ n_bins (int, optional): Number of bins for expression discretization. 0 means
+ no binning. Defaults to 0.
+ curiculum (int, optional): Curriculum learning parameter. If > 0, gradually
+ increases sampling weight balance over epochs. Defaults to 0.
+ start_at (int, optional): Starting index for resuming inference. Requires same
+ number of GPUs as previous run. Defaults to 0.
+ **kwargs: Additional arguments passed to PyTorch DataLoader (e.g., batch_size,
+ num_workers, pin_memory).
+
+ Attributes:
+ dataset (Dataset): The underlying Dataset instance.
+ classes (dict[str, int]): Mapping from class names to number of categories.
+ train_labels (np.ndarray): Label array for weighted sampling.
+ idx_full (np.ndarray): Indices for training samples.
+ valid_idx (np.ndarray): Indices for validation samples.
+ test_idx (np.ndarray): Indices for test samples.
+ test_datasets (List[str]): Paths to datasets used for testing.
+
+ Raises:
+ ValueError: If "organism_ontology_term_id" not in clss_to_predict.
+ ValueError: If both metacell_mode > 0 and get_knn_cells are True.
+
+ Example:
+ >>> dm = DataModule(
+ ... collection_name="my_collection",
+ ... batch_size=32,
+ ... num_workers=4,
+ ... max_len=2000,
+ ... )
+ >>> dm.setup()
+ >>> train_loader = dm.train_dataloader()
  """
  if "organism_ontology_term_id" not in clss_to_predict:
  raise ValueError(
@@ -162,6 +222,7 @@ class DataModule(L.LightningDataModule):
  self.sampler_chunk_size = sampler_chunk_size
  self.store_location = store_location
  self.nnz = None
+ self.start_at = start_at
  self.idx_full = None
  self.max_len = max_len
  self.test_datasets = []
@@ -283,12 +344,24 @@ class DataModule(L.LightningDataModule):

  def setup(self, stage=None):
  """
- setup method is used to prepare the data for the training, validation, and test sets.
- It shuffles the data, calculates weights for each set, and creates samplers for each set.
+ Prepare data splits for training, validation, and testing.
+
+ This method shuffles the data, computes sample weights for weighted sampling,
+ removes samples from dropped assays, and creates train/val/test splits.
+ Test splits use entire datasets to ensure evaluation on unseen data sources.
+
+ Results can be cached to `store_location` for faster subsequent runs.

  Args:
- stage (str, optional): The stage of the model training process.
- It can be either 'fit' or 'test'. Defaults to None.
+ stage (str, optional): Training stage ('fit', 'test', or None for both).
+ Currently not used but kept for Lightning compatibility. Defaults to None.
+
+ Returns:
+ List[str]: List of paths to test datasets.
+
+ Note:
+ Must be called before using dataloaders. The train/val/test split is
+ deterministic when loading from cache.
  """
  print("setting up the datamodule")
  start_time = time.time()
@@ -324,9 +397,9 @@ class DataModule(L.LightningDataModule):
  len_test = self.test_split
  else:
  len_test = int(self.n_samples * self.test_split)
- assert len_test + len_valid < self.n_samples, (
- "test set + valid set size is configured to be larger than entire dataset."
- )
+ assert (
+ len_test + len_valid < self.n_samples
+ ), "test set + valid set size is configured to be larger than entire dataset."

  idx_full = []
  if len(self.assays_to_drop) > 0:
@@ -439,6 +512,22 @@ class DataModule(L.LightningDataModule):
  return self.test_datasets

  def train_dataloader(self, **kwargs):
+ """
+ Create the training DataLoader with weighted random sampling.
+
+ Uses LabelWeightedSampler for class-balanced sampling when weight_scaler > 0
+ and clss_to_weight is specified. Otherwise uses RankShardSampler for
+ distributed training without weighting.
+
+ Args:
+ **kwargs: Additional arguments passed to DataLoader, overriding defaults.
+
+ Returns:
+ DataLoader: Training DataLoader instance.
+
+ Raises:
+ ValueError: If setup() has not been called.
+ """
  if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
  try:
  print("Setting up the parallel train sampler...")
@@ -461,7 +550,7 @@ class DataModule(L.LightningDataModule):
  dataset = None
  else:
  dataset = Subset(self.dataset, self.idx_full)
- train_sampler = RankShardSampler(len(dataset))
+ train_sampler = RankShardSampler(len(dataset), start_at=self.start_at)
  current_loader_kwargs = kwargs.copy()
  current_loader_kwargs.update(self.kwargs)
  return DataLoader(
@@ -471,6 +560,12 @@ class DataModule(L.LightningDataModule):
  )

  def val_dataloader(self):
+ """
+ Create the validation DataLoader.
+
+ Returns:
+ DataLoader | List: Validation DataLoader, or empty list if no validation split.
+ """
  return (
  DataLoader(
  Subset(self.dataset, self.valid_idx),
@@ -481,6 +576,12 @@ class DataModule(L.LightningDataModule):
  )

  def test_dataloader(self):
+ """
+ Create the test DataLoader with sequential sampling.
+
+ Returns:
+ DataLoader | List: Test DataLoader, or empty list if no test split.
+ """
  return (
  DataLoader(
  self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
@@ -490,24 +591,23 @@ class DataModule(L.LightningDataModule):
  )

  def predict_dataloader(self):
+ """
+ Create a DataLoader for prediction over all training data.
+
+ Uses RankShardSampler for distributed inference.
+
+ Returns:
+ DataLoader: Prediction DataLoader instance.
+ """
  subset = Subset(self.dataset, self.idx_full)
  return DataLoader(
- subset,
- sampler=RankShardSampler(len(subset)),
+ self.dataset,
+ sampler=RankShardSampler(len(subset), start_at=self.start_at),
  **self.kwargs,
  )


  class LabelWeightedSampler(Sampler[int]):
- """
- A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
-
- This sampler is designed to handle very large datasets efficiently, with optimizations for:
- 1. Parallel building of class indices
- 2. Chunked processing for large arrays
- 3. Efficient memory management
- 4. Proper handling of replacement and non-replacement sampling
- """

  label_weights: torch.Tensor
  klass_indices: dict[int, torch.Tensor]
@@ -529,16 +629,58 @@ class LabelWeightedSampler(Sampler[int]):
  curiculum: int = 0,
  ) -> None:
  """
- Initialize the sampler with parallel processing for large datasets.
+ Weighted random sampler balancing both class frequencies and element weights.
+
+ This sampler is optimized for very large datasets (millions of samples) with:
+ - Parallel construction of class indices using multiple CPU workers
+ - Chunked processing to manage memory usage
+ - Support for curriculum learning via progressive weight scaling
+ - Optional per-element weights (e.g., by number of expressed genes)
+
+ The sampling process:
+ 1. Sample class labels according to class weights
+ 2. For each sampled class, sample elements according to element weights
+ 3. Shuffle all sampled indices

  Args:
- weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
- labels: Class label for each dataset element (length = dataset size)
- num_samples: Number of samples to draw
- replacement: Whether to sample with replacement
- element_weights: Optional weights for each element within classes
- n_workers: Number of parallel workers to use (default: number of CPUs-1)
- chunk_size: Size of chunks to process in parallel (default: 10M elements)
+ labels (np.ndarray): Integer class label for each dataset element.
+ Shape: (dataset_size,). The last unique label is treated as
+ "excluded" with weight 0.
+ num_samples (int): Number of samples to draw per epoch.
+ replacement (bool, optional): Whether to sample with replacement.
+ Defaults to True.
+ weight_scaler (float, optional): Controls class weight balance.
+ Weight formula: (scaler * count) / (count + scaler).
+ Higher values = more uniform sampling. Defaults to None.
+ element_weights (Sequence[float], optional): Per-element sampling weights.
+ Shape: (dataset_size,). Defaults to None (uniform within class).
+ n_workers (int, optional): Number of parallel workers for index building.
+ Defaults to min(20, num_cpus - 1).
+ chunk_size (int, optional): Elements per chunk for parallel processing.
+ Auto-determined based on available memory if None.
+ store_location (str, optional): Directory to cache computed indices.
+ Defaults to None.
+ force_recompute_indices (bool, optional): Recompute indices even if cached.
+ Defaults to False.
+ curiculum (int, optional): Curriculum learning epochs. If > 0, weight
+ exponent increases from 0 to 1 over this many epochs. Defaults to 0.
+
+ Attributes:
+ label_weights (torch.Tensor): Computed weights per class label.
+ klass_indices (torch.Tensor): Concatenated indices for all classes.
+ klass_offsets (torch.Tensor): Starting offset for each class in klass_indices.
+ count (int): Number of times __iter__ has been called (for curriculum).
+
+ Example:
+ >>> sampler = LabelWeightedSampler(
+ ... labels=train_labels,
+ ... num_samples=1_000_000,
+ ... weight_scaler=10,
+ ... element_weights=nnz_weights,
+ ... )
+ >>> for idx in sampler:
+ ... # Process sample at idx
+ ... pass
  """
  print("Initializing optimized parallel weighted sampler...")
  super(LabelWeightedSampler, self).__init__(None)
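A quick numeric reading of the class-weight formula quoted in the docstring, (scaler * count) / (count + scaler): weights of abundant classes saturate near the scaler value, while rare classes keep weights close to their raw counts. This is plain arithmetic on the stated formula, not code from the package.

```python
def class_weight(count: int, scaler: float) -> float:
    # formula as stated in the docstring above
    return (scaler * count) / (count + scaler)

for count in (10, 1_000, 100_000):
    print(count, round(class_weight(count, scaler=10), 2))
# 10      5.0   -> rare class keeps a weight close to its count
# 1000    9.9   -> already close to the scaler
# 100000  10.0  -> abundant classes are capped near the scaler,
#                  compressing the gap between common and rare labels
```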
@@ -667,7 +809,9 @@ class LabelWeightedSampler(Sampler[int]):
  unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)

  # Initialize result tensor
- result_indices_list = []  # Changed name to avoid conflict if you had result_indices elsewhere
+ result_indices_list = (
+ []
+ )  # Changed name to avoid conflict if you had result_indices elsewhere

  # Process only the classes that were actually sampled
  for i, (label, count) in tqdm(
@@ -847,11 +991,36 @@ class LabelWeightedSampler(Sampler[int]):


  class RankShardSampler(Sampler[int]):
- """Shards a dataset contiguously across ranks without padding or duplicates.
- Preserves the existing order (e.g., your pre-shuffled idx_full)."""
+ """
+ Sampler that shards data contiguously across distributed ranks.
+
+ Divides the dataset into contiguous chunks, one per rank, without
+ padding or duplicating samples. Preserves the original data order
+ within each shard (useful for pre-shuffled data).
+
+ Args:
+ data_len (int): Total number of samples in the dataset.
+ start_at (int, optional): Global starting index for resuming training.
+ Requires the same number of GPUs as the previous run. Defaults to 0.
+
+ Attributes:
+ rank (int): Current process rank (0 if not distributed).
+ world_size (int): Total number of processes (1 if not distributed).
+ start (int): Starting index for this rank's shard.
+ end (int): Ending index (exclusive) for this rank's shard.
+
+ Note:
+ The last rank may have fewer samples than others if the dataset
+ size is not evenly divisible by world_size.
+
+ Example:
+ >>> sampler = RankShardSampler(len(dataset))
+ >>> loader = DataLoader(dataset, sampler=sampler)
+ """

- def __init__(self, data_len: int):
+ def __init__(self, data_len: int, start_at: int = 0) -> None:
  self.data_len = data_len
+ self.start_at = start_at
  if torch.distributed.is_available() and torch.distributed.is_initialized():
  self.rank = torch.distributed.get_rank()
  self.world_size = torch.distributed.get_world_size()
@@ -859,9 +1028,16 @@ class RankShardSampler(Sampler[int]):
  self.rank, self.world_size = 0, 1

  # contiguous chunk per rank (last rank may be shorter)
+ if self.start_at > 0:
+ print(
+ "!!!!ATTENTION: make sure that you are running on the exact same \
+ number of GPU as your previous run!!!!!"
+ )
+ print(f"Sharding data of size {data_len} over {self.world_size} ranks")
  per_rank = math.ceil(self.data_len / self.world_size)
- self.start = self.rank * per_rank
- self.end = min(self.start + per_rank, self.data_len)
+ self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
+ self.end = min((self.rank + 1) * per_rank, self.data_len)
+ print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")

  def __iter__(self):
  return iter(range(self.start, self.end))
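To see what the new start_at arithmetic does, here is the same computation pulled out into a standalone function (ranks simulated in a loop, no torch.distributed); the numbers are a toy example, not package code.

```python
import math

def shard_range(data_len: int, world_size: int, rank: int, start_at: int = 0):
    # mirrors the arithmetic added in the diff above
    per_rank = math.ceil(data_len / world_size)
    start = int((start_at / world_size) + (rank * per_rank))
    end = min((rank + 1) * per_rank, data_len)
    return start, end

# Fresh run: 10 samples over 2 ranks -> contiguous shards [0, 5) and [5, 10)
print([shard_range(10, 2, r) for r in range(2)])              # [(0, 5), (5, 10)]

# Resuming with start_at=4: each rank skips start_at / world_size samples of its
# own shard -> [2, 5) and [7, 10), i.e. 4 samples skipped in total across ranks.
print([shard_range(10, 2, r, start_at=4) for r in range(2)])  # [(2, 5), (7, 10)]
```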
@@ -1,3 +1,10 @@
+ """
+ Preprocessing utilities for single-cell gene expression data.
+
+ This module provides functions for normalizing, transforming, and discretizing
+ gene expression values for use with scPRINT and similar models.
+ """
+
  import gc
  import time
  from typing import Callable, List, Optional, Union
@@ -664,39 +671,41 @@ def is_log1p(adata: AnnData) -> bool:
  return True


- def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray:
+ def _digitize(values: np.ndarray, bins: np.ndarray) -> np.ndarray:
  """
- Digitize the data into bins. This method spreads data uniformly when bins
- have same values.
+ Digitize values into discrete bins with 1-based indexing.

- Args:
+ Similar to np.digitize but ensures output is 1-indexed (bin 0 reserved for
+ zero values) and handles edge cases for expression binning.

- x (:class:`np.ndarray`):
- The data to digitize.
- bins (:class:`np.ndarray`):
- The bins to use for digitization, in increasing order.
- side (:class:`str`, optional):
- The side to use for digitization. If "one", the left side is used. If
- "both", the left and right side are used. Default to "one".
+ Args:
+ values (np.ndarray): Array of values to discretize. Should be non-zero
+ expression values.
+ bins (np.ndarray): Bin edges from np.quantile or similar. Values are
+ assigned to bins based on which edges they fall between.

  Returns:
-
- :class:`np.ndarray`:
- The digitized data.
+ np.ndarray: Integer bin indices, 1-indexed. Values equal to bins[i]
+ are assigned to bin i+1.
+
+ Example:
+ >>> values = np.array([0.5, 1.5, 2.5, 3.5])
+ >>> bins = np.array([1.0, 2.0, 3.0])
+ >>> _digitize(values, bins)
+ array([1, 2, 3, 3])
+
+ Note:
+ This function is used internally by the Collator for expression binning.
+ Zero values should be handled separately before calling this function.
  """
- assert x.ndim == 1 and bins.ndim == 1
-
- left_digits = np.digitize(x, bins)
- if side == "one":
- return left_digits
+ assert values.ndim == 1 and bins.ndim == 1

- right_difits = np.digitize(x, bins, right=True)
+ left_digits = np.digitize(values, bins)
+ return left_digits

- rands = np.random.rand(len(x))  # uniform random numbers

- digits = rands * (right_difits - left_digits) + left_digits
- digits = np.ceil(digits).astype(np.int64)
- return digits
+ # Add documentation for any other functions in preprocess.py
+ # ...existing code...


  def binning(row: np.ndarray, n_bins: int) -> np.ndarray: