scdataloader 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/collator.py +99 -36
- scdataloader/config.py +1151 -0
- scdataloader/data.py +177 -39
- scdataloader/datamodule.py +222 -57
- scdataloader/preprocess.py +33 -24
- scdataloader/utils.py +31 -181
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/METADATA +1 -1
- scdataloader-2.0.5.dist-info/RECORD +16 -0
- scdataloader-2.0.3.dist-info/RECORD +0 -16
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/WHEEL +0 -0
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/entry_points.txt +0 -0
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/licenses/LICENSE +0 -0
scdataloader/data.py
CHANGED
|
@@ -22,30 +22,66 @@ from .mapped import MappedCollection, _Connect
|
|
|
22
22
|
@dataclass
|
|
23
23
|
class Dataset(torchDataset):
|
|
24
24
|
"""
|
|
25
|
-
Dataset
|
|
25
|
+
PyTorch Dataset for loading single-cell data from a LaminDB Collection.
|
|
26
26
|
|
|
27
|
-
This
|
|
28
|
-
|
|
27
|
+
This class wraps LaminDB's MappedCollection to provide additional features:
|
|
28
|
+
- Management of hierarchical ontology labels (cell type, tissue, disease, etc.)
|
|
29
|
+
- Automatic encoding of categorical labels to integers
|
|
30
|
+
- Multi-species gene handling with unified gene indexing
|
|
31
|
+
- Optional metacell aggregation and KNN neighbor retrieval
|
|
29
32
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
.. note::
|
|
33
|
-
|
|
34
|
-
A related data loader exists `here
|
|
35
|
-
<https://github.com/Genentech/scimilarity>`__.
|
|
33
|
+
The dataset lazily loads data from storage, making it memory-efficient for
|
|
34
|
+
large collections spanning multiple files.
|
|
36
35
|
|
|
37
36
|
Args:
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
clss_to_predict (List[str]):
|
|
43
|
-
|
|
44
|
-
hierarchical_clss
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
37
|
+
lamin_dataset (ln.Collection): LaminDB Collection containing the artifacts to load.
|
|
38
|
+
genedf (pd.DataFrame, optional): DataFrame with gene information, indexed by gene ID
|
|
39
|
+
with an 'organism' column. If None, automatically loaded based on organisms
|
|
40
|
+
in the dataset. Defaults to None.
|
|
41
|
+
clss_to_predict (List[str], optional): Observation columns to encode as prediction
|
|
42
|
+
targets. These will be integer-encoded in the output. Defaults to [].
|
|
43
|
+
hierarchical_clss (List[str], optional): Observation columns with hierarchical
|
|
44
|
+
ontology structure. These will have their ancestry relationships computed
|
|
45
|
+
using Bionty. Supported columns:
|
|
46
|
+
- "cell_type_ontology_term_id"
|
|
47
|
+
- "tissue_ontology_term_id"
|
|
48
|
+
- "disease_ontology_term_id"
|
|
49
|
+
- "development_stage_ontology_term_id"
|
|
50
|
+
- "assay_ontology_term_id"
|
|
51
|
+
- "self_reported_ethnicity_ontology_term_id"
|
|
52
|
+
Defaults to [].
|
|
53
|
+
join_vars (str, optional): How to join variables across artifacts.
|
|
54
|
+
"inner" for intersection, "outer" for union, None for no joining.
|
|
55
|
+
Defaults to None.
|
|
56
|
+
metacell_mode (float, optional): Probability of returning aggregated metacell
|
|
57
|
+
expression instead of single-cell. Defaults to 0.0.
|
|
58
|
+
get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
|
|
59
|
+
expression in the output. Requires precomputed neighbors in the data.
|
|
60
|
+
Defaults to False.
|
|
61
|
+
store_location (str, optional): Directory path to cache computed indices.
|
|
62
|
+
Defaults to None.
|
|
63
|
+
force_recompute_indices (bool, optional): Force recomputation of cached data.
|
|
64
|
+
Defaults to False.
|
|
65
|
+
|
|
66
|
+
Attributes:
|
|
67
|
+
mapped_dataset (MappedCollection): Underlying mapped collection for data access.
|
|
68
|
+
genedf (pd.DataFrame): Gene information DataFrame.
|
|
69
|
+
organisms (List[str]): List of organism ontology term IDs in the dataset.
|
|
70
|
+
class_topred (dict[str, set]): Mapping from class name to set of valid labels.
|
|
71
|
+
labels_groupings (dict[str, dict]): Hierarchical groupings for ontology classes.
|
|
72
|
+
encoder (dict[str, dict]): Label encoders mapping strings to integers.
|
|
73
|
+
|
|
74
|
+
Raises:
|
|
75
|
+
ValueError: If genedf is None and "organism_ontology_term_id" is not in clss_to_predict.
|
|
76
|
+
|
|
77
|
+
Example:
|
|
78
|
+
>>> collection = ln.Collection.filter(key="my_collection").first()
|
|
79
|
+
>>> dataset = Dataset(
|
|
80
|
+
... lamin_dataset=collection,
|
|
81
|
+
... clss_to_predict=["organism_ontology_term_id", "cell_type_ontology_term_id"],
|
|
82
|
+
... hierarchical_clss=["cell_type_ontology_term_id"],
|
|
83
|
+
... )
|
|
84
|
+
>>> sample = dataset[0] # Returns dict with "X" and encoded labels
|
|
49
85
|
"""
|
|
50
86
|
|
|
51
87
|
lamin_dataset: ln.Collection
|
|
@@ -165,7 +201,20 @@ class Dataset(torchDataset):
|
|
|
165
201
|
self,
|
|
166
202
|
obs_keys: Union[str, List[str]],
|
|
167
203
|
):
|
|
168
|
-
"""
|
|
204
|
+
"""
|
|
205
|
+
Get combined categorical codes for one or more label columns.
|
|
206
|
+
|
|
207
|
+
Retrieves labels from the mapped dataset and combines them into a single
|
|
208
|
+
categorical encoding. Useful for creating compound class labels for
|
|
209
|
+
stratified sampling.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
obs_keys (str | List[str]): Column name(s) to retrieve and combine.
|
|
213
|
+
|
|
214
|
+
Returns:
|
|
215
|
+
np.ndarray: Integer codes representing the (combined) categories.
|
|
216
|
+
Shape: (n_samples,).
|
|
217
|
+
"""
|
|
169
218
|
if isinstance(obs_keys, str):
|
|
170
219
|
obs_keys = [obs_keys]
|
|
171
220
|
labels = None
|
|
@@ -179,25 +228,37 @@ class Dataset(torchDataset):
|
|
|
179
228
|
|
|
180
229
|
def get_unseen_mapped_dataset_elements(self, idx: int):
|
|
181
230
|
"""
|
|
182
|
-
|
|
231
|
+
Get genes marked as unseen for a specific sample.
|
|
232
|
+
|
|
233
|
+
Retrieves the list of genes that were not observed (expression = 0 or
|
|
234
|
+
marked as unseen) for the sample at the given index.
|
|
183
235
|
|
|
184
236
|
Args:
|
|
185
|
-
idx (int): index
|
|
237
|
+
idx (int): Sample index in the dataset.
|
|
186
238
|
|
|
187
239
|
Returns:
|
|
188
|
-
List[str]:
|
|
240
|
+
List[str]: List of unseen gene identifiers.
|
|
189
241
|
"""
|
|
190
242
|
return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]
|
|
191
243
|
|
|
192
244
|
def define_hierarchies(self, clsses: List[str]):
|
|
193
245
|
"""
|
|
194
|
-
|
|
246
|
+
Define hierarchical label groupings from ontology relationships.
|
|
247
|
+
|
|
248
|
+
Uses Bionty to retrieve parent-child relationships for ontology terms,
|
|
249
|
+
then builds groupings mapping parent terms to their descendants.
|
|
250
|
+
Updates encoders to include parent terms and reorders labels so that
|
|
251
|
+
leaf terms (directly predictable) come first.
|
|
195
252
|
|
|
196
253
|
Args:
|
|
197
|
-
clsses (List[str]):
|
|
254
|
+
clsses (List[str]): List of ontology column names to process.
|
|
198
255
|
|
|
199
256
|
Raises:
|
|
200
|
-
ValueError:
|
|
257
|
+
ValueError: If a class name is not in the supported ontology types.
|
|
258
|
+
|
|
259
|
+
Note:
|
|
260
|
+
Modifies self.labels_groupings, self.class_topred, and
|
|
261
|
+
self.mapped_dataset.encoders in place.
|
|
201
262
|
"""
|
|
202
263
|
# TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
|
|
203
264
|
self.labels_groupings = {}
|
|
@@ -318,6 +379,7 @@ class Dataset(torchDataset):
|
|
|
318
379
|
|
|
319
380
|
|
|
320
381
|
class SimpleAnnDataset(torchDataset):
|
|
382
|
+
|
|
321
383
|
def __init__(
|
|
322
384
|
self,
|
|
323
385
|
adata: AnnData,
|
|
@@ -327,16 +389,40 @@ class SimpleAnnDataset(torchDataset):
|
|
|
327
389
|
encoder: Optional[dict[str, dict]] = None,
|
|
328
390
|
):
|
|
329
391
|
"""
|
|
330
|
-
|
|
331
|
-
|
|
392
|
+
Simple PyTorch Dataset wrapper for a single AnnData object.
|
|
393
|
+
|
|
394
|
+
Provides a lightweight interface for using AnnData with PyTorch DataLoaders,
|
|
395
|
+
compatible with the scDataLoader collator. Useful for inference on new data
|
|
396
|
+
that isn't stored in LaminDB.
|
|
332
397
|
|
|
333
398
|
Args:
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
layer (str):
|
|
338
|
-
|
|
339
|
-
|
|
399
|
+
adata (AnnData): AnnData object containing expression data.
|
|
400
|
+
obs_to_output (List[str], optional): Observation columns to include in
|
|
401
|
+
output dictionaries. Defaults to [].
|
|
402
|
+
layer (str, optional): Layer name to use for expression values. If None,
|
|
403
|
+
uses adata.X. Defaults to None.
|
|
404
|
+
get_knn_cells (bool, optional): Whether to include k-nearest neighbor
|
|
405
|
+
expression data. Requires precomputed neighbors in adata.obsp.
|
|
406
|
+
Defaults to False.
|
|
407
|
+
encoder (dict[str, dict], optional): Dictionary mapping observation column
|
|
408
|
+
names to encoding dictionaries (str -> int). Defaults to None.
|
|
409
|
+
|
|
410
|
+
Attributes:
|
|
411
|
+
adataX (np.ndarray): Dense expression matrix.
|
|
412
|
+
encoder (dict): Label encoders.
|
|
413
|
+
obs_to_output (pd.DataFrame): Observation metadata to include.
|
|
414
|
+
distances (scipy.sparse matrix): KNN distance matrix (if get_knn_cells=True).
|
|
415
|
+
|
|
416
|
+
Raises:
|
|
417
|
+
ValueError: If get_knn_cells=True but "connectivities" not in adata.obsp.
|
|
418
|
+
|
|
419
|
+
Example:
|
|
420
|
+
>>> dataset = SimpleAnnDataset(
|
|
421
|
+
... adata=my_adata,
|
|
422
|
+
... obs_to_output=["cell_type", "organism_ontology_term_id"],
|
|
423
|
+
... encoder={"organism_ontology_term_id": {"NCBITaxon:9606": 0}},
|
|
424
|
+
... )
|
|
425
|
+
>>> loader = DataLoader(dataset, batch_size=32, collate_fn=collator)
|
|
340
426
|
"""
|
|
341
427
|
self.adataX = adata.layers[layer] if layer is not None else adata.X
|
|
342
428
|
self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
|
|
@@ -392,6 +478,46 @@ def mapped(
|
|
|
392
478
|
store_location: str | None = None,
|
|
393
479
|
force_recompute_indices: bool = False,
|
|
394
480
|
) -> MappedCollection:
|
|
481
|
+
"""
|
|
482
|
+
Create a MappedCollection from a LaminDB Collection.
|
|
483
|
+
|
|
484
|
+
Factory function that handles artifact path resolution (staging or streaming)
|
|
485
|
+
and creates a MappedCollection for efficient access to multiple h5ad/zarr files.
|
|
486
|
+
|
|
487
|
+
Args:
|
|
488
|
+
dataset (ln.Collection): LaminDB Collection containing artifacts to map.
|
|
489
|
+
obs_keys (List[str], optional): Observation columns to load. Defaults to None.
|
|
490
|
+
obsm_keys (List[str], optional): Obsm keys to load. Defaults to None.
|
|
491
|
+
obs_filter (dict, optional): Filter observations by column values.
|
|
492
|
+
Keys are column names, values are allowed values. Defaults to None.
|
|
493
|
+
join (str, optional): How to join variables across files. "inner" for
|
|
494
|
+
intersection, "outer" for union. Defaults to "inner".
|
|
495
|
+
encode_labels (bool | List[str], optional): Whether/which columns to
|
|
496
|
+
integer-encode. True encodes all obs_keys. Defaults to True.
|
|
497
|
+
unknown_label (str | dict, optional): Label to use for unknown/missing
|
|
498
|
+
categories. Defaults to None.
|
|
499
|
+
cache_categories (bool, optional): Whether to cache category mappings.
|
|
500
|
+
Defaults to True.
|
|
501
|
+
parallel (bool, optional): Enable parallel data loading. Defaults to False.
|
|
502
|
+
dtype (str, optional): Data type for expression values. Defaults to None.
|
|
503
|
+
stream (bool, optional): If True, stream from cloud storage instead of
|
|
504
|
+
staging locally. Defaults to False.
|
|
505
|
+
is_run_input (bool, optional): Track as run input in LaminDB. Defaults to None.
|
|
506
|
+
metacell_mode (bool, optional): Enable metacell aggregation. Defaults to False.
|
|
507
|
+
meta_assays (List[str], optional): Assay types to treat as metacell-like.
|
|
508
|
+
Defaults to ["EFO:0022857", "EFO:0010961"].
|
|
509
|
+
get_knn_cells (bool, optional): Include KNN neighbor data. Defaults to False.
|
|
510
|
+
store_location (str, optional): Cache directory path. Defaults to None.
|
|
511
|
+
force_recompute_indices (bool, optional): Force recompute cached data.
|
|
512
|
+
Defaults to False.
|
|
513
|
+
|
|
514
|
+
Returns:
|
|
515
|
+
MappedCollection: Mapped collection for data access.
|
|
516
|
+
|
|
517
|
+
Note:
|
|
518
|
+
Artifacts with suffixes other than .h5ad, .zrad, .zarr are ignored.
|
|
519
|
+
Non-existent paths are skipped with a warning.
|
|
520
|
+
"""
|
|
395
521
|
path_list = []
|
|
396
522
|
for artifact in dataset.artifacts.all():
|
|
397
523
|
if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
|
|
@@ -425,14 +551,26 @@ def mapped(
|
|
|
425
551
|
|
|
426
552
|
|
|
427
553
|
def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
|
|
428
|
-
"""
|
|
429
|
-
|
|
554
|
+
"""
|
|
555
|
+
Efficiently combine multiple categorical arrays into a single encoding.
|
|
556
|
+
|
|
557
|
+
Creates a combined categorical where each unique combination of input
|
|
558
|
+
categories gets a unique code. Only combinations that exist in the data
|
|
559
|
+
are assigned codes (sparse encoding).
|
|
430
560
|
|
|
431
561
|
Args:
|
|
432
|
-
series_list: List of
|
|
562
|
+
series_list (List[pd.Categorical]): List of categorical arrays to combine.
|
|
563
|
+
All arrays must have the same length.
|
|
433
564
|
|
|
434
565
|
Returns:
|
|
435
|
-
Combined
|
|
566
|
+
pd.Categorical: Combined categorical with compressed codes representing
|
|
567
|
+
unique combinations present in the data.
|
|
568
|
+
|
|
569
|
+
Example:
|
|
570
|
+
>>> cat1 = pd.Categorical(["a", "a", "b", "b"])
|
|
571
|
+
>>> cat2 = pd.Categorical(["x", "y", "x", "y"])
|
|
572
|
+
>>> combined = concat_categorical_codes([cat1, cat2])
|
|
573
|
+
>>> # Results in 4 unique codes for (a,x), (a,y), (b,x), (b,y)
|
|
436
574
|
"""
|
|
437
575
|
# Get the codes for each categorical
|
|
438
576
|
codes_list = [s.codes.astype(np.int32) for s in series_list]
|