scdataloader 2.0.2.tar.gz → 2.0.3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scdataloader-2.0.2 → scdataloader-2.0.3}/PKG-INFO +1 -1
- {scdataloader-2.0.2 → scdataloader-2.0.3}/pyproject.toml +2 -3
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/datamodule.py +22 -10
- {scdataloader-2.0.2 → scdataloader-2.0.3}/.gitignore +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/LICENSE +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/README.md +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/__init__.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/__main__.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/base.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/collator.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/config.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/data.json +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/data.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/mapped.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/preprocess.py +0 -0
- {scdataloader-2.0.2 → scdataloader-2.0.3}/scdataloader/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "scdataloader"
|
|
3
|
-
version = "2.0.
|
|
3
|
+
version = "2.0.3"
|
|
4
4
|
description = "a dataloader for single cell data in lamindb"
|
|
5
5
|
authors = [
|
|
6
6
|
{name = "jkobject", email = "jkobject@gmail.com"}
|
|
@@ -15,6 +15,7 @@ dependencies = [
|
|
|
15
15
|
"cellxgene-census>=0.1.0",
|
|
16
16
|
"torch>=2.2.0",
|
|
17
17
|
"pytorch-lightning>=2.3.0",
|
|
18
|
+
"lightning>=2.3.0",
|
|
18
19
|
"anndata>=0.9.0",
|
|
19
20
|
"zarr>=2.10.0",
|
|
20
21
|
"matplotlib>=3.5.0",
|
|
@@ -27,8 +28,6 @@ dependencies = [
|
|
|
27
28
|
"django>=4.0.0",
|
|
28
29
|
"scikit-misc>=0.5.0",
|
|
29
30
|
"jupytext>=1.16.0",
|
|
30
|
-
"lightning>=2.3.0",
|
|
31
|
-
"pytorch-lightning>=2.3.0",
|
|
32
31
|
]
|
|
33
32
|
|
|
34
33
|
[project.optional-dependencies]
|
|
@@ -65,6 +65,7 @@ class DataModule(L.LightningDataModule):
|
|
|
65
65
|
genedf: Optional[pd.DataFrame] = None,
|
|
66
66
|
n_bins: int = 0,
|
|
67
67
|
curiculum: int = 0,
|
|
68
|
+
start_at: int = 0,
|
|
68
69
|
**kwargs,
|
|
69
70
|
):
|
|
70
71
|
"""
|
|
@@ -162,6 +163,7 @@ class DataModule(L.LightningDataModule):
|
|
|
162
163
|
self.sampler_chunk_size = sampler_chunk_size
|
|
163
164
|
self.store_location = store_location
|
|
164
165
|
self.nnz = None
|
|
166
|
+
self.start_at = start_at
|
|
165
167
|
self.idx_full = None
|
|
166
168
|
self.max_len = max_len
|
|
167
169
|
self.test_datasets = []
|
|
@@ -324,9 +326,9 @@ class DataModule(L.LightningDataModule):
|
|
|
324
326
|
len_test = self.test_split
|
|
325
327
|
else:
|
|
326
328
|
len_test = int(self.n_samples * self.test_split)
|
|
327
|
-
assert
|
|
328
|
-
|
|
329
|
-
)
|
|
329
|
+
assert (
|
|
330
|
+
len_test + len_valid < self.n_samples
|
|
331
|
+
), "test set + valid set size is configured to be larger than entire dataset."
|
|
330
332
|
|
|
331
333
|
idx_full = []
|
|
332
334
|
if len(self.assays_to_drop) > 0:
|
|
@@ -461,7 +463,7 @@ class DataModule(L.LightningDataModule):
|
|
|
461
463
|
dataset = None
|
|
462
464
|
else:
|
|
463
465
|
dataset = Subset(self.dataset, self.idx_full)
|
|
464
|
-
train_sampler = RankShardSampler(len(dataset))
|
|
466
|
+
train_sampler = RankShardSampler(len(dataset), start_at=self.start_at)
|
|
465
467
|
current_loader_kwargs = kwargs.copy()
|
|
466
468
|
current_loader_kwargs.update(self.kwargs)
|
|
467
469
|
return DataLoader(
|
|
@@ -492,8 +494,8 @@ class DataModule(L.LightningDataModule):
|
|
|
492
494
|
def predict_dataloader(self):
|
|
493
495
|
subset = Subset(self.dataset, self.idx_full)
|
|
494
496
|
return DataLoader(
|
|
495
|
-
|
|
496
|
-
sampler=RankShardSampler(len(subset)),
|
|
497
|
+
self.dataset,
|
|
498
|
+
sampler=RankShardSampler(len(subset), start_at=self.start_at),
|
|
497
499
|
**self.kwargs,
|
|
498
500
|
)
|
|
499
501
|
|
|
@@ -667,7 +669,9 @@ class LabelWeightedSampler(Sampler[int]):
|
|
|
667
669
|
unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
|
|
668
670
|
|
|
669
671
|
# Initialize result tensor
|
|
670
|
-
result_indices_list =
|
|
672
|
+
result_indices_list = (
|
|
673
|
+
[]
|
|
674
|
+
) # Changed name to avoid conflict if you had result_indices elsewhere
|
|
671
675
|
|
|
672
676
|
# Process only the classes that were actually sampled
|
|
673
677
|
for i, (label, count) in tqdm(
|
|
@@ -850,8 +854,9 @@ class RankShardSampler(Sampler[int]):
|
|
|
850
854
|
"""Shards a dataset contiguously across ranks without padding or duplicates.
|
|
851
855
|
Preserves the existing order (e.g., your pre-shuffled idx_full)."""
|
|
852
856
|
|
|
853
|
-
def __init__(self, data_len: int):
|
|
857
|
+
def __init__(self, data_len: int, start_at: int = 0) -> None:
|
|
854
858
|
self.data_len = data_len
|
|
859
|
+
self.start_at = start_at
|
|
855
860
|
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
|
856
861
|
self.rank = torch.distributed.get_rank()
|
|
857
862
|
self.world_size = torch.distributed.get_world_size()
|
|
@@ -859,9 +864,16 @@ class RankShardSampler(Sampler[int]):
|
|
|
859
864
|
self.rank, self.world_size = 0, 1
|
|
860
865
|
|
|
861
866
|
# contiguous chunk per rank (last rank may be shorter)
|
|
862
|
-
|
|
863
|
-
|
|
867
|
+
if self.start_at > 0:
|
|
868
|
+
print(
|
|
869
|
+
"!!!!ATTTENTION: make sure that you are running on the exact same \
|
|
870
|
+
number of GPU as your previous run!!!!!"
|
|
871
|
+
)
|
|
872
|
+
print(f"Sharding data of size {data_len} over {self.world_size} ranks")
|
|
873
|
+
per_rank = math.ceil((self.data_len - self.start_at) / self.world_size)
|
|
874
|
+
self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
|
|
864
875
|
self.end = min(self.start + per_rank, self.data_len)
|
|
876
|
+
print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")
|
|
865
877
|
|
|
866
878
|
def __iter__(self):
|
|
867
879
|
return iter(range(self.start, self.end))
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|