scdataloader 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/collator.py +99 -36
- scdataloader/config.py +1151 -0
- scdataloader/data.py +177 -39
- scdataloader/datamodule.py +222 -57
- scdataloader/preprocess.py +33 -24
- scdataloader/utils.py +31 -181
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/METADATA +1 -1
- scdataloader-2.0.5.dist-info/RECORD +16 -0
- scdataloader-2.0.3.dist-info/RECORD +0 -16
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/WHEEL +0 -0
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/entry_points.txt +0 -0
- {scdataloader-2.0.3.dist-info → scdataloader-2.0.5.dist-info}/licenses/LICENSE +0 -0
scdataloader/datamodule.py
CHANGED
@@ -69,39 +69,98 @@ class DataModule(L.LightningDataModule):
         **kwargs,
     ):
         """
-        DataModule
-        it can work with bare pytorch too
+        PyTorch Lightning DataModule for loading single-cell data from a LaminDB Collection.

-
-
+        This DataModule provides train/val/test dataloaders with configurable sampling strategies.
+        It combines MappedCollection, Dataset, and Collator to create efficient data pipelines
+        for training single-cell foundation models.
+
+        The training dataloader uses weighted random sampling based on class frequencies,
+        validation uses random sampling, and test uses sequential sampling on held-out datasets.

         Args:
-            collection_name (str):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            collection_name (str): Key of the LaminDB Collection to load.
+            clss_to_weight (List[str], optional): Label columns to use for weighted sampling
+                in the training dataloader. Supports "nnz" for weighting by number of
+                non-zero genes. Defaults to ["organism_ontology_term_id"].
+            weight_scaler (int, optional): Controls balance between rare and common classes.
+                Higher values lead to more uniform sampling across classes. Set to 0 to
+                disable weighted sampling. Defaults to 10.
+            n_samples_per_epoch (int, optional): Number of samples to draw per training epoch.
+                Defaults to 2,000,000.
+            validation_split (float | int, optional): Proportion (float) or absolute number (int)
+                of samples for validation. Defaults to 0.2.
+            test_split (float | int, optional): Proportion (float) or absolute number (int)
+                of samples for testing. Uses entire datasets as test sets, rounding to
+                nearest dataset boundary. Defaults to 0.
+            use_default_col (bool, optional): Whether to use the default Collator for batch
+                preparation. If False, no collate_fn is applied. Defaults to True.
+            clss_to_predict (List[str], optional): Observation columns to encode as prediction
+                targets. Must include "organism_ontology_term_id". Defaults to
+                ["organism_ontology_term_id"].
+            hierarchical_clss (List[str], optional): Observation columns with hierarchical
+                ontology structure to be processed. Defaults to [].
+            how (str, optional): Gene selection strategy passed to Collator. One of
+                "most expr", "random expr", "all", "some". Defaults to "random expr".
+            organism_col (str, optional): Column name for organism ontology term ID.
+                Defaults to "organism_ontology_term_id".
+            max_len (int, optional): Maximum number of genes per sample passed to Collator.
+                Defaults to 1000.
+            replacement (bool, optional): Whether to sample with replacement in training.
+                Defaults to True.
+            gene_subset (List[str], optional): List of genes to restrict the dataset to.
+                Useful when model only supports specific genes. Defaults to None.
+            tp_name (str, optional): Column name for time point or heat diffusion values.
+                Defaults to None.
+            assays_to_drop (List[str], optional): List of assay ontology term IDs to exclude
+                from training. Defaults to ["EFO:0030007"] (ATAC-seq).
+            metacell_mode (float, optional): Probability of using metacell aggregation mode.
+                Cannot be used with get_knn_cells. Defaults to 0.0.
+            get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
+                expression data. Cannot be used with metacell_mode. Defaults to False.
+            store_location (str, optional): Directory path to cache sampler indices and
+                labels for faster subsequent loading. Defaults to None.
+            force_recompute_indices (bool, optional): Force recomputation of cached indices
+                even if they exist. Defaults to False.
+            sampler_workers (int, optional): Number of parallel workers for building sampler
+                indices. Auto-determined based on available CPUs if None. Defaults to None.
+            sampler_chunk_size (int, optional): Chunk size for parallel sampler processing.
+                Auto-determined based on available memory if None. Defaults to None.
+            organisms (List[str], optional): List of organisms to include. If None, uses
+                all organisms in the dataset. Defaults to None.
+            genedf (pd.DataFrame, optional): Gene information DataFrame. If None, loaded
+                automatically. Defaults to None.
+            n_bins (int, optional): Number of bins for expression discretization. 0 means
+                no binning. Defaults to 0.
+            curiculum (int, optional): Curriculum learning parameter. If > 0, gradually
+                increases sampling weight balance over epochs. Defaults to 0.
+            start_at (int, optional): Starting index for resuming inference. Requires same
+                number of GPUs as previous run. Defaults to 0.
+            **kwargs: Additional arguments passed to PyTorch DataLoader (e.g., batch_size,
+                num_workers, pin_memory).
+
+        Attributes:
+            dataset (Dataset): The underlying Dataset instance.
+            classes (dict[str, int]): Mapping from class names to number of categories.
+            train_labels (np.ndarray): Label array for weighted sampling.
+            idx_full (np.ndarray): Indices for training samples.
+            valid_idx (np.ndarray): Indices for validation samples.
+            test_idx (np.ndarray): Indices for test samples.
+            test_datasets (List[str]): Paths to datasets used for testing.
+
+        Raises:
+            ValueError: If "organism_ontology_term_id" not in clss_to_predict.
+            ValueError: If both metacell_mode > 0 and get_knn_cells are True.
+
+        Example:
+            >>> dm = DataModule(
+            ...     collection_name="my_collection",
+            ...     batch_size=32,
+            ...     num_workers=4,
+            ...     max_len=2000,
+            ... )
+            >>> dm.setup()
+            >>> train_loader = dm.train_dataloader()
         """
         if "organism_ontology_term_id" not in clss_to_predict:
             raise ValueError(
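To make the docstring's Example concrete end to end, here is a minimal usage sketch. It assumes the top-level `DataModule` export and uses a hypothetical collection key and label column; exact keyword support may vary between versions:

```python
# Hedged usage sketch: "my_collection" and the clss_to_weight column are
# placeholders, not values shipped with the package.
from scdataloader import DataModule

dm = DataModule(
    collection_name="my_collection",          # LaminDB Collection key (placeholder)
    clss_to_weight=["cell_type_ontology_term_id", "nnz"],  # balance rare cell types
    weight_scaler=10,                         # see LabelWeightedSampler below
    validation_split=0.2,                     # float -> proportion of samples
    test_split=0.05,                          # rounded to a whole-dataset boundary
    max_len=2000,                             # genes per cell passed to the Collator
    batch_size=64,                            # forwarded to the torch DataLoader
    num_workers=8,
)
dm.setup()                                    # builds train/val/test splits
train_loader = dm.train_dataloader()          # weighted random sampling
val_loader = dm.val_dataloader()              # plain random sampling
```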
@@ -131,6 +190,7 @@ class DataModule(L.LightningDataModule):
         self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
         # we might want not to order the genes by expression (or do it?)
         # we might want to not introduce zeros and
+
         if use_default_col:
             kwargs["collate_fn"] = Collator(
                 organisms=mdataset.organisms if organisms is None else organisms,
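This hunk wires the default Collator into the DataLoader kwargs. As an illustration of the "most expr" gene-selection strategy named in the docstring above, a simplified selection step might look like the sketch below; the real Collator additionally handles organism-aware gene mapping, padding, and optional binning:

```python
import numpy as np

def select_genes_most_expr(expr: np.ndarray, max_len: int) -> np.ndarray:
    """Illustrative only: indices of the max_len highest-expression genes."""
    if expr.shape[0] <= max_len:
        return np.arange(expr.shape[0])
    # argpartition finds the max_len largest values in O(n), then we order them
    top = np.argpartition(expr, -max_len)[-max_len:]
    return top[np.argsort(expr[top])[::-1]]

cell = np.array([0.0, 5.2, 1.1, 0.3, 7.8])
print(select_genes_most_expr(cell, max_len=3))  # [4 1 2]
```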
@@ -285,12 +345,24 @@ class DataModule(L.LightningDataModule):

     def setup(self, stage=None):
         """
-
-
+        Prepare data splits for training, validation, and testing.
+
+        This method shuffles the data, computes sample weights for weighted sampling,
+        removes samples from dropped assays, and creates train/val/test splits.
+        Test splits use entire datasets to ensure evaluation on unseen data sources.
+
+        Results can be cached to `store_location` for faster subsequent runs.

         Args:
-            stage (str, optional):
-
+            stage (str, optional): Training stage ('fit', 'test', or None for both).
+                Currently not used but kept for Lightning compatibility. Defaults to None.
+
+        Returns:
+            List[str]: List of paths to test datasets.
+
+        Note:
+            Must be called before using dataloaders. The train/val/test split is
+            deterministic when loading from cache.
         """
         print("setting up the datamodule")
         start_time = time.time()
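The float-versus-int convention used by `validation_split` and `test_split` can be made concrete. This is a minimal sketch of the assumed semantics, not scdataloader's exact implementation (in particular, the real test split additionally rounds to whole-dataset boundaries):

```python
def resolve_split(split: float | int, n_samples: int) -> int:
    """Return an absolute number of held-out samples (assumed semantics)."""
    if isinstance(split, float):
        if not 0.0 <= split < 1.0:
            raise ValueError("a float split must be a proportion in [0, 1)")
        return int(n_samples * split)
    return min(split, n_samples)  # an int is already an absolute count

assert resolve_split(0.2, 1_000) == 200  # proportion
assert resolve_split(150, 1_000) == 150  # absolute count
```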
@@ -441,6 +513,22 @@ class DataModule(L.LightningDataModule):
         return self.test_datasets

     def train_dataloader(self, **kwargs):
+        """
+        Create the training DataLoader with weighted random sampling.
+
+        Uses LabelWeightedSampler for class-balanced sampling when weight_scaler > 0
+        and clss_to_weight is specified. Otherwise uses RankShardSampler for
+        distributed training without weighting.
+
+        Args:
+            **kwargs: Additional arguments passed to DataLoader, overriding defaults.
+
+        Returns:
+            DataLoader: Training DataLoader instance.
+
+        Raises:
+            ValueError: If setup() has not been called.
+        """
         if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
             try:
                 print("Setting up the parallel train sampler...")
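Condensed, the branching this docstring describes looks roughly as follows; attribute and class names follow the surrounding diff, but this is not the method's verbatim body:

```python
def _choose_train_sampler(self):
    # Rough sketch of the documented behavior, not the actual method body.
    if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
        # class-balanced, optionally element-weighted sampling
        return LabelWeightedSampler(
            labels=self.train_labels,
            num_samples=self.n_samples_per_epoch,
            weight_scaler=self.weight_scaler,
        )
    # no weighting: one contiguous shard per distributed rank
    return RankShardSampler(len(self.dataset))
```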
@@ -473,6 +561,12 @@ class DataModule(L.LightningDataModule):
         )

     def val_dataloader(self):
+        """
+        Create the validation DataLoader.
+
+        Returns:
+            DataLoader | List: Validation DataLoader, or empty list if no validation split.
+        """
         return (
             DataLoader(
                 Subset(self.dataset, self.valid_idx),
@@ -483,6 +577,12 @@ class DataModule(L.LightningDataModule):
         )

     def test_dataloader(self):
+        """
+        Create the test DataLoader with sequential sampling.
+
+        Returns:
+            DataLoader | List: Test DataLoader, or empty list if no test split.
+        """
         return (
             DataLoader(
                 self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
@@ -492,6 +592,14 @@ class DataModule(L.LightningDataModule):
         )

     def predict_dataloader(self):
+        """
+        Create a DataLoader for prediction over all training data.
+
+        Uses RankShardSampler for distributed inference.
+
+        Returns:
+            DataLoader: Prediction DataLoader instance.
+        """
         subset = Subset(self.dataset, self.idx_full)
         return DataLoader(
             self.dataset,
@@ -501,15 +609,6 @@ class DataModule(L.LightningDataModule):


 class LabelWeightedSampler(Sampler[int]):
-    """
-    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
-
-    This sampler is designed to handle very large datasets efficiently, with optimizations for:
-    1. Parallel building of class indices
-    2. Chunked processing for large arrays
-    3. Efficient memory management
-    4. Proper handling of replacement and non-replacement sampling
-    """

     label_weights: torch.Tensor
     klass_indices: dict[int, torch.Tensor]
@@ -531,16 +630,58 @@ class LabelWeightedSampler(Sampler[int]):
         curiculum: int = 0,
     ) -> None:
         """
-
+        Weighted random sampler balancing both class frequencies and element weights.
+
+        This sampler is optimized for very large datasets (millions of samples) with:
+        - Parallel construction of class indices using multiple CPU workers
+        - Chunked processing to manage memory usage
+        - Support for curriculum learning via progressive weight scaling
+        - Optional per-element weights (e.g., by number of expressed genes)
+
+        The sampling process:
+        1. Sample class labels according to class weights
+        2. For each sampled class, sample elements according to element weights
+        3. Shuffle all sampled indices

         Args:
-
-
-
-
-
-
-
+            labels (np.ndarray): Integer class label for each dataset element.
+                Shape: (dataset_size,). The last unique label is treated as
+                "excluded" with weight 0.
+            num_samples (int): Number of samples to draw per epoch.
+            replacement (bool, optional): Whether to sample with replacement.
+                Defaults to True.
+            weight_scaler (float, optional): Controls class weight balance.
+                Weight formula: (scaler * count) / (count + scaler).
+                Higher values = more uniform sampling. Defaults to None.
+            element_weights (Sequence[float], optional): Per-element sampling weights.
+                Shape: (dataset_size,). Defaults to None (uniform within class).
+            n_workers (int, optional): Number of parallel workers for index building.
+                Defaults to min(20, num_cpus - 1).
+            chunk_size (int, optional): Elements per chunk for parallel processing.
+                Auto-determined based on available memory if None.
+            store_location (str, optional): Directory to cache computed indices.
+                Defaults to None.
+            force_recompute_indices (bool, optional): Recompute indices even if cached.
+                Defaults to False.
+            curiculum (int, optional): Curriculum learning epochs. If > 0, weight
+                exponent increases from 0 to 1 over this many epochs. Defaults to 0.
+
+        Attributes:
+            label_weights (torch.Tensor): Computed weights per class label.
+            klass_indices (torch.Tensor): Concatenated indices for all classes.
+            klass_offsets (torch.Tensor): Starting offset for each class in klass_indices.
+            count (int): Number of times __iter__ has been called (for curriculum).
+
+        Example:
+            >>> sampler = LabelWeightedSampler(
+            ...     labels=train_labels,
+            ...     num_samples=1_000_000,
+            ...     weight_scaler=10,
+            ...     element_weights=nnz_weights,
+            ... )
+            >>> for idx in sampler:
+            ...     # Process sample at idx
+            ...     pass
         """
         print("Initializing optimized parallel weighted sampler...")
         super(LabelWeightedSampler, self).__init__(None)
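The class-weight formula quoted in the Args is easiest to read with numbers plugged in. The snippet below is a standalone evaluation of that formula, not the sampler's internal code:

```python
def class_weight(count: int, scaler: float) -> float:
    # w(count) = (scaler * count) / (count + scaler): saturates at `scaler`,
    # so very common classes stop gaining sampling weight while a rare
    # class keeps a weight close to its raw count.
    return (scaler * count) / (count + scaler)

for count in (10, 1_000, 100_000):
    print(f"count={count:>6}  weight={class_weight(count, scaler=10):.2f}")
# count=    10  weight=5.00   (rare class: about half its raw count)
# count=  1000  weight=9.90   (already near the cap of 10)
# count=100000  weight=10.00  (common class: capped at ~scaler)

# Curriculum, per the docstring: over `curiculum` epochs the weight exponent
# ramps from 0 to 1, so early epochs sample near-uniformly (w ** 0 == 1)
# and later epochs approach the full weights (w ** 1).
```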
@@ -851,8 +992,32 @@ class LabelWeightedSampler(Sampler[int]):


 class RankShardSampler(Sampler[int]):
-    """
-
+    """
+    Sampler that shards data contiguously across distributed ranks.
+
+    Divides the dataset into contiguous chunks, one per rank, without
+    padding or duplicating samples. Preserves the original data order
+    within each shard (useful for pre-shuffled data).
+
+    Args:
+        data_len (int): Total number of samples in the dataset.
+        start_at (int, optional): Global starting index for resuming training.
+            Requires the same number of GPUs as the previous run. Defaults to 0.
+
+    Attributes:
+        rank (int): Current process rank (0 if not distributed).
+        world_size (int): Total number of processes (1 if not distributed).
+        start (int): Starting index for this rank's shard.
+        end (int): Ending index (exclusive) for this rank's shard.
+
+    Note:
+        The last rank may have fewer samples than others if the dataset
+        size is not evenly divisible by world_size.
+
+    Example:
+        >>> sampler = RankShardSampler(len(dataset))
+        >>> loader = DataLoader(dataset, sampler=sampler)
+    """

     def __init__(self, data_len: int, start_at: int = 0) -> None:
         self.data_len = data_len
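Since the docstring describes the chunking in words, here is a torch-free sketch of the same arithmetic for the default case (`start_at == 0`); the real class reads its rank and world size from torch.distributed:

```python
import math

def shard_bounds(data_len: int, rank: int, world_size: int) -> tuple[int, int]:
    # one ceil-sized contiguous chunk per rank; the last may be shorter
    per_rank = math.ceil(data_len / world_size)
    start = rank * per_rank
    end = min((rank + 1) * per_rank, data_len)
    return start, end

# 10 samples over 4 ranks -> (0, 3) (3, 6) (6, 9) (9, 10):
# no padding, no duplication, order preserved within each shard
for rank in range(4):
    print(rank, shard_bounds(10, rank, 4))
```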
@@ -866,13 +1031,13 @@ class RankShardSampler(Sampler[int]):
         # contiguous chunk per rank (last rank may be shorter)
         if self.start_at > 0:
             print(
-                "!!!!
-
+                "!!!!ATTENTION: make sure that you are running on the exact same \
+number of GPU as your previous run!!!!!"
             )
         print(f"Sharding data of size {data_len} over {self.world_size} ranks")
-        per_rank = math.ceil(
+        per_rank = math.ceil(self.data_len / self.world_size)
         self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
-        self.end = min(self.
+        self.end = min((self.rank + 1) * per_rank, self.data_len)
         print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")

     def __iter__(self):
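The three corrected lines in this hunk determine how a resumed run is sharded. A worked example with hypothetical sizes shows why `start_at` requires the same world size as the original run: the global offset is divided evenly across ranks:

```python
import math

data_len, world_size, start_at = 1_000, 4, 400
per_rank = math.ceil(data_len / world_size)               # 250
for rank in range(world_size):
    start = int(start_at / world_size + rank * per_rank)  # each rank skips 100
    end = min((rank + 1) * per_rank, data_len)
    print(rank, start, end)
# 0 100 250
# 1 350 500
# 2 600 750
# 3 850 1000
# 4 ranks x 100 skipped samples == start_at == 400; with a different
# world_size the per-rank offsets would no longer line up with the old run.
```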
scdataloader/preprocess.py
CHANGED
@@ -1,3 +1,10 @@
+"""
+Preprocessing utilities for single-cell gene expression data.
+
+This module provides functions for normalizing, transforming, and discretizing
+gene expression values for use with scPRINT and similar models.
+"""
+
 import gc
 import time
 from typing import Callable, List, Optional, Union
@@ -664,39 +671,41 @@ def is_log1p(adata: AnnData) -> bool:
     return True


-def _digitize(
+def _digitize(values: np.ndarray, bins: np.ndarray) -> np.ndarray:
     """
-    Digitize
-    have same values.
+    Digitize values into discrete bins with 1-based indexing.

-
+    Similar to np.digitize but ensures output is 1-indexed (bin 0 reserved for
+    zero values) and handles edge cases for expression binning.

-
-
-
-
-
-    The side to use for digitization. If "one", the left side is used. If
-    "both", the left and right side are used. Default to "one".
+    Args:
+        values (np.ndarray): Array of values to discretize. Should be non-zero
+            expression values.
+        bins (np.ndarray): Bin edges from np.quantile or similar. Values are
+            assigned to bins based on which edges they fall between.

     Returns:
-
-
-
+        np.ndarray: Integer bin indices, 1-indexed. Values equal to bins[i]
+            are assigned to bin i+1.
+
+    Example:
+        >>> values = np.array([0.5, 1.5, 2.5, 3.5])
+        >>> bins = np.array([1.0, 2.0, 3.0])
+        >>> _digitize(values, bins)
+        array([1, 2, 3, 3])
+
+    Note:
+        This function is used internally by the Collator for expression binning.
+        Zero values should be handled separately before calling this function.
     """
-    assert
-
-    left_digits = np.digitize(x, bins)
-    if side == "one":
-        return left_digits
+    assert values.ndim == 1 and bins.ndim == 1

-
+    left_digits = np.digitize(values, bins)
+    return left_digits

-    rands = np.random.rand(len(x)) # uniform random numbers

-
-
-    return digits
+    # Add documentation for any other functions in preprocess.py
+    # ...existing code...


 def binning(row: np.ndarray, n_bins: int) -> np.ndarray: