scdataloader 1.9.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,9 @@
1
+ import multiprocessing as mp
1
2
  import os
3
+ import random
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor, as_completed
6
+ from functools import partial
2
7
  from typing import Optional, Sequence, Union
3
8
 
4
9
  import lamindb as ln
@@ -13,10 +18,11 @@ from torch.utils.data.sampler import (
13
18
  SubsetRandomSampler,
14
19
  WeightedRandomSampler,
15
20
  )
21
+ from tqdm import tqdm
16
22
 
17
23
  from .collator import Collator
18
24
  from .data import Dataset
19
- from .utils import getBiomartTable, slurm_restart_count
25
+ from .utils import fileToList, getBiomartTable, listToFile
20
26
 
21
27
  FILE_DIR = os.path.dirname(os.path.abspath(__file__))
22
28
 
@@ -26,9 +32,8 @@ class DataModule(L.LightningDataModule):
26
32
  self,
27
33
  collection_name: str,
28
34
  clss_to_weight: list = ["organism_ontology_term_id"],
29
- organisms: list = ["NCBITaxon:9606"],
30
35
  weight_scaler: int = 10,
31
- train_oversampling_per_epoch: float = 0.1,
36
+ n_samples_per_epoch: int = 2_000_000,
32
37
  validation_split: float = 0.2,
33
38
  test_split: float = 0,
34
39
  gene_embeddings: str = "",
@@ -43,7 +48,7 @@ class DataModule(L.LightningDataModule):
43
48
  max_len: int = 1000,
44
49
  add_zero_genes: int = 100,
45
50
  replacement: bool = True,
46
- do_gene_pos: Union[bool, str] = True,
51
+ do_gene_pos: str = "",
47
52
  tp_name: Optional[str] = None, # "heat_diff"
48
53
  assays_to_drop: list = [
49
54
  # "EFO:0008853", #patch seq
@@ -53,7 +58,10 @@ class DataModule(L.LightningDataModule):
53
58
  ],
54
59
  metacell_mode: float = 0.0,
55
60
  get_knn_cells: bool = False,
56
- modify_seed_on_requeue: bool = True,
61
+ store_location: str = None,
62
+ force_recompute_indices: bool = False,
63
+ sampler_workers: int = None,
64
+ sampler_chunk_size: int = None,
57
65
  **kwargs,
58
66
  ):
59
67
  """
@@ -65,9 +73,8 @@ class DataModule(L.LightningDataModule):
65
73
 
66
74
  Args:
67
75
  collection_name (str): The lamindb collection to be used.
68
- organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
69
76
  weight_scaler (int, optional): how much more often the most frequent category is sampled relative to the least frequent one.
70
- train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
77
+ n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
71
78
  validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
72
79
  test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
73
80
  the test split holds out entire datasets, rounding to the nearest dataset's cell count.
@@ -88,61 +95,38 @@ class DataModule(L.LightningDataModule):
88
95
  hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
89
96
  metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
90
97
  clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
91
- modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
92
98
  get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cell. Defaults to False.
99
+ store_location (str, optional): The location to store the sampler indices. Defaults to None.
100
+ force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
101
+ sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
102
+ sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
93
103
  **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
94
104
  see @file data.py and @file collator.py for more details about some of the parameters
95
105
  """
96
- if collection_name is not None:
97
- mdataset = Dataset(
98
- ln.Collection.filter(name=collection_name).first(),
99
- organisms=organisms,
100
- clss_to_predict=clss_to_predict,
101
- hierarchical_clss=hierarchical_clss,
102
- metacell_mode=metacell_mode,
103
- get_knn_cells=get_knn_cells,
106
+ if "organism_ontology_term_id" not in clss_to_predict:
107
+ raise ValueError(
108
+ "need 'organism_ontology_term_id' in the set of classes at least"
104
109
  )
110
+ mdataset = Dataset(
111
+ ln.Collection.filter(name=collection_name, is_latest=True).first(),
112
+ clss_to_predict=clss_to_predict,
113
+ hierarchical_clss=hierarchical_clss,
114
+ metacell_mode=metacell_mode,
115
+ get_knn_cells=get_knn_cells,
116
+ store_location=store_location,
117
+ force_recompute_indices=force_recompute_indices,
118
+ )
105
119
  # and location
106
120
  self.metacell_mode = bool(metacell_mode)
107
121
  self.gene_pos = None
108
122
  self.collection_name = collection_name
109
123
  if do_gene_pos:
110
- if type(do_gene_pos) is str:
111
- print("seeing a string: loading gene positions as biomart parquet file")
112
- biomart = pd.read_parquet(do_gene_pos)
113
- else:
114
- # and annotations
115
- if organisms != ["NCBITaxon:9606"]:
116
- raise ValueError(
117
- "need to provide your own table as this automated function only works for humans for now"
118
- )
119
- biomart = getBiomartTable(
120
- attributes=["start_position", "chromosome_name"],
121
- useCache=True,
122
- ).set_index("ensembl_gene_id")
123
- biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
124
- biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
125
- c = []
126
- i = 0
127
- prev_position = -100000
128
- prev_chromosome = None
129
- for _, r in biomart.iterrows():
130
- if (
131
- r["chromosome_name"] != prev_chromosome
132
- or r["start_position"] - prev_position > gene_position_tolerance
133
- ):
134
- i += 1
135
- c.append(i)
136
- prev_position = r["start_position"]
137
- prev_chromosome = r["chromosome_name"]
138
- print(f"reduced the size to {len(set(c)) / len(biomart)}")
139
- biomart["pos"] = c
124
+ biomart = pd.read_parquet(do_gene_pos)
140
125
  mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
141
126
  self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
142
-
143
127
  if gene_embeddings != "":
144
128
  mdataset.genedf = mdataset.genedf.join(
145
- pd.read_parquet(gene_embeddings), how="inner"
129
+ pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
146
130
  )
147
131
  if do_gene_pos:
148
132
  self.gene_pos = mdataset.genedf["pos"].tolist()
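
Taken together, the constructor changes above redefine the public API: organisms, train_oversampling_per_epoch and modify_seed_on_requeue are gone, do_gene_pos now expects a path to a pre-built biomart parquet file (with a "pos" column) instead of triggering an automatic download, and new caching/sampling knobs (n_samples_per_epoch, store_location, force_recompute_indices, sampler_workers, sampler_chunk_size) are introduced. A minimal usage sketch against the 2.0.0 signature shown in this diff; the collection name and file paths are placeholders, and the top-level DataModule import is assumed:

    from scdataloader import DataModule  # assumed import path

    datamodule = DataModule(
        collection_name="my-lamin-collection",          # placeholder lamindb collection
        clss_to_predict=["organism_ontology_term_id"],  # must contain this term, see the check above
        clss_to_weight=["organism_ontology_term_id"],
        do_gene_pos="gene_positions.parquet",           # pre-built biomart table with a "pos" column
        n_samples_per_epoch=2_000_000,
        store_location="cache/scdataloader",            # split and sampler indices are cached here
        force_recompute_indices=False,
        sampler_workers=8,
        batch_size=64,                                  # forwarded to the torch DataLoader via **kwargs
    )
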
@@ -151,7 +135,7 @@ class DataModule(L.LightningDataModule):
151
135
  # we might want to not introduce zeros and
152
136
  if use_default_col:
153
137
  kwargs["collate_fn"] = Collator(
154
- organisms=organisms,
138
+ organisms=mdataset.organisms,
155
139
  how=how,
156
140
  valid_genes=mdataset.genedf.index.tolist(),
157
141
  max_len=max_len,
@@ -159,7 +143,7 @@ class DataModule(L.LightningDataModule):
159
143
  org_to_id=mdataset.encoder[organism_name],
160
144
  tp_name=tp_name,
161
145
  organism_name=organism_name,
162
- class_names=clss_to_predict,
146
+ class_names=list(self.classes.keys()),
163
147
  )
164
148
  self.validation_split = validation_split
165
149
  self.test_split = test_split
@@ -171,16 +155,19 @@ class DataModule(L.LightningDataModule):
171
155
  self.assays_to_drop = assays_to_drop
172
156
  self.n_samples = len(mdataset)
173
157
  self.weight_scaler = weight_scaler
174
- self.train_oversampling_per_epoch = train_oversampling_per_epoch
158
+ self.n_samples_per_epoch = n_samples_per_epoch
175
159
  self.clss_to_weight = clss_to_weight
176
160
  self.train_weights = None
177
161
  self.train_labels = None
178
- self.modify_seed_on_requeue = modify_seed_on_requeue
162
+ self.sampler_workers = sampler_workers
163
+ self.sampler_chunk_size = sampler_chunk_size
164
+ self.store_location = store_location
179
165
  self.nnz = None
180
- self.restart_num = 0
181
166
  self.test_datasets = []
167
+ self.force_recompute_indices = force_recompute_indices
182
168
  self.test_idx = []
183
169
  super().__init__()
170
+ print("finished init")
184
171
 
185
172
  def __repr__(self):
186
173
  return (
@@ -190,7 +177,7 @@ class DataModule(L.LightningDataModule):
190
177
  f"\ttest_split={self.test_split},\n"
191
178
  f"\tn_samples={self.n_samples},\n"
192
179
  f"\tweight_scaler={self.weight_scaler},\n"
193
- f"\ttrain_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
180
+ f"\tn_samples_per_epoch={self.n_samples_per_epoch},\n"
194
181
  f"\tassays_to_drop={self.assays_to_drop},\n"
195
182
  f"\tnum_datasets={len(self.dataset.mapped_dataset.storages)},\n"
196
183
  f"\ttest datasets={str(self.test_datasets)},\n"
@@ -242,6 +229,44 @@ class DataModule(L.LightningDataModule):
242
229
  """
243
230
  return self.dataset.genedf.index.tolist()
244
231
 
232
+ @genes.setter
233
+ def genes(self, genes):
234
+ self.dataset.genedf = self.dataset.genedf.loc[genes]
235
+ self.kwargs["collate_fn"].genes = genes
236
+ self.kwargs["collate_fn"]._setup(
237
+ genedf=self.dataset.genedf,
238
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
239
+ valid_genes=genes,
240
+ )
241
+
242
+ @property
243
+ def encoders(self):
244
+ return self.dataset.encoder
245
+
246
+ @encoders.setter
247
+ def encoders(self, encoders):
248
+ self.dataset.encoder = encoders
249
+ self.kwargs["collate_fn"].org_to_id = encoders[
250
+ self.kwargs["collate_fn"].organism_name
251
+ ]
252
+ self.kwargs["collate_fn"]._setup(
253
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
254
+ valid_genes=self.genes,
255
+ )
256
+
257
+ @property
258
+ def organisms(self):
259
+ return self.dataset.organisms
260
+
261
+ @organisms.setter
262
+ def organisms(self, organisms):
263
+ self.dataset.organisms = organisms
264
+ self.kwargs["collate_fn"].organisms = organisms
265
+ self.kwargs["collate_fn"]._setup(
266
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
267
+ valid_genes=self.genes,
268
+ )
269
+
245
270
  @property
246
271
  def num_datasets(self):
247
272
  return len(self.dataset.mapped_dataset.storages)
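
The new genes, encoders and organisms properties added above keep the underlying Dataset and the Collator stored in kwargs["collate_fn"] in sync whenever one of them is overridden (each setter re-runs Collator._setup). A hedged sketch of how a caller might use them; the gene panel and encoder mapping below are illustrative, not values from this diff:

    # restrict the datamodule to a fixed gene panel
    panel = datamodule.genes[:2000]
    datamodule.genes = panel

    # reuse label encoders from a pretrained model so category ids stay consistent,
    # e.g. {"organism_ontology_term_id": {"NCBITaxon:9606": 0, "NCBITaxon:10090": 1}, ...}
    datamodule.encoders = pretrained_encoders

    # narrow the organisms handled by the collator
    datamodule.organisms = ["NCBITaxon:9606"]
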
@@ -256,106 +281,191 @@ class DataModule(L.LightningDataModule):
256
281
  It can be either 'fit' or 'test'. Defaults to None.
257
282
  """
258
283
  SCALE = 10
259
- if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
260
- self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
261
- self.clss_to_weight.remove("nnz")
262
- (
263
- (self.nnz.max() / SCALE)
264
- / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
265
- ).min()
266
- if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
267
- weights, labels = self.dataset.get_label_weights(
268
- self.clss_to_weight,
269
- scaler=self.weight_scaler,
270
- return_categories=True,
284
+ print("setting up the datamodule")
285
+ start_time = time.time()
286
+ if (
287
+ self.store_location is None
288
+ or not os.path.exists(
289
+ os.path.join(self.store_location, "train_weights.npy")
271
290
  )
272
- else:
273
- weights = np.ones(1)
274
- labels = np.zeros(self.n_samples, dtype=int)
275
- if isinstance(self.validation_split, int):
276
- len_valid = self.validation_split
277
- else:
278
- len_valid = int(self.n_samples * self.validation_split)
279
- if isinstance(self.test_split, int):
280
- len_test = self.test_split
281
- else:
282
- len_test = int(self.n_samples * self.test_split)
283
- assert (
284
- len_test + len_valid < self.n_samples
285
- ), "test set + valid set size is configured to be larger than entire dataset."
286
-
287
- idx_full = []
288
- if len(self.assays_to_drop) > 0:
289
- badloc = np.isin(
290
- self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
291
- self.assays_to_drop,
292
- )
293
- idx_full = np.arange(len(labels))[~badloc]
294
- else:
295
- idx_full = np.arange(self.n_samples)
296
- if len_test > 0:
297
- # this way we work on some never seen datasets
298
- # keeping at least one
299
- len_test = (
300
- len_test
301
- if len_test > self.dataset.mapped_dataset.n_obs_list[0]
302
- else self.dataset.mapped_dataset.n_obs_list[0]
291
+ or self.force_recompute_indices
292
+ ):
293
+ if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
294
+ self.nnz = self.dataset.mapped_dataset.get_merged_labels(
295
+ "nnz", is_cat=False
296
+ )
297
+ self.clss_to_weight.remove("nnz")
298
+ (
299
+ (self.nnz.max() / SCALE)
300
+ / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
301
+ ).min()
302
+ if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
303
+ weights, labels = self.dataset.get_label_weights(
304
+ self.clss_to_weight,
305
+ scaler=self.weight_scaler,
306
+ return_categories=True,
307
+ )
308
+ else:
309
+ weights = np.ones(1)
310
+ labels = np.zeros(self.n_samples, dtype=int)
311
+ if isinstance(self.validation_split, int):
312
+ len_valid = self.validation_split
313
+ else:
314
+ len_valid = int(self.n_samples * self.validation_split)
315
+ if isinstance(self.test_split, int):
316
+ len_test = self.test_split
317
+ else:
318
+ len_test = int(self.n_samples * self.test_split)
319
+ assert len_test + len_valid < self.n_samples, (
320
+ "test set + valid set size is configured to be larger than entire dataset."
303
321
  )
304
- cs = 0
305
- for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
306
- if cs + c > len_test:
307
- break
308
- else:
309
- self.test_datasets.append(
310
- self.dataset.mapped_dataset.path_list[i].path
311
- )
312
- cs += c
313
- len_test = cs
314
- self.test_idx = idx_full[:len_test]
315
- idx_full = idx_full[len_test:]
316
- else:
317
- self.test_idx = None
318
322
 
319
- np.random.shuffle(idx_full)
320
- if len_valid > 0:
321
- self.valid_idx = idx_full[:len_valid].copy()
322
- # store it for later
323
- idx_full = idx_full[len_valid:]
324
- else:
325
- self.valid_idx = None
326
- weights = np.concatenate([weights, np.zeros(1)])
327
- labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
328
- # some labels will now not exist anymore as replaced by len(weights) - 1.
329
- # this means that the associated weights should be 0.
330
- # by doing np.bincount(labels)*weights this will be taken into account
331
- self.train_weights = weights
332
- self.train_labels = labels
333
- self.idx_full = idx_full
323
+ idx_full = []
324
+ if len(self.assays_to_drop) > 0:
325
+ badloc = np.isin(
326
+ self.dataset.mapped_dataset.get_merged_labels(
327
+ "assay_ontology_term_id"
328
+ ),
329
+ self.assays_to_drop,
330
+ )
331
+ idx_full = np.arange(len(labels))[~badloc]
332
+ else:
333
+ idx_full = np.arange(self.n_samples)
334
+ if len_test > 0:
335
+ # this way we work on some never seen datasets
336
+ # keeping at least one
337
+ len_test = (
338
+ len_test
339
+ if len_test > self.dataset.mapped_dataset.n_obs_list[0]
340
+ else self.dataset.mapped_dataset.n_obs_list[0]
341
+ )
342
+ cs = 0
343
+ d_size = list(enumerate(self.dataset.mapped_dataset.n_obs_list))
344
+ random.Random(42).shuffle(d_size) # always same order
345
+ for i, c in d_size:
346
+ if cs + c > len_test:
347
+ break
348
+ else:
349
+ self.test_datasets.append(
350
+ self.dataset.mapped_dataset.path_list[i].path
351
+ )
352
+ cs += c
353
+ len_test = cs
354
+ self.test_idx = idx_full[:len_test]
355
+ idx_full = idx_full[len_test:]
356
+ else:
357
+ self.test_idx = None
358
+
359
+ np.random.shuffle(idx_full)
360
+ if len_valid > 0:
361
+ self.valid_idx = idx_full[:len_valid].copy()
362
+ # store it for later
363
+ idx_full = idx_full[len_valid:]
364
+ else:
365
+ self.valid_idx = None
366
+ weights = np.concatenate([weights, np.zeros(1)])
367
+ labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
368
+ # some labels will now not exist anymore as replaced by len(weights) - 1.
369
+ # this means that the associated weights should be 0.
370
+ # by doing np.bincount(labels)*weights this will be taken into account
371
+ self.train_weights = weights
372
+ self.train_labels = labels
373
+ self.idx_full = idx_full
374
+ if self.store_location is not None:
375
+ if (
376
+ not os.path.exists(
377
+ os.path.join(self.store_location, "train_weights.npy")
378
+ )
379
+ or self.force_recompute_indices
380
+ ):
381
+ os.makedirs(self.store_location, exist_ok=True)
382
+ if self.nnz is not None:
383
+ np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
384
+ np.save(
385
+ os.path.join(self.store_location, "train_weights.npy"),
386
+ self.train_weights,
387
+ )
388
+ np.save(
389
+ os.path.join(self.store_location, "train_labels.npy"),
390
+ self.train_labels,
391
+ )
392
+ np.save(
393
+ os.path.join(self.store_location, "idx_full.npy"), self.idx_full
394
+ )
395
+ if self.test_idx is not None:
396
+ np.save(
397
+ os.path.join(self.store_location, "test_idx.npy"), self.test_idx
398
+ )
399
+ if self.valid_idx is not None:
400
+ np.save(
401
+ os.path.join(self.store_location, "valid_idx.npy"),
402
+ self.valid_idx,
403
+ )
404
+ listToFile(
405
+ self.test_datasets,
406
+ os.path.join(self.store_location, "test_datasets.txt"),
407
+ )
408
+ else:
409
+ self.nnz = (
410
+ np.load(os.path.join(self.store_location, "nnz.npy"), mmap_mode="r")
411
+ if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
412
+ else None
413
+ )
414
+ self.train_weights = np.load(
415
+ os.path.join(self.store_location, "train_weights.npy")
416
+ )
417
+ self.train_labels = np.load(
418
+ os.path.join(self.store_location, "train_labels.npy")
419
+ )
420
+ self.idx_full = np.load(
421
+ os.path.join(self.store_location, "idx_full.npy"), mmap_mode="r"
422
+ )
423
+ self.test_idx = (
424
+ np.load(os.path.join(self.store_location, "test_idx.npy"))
425
+ if os.path.exists(os.path.join(self.store_location, "test_idx.npy"))
426
+ else None
427
+ )
428
+ self.valid_idx = (
429
+ np.load(os.path.join(self.store_location, "valid_idx.npy"))
430
+ if os.path.exists(
431
+ os.path.join(self.store_location, "valid_idx.npy")
432
+ )
433
+ else None
434
+ )
435
+ self.test_datasets = fileToList(
436
+ os.path.join(self.store_location, "test_datasets.txt")
437
+ )
438
+ print("loaded from store")
439
+ print(f"done setup, took {time.time() - start_time:.2f} seconds")
334
440
  return self.test_datasets
335
441
 
336
442
  def train_dataloader(self, **kwargs):
337
- # train_sampler = WeightedRandomSampler(
338
- # self.train_weights[self.train_labels],
339
- # int(self.n_samples*self.train_oversampling_per_epoch),
340
- # replacement=True,
341
- # )
342
- try:
343
- train_sampler = LabelWeightedSampler(
344
- label_weights=self.train_weights,
345
- labels=self.train_labels,
346
- num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
347
- element_weights=self.nnz,
348
- replacement=self.replacement,
349
- restart_num=self.restart_num,
350
- modify_seed_on_requeue=self.modify_seed_on_requeue,
351
- )
352
- except ValueError as e:
353
- raise ValueError(e + "have you run `datamodule.setup()`?")
443
+ if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
444
+ try:
445
+ print("Setting up the parallel train sampler...")
446
+ # Create the optimized parallel sampler
447
+ print(f"Using {self.sampler_workers} workers for class indexing")
448
+ train_sampler = LabelWeightedSampler(
449
+ label_weights=self.train_weights,
450
+ labels=self.train_labels,
451
+ num_samples=int(self.n_samples_per_epoch),
452
+ element_weights=self.nnz,
453
+ replacement=self.replacement,
454
+ n_workers=self.sampler_workers,
455
+ chunk_size=self.sampler_chunk_size,
456
+ store_location=self.store_location,
457
+ force_recompute_indices=self.force_recompute_indices,
458
+ )
459
+ except ValueError as e:
460
+ raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
461
+ else:
462
+ train_sampler = SubsetRandomSampler(self.idx_full)
463
+ current_loader_kwargs = kwargs.copy()
464
+ current_loader_kwargs.update(self.kwargs)
354
465
  return DataLoader(
355
466
  self.dataset,
356
467
  sampler=train_sampler,
357
- **self.kwargs,
358
- **kwargs,
468
+ **current_loader_kwargs,
359
469
  )
360
470
 
361
471
  def val_dataloader(self):
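
The reworked setup() above computes the train/validation/test split and class weights once and, when store_location is given, persists them (train_weights.npy, train_labels.npy, idx_full.npy, test_idx.npy, valid_idx.npy, test_datasets.txt) so later runs just reload them; force_recompute_indices=True rebuilds the cache. train_dataloader() now uses the parallel LabelWeightedSampler only when clss_to_weight is non-empty and weight_scaler > 0, otherwise it falls back to a plain SubsetRandomSampler over the training indices. A minimal sketch of the intended flow, with placeholder loader arguments:

    datamodule.setup()                                      # computes or reloads the cached indices
    train_dl = datamodule.train_dataloader(num_workers=4)
    val_dl = datamodule.val_dataloader()

    for batch in train_dl:
        ...                                                 # training step
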
@@ -385,115 +495,335 @@ class DataModule(L.LightningDataModule):
385
495
  **self.kwargs,
386
496
  )
387
497
 
388
- # def teardown(self):
389
- # clean up state after the trainer stops, delete files...
390
- # called on every process in DDP
391
- # pass
392
-
393
498
 
394
499
  class LabelWeightedSampler(Sampler[int]):
395
- label_weights: Sequence[float]
396
- klass_indices: Sequence[Sequence[int]]
500
+ """
501
+ A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
502
+
503
+ This sampler is designed to handle very large datasets efficiently, with optimizations for:
504
+ 1. Parallel building of class indices
505
+ 2. Chunked processing for large arrays
506
+ 3. Efficient memory management
507
+ 4. Proper handling of replacement and non-replacement sampling
508
+ """
509
+
510
+ label_weights: torch.Tensor
511
+ klass_indices: torch.Tensor
+ klass_offsets: torch.Tensor
397
512
  num_samples: int
398
- nnz: Optional[Sequence[int]]
513
+ element_weights: Optional[torch.Tensor]
399
514
  replacement: bool
400
- restart_num: int
401
- modify_seed_on_requeue: bool
402
- # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
403
- # this will result a class-balanced sampling, no matter how imbalance the labels are.
404
515
 
405
516
  def __init__(
406
517
  self,
407
518
  label_weights: Sequence[float],
408
- labels: Sequence[int],
519
+ labels: np.ndarray,
409
520
  num_samples: int,
410
521
  replacement: bool = True,
411
- element_weights: Sequence[float] = None,
412
- restart_num: int = 0,
413
- modify_seed_on_requeue: bool = True,
522
+ element_weights: Optional[Sequence[float]] = None,
523
+ n_workers: int = None,
524
+ chunk_size: int = None, # elements per chunk; auto-determined when None
525
+ store_location: str = None,
526
+ force_recompute_indices: bool = False,
414
527
  ) -> None:
415
528
  """
529
+ Initialize the sampler with parallel processing for large datasets.
416
530
 
417
- :param label_weights: list(len=num_classes)[float], weights for each class.
418
- :param labels: list(len=dataset_len)[int], labels of a dataset.
419
- :param num_samples: number of samples.
420
- :param restart_num: if we are continuing a previous run, we need to restart the sampler from the same point.
531
+ Args:
532
+ label_weights: Weights for each class (length = number of classes)
533
+ labels: Class label for each dataset element (length = dataset size)
534
+ num_samples: Number of samples to draw
535
+ replacement: Whether to sample with replacement
536
+ element_weights: Optional weights for each element within classes
537
+ n_workers: Number of parallel workers to use (default: number of CPUs-1)
538
+ chunk_size: Size of chunks to process in parallel (default: 10M elements)
421
539
  """
422
-
540
+ print("Initializing optimized parallel weighted sampler...")
423
541
  super(LabelWeightedSampler, self).__init__(None)
424
- # reweight labels from counter otherwsie same weight to labels that have many elements vs a few
425
- label_weights = np.array(label_weights) * np.bincount(labels)
426
-
427
- self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
428
- self.labels = torch.as_tensor(labels, dtype=torch.int)
429
- self.element_weights = (
430
- torch.as_tensor(element_weights, dtype=torch.float32)
431
- if element_weights is not None
432
- else None
433
- )
542
+
543
+ # Compute label weights (incorporating class frequencies)
544
+ # Directly use labels as numpy array without conversion
545
+ label_weights = np.asarray(label_weights) * np.bincount(labels)
546
+ self.label_weights = torch.as_tensor(
547
+ label_weights, dtype=torch.float32
548
+ ).share_memory_()
549
+
550
+ # Store element weights if provided
551
+ if element_weights is not None:
552
+ self.element_weights = torch.as_tensor(
553
+ element_weights, dtype=torch.float32
554
+ ).share_memory_()
555
+ else:
556
+ self.element_weights = None
557
+
434
558
  self.replacement = replacement
435
559
  self.num_samples = num_samples
436
- self.restart_num = slurm_restart_count(use_mine=True) + restart_num
437
- self.modify_seed_on_requeue = modify_seed_on_requeue
438
- # list of tensor.
439
- self.klass_indices = [
440
- (self.labels == i_klass).nonzero().squeeze(1)
441
- for i_klass in range(len(label_weights))
442
- ]
443
- self.klass_sizes = [len(klass_indices) for klass_indices in self.klass_indices]
560
+ if (
561
+ store_location is None
562
+ or not os.path.exists(os.path.join(store_location, "klass_indices.pt"))
563
+ or force_recompute_indices
564
+ ):
565
+ # Set number of workers (default: CPU count - 1, at least 1, capped at 20)
566
+ if n_workers is None:
567
+ # Check if running on SLURM
568
+ n_workers = min(20, max(1, mp.cpu_count() - 1))
569
+ if "SLURM_CPUS_PER_TASK" in os.environ:
570
+ n_workers = min(
571
+ 20, max(1, int(os.environ["SLURM_CPUS_PER_TASK"]) - 1)
572
+ )
573
+
574
+ # Try to auto-determine optimal chunk size based on memory
575
+ if chunk_size is None:
576
+ try:
577
+ import psutil
578
+
579
+ # Start from the locally available memory; SLURM limits override it below
580
+ available_memory = psutil.virtual_memory().available
581
+ for name in [
582
+ "SLURM_MEM_PER_NODE",
583
+ "SLURM_MEM_PER_CPU",
584
+ "SLURM_MEM_PER_GPU",
585
+ "SLURM_MEM_PER_TASK",
586
+ ]:
587
+ if name in os.environ:
588
+ available_memory = (
589
+ int(os.environ[name]) * 1024 * 1024
590
+ ) # Convert MB to bytes
591
+ break
592
+
593
+ # Use at most 50% of available memory across all workers
594
+ memory_per_worker = 0.5 * available_memory / n_workers
595
+ # Rough estimate: each label takes 4 bytes, each index 8 bytes
596
+ bytes_per_element = 12
597
+ chunk_size = min(
598
+ max(100_000, int(memory_per_worker / bytes_per_element / 3)),
599
+ 2_000_000,
600
+ )
601
+ print(f"Auto-determined chunk size: {chunk_size:,} elements")
602
+ except (ImportError, KeyError):
603
+ chunk_size = 2_000_000
604
+ print(f"Using default chunk size: {chunk_size:,} elements")
605
+
606
+ # Parallelize the class indices building
607
+ print(f"Building class indices in parallel with {n_workers} workers...")
608
+ klass_indices = self._build_class_indices_parallel(
609
+ labels, chunk_size, n_workers
610
+ )
611
+
612
+ # Convert klass_indices to a single tensor and offset vector
613
+ all_indices = []
614
+ offsets = []
615
+ current_offset = 0
616
+
617
+ # Sort keys to ensure consistent ordering
618
+ keys = klass_indices.keys()
619
+
620
+ # Build concatenated tensor and track offsets
621
+ for i in range(max(keys) + 1):
622
+ offsets.append(current_offset)
623
+ if i in keys:
624
+ indices = klass_indices[i]
625
+ all_indices.append(indices)
626
+ current_offset += len(indices)
627
+
628
+ # Convert to tensors
629
+ self.klass_indices = torch.cat(all_indices).to(torch.int32).share_memory_()
630
+ self.klass_offsets = torch.tensor(offsets, dtype=torch.long).share_memory_()
631
+ if store_location is not None:
632
+ store_path = os.path.join(store_location, "klass_indices.pt")
633
+ if os.path.exists(store_path) and not force_recompute_indices:
634
+ self.klass_indices = torch.load(store_path).share_memory_()
635
+ self.klass_offsets = torch.load(
636
+ store_path.replace(".pt", "_offsets.pt")
637
+ ).share_memory_()
638
+ print(f"Loaded sampler indices from {store_path}")
639
+ else:
640
+ torch.save(self.klass_indices, store_path)
641
+ torch.save(self.klass_offsets, store_path.replace(".pt", "_offsets.pt"))
642
+ print(f"Saved sampler indices to {store_path}")
643
+ print(f"Done initializing sampler with {len(self.klass_offsets)} classes")
444
644
 
445
645
  def __iter__(self):
646
+ # Sample classes according to their weights
647
+ print("sampling a new batch of size", self.num_samples)
648
+
446
649
  sample_labels = torch.multinomial(
447
650
  self.label_weights,
448
651
  num_samples=self.num_samples,
449
652
  replacement=True,
450
- generator=None
451
- if self.restart_num == 0 and not self.modify_seed_on_requeue
452
- else torch.Generator().manual_seed(self.restart_num),
453
653
  )
454
- sample_indices = torch.empty_like(sample_labels)
455
- for i_klass, klass_index in enumerate(self.klass_indices):
654
+ # Get counts of each class in sample_labels
655
+ unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
656
+
657
+ # Collect the sampled indices of each class, concatenated after the loop
658
+ result_indices_list = []
659
+
660
+ # Process only the classes that were actually sampled
661
+ for i, (label, count) in tqdm(
662
+ enumerate(zip(unique_samples.tolist(), sample_counts.tolist())),
663
+ total=len(unique_samples),
664
+ desc="Processing classes in sampler",
665
+ ):
666
+ klass_index = self.klass_indices[
667
+ self.klass_offsets[label] : self.klass_offsets[label + 1]
668
+ ]
669
+
456
670
  if klass_index.numel() == 0:
457
671
  continue
458
- left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
459
- if len(left_inds) == 0:
460
- continue
672
+
673
+ # Sample elements from this class
461
674
  if self.element_weights is not None:
462
- right_inds = torch.multinomial(
463
- self.element_weights[klass_index],
464
- num_samples=len(klass_index)
465
- if not self.replacement and len(klass_index) < len(left_inds)
466
- else len(left_inds),
467
- replacement=self.replacement,
468
- generator=None
469
- if self.restart_num == 0 and not self.modify_seed_on_requeue
470
- else torch.Generator().manual_seed(self.restart_num),
471
- )
675
+ # Slicing the element weights for this class is the main memory cost here
676
+ current_element_weights_slice = self.element_weights[klass_index]
677
+
678
+ if self.replacement:
679
+ right_inds = torch.multinomial(
680
+ current_element_weights_slice,
681
+ num_samples=count,
682
+ replacement=True,
683
+ )
684
+ else:
685
+ num_to_sample = min(count, len(klass_index))
686
+ right_inds = torch.multinomial(
687
+ current_element_weights_slice,
688
+ num_samples=num_to_sample,
689
+ replacement=False,
690
+ )
472
691
  elif self.replacement:
473
- right_inds = torch.randint(
474
- len(klass_index),
475
- size=(len(left_inds),),
476
- generator=None
477
- if self.restart_num == 0 and not self.modify_seed_on_requeue
478
- else torch.Generator().manual_seed(self.restart_num),
479
- )
692
+ right_inds = torch.randint(len(klass_index), size=(count,))
480
693
  else:
481
- maxelem = (
482
- len(left_inds)
483
- if len(left_inds) < len(klass_index)
484
- else len(klass_index)
485
- )
486
- right_inds = torch.randperm(len(klass_index))[:maxelem]
487
- sample_indices[left_inds[: len(right_inds)]] = klass_index[right_inds]
488
- if len(right_inds) < len(left_inds):
489
- sample_indices[left_inds[len(right_inds) :]] = -1
490
- # drop all -1
491
- sample_indices = sample_indices[sample_indices != -1]
492
- # torch shuffle
493
- sample_indices = sample_indices[torch.randperm(len(sample_indices))]
494
- self.num_samples = len(sample_indices)
495
- # raise Exception("stop")
496
- yield from iter(sample_indices.tolist())
694
+ num_to_sample = min(count, len(klass_index))
695
+ right_inds = torch.randperm(len(klass_index))[:num_to_sample]
696
+
697
+ # Get actual indices
698
+ sampled_indices = klass_index[right_inds]
699
+ result_indices_list.append(sampled_indices)
700
+
701
+ # Combine all indices
702
+ if result_indices_list: # Check if the list is not empty
703
+ final_result_indices = torch.cat(
704
+ result_indices_list
705
+ ) # concatenate the per-class index tensors
706
+
707
+ # Shuffle the combined indices
708
+ shuffled_indices = final_result_indices[
709
+ torch.randperm(len(final_result_indices))
710
+ ]
711
+ self.num_samples = len(shuffled_indices)
712
+ yield from shuffled_indices.tolist()
713
+ else:
714
+ self.num_samples = 0
715
+ yield from iter([])
497
716
 
498
717
  def __len__(self):
499
718
  return self.num_samples
719
+
720
+ def _merge_chunk_results(self, results_list):
721
+ """Merge results from multiple chunks into a single dictionary.
722
+
723
+ Args:
724
+ results_list: list of dictionaries mapping class labels to index arrays
725
+
726
+ Returns:
727
+ merged dictionary with PyTorch tensors
728
+ """
729
+ merged = {}
730
+
731
+ # Collect all labels across all chunks
732
+ all_labels = set()
733
+ for chunk_result in results_list:
734
+ all_labels.update(chunk_result.keys())
735
+
736
+ # For each unique label
737
+ for label in all_labels:
738
+ # Collect indices from all chunks where this label appears
739
+ indices_lists = [
740
+ chunk_result[label]
741
+ for chunk_result in results_list
742
+ if label in chunk_result
743
+ ]
744
+
745
+ if indices_lists:
746
+ # Concatenate all indices for this label
747
+ merged[label] = torch.tensor(
748
+ np.concatenate(indices_lists), dtype=torch.long
749
+ )
750
+ else:
751
+ merged[label] = torch.tensor([], dtype=torch.long)
752
+
753
+ return merged
754
+
755
+ def _build_class_indices_parallel(self, labels, chunk_size, n_workers=None):
756
+ """Build class indices in parallel across multiple workers.
757
+
758
+ Args:
759
+ labels: array of class labels
760
+ n_workers: number of parallel workers
761
+ chunk_size: size of chunks to process
762
+
763
+ Returns:
764
+ dictionary mapping class labels to tensors of indices
765
+ """
766
+ n = len(labels)
767
+ results = []
768
+ # Create chunks of the labels array with proper sizing
769
+ n_chunks = (n + chunk_size - 1) // chunk_size # Ceiling division
770
+ print(f"Processing {n:,} elements in {n_chunks} chunks...")
771
+
772
+ # Process in chunks to limit memory usage
773
+ with ProcessPoolExecutor(
774
+ max_workers=n_workers, mp_context=mp.get_context("spawn")
775
+ ) as executor:
776
+ # Submit chunks for processing
777
+ futures = []
778
+ for i in range(n_chunks):
779
+ start_idx = i * chunk_size
780
+ end_idx = min((i + 1) * chunk_size, n)
781
+ # Each task receives the chunk boundaries together with the labels array;
782
+ # with the "spawn" context the array is serialized once per submitted chunk
783
+ futures.append(
784
+ executor.submit(
785
+ self._process_chunk_with_slice,
786
+ (start_idx, end_idx, labels),
787
+ )
788
+ )
789
+
790
+ # Collect results as they complete with progress reporting
791
+ for future in tqdm(
792
+ as_completed(futures), total=len(futures), desc="Processing chunks"
793
+ ):
794
+ results.append(future.result())
795
+
796
+ # Merge results from all chunks
797
+ print("Merging results from all chunks...")
798
+ merged_results = self._merge_chunk_results(results)
799
+
800
+ return merged_results
801
+
802
+ def _process_chunk_with_slice(self, slice_info):
803
+ """Process a slice of the labels array by indices.
804
+
805
+ Args:
806
+ slice_info: tuple of (start_idx, end_idx, labels_array) where
807
+ start_idx and end_idx define the slice to process
808
+
809
+ Returns:
810
+ dict mapping class labels to arrays of indices
811
+ """
812
+ start_idx, end_idx, labels_array = slice_info
813
+
814
+ # We're processing a slice of the original array
815
+ labels_slice = labels_array[start_idx:end_idx]
816
+ chunk_indices = {}
817
+
818
+ # Create a direct map of indices
819
+ indices = np.arange(start_idx, end_idx)
820
+
821
+ # Get unique labels in this slice for more efficient processing
822
+ unique_labels = np.unique(labels_slice)
823
+ # For each valid label, find its indices
824
+ for label in unique_labels:
825
+ # Find positions where this label appears (using direct boolean indexing)
826
+ label_mask = labels_slice == label
827
+ chunk_indices[int(label)] = indices[label_mask]
828
+
829
+ return chunk_indices
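
The sampler stores every class's indices in one concatenated int32 tensor plus an offsets vector, so the indices of class c are the slice klass_indices[klass_offsets[c]:klass_offsets[c + 1]] used in __iter__ above. A small self-contained illustration of that layout (toy data, independent of the package):

    import torch

    labels = torch.tensor([2, 0, 1, 0, 2, 2])                    # toy labels for 3 classes
    per_class = {c: (labels == c).nonzero().squeeze(1) for c in range(3)}

    flat, offsets, cur = [], [], 0
    for c in range(3):                                           # ascending label ids, as in __init__
        offsets.append(cur)
        flat.append(per_class[c])
        cur += len(per_class[c])

    klass_indices = torch.cat(flat)                              # tensor([1, 3, 2, 0, 4, 5])
    klass_offsets = torch.tensor(offsets)                        # tensor([0, 2, 3])

    c = 1                                                        # class 1 lives in the slice [2, 3)
    print(klass_indices[klass_offsets[c]:klass_offsets[c + 1]])  # tensor([2])

    # The last class has no end offset; in the package that slot is the zero-weight
    # filler class created in setup(), so __iter__ never asks for klass_offsets[last + 1].
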