scdataloader 2.0.2__tar.gz → 2.0.3__tar.gz

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: scdataloader
-Version: 2.0.2
+Version: 2.0.3
 Summary: a dataloader for single cell data in lamindb
 Project-URL: repository, https://github.com/jkobject/scDataLoader
 Author-email: jkobject <jkobject@gmail.com>
@@ -1,6 +1,6 @@
 [project]
 name = "scdataloader"
-version = "2.0.2"
+version = "2.0.3"
 description = "a dataloader for single cell data in lamindb"
 authors = [
     {name = "jkobject", email = "jkobject@gmail.com"}
@@ -15,6 +15,7 @@ dependencies = [
     "cellxgene-census>=0.1.0",
     "torch>=2.2.0",
     "pytorch-lightning>=2.3.0",
+    "lightning>=2.3.0",
     "anndata>=0.9.0",
     "zarr>=2.10.0",
     "matplotlib>=3.5.0",
@@ -27,8 +28,6 @@ dependencies = [
     "django>=4.0.0",
     "scikit-misc>=0.5.0",
     "jupytext>=1.16.0",
-    "lightning>=2.3.0",
-    "pytorch-lightning>=2.3.0",
 ]

 [project.optional-dependencies]
@@ -65,6 +65,7 @@ class DataModule(L.LightningDataModule):
         genedf: Optional[pd.DataFrame] = None,
         n_bins: int = 0,
         curiculum: int = 0,
+        start_at: int = 0,
         **kwargs,
     ):
         """
@@ -162,6 +163,7 @@ class DataModule(L.LightningDataModule):
         self.sampler_chunk_size = sampler_chunk_size
         self.store_location = store_location
         self.nnz = None
+        self.start_at = start_at
         self.idx_full = None
         self.max_len = max_len
         self.test_datasets = []
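
The start_at argument added above lets a run resume partway through an epoch by skipping samples that were already consumed. The snippet below is a minimal, hypothetical sketch of how it might be used; it is not taken from the package's documentation, and every argument except start_at is a placeholder:

from scdataloader import DataModule  # import path assumed

dm = DataModule(
    # collection / dataset arguments omitted here (project-specific)
    batch_size=64,      # placeholder value
    start_at=120_000,   # hypothetical: skip the first 120k samples of the interrupted epoch
)
dm.setup()
loader = dm.train_dataloader()  # its sampler now begins past the skipped indices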
@@ -324,9 +326,9 @@ class DataModule(L.LightningDataModule):
             len_test = self.test_split
         else:
             len_test = int(self.n_samples * self.test_split)
-        assert len_test + len_valid < self.n_samples, (
-            "test set + valid set size is configured to be larger than entire dataset."
-        )
+        assert (
+            len_test + len_valid < self.n_samples
+        ), "test set + valid set size is configured to be larger than entire dataset."

         idx_full = []
         if len(self.assays_to_drop) > 0:
@@ -461,7 +463,7 @@ class DataModule(L.LightningDataModule):
             dataset = None
         else:
             dataset = Subset(self.dataset, self.idx_full)
-        train_sampler = RankShardSampler(len(dataset))
+        train_sampler = RankShardSampler(len(dataset), start_at=self.start_at)
         current_loader_kwargs = kwargs.copy()
         current_loader_kwargs.update(self.kwargs)
         return DataLoader(
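
For context, the sampler built here is consumed by a plain torch DataLoader. The sketch below mirrors that wiring outside the DataModule with a toy dataset; the values are stand-ins, not the package's defaults, and the import path of RankShardSampler is assumed:

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from scdataloader.datamodule import RankShardSampler  # module path assumed

# Toy stand-ins; in the package these come from the lamindb-backed dataset and setup()
dataset = TensorDataset(torch.arange(1_000))
idx_full = torch.randperm(1_000).tolist()  # pre-shuffled indices, as in setup()

subset = Subset(dataset, idx_full)
sampler = RankShardSampler(len(subset), start_at=0)  # start_at > 0 only when resuming
loader = DataLoader(subset, sampler=sampler, batch_size=64)  # batch_size is a placeholder
for batch in loader:
    ...  # each rank iterates only over its contiguous shard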
@@ -492,8 +494,8 @@ class DataModule(L.LightningDataModule):
     def predict_dataloader(self):
         subset = Subset(self.dataset, self.idx_full)
         return DataLoader(
-            subset,
-            sampler=RankShardSampler(len(subset)),
+            self.dataset,
+            sampler=RankShardSampler(len(subset), start_at=self.start_at),
             **self.kwargs,
         )

@@ -667,7 +669,9 @@ class LabelWeightedSampler(Sampler[int]):
         unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)

         # Initialize result tensor
-        result_indices_list = []  # Changed name to avoid conflict if you had result_indices elsewhere
+        result_indices_list = (
+            []
+        )  # Changed name to avoid conflict if you had result_indices elsewhere

         # Process only the classes that were actually sampled
         for i, (label, count) in tqdm(
@@ -850,8 +854,9 @@ class RankShardSampler(Sampler[int]):
     """Shards a dataset contiguously across ranks without padding or duplicates.
     Preserves the existing order (e.g., your pre-shuffled idx_full)."""

-    def __init__(self, data_len: int):
+    def __init__(self, data_len: int, start_at: int = 0) -> None:
         self.data_len = data_len
+        self.start_at = start_at
         if torch.distributed.is_available() and torch.distributed.is_initialized():
             self.rank = torch.distributed.get_rank()
             self.world_size = torch.distributed.get_world_size()
@@ -859,9 +864,16 @@ class RankShardSampler(Sampler[int]):
             self.rank, self.world_size = 0, 1

         # contiguous chunk per rank (last rank may be shorter)
-        per_rank = math.ceil(self.data_len / self.world_size)
-        self.start = self.rank * per_rank
+        if self.start_at > 0:
+            print(
+                "!!!!ATTTENTION: make sure that you are running on the exact same \
+                number of GPU as your previous run!!!!!"
+            )
+        print(f"Sharding data of size {data_len} over {self.world_size} ranks")
+        per_rank = math.ceil((self.data_len - self.start_at) / self.world_size)
+        self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
         self.end = min(self.start + per_rank, self.data_len)
+        print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")

     def __iter__(self):
         return iter(range(self.start, self.end))
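
To make the new arithmetic concrete, here is a standalone sketch that reproduces it with explicit rank and world_size arguments instead of querying torch.distributed; it is illustrative only, not code from the package:

import math

def shard_range(data_len: int, rank: int, world_size: int, start_at: int = 0) -> range:
    # Mirrors RankShardSampler: contiguous chunks of the remaining (data_len - start_at)
    # indices, with every rank's window shifted forward by start_at / world_size.
    per_rank = math.ceil((data_len - start_at) / world_size)
    start = int(start_at / world_size + rank * per_rank)
    end = min(start + per_rank, data_len)
    return range(start, end)

# Worked example: 1_000 samples, 4 ranks, resuming after 200 samples were already consumed.
for r in range(4):
    print(r, shard_range(1_000, r, 4, start_at=200))
# 0 range(50, 250)
# 1 range(250, 450)
# 2 range(450, 650)
# 3 range(650, 850)

Across the four ranks this covers 800 = data_len - start_at indices in total, which is why the added warning insists on resuming with the same number of GPUs as the interrupted run.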