scdataloader 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
scdataloader/data.py CHANGED
@@ -22,30 +22,66 @@ from .mapped import MappedCollection, _Connect
22
22
  @dataclass
23
23
  class Dataset(torchDataset):
24
24
  """
25
- Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory efficient way.
25
+ PyTorch Dataset for loading single-cell data from a LaminDB Collection.
26
26
 
27
- This serves as a wrapper around lamin's mappedCollection to provide more features,
28
- mostly, the management of hierarchical labels, the encoding of labels, the management of multiple species
27
+ This class wraps LaminDB's MappedCollection to provide additional features:
28
+ - Management of hierarchical ontology labels (cell type, tissue, disease, etc.)
29
+ - Automatic encoding of categorical labels to integers
30
+ - Multi-species gene handling with unified gene indexing
31
+ - Optional metacell aggregation and KNN neighbor retrieval
29
32
 
30
- For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.
31
-
32
- .. note::
33
-
34
- A related data loader exists `here
35
- <https://github.com/Genentech/scimilarity>`__.
33
+ The dataset lazily loads data from storage, making it memory-efficient for
34
+ large collections spanning multiple files.
36
35
 
37
36
  Args:
38
- ----
39
- lamin_dataset (lamindb.Dataset): lamin dataset to load
40
- genedf (pd.Dataframe): dataframe containing the genes to load
41
- obs (List[str]): list of observations to load from the Collection
42
- clss_to_predict (List[str]): list of observations to encode
43
- join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
44
- hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
45
- metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
46
- get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each cell. Defaults to False.
47
- store_location (str, optional): The location to store the sampler indices. Defaults to None.
48
- force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
37
+ lamin_dataset (ln.Collection): LaminDB Collection containing the artifacts to load.
38
+ genedf (pd.DataFrame, optional): DataFrame with gene information, indexed by gene ID
39
+ with an 'organism' column. If None, automatically loaded based on organisms
40
+ in the dataset. Defaults to None.
41
+ clss_to_predict (List[str], optional): Observation columns to encode as prediction
42
+ targets. These will be integer-encoded in the output. Defaults to [].
43
+ hierarchical_clss (List[str], optional): Observation columns with hierarchical
44
+ ontology structure. These will have their ancestry relationships computed
45
+ using Bionty. Supported columns:
46
+ - "cell_type_ontology_term_id"
47
+ - "tissue_ontology_term_id"
48
+ - "disease_ontology_term_id"
49
+ - "development_stage_ontology_term_id"
50
+ - "assay_ontology_term_id"
51
+ - "self_reported_ethnicity_ontology_term_id"
52
+ Defaults to [].
53
+ join_vars (str, optional): How to join variables across artifacts.
54
+ "inner" for intersection, "outer" for union, None for no joining.
55
+ Defaults to None.
56
+ metacell_mode (float, optional): Probability of returning aggregated metacell
57
+ expression instead of single-cell. Defaults to 0.0.
58
+ get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
59
+ expression in the output. Requires precomputed neighbors in the data.
60
+ Defaults to False.
61
+ store_location (str, optional): Directory path to cache computed indices.
62
+ Defaults to None.
63
+ force_recompute_indices (bool, optional): Force recomputation of cached data.
64
+ Defaults to False.
65
+
66
+ Attributes:
67
+ mapped_dataset (MappedCollection): Underlying mapped collection for data access.
68
+ genedf (pd.DataFrame): Gene information DataFrame.
69
+ organisms (List[str]): List of organism ontology term IDs in the dataset.
70
+ class_topred (dict[str, set]): Mapping from class name to set of valid labels.
71
+ labels_groupings (dict[str, dict]): Hierarchical groupings for ontology classes.
72
+ encoder (dict[str, dict]): Label encoders mapping strings to integers.
73
+
74
+ Raises:
75
+ ValueError: If genedf is None and "organism_ontology_term_id" is not in clss_to_predict.
76
+
77
+ Example:
78
+ >>> collection = ln.Collection.filter(key="my_collection").first()
79
+ >>> dataset = Dataset(
80
+ ... lamin_dataset=collection,
81
+ ... clss_to_predict=["organism_ontology_term_id", "cell_type_ontology_term_id"],
82
+ ... hierarchical_clss=["cell_type_ontology_term_id"],
83
+ ... )
84
+ >>> sample = dataset[0] # Returns dict with "X" and encoded labels
49
85
  """
50
86
 
51
87
  lamin_dataset: ln.Collection
@@ -165,7 +201,20 @@ class Dataset(torchDataset):
165
201
  self,
166
202
  obs_keys: Union[str, List[str]],
167
203
  ):
168
- """Get all categories for the given label keys."""
204
+ """
205
+ Get combined categorical codes for one or more label columns.
206
+
207
+ Retrieves labels from the mapped dataset and combines them into a single
208
+ categorical encoding. Useful for creating compound class labels for
209
+ stratified sampling.
210
+
211
+ Args:
212
+ obs_keys (str | List[str]): Column name(s) to retrieve and combine.
213
+
214
+ Returns:
215
+ np.ndarray: Integer codes representing the (combined) categories.
216
+ Shape: (n_samples,).
217
+ """
169
218
  if isinstance(obs_keys, str):
170
219
  obs_keys = [obs_keys]
171
220
  labels = None
@@ -179,25 +228,37 @@ class Dataset(torchDataset):
179
228
 
180
229
  def get_unseen_mapped_dataset_elements(self, idx: int):
181
230
  """
182
- get_unseen_mapped_dataset_elements is a wrapper around mappedDataset.get_unseen_mapped_dataset_elements
231
+ Get genes marked as unseen for a specific sample.
232
+
233
+ Retrieves the list of genes that were not observed (expression = 0 or
234
+ marked as unseen) for the sample at the given index.
183
235
 
184
236
  Args:
185
- idx (int): index of the element to get
237
+ idx (int): Sample index in the dataset.
186
238
 
187
239
  Returns:
188
- List[str]: list of unseen genes
240
+ List[str]: List of unseen gene identifiers.
189
241
  """
190
242
  return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]
191
243
 
192
244
  def define_hierarchies(self, clsses: List[str]):
193
245
  """
194
- define_hierarchies is a method to define the hierarchies for the classes to predict
246
+ Define hierarchical label groupings from ontology relationships.
247
+
248
+ Uses Bionty to retrieve parent-child relationships for ontology terms,
249
+ then builds groupings mapping parent terms to their descendants.
250
+ Updates encoders to include parent terms and reorders labels so that
251
+ leaf terms (directly predictable) come first.
195
252
 
196
253
  Args:
197
- clsses (List[str]): list of classes to predict
254
+ clsses (List[str]): List of ontology column names to process.
198
255
 
199
256
  Raises:
200
- ValueError: if the class is not in the accepted classes
257
+ ValueError: If a class name is not in the supported ontology types.
258
+
259
+ Note:
260
+ Modifies self.labels_groupings, self.class_topred, and
261
+ self.mapped_dataset.encoders in place.
201
262
  """
202
263
  # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
203
264
  self.labels_groupings = {}
@@ -318,6 +379,7 @@ class Dataset(torchDataset):
318
379
 
319
380
 
320
381
  class SimpleAnnDataset(torchDataset):
382
+
321
383
  def __init__(
322
384
  self,
323
385
  adata: AnnData,
@@ -327,16 +389,40 @@ class SimpleAnnDataset(torchDataset):
327
389
  encoder: Optional[dict[str, dict]] = None,
328
390
  ):
329
391
  """
330
- SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
331
- scDataloader and with your model during inference.
392
+ Simple PyTorch Dataset wrapper for a single AnnData object.
393
+
394
+ Provides a lightweight interface for using AnnData with PyTorch DataLoaders,
395
+ compatible with the scDataLoader collator. Useful for inference on new data
396
+ that isn't stored in LaminDB.
332
397
 
333
398
  Args:
334
- ----
335
- adata (anndata.AnnData): anndata object to use
336
- obs_to_output (List[str]): list of observations to output from anndata.obs
337
- layer (str): layer of the anndata to use
338
- get_knn_cells (bool): whether to get the knn cells
339
- encoder (dict[str, dict]): dictionary of encoders for the observations.
399
+ adata (AnnData): AnnData object containing expression data.
400
+ obs_to_output (List[str], optional): Observation columns to include in
401
+ output dictionaries. Defaults to [].
402
+ layer (str, optional): Layer name to use for expression values. If None,
403
+ uses adata.X. Defaults to None.
404
+ get_knn_cells (bool, optional): Whether to include k-nearest neighbor
405
+ expression data. Requires precomputed neighbors in adata.obsp.
406
+ Defaults to False.
407
+ encoder (dict[str, dict], optional): Dictionary mapping observation column
408
+ names to encoding dictionaries (str -> int). Defaults to None.
409
+
410
+ Attributes:
411
+ adataX (np.ndarray): Dense expression matrix.
412
+ encoder (dict): Label encoders.
413
+ obs_to_output (pd.DataFrame): Observation metadata to include.
414
+ distances (scipy.sparse matrix): KNN distance matrix (if get_knn_cells=True).
415
+
416
+ Raises:
417
+ ValueError: If get_knn_cells=True but "connectivities" not in adata.obsp.
418
+
419
+ Example:
420
+ >>> dataset = SimpleAnnDataset(
421
+ ... adata=my_adata,
422
+ ... obs_to_output=["cell_type", "organism_ontology_term_id"],
423
+ ... encoder={"organism_ontology_term_id": {"NCBITaxon:9606": 0}},
424
+ ... )
425
+ >>> loader = DataLoader(dataset, batch_size=32, collate_fn=collator)
340
426
  """
341
427
  self.adataX = adata.layers[layer] if layer is not None else adata.X
342
428
  self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
@@ -392,6 +478,46 @@ def mapped(
392
478
  store_location: str | None = None,
393
479
  force_recompute_indices: bool = False,
394
480
  ) -> MappedCollection:
481
+ """
482
+ Create a MappedCollection from a LaminDB Collection.
483
+
484
+ Factory function that handles artifact path resolution (staging or streaming)
485
+ and creates a MappedCollection for efficient access to multiple h5ad/zarr files.
486
+
487
+ Args:
488
+ dataset (ln.Collection): LaminDB Collection containing artifacts to map.
489
+ obs_keys (List[str], optional): Observation columns to load. Defaults to None.
490
+ obsm_keys (List[str], optional): Obsm keys to load. Defaults to None.
491
+ obs_filter (dict, optional): Filter observations by column values.
492
+ Keys are column names, values are allowed values. Defaults to None.
493
+ join (str, optional): How to join variables across files. "inner" for
494
+ intersection, "outer" for union. Defaults to "inner".
495
+ encode_labels (bool | List[str], optional): Whether/which columns to
496
+ integer-encode. True encodes all obs_keys. Defaults to True.
497
+ unknown_label (str | dict, optional): Label to use for unknown/missing
498
+ categories. Defaults to None.
499
+ cache_categories (bool, optional): Whether to cache category mappings.
500
+ Defaults to True.
501
+ parallel (bool, optional): Enable parallel data loading. Defaults to False.
502
+ dtype (str, optional): Data type for expression values. Defaults to None.
503
+ stream (bool, optional): If True, stream from cloud storage instead of
504
+ staging locally. Defaults to False.
505
+ is_run_input (bool, optional): Track as run input in LaminDB. Defaults to None.
506
+ metacell_mode (bool, optional): Enable metacell aggregation. Defaults to False.
507
+ meta_assays (List[str], optional): Assay types to treat as metacell-like.
508
+ Defaults to ["EFO:0022857", "EFO:0010961"].
509
+ get_knn_cells (bool, optional): Include KNN neighbor data. Defaults to False.
510
+ store_location (str, optional): Cache directory path. Defaults to None.
511
+ force_recompute_indices (bool, optional): Force recompute cached data.
512
+ Defaults to False.
513
+
514
+ Returns:
515
+ MappedCollection: Mapped collection for data access.
516
+
517
+ Note:
518
+ Artifacts with suffixes other than .h5ad, .zrad, .zarr are ignored.
519
+ Non-existent paths are skipped with a warning.
520
+ """
395
521
  path_list = []
396
522
  for artifact in dataset.artifacts.all():
397
523
  if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
@@ -425,14 +551,26 @@ def mapped(
425
551
 
426
552
 
427
553
  def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
428
- """Efficiently combine multiple categorical data using their codes,
429
- only creating categories for combinations that exist in the data.
554
+ """
555
+ Efficiently combine multiple categorical arrays into a single encoding.
556
+
557
+ Creates a combined categorical where each unique combination of input
558
+ categories gets a unique code. Only combinations that exist in the data
559
+ are assigned codes (sparse encoding).
430
560
 
431
561
  Args:
432
- series_list: List of pandas Categorical data
562
+ series_list (List[pd.Categorical]): List of categorical arrays to combine.
563
+ All arrays must have the same length.
433
564
 
434
565
  Returns:
435
- Combined Categorical with only existing combinations
566
+ pd.Categorical: Combined categorical with compressed codes representing
567
+ unique combinations present in the data.
568
+
569
+ Example:
570
+ >>> cat1 = pd.Categorical(["a", "a", "b", "b"])
571
+ >>> cat2 = pd.Categorical(["x", "y", "x", "y"])
572
+ >>> combined = concat_categorical_codes([cat1, cat2])
573
+ >>> # Results in 4 unique codes for (a,x), (a,y), (b,x), (b,y)
436
574
  """
437
575
  # Get the codes for each categorical
438
576
  codes_list = [s.codes.astype(np.int32) for s in series_list]