scdataloader 2.0.3__py3-none-any.whl → 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/collator.py +99 -36
- scdataloader/data.py +177 -39
- scdataloader/datamodule.py +221 -57
- scdataloader/preprocess.py +33 -24
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/METADATA +1 -1
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/RECORD +9 -9
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/WHEEL +0 -0
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/entry_points.txt +0 -0
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/licenses/LICENSE +0 -0
scdataloader/collator.py
CHANGED

@@ -27,37 +27,60 @@ class Collator:
         genedf: Optional[pd.DataFrame] = None,
     ):
         """
-
-        organization and preparation of gene expression data from different organisms,
-        allowing for various configurations such as maximum gene list length, normalization,
-        and selection method for gene expression.
+        Collator for preparing gene expression data batches for the scPRINT model.

-        This
+        This class handles the organization and preparation of gene expression data from
+        different organisms, allowing for various configurations such as maximum gene list
+        length, normalization, binning, and gene selection strategies.
+
+        Compatible with scVI's dataloader and other PyTorch data loading pipelines.

         Args:
-            organisms (
-
-            how (
-
-
-
-
-
-            "
-
-
-            valid_genes (
-
-
-
-
-
-
-
-
-
-
+            organisms (List[str]): List of organism ontology term IDs to include.
+                Samples from other organisms will be dropped (may lead to variable batch sizes).
+            how (str, optional): Gene selection strategy. Defaults to "all".
+                - "most expr": Select the `max_len` most expressed genes. If fewer genes
+                  are expressed, randomly sample unexpressed genes to fill.
+                - "random expr": Randomly select `max_len` expressed genes. If fewer genes
+                  are expressed, randomly sample unexpressed genes to fill.
+                - "all": Use all genes without filtering.
+                - "some": Use only genes specified in the `genelist` parameter.
+            org_to_id (dict[str, int], optional): Mapping from organism names to integer IDs.
+                If None, organism names are used directly. Defaults to None.
+            valid_genes (List[str], optional): List of gene names to consider from input data.
+                Genes not in this list will be dropped. Useful when the model only supports
+                specific genes. Defaults to None (use all genes).
+            max_len (int, optional): Maximum number of genes to include when using "most expr"
+                or "random expr" selection methods. Defaults to 2000.
+            add_zero_genes (int, optional): Number of additional unexpressed genes to include
+                in the output. Only applies when `how` is "most expr" or "random expr".
+                Defaults to 0.
+            logp1 (bool, optional): Apply log2(1 + x) transformation to expression values.
+                Applied after normalization if both are enabled. Defaults to False.
+            norm_to (float, optional): Target sum for count normalization. Expression values
+                are scaled so that total counts equal this value. Defaults to None (no normalization).
+            n_bins (int, optional): Number of bins for expression value binning. If 0, no
+                binning is applied. Binning uses quantile-based discretization. Defaults to 0.
+            tp_name (str, optional): Column name in batch data for time point or heat diffusion
+                values. If None, time point values default to 0. Defaults to None.
+            organism_name (str, optional): Column name in batch data for organism ontology
+                term ID. Defaults to "organism_ontology_term_id".
+            class_names (List[str], optional): List of additional metadata column names to
+                include in the output. Defaults to [].
+            genelist (List[str], optional): List of specific genes to use when `how="some"`.
+                Required if `how="some"`. Defaults to [].
+            genedf (pd.DataFrame, optional): DataFrame containing gene information indexed by
+                gene name with an 'organism' column. If None, loaded automatically using
+                `load_genes()`. Defaults to None.
+
+        Attributes:
+            organism_ids (set): Set of organism IDs being processed.
+            start_idx (dict): Mapping from organism ID to starting gene index in the model.
+            accepted_genes (dict): Boolean masks for valid genes per organism.
+            to_subset (dict): Boolean masks for genelist filtering per organism.
+
+        Raises:
+            AssertionError: If `how="some"` but `genelist` is empty.
         """
         self.organisms = organisms
         self.max_len = max_len

@@ -77,6 +100,19 @@ class Collator:
         self._setup(genedf, org_to_id, valid_genes, genelist)

     def _setup(self, genedf=None, org_to_id=None, valid_genes=[], genelist=[]):
+        """
+        Initialize gene mappings and indices for each organism.
+
+        Sets up internal data structures for gene filtering, organism-specific
+        gene indices, and gene subsetting based on the provided configuration.
+
+        Args:
+            genedf (pd.DataFrame, optional): Gene information DataFrame. If None,
+                loaded via `load_genes()`. Defaults to None.
+            org_to_id (dict, optional): Organism name to ID mapping. Defaults to None.
+            valid_genes (List[str], optional): Genes to accept from input. Defaults to [].
+            genelist (List[str], optional): Genes to subset to when `how="some"`. Defaults to [].
+        """
         if genedf is None:
             genedf = load_genes(self.organisms)
         self.organism_ids = (

@@ -108,18 +144,45 @@ class Collator:

     def __call__(self, batch) -> dict[str, Tensor]:
         """
-
+        Collate a minibatch of gene expression data.
+
+        Processes a list of sample dictionaries, applying gene selection, normalization,
+        log transformation, and binning as configured. Filters out samples from organisms
+        not in the configured organism list.

         Args:
-            batch (List[dict
-
-
-
-
-
+            batch (List[dict]): List of sample dictionaries, each containing:
+                - "X" (array): Gene expression values.
+                - organism_name (any): Organism identifier (column name set by `organism_name`).
+                - tp_name (float, optional): Time point value (column name set by `tp_name`).
+                - class_names... (any, optional): Additional class labels.
+                - "_storage_idx" (int, optional): Dataset storage index.
+                - "is_meta" (int, optional): Metadata flag.
+                - "knn_cells" (array, optional): KNN neighbor expression data.
+                - "knn_cells_info" (array, optional): KNN neighbor metadata.

         Returns:
-
+            dict[str, Tensor]: Dictionary containing collated tensors:
+                - "x" (Tensor): Gene expression matrix of shape (batch_size, n_genes).
+                  Values may be raw counts, normalized, log-transformed, or binned
+                  depending on configuration.
+                - "genes" (Tensor): Gene indices of shape (batch_size, n_genes) as int32.
+                  Indices correspond to positions in the model's gene vocabulary.
+                - "class" (Tensor): Class labels of shape (batch_size, n_classes) as int32.
+                - "tp" (Tensor): Time point values of shape (batch_size,).
+                - "depth" (Tensor): Total counts per cell of shape (batch_size,).
+                - "is_meta" (Tensor, optional): Metadata flags as int32. Present if input
+                  contains "is_meta".
+                - "knn_cells" (Tensor, optional): KNN expression data. Present if input
+                  contains "knn_cells".
+                - "knn_cells_info" (Tensor, optional): KNN metadata. Present if input
+                  contains "knn_cells_info".
+                - "dataset" (Tensor, optional): Dataset indices as int64. Present if input
+                  contains "_storage_idx".
+
+        Note:
+            Batch size in output may be smaller than input if some samples are filtered
+            out due to organism mismatch.
         """
         # do count selection
         # get the unseen info and don't add any unseen
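
Taken together, the constructor and `__call__` docstrings above fully specify the collator's contract. A schematic sketch of driving it by hand (not from the package: the organism choice, toy counts, and `n_genes` placeholder are illustrative, and in practice each "X" vector must line up with the gene table `load_genes()` returns):

    import numpy as np
    from scdataloader.collator import Collator

    # Mirrors the documented options: keep the 2000 most expressed genes,
    # normalize each cell to 10k counts, then apply log2(1 + x).
    collator = Collator(
        organisms=["NCBITaxon:9606"],
        how="most expr",
        max_len=2000,
        norm_to=1e4,
        logp1=True,
    )

    n_genes = 70_000  # placeholder; must match the human gene table in practice
    batch = [
        {
            "X": np.random.poisson(1.0, size=n_genes),
            "organism_ontology_term_id": "NCBITaxon:9606",
        }
        for _ in range(4)
    ]

    out = collator(batch)
    # out["x"] and out["genes"] are (4, 2000); "class", "tp" and "depth"
    # follow the Returns section documented above.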
scdataloader/data.py
CHANGED

@@ -22,30 +22,66 @@ from .mapped import MappedCollection, _Connect
 @dataclass
 class Dataset(torchDataset):
     """
-    Dataset
+    PyTorch Dataset for loading single-cell data from a LaminDB Collection.

-    This
-
+    This class wraps LaminDB's MappedCollection to provide additional features:
+    - Management of hierarchical ontology labels (cell type, tissue, disease, etc.)
+    - Automatic encoding of categorical labels to integers
+    - Multi-species gene handling with unified gene indexing
+    - Optional metacell aggregation and KNN neighbor retrieval

-
-
-    .. note::
-
-        A related data loader exists `here
-        <https://github.com/Genentech/scimilarity>`__.
+    The dataset lazily loads data from storage, making it memory-efficient for
+    large collections spanning multiple files.

     Args:
-
-
-
-
-        clss_to_predict (List[str]):
-
-        hierarchical_clss
-
-
-
-
+        lamin_dataset (ln.Collection): LaminDB Collection containing the artifacts to load.
+        genedf (pd.DataFrame, optional): DataFrame with gene information, indexed by gene ID
+            with an 'organism' column. If None, automatically loaded based on organisms
+            in the dataset. Defaults to None.
+        clss_to_predict (List[str], optional): Observation columns to encode as prediction
+            targets. These will be integer-encoded in the output. Defaults to [].
+        hierarchical_clss (List[str], optional): Observation columns with hierarchical
+            ontology structure. These will have their ancestry relationships computed
+            using Bionty. Supported columns:
+            - "cell_type_ontology_term_id"
+            - "tissue_ontology_term_id"
+            - "disease_ontology_term_id"
+            - "development_stage_ontology_term_id"
+            - "assay_ontology_term_id"
+            - "self_reported_ethnicity_ontology_term_id"
+            Defaults to [].
+        join_vars (str, optional): How to join variables across artifacts.
+            "inner" for intersection, "outer" for union, None for no joining.
+            Defaults to None.
+        metacell_mode (float, optional): Probability of returning aggregated metacell
+            expression instead of single-cell. Defaults to 0.0.
+        get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
+            expression in the output. Requires precomputed neighbors in the data.
+            Defaults to False.
+        store_location (str, optional): Directory path to cache computed indices.
+            Defaults to None.
+        force_recompute_indices (bool, optional): Force recomputation of cached data.
+            Defaults to False.
+
+    Attributes:
+        mapped_dataset (MappedCollection): Underlying mapped collection for data access.
+        genedf (pd.DataFrame): Gene information DataFrame.
+        organisms (List[str]): List of organism ontology term IDs in the dataset.
+        class_topred (dict[str, set]): Mapping from class name to set of valid labels.
+        labels_groupings (dict[str, dict]): Hierarchical groupings for ontology classes.
+        encoder (dict[str, dict]): Label encoders mapping strings to integers.
+
+    Raises:
+        ValueError: If genedf is None and "organism_ontology_term_id" is not in clss_to_predict.
+
+    Example:
+        >>> collection = ln.Collection.filter(key="my_collection").first()
+        >>> dataset = Dataset(
+        ...     lamin_dataset=collection,
+        ...     clss_to_predict=["organism_ontology_term_id", "cell_type_ontology_term_id"],
+        ...     hierarchical_clss=["cell_type_ontology_term_id"],
+        ... )
+        >>> sample = dataset[0]  # Returns dict with "X" and encoded labels
     """

     lamin_dataset: ln.Collection

@@ -165,7 +201,20 @@ class Dataset(torchDataset):
         self,
         obs_keys: Union[str, List[str]],
     ):
-        """
+        """
+        Get combined categorical codes for one or more label columns.
+
+        Retrieves labels from the mapped dataset and combines them into a single
+        categorical encoding. Useful for creating compound class labels for
+        stratified sampling.
+
+        Args:
+            obs_keys (str | List[str]): Column name(s) to retrieve and combine.
+
+        Returns:
+            np.ndarray: Integer codes representing the (combined) categories.
+                Shape: (n_samples,).
+        """
         if isinstance(obs_keys, str):
             obs_keys = [obs_keys]
         labels = None

@@ -179,25 +228,37 @@ class Dataset(torchDataset):

     def get_unseen_mapped_dataset_elements(self, idx: int):
         """
-
+        Get genes marked as unseen for a specific sample.
+
+        Retrieves the list of genes that were not observed (expression = 0 or
+        marked as unseen) for the sample at the given index.

         Args:
-            idx (int): index
+            idx (int): Sample index in the dataset.

         Returns:
-            List[str]:
+            List[str]: List of unseen gene identifiers.
         """
         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

     def define_hierarchies(self, clsses: List[str]):
         """
-
+        Define hierarchical label groupings from ontology relationships.
+
+        Uses Bionty to retrieve parent-child relationships for ontology terms,
+        then builds groupings mapping parent terms to their descendants.
+        Updates encoders to include parent terms and reorders labels so that
+        leaf terms (directly predictable) come first.

         Args:
-            clsses (List[str]):
+            clsses (List[str]): List of ontology column names to process.

         Raises:
-            ValueError:
+            ValueError: If a class name is not in the supported ontology types.
+
+        Note:
+            Modifies self.labels_groupings, self.class_topred, and
+            self.mapped_dataset.encoders in place.
         """
         # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
         self.labels_groupings = {}

@@ -318,6 +379,7 @@ class Dataset(torchDataset):


 class SimpleAnnDataset(torchDataset):
+
     def __init__(
         self,
         adata: AnnData,

@@ -327,16 +389,40 @@ class SimpleAnnDataset(torchDataset):
         encoder: Optional[dict[str, dict]] = None,
     ):
         """
-
-
+        Simple PyTorch Dataset wrapper for a single AnnData object.
+
+        Provides a lightweight interface for using AnnData with PyTorch DataLoaders,
+        compatible with the scDataLoader collator. Useful for inference on new data
+        that isn't stored in LaminDB.

         Args:
-
-
-            layer (str):
-
-
+            adata (AnnData): AnnData object containing expression data.
+            obs_to_output (List[str], optional): Observation columns to include in
+                output dictionaries. Defaults to [].
+            layer (str, optional): Layer name to use for expression values. If None,
+                uses adata.X. Defaults to None.
+            get_knn_cells (bool, optional): Whether to include k-nearest neighbor
+                expression data. Requires precomputed neighbors in adata.obsp.
+                Defaults to False.
+            encoder (dict[str, dict], optional): Dictionary mapping observation column
+                names to encoding dictionaries (str -> int). Defaults to None.
+
+        Attributes:
+            adataX (np.ndarray): Dense expression matrix.
+            encoder (dict): Label encoders.
+            obs_to_output (pd.DataFrame): Observation metadata to include.
+            distances (scipy.sparse matrix): KNN distance matrix (if get_knn_cells=True).
+
+        Raises:
+            ValueError: If get_knn_cells=True but "connectivities" not in adata.obsp.
+
+        Example:
+            >>> dataset = SimpleAnnDataset(
+            ...     adata=my_adata,
+            ...     obs_to_output=["cell_type", "organism_ontology_term_id"],
+            ...     encoder={"organism_ontology_term_id": {"NCBITaxon:9606": 0}},
+            ... )
+            >>> loader = DataLoader(dataset, batch_size=32, collate_fn=collator)
         """
         self.adataX = adata.layers[layer] if layer is not None else adata.X
         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX

@@ -392,6 +478,46 @@ def mapped(
     store_location: str | None = None,
     force_recompute_indices: bool = False,
 ) -> MappedCollection:
+    """
+    Create a MappedCollection from a LaminDB Collection.
+
+    Factory function that handles artifact path resolution (staging or streaming)
+    and creates a MappedCollection for efficient access to multiple h5ad/zarr files.
+
+    Args:
+        dataset (ln.Collection): LaminDB Collection containing artifacts to map.
+        obs_keys (List[str], optional): Observation columns to load. Defaults to None.
+        obsm_keys (List[str], optional): Obsm keys to load. Defaults to None.
+        obs_filter (dict, optional): Filter observations by column values.
+            Keys are column names, values are allowed values. Defaults to None.
+        join (str, optional): How to join variables across files. "inner" for
+            intersection, "outer" for union. Defaults to "inner".
+        encode_labels (bool | List[str], optional): Whether/which columns to
+            integer-encode. True encodes all obs_keys. Defaults to True.
+        unknown_label (str | dict, optional): Label to use for unknown/missing
+            categories. Defaults to None.
+        cache_categories (bool, optional): Whether to cache category mappings.
+            Defaults to True.
+        parallel (bool, optional): Enable parallel data loading. Defaults to False.
+        dtype (str, optional): Data type for expression values. Defaults to None.
+        stream (bool, optional): If True, stream from cloud storage instead of
+            staging locally. Defaults to False.
+        is_run_input (bool, optional): Track as run input in LaminDB. Defaults to None.
+        metacell_mode (bool, optional): Enable metacell aggregation. Defaults to False.
+        meta_assays (List[str], optional): Assay types to treat as metacell-like.
+            Defaults to ["EFO:0022857", "EFO:0010961"].
+        get_knn_cells (bool, optional): Include KNN neighbor data. Defaults to False.
+        store_location (str, optional): Cache directory path. Defaults to None.
+        force_recompute_indices (bool, optional): Force recompute cached data.
+            Defaults to False.
+
+    Returns:
+        MappedCollection: Mapped collection for data access.
+
+    Note:
+        Artifacts with suffixes other than .h5ad, .zrad, .zarr are ignored.
+        Non-existent paths are skipped with a warning.
+    """
     path_list = []
     for artifact in dataset.artifacts.all():
         if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:

@@ -425,14 +551,26 @@ def mapped(


 def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
-    """
-
+    """
+    Efficiently combine multiple categorical arrays into a single encoding.
+
+    Creates a combined categorical where each unique combination of input
+    categories gets a unique code. Only combinations that exist in the data
+    are assigned codes (sparse encoding).

     Args:
-        series_list: List of
+        series_list (List[pd.Categorical]): List of categorical arrays to combine.
+            All arrays must have the same length.

     Returns:
-        Combined
+        pd.Categorical: Combined categorical with compressed codes representing
+            unique combinations present in the data.
+
+    Example:
+        >>> cat1 = pd.Categorical(["a", "a", "b", "b"])
+        >>> cat2 = pd.Categorical(["x", "y", "x", "y"])
+        >>> combined = concat_categorical_codes([cat1, cat2])
+        >>> # Results in 4 unique codes for (a,x), (a,y), (b,x), (b,y)
     """
     # Get the codes for each categorical
     codes_list = [s.codes.astype(np.int32) for s in series_list]
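
The Example in `concat_categorical_codes` is easy to reproduce with a standalone sketch of the same idea — factorize row-wise so only combinations present in the data receive a code. This is an illustration of the technique, not the package's implementation:

    import numpy as np
    import pandas as pd

    def combine_codes_sketch(cats: list) -> np.ndarray:
        """Combine several pd.Categorical arrays into one code per row (illustrative)."""
        # One column of integer codes per categorical.
        codes = np.stack([np.asarray(c.codes, dtype=np.int64) for c in cats], axis=1)
        # np.unique over rows returns an inverse index: a dense code per
        # unique combination that actually occurs (sparse encoding).
        _, combined = np.unique(codes, axis=0, return_inverse=True)
        return combined

    cat1 = pd.Categorical(["a", "a", "b", "b"])
    cat2 = pd.Categorical(["x", "y", "x", "y"])
    print(combine_codes_sketch([cat1, cat2]))  # -> [0 1 2 3]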
scdataloader/datamodule.py
CHANGED

@@ -69,39 +69,98 @@ class DataModule(L.LightningDataModule):
         **kwargs,
     ):
         """
-        DataModule
-        it can work with bare pytorch too
+        PyTorch Lightning DataModule for loading single-cell data from a LaminDB Collection.

-
-
+        This DataModule provides train/val/test dataloaders with configurable sampling strategies.
+        It combines MappedCollection, Dataset, and Collator to create efficient data pipelines
+        for training single-cell foundation models.
+
+        The training dataloader uses weighted random sampling based on class frequencies,
+        validation uses random sampling, and test uses sequential sampling on held-out datasets.

         Args:
-            collection_name (str):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            collection_name (str): Key of the LaminDB Collection to load.
+            clss_to_weight (List[str], optional): Label columns to use for weighted sampling
+                in the training dataloader. Supports "nnz" for weighting by number of
+                non-zero genes. Defaults to ["organism_ontology_term_id"].
+            weight_scaler (int, optional): Controls balance between rare and common classes.
+                Higher values lead to more uniform sampling across classes. Set to 0 to
+                disable weighted sampling. Defaults to 10.
+            n_samples_per_epoch (int, optional): Number of samples to draw per training epoch.
+                Defaults to 2,000,000.
+            validation_split (float | int, optional): Proportion (float) or absolute number (int)
+                of samples for validation. Defaults to 0.2.
+            test_split (float | int, optional): Proportion (float) or absolute number (int)
+                of samples for testing. Uses entire datasets as test sets, rounding to
+                nearest dataset boundary. Defaults to 0.
+            use_default_col (bool, optional): Whether to use the default Collator for batch
+                preparation. If False, no collate_fn is applied. Defaults to True.
+            clss_to_predict (List[str], optional): Observation columns to encode as prediction
+                targets. Must include "organism_ontology_term_id". Defaults to
+                ["organism_ontology_term_id"].
+            hierarchical_clss (List[str], optional): Observation columns with hierarchical
+                ontology structure to be processed. Defaults to [].
+            how (str, optional): Gene selection strategy passed to Collator. One of
+                "most expr", "random expr", "all", "some". Defaults to "random expr".
+            organism_col (str, optional): Column name for organism ontology term ID.
+                Defaults to "organism_ontology_term_id".
+            max_len (int, optional): Maximum number of genes per sample passed to Collator.
+                Defaults to 1000.
+            replacement (bool, optional): Whether to sample with replacement in training.
+                Defaults to True.
+            gene_subset (List[str], optional): List of genes to restrict the dataset to.
+                Useful when model only supports specific genes. Defaults to None.
+            tp_name (str, optional): Column name for time point or heat diffusion values.
+                Defaults to None.
+            assays_to_drop (List[str], optional): List of assay ontology term IDs to exclude
+                from training. Defaults to ["EFO:0030007"] (ATAC-seq).
+            metacell_mode (float, optional): Probability of using metacell aggregation mode.
+                Cannot be used with get_knn_cells. Defaults to 0.0.
+            get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
+                expression data. Cannot be used with metacell_mode. Defaults to False.
+            store_location (str, optional): Directory path to cache sampler indices and
+                labels for faster subsequent loading. Defaults to None.
+            force_recompute_indices (bool, optional): Force recomputation of cached indices
+                even if they exist. Defaults to False.
+            sampler_workers (int, optional): Number of parallel workers for building sampler
+                indices. Auto-determined based on available CPUs if None. Defaults to None.
+            sampler_chunk_size (int, optional): Chunk size for parallel sampler processing.
+                Auto-determined based on available memory if None. Defaults to None.
+            organisms (List[str], optional): List of organisms to include. If None, uses
+                all organisms in the dataset. Defaults to None.
+            genedf (pd.DataFrame, optional): Gene information DataFrame. If None, loaded
+                automatically. Defaults to None.
+            n_bins (int, optional): Number of bins for expression discretization. 0 means
+                no binning. Defaults to 0.
+            curiculum (int, optional): Curriculum learning parameter. If > 0, gradually
+                increases sampling weight balance over epochs. Defaults to 0.
+            start_at (int, optional): Starting index for resuming inference. Requires same
+                number of GPUs as previous run. Defaults to 0.
+            **kwargs: Additional arguments passed to PyTorch DataLoader (e.g., batch_size,
+                num_workers, pin_memory).
+
+        Attributes:
+            dataset (Dataset): The underlying Dataset instance.
+            classes (dict[str, int]): Mapping from class names to number of categories.
+            train_labels (np.ndarray): Label array for weighted sampling.
+            idx_full (np.ndarray): Indices for training samples.
+            valid_idx (np.ndarray): Indices for validation samples.
+            test_idx (np.ndarray): Indices for test samples.
+            test_datasets (List[str]): Paths to datasets used for testing.
+
+        Raises:
+            ValueError: If "organism_ontology_term_id" not in clss_to_predict.
+            ValueError: If both metacell_mode > 0 and get_knn_cells are True.
+
+        Example:
+            >>> dm = DataModule(
+            ...     collection_name="my_collection",
+            ...     batch_size=32,
+            ...     num_workers=4,
+            ...     max_len=2000,
+            ... )
+            >>> dm.setup()
+            >>> train_loader = dm.train_dataloader()
         """
         if "organism_ontology_term_id" not in clss_to_predict:
             raise ValueError(

@@ -285,12 +344,24 @@ class DataModule(L.LightningDataModule):

     def setup(self, stage=None):
         """
-
-
+        Prepare data splits for training, validation, and testing.
+
+        This method shuffles the data, computes sample weights for weighted sampling,
+        removes samples from dropped assays, and creates train/val/test splits.
+        Test splits use entire datasets to ensure evaluation on unseen data sources.
+
+        Results can be cached to `store_location` for faster subsequent runs.

         Args:
-            stage (str, optional):
-
+            stage (str, optional): Training stage ('fit', 'test', or None for both).
+                Currently not used but kept for Lightning compatibility. Defaults to None.
+
+        Returns:
+            List[str]: List of paths to test datasets.
+
+        Note:
+            Must be called before using dataloaders. The train/val/test split is
+            deterministic when loading from cache.
         """
         print("setting up the datamodule")
         start_time = time.time()

@@ -441,6 +512,22 @@ class DataModule(L.LightningDataModule):
         return self.test_datasets

     def train_dataloader(self, **kwargs):
+        """
+        Create the training DataLoader with weighted random sampling.
+
+        Uses LabelWeightedSampler for class-balanced sampling when weight_scaler > 0
+        and clss_to_weight is specified. Otherwise uses RankShardSampler for
+        distributed training without weighting.
+
+        Args:
+            **kwargs: Additional arguments passed to DataLoader, overriding defaults.
+
+        Returns:
+            DataLoader: Training DataLoader instance.
+
+        Raises:
+            ValueError: If setup() has not been called.
+        """
         if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
             try:
                 print("Setting up the parallel train sampler...")

@@ -473,6 +560,12 @@ class DataModule(L.LightningDataModule):
         )

     def val_dataloader(self):
+        """
+        Create the validation DataLoader.
+
+        Returns:
+            DataLoader | List: Validation DataLoader, or empty list if no validation split.
+        """
         return (
             DataLoader(
                 Subset(self.dataset, self.valid_idx),

@@ -483,6 +576,12 @@ class DataModule(L.LightningDataModule):
         )

     def test_dataloader(self):
+        """
+        Create the test DataLoader with sequential sampling.
+
+        Returns:
+            DataLoader | List: Test DataLoader, or empty list if no test split.
+        """
         return (
             DataLoader(
                 self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs

@@ -492,6 +591,14 @@ class DataModule(L.LightningDataModule):
         )

     def predict_dataloader(self):
+        """
+        Create a DataLoader for prediction over all training data.
+
+        Uses RankShardSampler for distributed inference.
+
+        Returns:
+            DataLoader: Prediction DataLoader instance.
+        """
         subset = Subset(self.dataset, self.idx_full)
         return DataLoader(
             self.dataset,

@@ -501,15 +608,6 @@ class DataModule(L.LightningDataModule):


 class LabelWeightedSampler(Sampler[int]):
-    """
-    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
-
-    This sampler is designed to handle very large datasets efficiently, with optimizations for:
-    1. Parallel building of class indices
-    2. Chunked processing for large arrays
-    3. Efficient memory management
-    4. Proper handling of replacement and non-replacement sampling
-    """

     label_weights: torch.Tensor
     klass_indices: dict[int, torch.Tensor]

@@ -531,16 +629,58 @@ class LabelWeightedSampler(Sampler[int]):
         curiculum: int = 0,
     ) -> None:
         """
-
+        Weighted random sampler balancing both class frequencies and element weights.
+
+        This sampler is optimized for very large datasets (millions of samples) with:
+        - Parallel construction of class indices using multiple CPU workers
+        - Chunked processing to manage memory usage
+        - Support for curriculum learning via progressive weight scaling
+        - Optional per-element weights (e.g., by number of expressed genes)
+
+        The sampling process:
+        1. Sample class labels according to class weights
+        2. For each sampled class, sample elements according to element weights
+        3. Shuffle all sampled indices

         Args:
-
-
-
-
-
-
-
+            labels (np.ndarray): Integer class label for each dataset element.
+                Shape: (dataset_size,). The last unique label is treated as
+                "excluded" with weight 0.
+            num_samples (int): Number of samples to draw per epoch.
+            replacement (bool, optional): Whether to sample with replacement.
+                Defaults to True.
+            weight_scaler (float, optional): Controls class weight balance.
+                Weight formula: (scaler * count) / (count + scaler).
+                Higher values = more uniform sampling. Defaults to None.
+            element_weights (Sequence[float], optional): Per-element sampling weights.
+                Shape: (dataset_size,). Defaults to None (uniform within class).
+            n_workers (int, optional): Number of parallel workers for index building.
+                Defaults to min(20, num_cpus - 1).
+            chunk_size (int, optional): Elements per chunk for parallel processing.
+                Auto-determined based on available memory if None.
+            store_location (str, optional): Directory to cache computed indices.
+                Defaults to None.
+            force_recompute_indices (bool, optional): Recompute indices even if cached.
+                Defaults to False.
+            curiculum (int, optional): Curriculum learning epochs. If > 0, weight
+                exponent increases from 0 to 1 over this many epochs. Defaults to 0.
+
+        Attributes:
+            label_weights (torch.Tensor): Computed weights per class label.
+            klass_indices (torch.Tensor): Concatenated indices for all classes.
+            klass_offsets (torch.Tensor): Starting offset for each class in klass_indices.
+            count (int): Number of times __iter__ has been called (for curriculum).
+
+        Example:
+            >>> sampler = LabelWeightedSampler(
+            ...     labels=train_labels,
+            ...     num_samples=1_000_000,
+            ...     weight_scaler=10,
+            ...     element_weights=nnz_weights,
+            ... )
+            >>> for idx in sampler:
+            ...     # Process sample at idx
+            ...     pass
         """
         print("Initializing optimized parallel weighted sampler...")
         super(LabelWeightedSampler, self).__init__(None)

@@ -851,8 +991,32 @@ class LabelWeightedSampler(Sampler[int]):


 class RankShardSampler(Sampler[int]):
-    """
-
+    """
+    Sampler that shards data contiguously across distributed ranks.
+
+    Divides the dataset into contiguous chunks, one per rank, without
+    padding or duplicating samples. Preserves the original data order
+    within each shard (useful for pre-shuffled data).
+
+    Args:
+        data_len (int): Total number of samples in the dataset.
+        start_at (int, optional): Global starting index for resuming training.
+            Requires the same number of GPUs as the previous run. Defaults to 0.
+
+    Attributes:
+        rank (int): Current process rank (0 if not distributed).
+        world_size (int): Total number of processes (1 if not distributed).
+        start (int): Starting index for this rank's shard.
+        end (int): Ending index (exclusive) for this rank's shard.
+
+    Note:
+        The last rank may have fewer samples than others if the dataset
+        size is not evenly divisible by world_size.
+
+    Example:
+        >>> sampler = RankShardSampler(len(dataset))
+        >>> loader = DataLoader(dataset, sampler=sampler)
+    """

     def __init__(self, data_len: int, start_at: int = 0) -> None:
         self.data_len = data_len

@@ -866,13 +1030,13 @@ class RankShardSampler(Sampler[int]):
         # contiguous chunk per rank (last rank may be shorter)
         if self.start_at > 0:
             print(
-                "!!!!
-
+                "!!!!ATTENTION: make sure that you are running on the exact same \
+                number of GPU as your previous run!!!!!"
             )
         print(f"Sharding data of size {data_len} over {self.world_size} ranks")
-        per_rank = math.ceil(
+        per_rank = math.ceil(self.data_len / self.world_size)
         self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
-        self.end = min(self.
+        self.end = min((self.rank + 1) * per_rank, self.data_len)
         print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")

     def __iter__(self):
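
The final hunk rewrites the shard arithmetic in `RankShardSampler.__init__`. The formulas are simple enough to sanity-check in isolation; the helper below is not package code, just the same arithmetic extracted for inspection:

    import math

    def shard_range(data_len: int, world_size: int, rank: int, start_at: int = 0):
        """Contiguous [start, end) shard per rank, mirroring the diff's formulas."""
        per_rank = math.ceil(data_len / world_size)
        start = int((start_at / world_size) + rank * per_rank)
        end = min((rank + 1) * per_rank, data_len)
        return start, end

    # 10 samples over 3 ranks: no padding, no duplication, last rank is shorter.
    for rank in range(3):
        print(rank, shard_range(10, 3, rank))
    # 0 (0, 4)
    # 1 (4, 8)
    # 2 (8, 10)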
scdataloader/preprocess.py
CHANGED

@@ -1,3 +1,10 @@
+"""
+Preprocessing utilities for single-cell gene expression data.
+
+This module provides functions for normalizing, transforming, and discretizing
+gene expression values for use with scPRINT and similar models.
+"""
+
 import gc
 import time
 from typing import Callable, List, Optional, Union

@@ -664,39 +671,41 @@ def is_log1p(adata: AnnData) -> bool:
     return True


-def _digitize(
+def _digitize(values: np.ndarray, bins: np.ndarray) -> np.ndarray:
     """
-    Digitize
-    have same values.
+    Digitize values into discrete bins with 1-based indexing.

-
+    Similar to np.digitize but ensures output is 1-indexed (bin 0 reserved for
+    zero values) and handles edge cases for expression binning.

-
-
-
-
-
-        The side to use for digitization. If "one", the left side is used. If
-        "both", the left and right side are used. Default to "one".
+    Args:
+        values (np.ndarray): Array of values to discretize. Should be non-zero
+            expression values.
+        bins (np.ndarray): Bin edges from np.quantile or similar. Values are
+            assigned to bins based on which edges they fall between.

     Returns:
-
-
-
+        np.ndarray: Integer bin indices, 1-indexed. Values equal to bins[i]
+            are assigned to bin i+1.
+
+    Example:
+        >>> values = np.array([0.5, 1.5, 2.5, 3.5])
+        >>> bins = np.array([1.0, 2.0, 3.0])
+        >>> _digitize(values, bins)
+        array([1, 2, 3, 3])
+
+    Note:
+        This function is used internally by the Collator for expression binning.
+        Zero values should be handled separately before calling this function.
     """
-    assert
-
-    left_digits = np.digitize(x, bins)
-    if side == "one":
-        return left_digits
+    assert values.ndim == 1 and bins.ndim == 1

-
+    left_digits = np.digitize(values, bins)
+    return left_digits

-    rands = np.random.rand(len(x))  # uniform random numbers

-
-
-    return digits
+    # Add documentation for any other functions in preprocess.py
+    # ...existing code...


 def binning(row: np.ndarray, n_bins: int) -> np.ndarray:
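
The `_digitize` docstring above pins down the binning convention used with the collator's `n_bins` option: bin 0 is reserved for unexpressed genes and expressed genes land in 1-indexed quantile bins. A self-contained sketch of that convention (an assumed illustration — the package's actual `binning()` may differ in details such as edge handling):

    import numpy as np

    def quantile_binning_sketch(row: np.ndarray, n_bins: int) -> np.ndarray:
        """Quantile-bin one expression vector; bin 0 means "not expressed"."""
        out = np.zeros(row.shape, dtype=np.int64)
        expressed = row > 0
        if not expressed.any():
            return out
        # Interior bin edges at evenly spaced quantiles of the expressed values.
        edges = np.quantile(row[expressed], np.linspace(0, 1, n_bins + 1))[1:-1]
        # Shift np.digitize's 0-based result so expressed genes occupy bins 1..n_bins.
        out[expressed] = np.digitize(row[expressed], edges) + 1
        return out

    row = np.array([0.0, 0.5, 1.5, 2.5, 9.0])
    print(quantile_binning_sketch(row, n_bins=4))  # -> [0 1 2 3 4]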
{scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/RECORD
CHANGED

@@ -1,16 +1,16 @@
 scdataloader/__init__.py,sha256=Z5HURehoWw1GrecImmTXIkv4ih8Q5RxNQWPm8zjjXOA,226
 scdataloader/__main__.py,sha256=xPOtrEpQQQZUGTnm8KTvsQcA_jR45oMG_VHqd0Ny7_M,8677
 scdataloader/base.py,sha256=M1gD59OffRdLOgS1vHKygOomUoAMuzjpRtAfM3SBKF8,338
-scdataloader/collator.py,sha256=
+scdataloader/collator.py,sha256=VcFJcVAIeKvYkG1DPRXzoBaw2wQ6D_0lsv5Mcv-9USI,17419
 scdataloader/config.py,sha256=nM8J11z2-lornryy1KxDE9675Rcxge4RGhdmpeiMhuI,7173
 scdataloader/data.json,sha256=Zb8c27yk3rwMgtAU8kkiWWAyUwYBrlCqKUyEtaAx9i8,8785
-scdataloader/data.py,sha256=
-scdataloader/datamodule.py,sha256=
+scdataloader/data.py,sha256=fMW1OgllPCz87si3DpkzOSoqnufgKlh8aW5rEVmeC_c,25133
+scdataloader/datamodule.py,sha256=6B5nwo8NG_b8dNGPDRtDyFt5Hj095xiHDFa3ga0_s-Y,43599
 scdataloader/mapped.py,sha256=h9YKQ8SG9tyZL8c6_Wu5Xov5ODGK6FzVuFopz58xwN4,29887
-scdataloader/preprocess.py,sha256=
+scdataloader/preprocess.py,sha256=oAGMilgdIgggyp9B9c9627kdo6SCco2tnFhhIHY4-yc,39642
 scdataloader/utils.py,sha256=Z6td0cIphrYDLVrPrV8q4jUC_HtwGQmi-NcbpdbWrns,31034
-scdataloader-2.0.
-scdataloader-2.0.
-scdataloader-2.0.
-scdataloader-2.0.
-scdataloader-2.0.
+scdataloader-2.0.4.dist-info/METADATA,sha256=--g4uHOlhQ2Y_Jkxo9LOr--tH0BPL_sxODaLhUCMcw8,10314
+scdataloader-2.0.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+scdataloader-2.0.4.dist-info/entry_points.txt,sha256=VXAN1m_CjbdLJ6SKYR0sBLGDV4wvv31ri7fWWuwbpno,60
+scdataloader-2.0.4.dist-info/licenses/LICENSE,sha256=rGy_eYmnxtbOvKs7qt5V0czSWxJwgX_MlgMyTZwDHbc,1073
+scdataloader-2.0.4.dist-info/RECORD,,

{scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/WHEEL
File without changes

{scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/entry_points.txt
File without changes

{scdataloader-2.0.3.dist-info → scdataloader-2.0.4.dist-info}/licenses/LICENSE
File without changes