scdataloader 1.9.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,9 @@
1
+ import multiprocessing as mp
1
2
  import os
3
+ import random
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor, as_completed
6
+ from functools import partial
2
7
  from typing import Optional, Sequence, Union
3
8
 
4
9
  import lamindb as ln
@@ -13,10 +18,11 @@ from torch.utils.data.sampler import (
13
18
  SubsetRandomSampler,
14
19
  WeightedRandomSampler,
15
20
  )
21
+ from tqdm import tqdm
16
22
 
17
23
  from .collator import Collator
18
24
  from .data import Dataset
19
- from .utils import getBiomartTable
25
+ from .utils import fileToList, getBiomartTable, listToFile
20
26
 
21
27
  FILE_DIR = os.path.dirname(os.path.abspath(__file__))
22
28
 
@@ -26,7 +32,6 @@ class DataModule(L.LightningDataModule):
26
32
  self,
27
33
  collection_name: str,
28
34
  clss_to_weight: list = ["organism_ontology_term_id"],
29
- organisms: list = ["NCBITaxon:9606"],
30
35
  weight_scaler: int = 10,
31
36
  n_samples_per_epoch: int = 2_000_000,
32
37
  validation_split: float = 0.2,
@@ -43,7 +48,7 @@ class DataModule(L.LightningDataModule):
43
48
  max_len: int = 1000,
44
49
  add_zero_genes: int = 100,
45
50
  replacement: bool = True,
46
- do_gene_pos: Union[bool, str] = True,
51
+ do_gene_pos: str = "",
47
52
  tp_name: Optional[str] = None, # "heat_diff"
48
53
  assays_to_drop: list = [
49
54
  # "EFO:0008853", #patch seq
@@ -53,6 +58,10 @@ class DataModule(L.LightningDataModule):
53
58
  ],
54
59
  metacell_mode: float = 0.0,
55
60
  get_knn_cells: bool = False,
61
+ store_location: str = None,
62
+ force_recompute_indices: bool = False,
63
+ sampler_workers: int = None,
64
+ sampler_chunk_size: int = None,
56
65
  **kwargs,
57
66
  ):
58
67
  """
@@ -64,7 +73,6 @@ class DataModule(L.LightningDataModule):
64
73
 
65
74
  Args:
66
75
  collection_name (str): The lamindb collection to be used.
67
- organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
68
76
  weight_scaler (int, optional): how much more you will see the most present vs less present category.
69
77
  n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
70
78
  validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
@@ -88,59 +96,37 @@ class DataModule(L.LightningDataModule):
88
96
  metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
89
97
  clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
90
98
  get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
99
+ store_location (str, optional): The directory in which to cache the computed split and sampler indices. Defaults to None.
100
+ force_recompute_indices (bool, optional): Whether to force recomputation of the cached sampler indices. Defaults to False.
101
+ sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
102
+ sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
91
103
  **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
92
104
  see @file data.py and @file collator.py for more details about some of the parameters
93
105
  """
94
- if collection_name is not None:
95
- mdataset = Dataset(
96
- ln.Collection.filter(name=collection_name).first(),
97
- organisms=organisms,
98
- clss_to_predict=clss_to_predict,
99
- hierarchical_clss=hierarchical_clss,
100
- metacell_mode=metacell_mode,
101
- get_knn_cells=get_knn_cells,
106
+ if "organism_ontology_term_id" not in clss_to_predict:
107
+ raise ValueError(
108
+ "need 'organism_ontology_term_id' in the set of classes at least"
102
109
  )
110
+ mdataset = Dataset(
111
+ ln.Collection.filter(name=collection_name, is_latest=True).first(),
112
+ clss_to_predict=clss_to_predict,
113
+ hierarchical_clss=hierarchical_clss,
114
+ metacell_mode=metacell_mode,
115
+ get_knn_cells=get_knn_cells,
116
+ store_location=store_location,
117
+ force_recompute_indices=force_recompute_indices,
118
+ )
103
119
  # and location
104
120
  self.metacell_mode = bool(metacell_mode)
105
121
  self.gene_pos = None
106
122
  self.collection_name = collection_name
107
123
  if do_gene_pos:
108
- if type(do_gene_pos) is str:
109
- print("seeing a string: loading gene positions as biomart parquet file")
110
- biomart = pd.read_parquet(do_gene_pos)
111
- else:
112
- # and annotations
113
- if organisms != ["NCBITaxon:9606"]:
114
- raise ValueError(
115
- "need to provide your own table as this automated function only works for humans for now"
116
- )
117
- biomart = getBiomartTable(
118
- attributes=["start_position", "chromosome_name"],
119
- useCache=True,
120
- ).set_index("ensembl_gene_id")
121
- biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
122
- biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
123
- c = []
124
- i = 0
125
- prev_position = -100000
126
- prev_chromosome = None
127
- for _, r in biomart.iterrows():
128
- if (
129
- r["chromosome_name"] != prev_chromosome
130
- or r["start_position"] - prev_position > gene_position_tolerance
131
- ):
132
- i += 1
133
- c.append(i)
134
- prev_position = r["start_position"]
135
- prev_chromosome = r["chromosome_name"]
136
- print(f"reduced the size to {len(set(c)) / len(biomart)}")
137
- biomart["pos"] = c
124
+ biomart = pd.read_parquet(do_gene_pos)
138
125
  mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
139
126
  self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
140
-
141
127
  if gene_embeddings != "":
142
128
  mdataset.genedf = mdataset.genedf.join(
143
- pd.read_parquet(gene_embeddings), how="inner"
129
+ pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
144
130
  )
145
131
  if do_gene_pos:
146
132
  self.gene_pos = mdataset.genedf["pos"].tolist()
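For orientation, a minimal usage sketch of the 2.0.0 constructor as changed above. This is a hedged illustration rather than code from the package: the collection name, file paths, worker count and batch settings are placeholders, and DataModule is assumed to be importable from the package root as in 1.9.x. Note that the organisms argument is gone (organisms are now taken from the dataset itself), do_gene_pos now expects the path to a gene-position parquet file, and clss_to_predict must include "organism_ontology_term_id".

    # Hypothetical construction of the new DataModule; every value below is illustrative.
    from scdataloader import DataModule  # assumed public import path

    datamodule = DataModule(
        collection_name="my-cellxgene-collection",   # placeholder lamindb collection name
        clss_to_predict=[
            "organism_ontology_term_id",             # required by the new constructor check
            "cell_type_ontology_term_id",
        ],
        do_gene_pos="gene_positions.parquet",        # parquet with a "pos" column indexed by gene id
        store_location="/tmp/scdataloader_cache",    # where split and sampler indices get cached
        force_recompute_indices=False,
        sampler_workers=8,                           # None lets the sampler auto-detect (capped at 20)
        sampler_chunk_size=None,                     # None triggers memory-based auto sizing
        batch_size=64,                               # forwarded to the torch DataLoader
        num_workers=4,
    )

Any keyword argument not consumed by the constructor, such as batch_size and num_workers above, is still forwarded to the PyTorch DataLoader.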
@@ -149,7 +135,7 @@ class DataModule(L.LightningDataModule):
149
135
  # we might want to not introduce zeros and
150
136
  if use_default_col:
151
137
  kwargs["collate_fn"] = Collator(
152
- organisms=organisms,
138
+ organisms=mdataset.organisms,
153
139
  how=how,
154
140
  valid_genes=mdataset.genedf.index.tolist(),
155
141
  max_len=max_len,
@@ -157,7 +143,7 @@ class DataModule(L.LightningDataModule):
157
143
  org_to_id=mdataset.encoder[organism_name],
158
144
  tp_name=tp_name,
159
145
  organism_name=organism_name,
160
- class_names=clss_to_predict,
146
+ class_names=list(self.classes.keys()),
161
147
  )
162
148
  self.validation_split = validation_split
163
149
  self.test_split = test_split
@@ -173,10 +159,15 @@ class DataModule(L.LightningDataModule):
173
159
  self.clss_to_weight = clss_to_weight
174
160
  self.train_weights = None
175
161
  self.train_labels = None
162
+ self.sampler_workers = sampler_workers
163
+ self.sampler_chunk_size = sampler_chunk_size
164
+ self.store_location = store_location
176
165
  self.nnz = None
177
166
  self.test_datasets = []
167
+ self.force_recompute_indices = force_recompute_indices
178
168
  self.test_idx = []
179
169
  super().__init__()
170
+ print("finished init")
180
171
 
181
172
  def __repr__(self):
182
173
  return (
@@ -238,6 +229,44 @@ class DataModule(L.LightningDataModule):
238
229
  """
239
230
  return self.dataset.genedf.index.tolist()
240
231
 
232
+ @genes.setter
233
+ def genes(self, genes):
234
+ self.dataset.genedf = self.dataset.genedf.loc[genes]
235
+ self.kwargs["collate_fn"].genes = genes
236
+ self.kwargs["collate_fn"]._setup(
237
+ genedf=self.dataset.genedf,
238
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
239
+ valid_genes=genes,
240
+ )
241
+
242
+ @property
243
+ def encoders(self):
244
+ return self.dataset.encoder
245
+
246
+ @encoders.setter
247
+ def encoders(self, encoders):
248
+ self.dataset.encoder = encoders
249
+ self.kwargs["collate_fn"].org_to_id = encoders[
250
+ self.kwargs["collate_fn"].organism_name
251
+ ]
252
+ self.kwargs["collate_fn"]._setup(
253
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
254
+ valid_genes=self.genes,
255
+ )
256
+
257
+ @property
258
+ def organisms(self):
259
+ return self.dataset.organisms
260
+
261
+ @organisms.setter
262
+ def organisms(self, organisms):
263
+ self.dataset.organisms = organisms
264
+ self.kwargs["collate_fn"].organisms = organisms
265
+ self.kwargs["collate_fn"]._setup(
266
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
267
+ valid_genes=self.genes,
268
+ )
269
+
241
270
  @property
242
271
  def num_datasets(self):
243
272
  return len(self.dataset.mapped_dataset.storages)
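The three property setters added above keep the underlying Dataset, its encoders and the collate function consistent when they are overwritten. As a rough, hypothetical sketch (the model object and its attributes are not part of scdataloader), this is how they might be used to align the datamodule with a pretrained model's vocabulary:

    # Hypothetical alignment with an external pretrained model. `model.genes`,
    # `model.label_encoders` and `model.organisms` are assumed attributes of that
    # model, not scdataloader API; `datamodule` is the instance from the sketch above.
    shared = [g for g in datamodule.genes if g in set(model.genes)]

    datamodule.genes = shared                    # subsets dataset.genedf and re-runs the collator setup
    datamodule.encoders = model.label_encoders   # propagates org_to_id into the collator
    datamodule.organisms = model.organisms       # keeps dataset and collator organism lists in sync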
@@ -252,104 +281,191 @@ class DataModule(L.LightningDataModule):
252
281
  It can be either 'fit' or 'test'. Defaults to None.
253
282
  """
254
283
  SCALE = 10
255
- if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
256
- self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
257
- self.clss_to_weight.remove("nnz")
258
- (
259
- (self.nnz.max() / SCALE)
260
- / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
261
- ).min()
262
- if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
263
- weights, labels = self.dataset.get_label_weights(
264
- self.clss_to_weight,
265
- scaler=self.weight_scaler,
266
- return_categories=True,
284
+ print("setting up the datamodule")
285
+ start_time = time.time()
286
+ if (
287
+ self.store_location is None
288
+ or not os.path.exists(
289
+ os.path.join(self.store_location, "train_weights.npy")
267
290
  )
268
- else:
269
- weights = np.ones(1)
270
- labels = np.zeros(self.n_samples, dtype=int)
271
- if isinstance(self.validation_split, int):
272
- len_valid = self.validation_split
273
- else:
274
- len_valid = int(self.n_samples * self.validation_split)
275
- if isinstance(self.test_split, int):
276
- len_test = self.test_split
277
- else:
278
- len_test = int(self.n_samples * self.test_split)
279
- assert (
280
- len_test + len_valid < self.n_samples
281
- ), "test set + valid set size is configured to be larger than entire dataset."
282
-
283
- idx_full = []
284
- if len(self.assays_to_drop) > 0:
285
- badloc = np.isin(
286
- self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
287
- self.assays_to_drop,
288
- )
289
- idx_full = np.arange(len(labels))[~badloc]
290
- else:
291
- idx_full = np.arange(self.n_samples)
292
- if len_test > 0:
293
- # this way we work on some never seen datasets
294
- # keeping at least one
295
- len_test = (
296
- len_test
297
- if len_test > self.dataset.mapped_dataset.n_obs_list[0]
298
- else self.dataset.mapped_dataset.n_obs_list[0]
291
+ or self.force_recompute_indices
292
+ ):
293
+ if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
294
+ self.nnz = self.dataset.mapped_dataset.get_merged_labels(
295
+ "nnz", is_cat=False
296
+ )
297
+ self.clss_to_weight.remove("nnz")
298
+ (
299
+ (self.nnz.max() / SCALE)
300
+ / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
301
+ ).min()
302
+ if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
303
+ weights, labels = self.dataset.get_label_weights(
304
+ self.clss_to_weight,
305
+ scaler=self.weight_scaler,
306
+ return_categories=True,
307
+ )
308
+ else:
309
+ weights = np.ones(1)
310
+ labels = np.zeros(self.n_samples, dtype=int)
311
+ if isinstance(self.validation_split, int):
312
+ len_valid = self.validation_split
313
+ else:
314
+ len_valid = int(self.n_samples * self.validation_split)
315
+ if isinstance(self.test_split, int):
316
+ len_test = self.test_split
317
+ else:
318
+ len_test = int(self.n_samples * self.test_split)
319
+ assert len_test + len_valid < self.n_samples, (
320
+ "test set + valid set size is configured to be larger than entire dataset."
299
321
  )
300
- cs = 0
301
- for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
302
- if cs + c > len_test:
303
- break
304
- else:
305
- self.test_datasets.append(
306
- self.dataset.mapped_dataset.path_list[i].path
307
- )
308
- cs += c
309
- len_test = cs
310
- self.test_idx = idx_full[:len_test]
311
- idx_full = idx_full[len_test:]
312
- else:
313
- self.test_idx = None
314
322
 
315
- np.random.shuffle(idx_full)
316
- if len_valid > 0:
317
- self.valid_idx = idx_full[:len_valid].copy()
318
- # store it for later
319
- idx_full = idx_full[len_valid:]
320
- else:
321
- self.valid_idx = None
322
- weights = np.concatenate([weights, np.zeros(1)])
323
- labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
324
- # some labels will now not exist anymore as replaced by len(weights) - 1.
325
- # this means that the associated weights should be 0.
326
- # by doing np.bincount(labels)*weights this will be taken into account
327
- self.train_weights = weights
328
- self.train_labels = labels
329
- self.idx_full = idx_full
323
+ idx_full = []
324
+ if len(self.assays_to_drop) > 0:
325
+ badloc = np.isin(
326
+ self.dataset.mapped_dataset.get_merged_labels(
327
+ "assay_ontology_term_id"
328
+ ),
329
+ self.assays_to_drop,
330
+ )
331
+ idx_full = np.arange(len(labels))[~badloc]
332
+ else:
333
+ idx_full = np.arange(self.n_samples)
334
+ if len_test > 0:
335
+ # this way we work on some never seen datasets
336
+ # keeping at least one
337
+ len_test = (
338
+ len_test
339
+ if len_test > self.dataset.mapped_dataset.n_obs_list[0]
340
+ else self.dataset.mapped_dataset.n_obs_list[0]
341
+ )
342
+ cs = 0
343
+ d_size = list(enumerate(self.dataset.mapped_dataset.n_obs_list))
344
+ random.Random(42).shuffle(d_size) # always same order
345
+ for i, c in d_size:
346
+ if cs + c > len_test:
347
+ break
348
+ else:
349
+ self.test_datasets.append(
350
+ self.dataset.mapped_dataset.path_list[i].path
351
+ )
352
+ cs += c
353
+ len_test = cs
354
+ self.test_idx = idx_full[:len_test]
355
+ idx_full = idx_full[len_test:]
356
+ else:
357
+ self.test_idx = None
358
+
359
+ np.random.shuffle(idx_full)
360
+ if len_valid > 0:
361
+ self.valid_idx = idx_full[:len_valid].copy()
362
+ # store it for later
363
+ idx_full = idx_full[len_valid:]
364
+ else:
365
+ self.valid_idx = None
366
+ weights = np.concatenate([weights, np.zeros(1)])
367
+ labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
368
+ # some labels will now not exist anymore as replaced by len(weights) - 1.
369
+ # this means that the associated weights should be 0.
370
+ # by doing np.bincount(labels)*weights this will be taken into account
371
+ self.train_weights = weights
372
+ self.train_labels = labels
373
+ self.idx_full = idx_full
374
+ if self.store_location is not None:
375
+ if (
376
+ not os.path.exists(
377
+ os.path.join(self.store_location, "train_weights.npy")
378
+ )
379
+ or self.force_recompute_indices
380
+ ):
381
+ os.makedirs(self.store_location, exist_ok=True)
382
+ if self.nnz is not None:
383
+ np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
384
+ np.save(
385
+ os.path.join(self.store_location, "train_weights.npy"),
386
+ self.train_weights,
387
+ )
388
+ np.save(
389
+ os.path.join(self.store_location, "train_labels.npy"),
390
+ self.train_labels,
391
+ )
392
+ np.save(
393
+ os.path.join(self.store_location, "idx_full.npy"), self.idx_full
394
+ )
395
+ if self.test_idx is not None:
396
+ np.save(
397
+ os.path.join(self.store_location, "test_idx.npy"), self.test_idx
398
+ )
399
+ if self.valid_idx is not None:
400
+ np.save(
401
+ os.path.join(self.store_location, "valid_idx.npy"),
402
+ self.valid_idx,
403
+ )
404
+ listToFile(
405
+ self.test_datasets,
406
+ os.path.join(self.store_location, "test_datasets.txt"),
407
+ )
408
+ else:
409
+ self.nnz = (
410
+ np.load(os.path.join(self.store_location, "nnz.npy"), mmap_mode="r")
411
+ if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
412
+ else None
413
+ )
414
+ self.train_weights = np.load(
415
+ os.path.join(self.store_location, "train_weights.npy")
416
+ )
417
+ self.train_labels = np.load(
418
+ os.path.join(self.store_location, "train_labels.npy")
419
+ )
420
+ self.idx_full = np.load(
421
+ os.path.join(self.store_location, "idx_full.npy"), mmap_mode="r"
422
+ )
423
+ self.test_idx = (
424
+ np.load(os.path.join(self.store_location, "test_idx.npy"))
425
+ if os.path.exists(os.path.join(self.store_location, "test_idx.npy"))
426
+ else None
427
+ )
428
+ self.valid_idx = (
429
+ np.load(os.path.join(self.store_location, "valid_idx.npy"))
430
+ if os.path.exists(
431
+ os.path.join(self.store_location, "valid_idx.npy")
432
+ )
433
+ else None
434
+ )
435
+ self.test_datasets = fileToList(
436
+ os.path.join(self.store_location, "test_datasets.txt")
437
+ )
438
+ print("loaded from store")
439
+ print(f"done setup, took {time.time() - start_time:.2f} seconds")
330
440
  return self.test_datasets
331
441
 
332
442
  def train_dataloader(self, **kwargs):
333
- # train_sampler = WeightedRandomSampler(
334
- # self.train_weights[self.train_labels],
335
- # int(self.n_samples*self.n_samples_per_epoch),
336
- # replacement=True,
337
- # )
338
- try:
339
- train_sampler = LabelWeightedSampler(
340
- label_weights=self.train_weights,
341
- labels=self.train_labels,
342
- num_samples=int(self.n_samples_per_epoch),
343
- element_weights=self.nnz,
344
- replacement=self.replacement,
345
- )
346
- except ValueError as e:
347
- raise ValueError(e + "have you run `datamodule.setup()`?")
443
+ if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
444
+ try:
445
+ print("Setting up the parallel train sampler...")
446
+ # Create the optimized parallel sampler
447
+ print(f"Using {self.sampler_workers} workers for class indexing")
448
+ train_sampler = LabelWeightedSampler(
449
+ label_weights=self.train_weights,
450
+ labels=self.train_labels,
451
+ num_samples=int(self.n_samples_per_epoch),
452
+ element_weights=self.nnz,
453
+ replacement=self.replacement,
454
+ n_workers=self.sampler_workers,
455
+ chunk_size=self.sampler_chunk_size,
456
+ store_location=self.store_location,
457
+ force_recompute_indices=self.force_recompute_indices,
458
+ )
459
+ except ValueError as e:
460
+ raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
461
+ else:
462
+ train_sampler = SubsetRandomSampler(self.idx_full)
463
+ current_loader_kwargs = kwargs.copy()
464
+ current_loader_kwargs.update(self.kwargs)
348
465
  return DataLoader(
349
466
  self.dataset,
350
467
  sampler=train_sampler,
351
- **self.kwargs,
352
- **kwargs,
468
+ **current_loader_kwargs,
353
469
  )
354
470
 
355
471
  def val_dataloader(self):
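The reworked setup() above caches its expensive outputs (split indices, training weights and labels, and the held-out test dataset list) under store_location, and reloads them on later runs unless force_recompute_indices is set. A hedged sketch of that round trip, continuing from the construction example earlier (the cache path is a placeholder):

    import os

    # First call: splits and weights are computed, then written to the cache directory.
    datamodule.setup()
    print(sorted(os.listdir("/tmp/scdataloader_cache")))
    # typically train_weights.npy, train_labels.npy, idx_full.npy and test_datasets.txt,
    # plus test_idx.npy / valid_idx.npy / nnz.npy when they apply

    # Later calls with the same store_location reload the cached arrays instead
    # (setup() prints "loaded from store"); pass force_recompute_indices=True to
    # the constructor to rebuild everything from scratch.
    datamodule.setup()

    # The weighted sampler is used only when clss_to_weight is non-empty and
    # weight_scaler > 0; otherwise train_dataloader() falls back to a plain
    # SubsetRandomSampler over idx_full. The sampler additionally caches
    # klass_indices.pt (and its offsets) in the same directory the first time it is built.
    train_loader = datamodule.train_dataloader()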
@@ -379,99 +495,335 @@ class DataModule(L.LightningDataModule):
379
495
  **self.kwargs,
380
496
  )
381
497
 
382
- # def teardown(self):
383
- # clean up state after the trainer stops, delete files...
384
- # called on every process in DDP
385
- # pass
386
-
387
498
 
388
499
  class LabelWeightedSampler(Sampler[int]):
389
- label_weights: Sequence[float]
390
- klass_indices: Sequence[Sequence[int]]
500
+ """
501
+ A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
502
+
503
+ This sampler is designed to handle very large datasets efficiently, with optimizations for:
504
+ 1. Parallel building of class indices
505
+ 2. Chunked processing for large arrays
506
+ 3. Efficient memory management
507
+ 4. Proper handling of replacement and non-replacement sampling
508
+ """
509
+
510
+ label_weights: torch.Tensor
511
+ klass_indices: dict[int, torch.Tensor]
391
512
  num_samples: int
392
- nnz: Optional[Sequence[int]]
513
+ element_weights: Optional[torch.Tensor]
393
514
  replacement: bool
394
- # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
395
- # this will result a class-balanced sampling, no matter how imbalance the labels are.
396
515
 
397
516
  def __init__(
398
517
  self,
399
518
  label_weights: Sequence[float],
400
- labels: Sequence[int],
519
+ labels: np.ndarray,
401
520
  num_samples: int,
402
521
  replacement: bool = True,
403
- element_weights: Sequence[float] = None,
522
+ element_weights: Optional[Sequence[float]] = None,
523
+ n_workers: int = None,
524
+ chunk_size: int = None,  # elements per parallel chunk; None = auto-sized (at most 2,000,000)
525
+ store_location: str = None,
526
+ force_recompute_indices: bool = False,
404
527
  ) -> None:
405
528
  """
529
+ Initialize the sampler with parallel processing for large datasets.
406
530
 
407
- :param label_weights: list(len=num_classes)[float], weights for each class.
408
- :param labels: list(len=dataset_len)[int], labels of a dataset.
409
- :param num_samples: number of samples.
531
+ Args:
532
+ label_weights: Weights for each class (length = number of classes)
533
+ labels: Class label for each dataset element (length = dataset size)
534
+ num_samples: Number of samples to draw
535
+ replacement: Whether to sample with replacement
536
+ element_weights: Optional weights for each element within classes
537
+ n_workers: Number of parallel workers to use (default: CPU count - 1, capped at 20)
538
+ chunk_size: Size of chunks to process in parallel (default: auto-determined, at most 2,000,000 elements)
410
539
  """
411
-
540
+ print("Initializing optimized parallel weighted sampler...")
412
541
  super(LabelWeightedSampler, self).__init__(None)
413
- # reweight labels from counter otherwsie same weight to labels that have many elements vs a few
414
- label_weights = np.array(label_weights) * np.bincount(labels)
415
-
416
- self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
417
- self.labels = torch.as_tensor(labels, dtype=torch.int)
418
- self.element_weights = (
419
- torch.as_tensor(element_weights, dtype=torch.float32)
420
- if element_weights is not None
421
- else None
422
- )
542
+
543
+ # Compute label weights (incorporating class frequencies)
544
+ # Directly use labels as numpy array without conversion
545
+ label_weights = np.asarray(label_weights) * np.bincount(labels)
546
+ self.label_weights = torch.as_tensor(
547
+ label_weights, dtype=torch.float32
548
+ ).share_memory_()
549
+
550
+ # Store element weights if provided
551
+ if element_weights is not None:
552
+ self.element_weights = torch.as_tensor(
553
+ element_weights, dtype=torch.float32
554
+ ).share_memory_()
555
+ else:
556
+ self.element_weights = None
557
+
423
558
  self.replacement = replacement
424
559
  self.num_samples = num_samples
425
- # list of tensor.
426
- self.klass_indices = [
427
- (self.labels == i_klass).nonzero().squeeze(1)
428
- for i_klass in range(len(label_weights))
429
- ]
430
- self.klass_sizes = [len(klass_indices) for klass_indices in self.klass_indices]
560
+ if (
561
+ store_location is None
562
+ or not os.path.exists(os.path.join(store_location, "klass_indices.pt"))
563
+ or force_recompute_indices
564
+ ):
565
+ # Set number of workers (default to CPU count - 1, but at least 1)
566
+ if n_workers is None:
567
+ # Check if running on SLURM
568
+ n_workers = min(20, max(1, mp.cpu_count() - 1))
569
+ if "SLURM_CPUS_PER_TASK" in os.environ:
570
+ n_workers = min(
571
+ 20, max(1, int(os.environ["SLURM_CPUS_PER_TASK"]) - 1)
572
+ )
573
+
574
+ # Try to auto-determine optimal chunk size based on memory
575
+ if chunk_size is None:
576
+ try:
577
+ import psutil
578
+
579
+ # Check if running on SLURM
580
+ available_memory = psutil.virtual_memory().available
581
+ for name in [
582
+ "SLURM_MEM_PER_NODE",
583
+ "SLURM_MEM_PER_CPU",
584
+ "SLURM_MEM_PER_GPU",
585
+ "SLURM_MEM_PER_TASK",
586
+ ]:
587
+ if name in os.environ:
588
+ available_memory = (
589
+ int(os.environ[name]) * 1024 * 1024
590
+ ) # Convert MB to bytes
591
+ break
592
+
593
+ # Use at most 50% of available memory across all workers
594
+ memory_per_worker = 0.5 * available_memory / n_workers
595
+ # Rough estimate: each label takes 4 bytes, each index 8 bytes
596
+ bytes_per_element = 12
597
+ chunk_size = min(
598
+ max(100_000, int(memory_per_worker / bytes_per_element / 3)),
599
+ 2_000_000,
600
+ )
601
+ print(f"Auto-determined chunk size: {chunk_size:,} elements")
602
+ except (ImportError, KeyError):
603
+ chunk_size = 2_000_000
604
+ print(f"Using default chunk size: {chunk_size:,} elements")
605
+
606
+ # Parallelize the class indices building
607
+ print(f"Building class indices in parallel with {n_workers} workers...")
608
+ klass_indices = self._build_class_indices_parallel(
609
+ labels, chunk_size, n_workers
610
+ )
611
+
612
+ # Convert klass_indices to a single tensor and offset vector
613
+ all_indices = []
614
+ offsets = []
615
+ current_offset = 0
616
+
617
+ # Iterate over 0..max(key) below so the offsets are built in a consistent order
618
+ keys = klass_indices.keys()
619
+
620
+ # Build concatenated tensor and track offsets
621
+ for i in range(max(keys) + 1):
622
+ offsets.append(current_offset)
623
+ if i in keys:
624
+ indices = klass_indices[i]
625
+ all_indices.append(indices)
626
+ current_offset += len(indices)
627
+
628
+ # Convert to tensors
629
+ self.klass_indices = torch.cat(all_indices).to(torch.int32).share_memory_()
630
+ self.klass_offsets = torch.tensor(offsets, dtype=torch.long).share_memory_()
631
+ if store_location is not None:
632
+ store_path = os.path.join(store_location, "klass_indices.pt")
633
+ if os.path.exists(store_path) and not force_recompute_indices:
634
+ self.klass_indices = torch.load(store_path).share_memory_()
635
+ self.klass_offsets = torch.load(
636
+ store_path.replace(".pt", "_offsets.pt")
637
+ ).share_memory_()
638
+ print(f"Loaded sampler indices from {store_path}")
639
+ else:
640
+ torch.save(self.klass_indices, store_path)
641
+ torch.save(self.klass_offsets, store_path.replace(".pt", "_offsets.pt"))
642
+ print(f"Saved sampler indices to {store_path}")
643
+ print(f"Done initializing sampler with {len(self.klass_offsets)} classes")
431
644
 
432
645
  def __iter__(self):
646
+ # Sample classes according to their weights
647
+ print("sampling a new batch of size", self.num_samples)
648
+
433
649
  sample_labels = torch.multinomial(
434
650
  self.label_weights,
435
651
  num_samples=self.num_samples,
436
652
  replacement=True,
437
653
  )
438
- sample_indices = torch.empty_like(sample_labels)
439
- for i_klass, klass_index in enumerate(self.klass_indices):
654
+ # Get counts of each class in sample_labels
655
+ unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
656
+
657
+ # Collect the sampled indices for each class, concatenated at the end
658
+ result_indices_list = []
659
+
660
+ # Process only the classes that were actually sampled
661
+ for i, (label, count) in tqdm(
662
+ enumerate(zip(unique_samples.tolist(), sample_counts.tolist())),
663
+ total=len(unique_samples),
664
+ desc="Processing classes in sampler",
665
+ ):
666
+ klass_index = self.klass_indices[
667
+ self.klass_offsets[label] : self.klass_offsets[label + 1]
668
+ ]
669
+
440
670
  if klass_index.numel() == 0:
441
671
  continue
442
- left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
443
- if len(left_inds) == 0:
444
- continue
672
+
673
+ # Sample elements from this class
445
674
  if self.element_weights is not None:
446
- right_inds = torch.multinomial(
447
- self.element_weights[klass_index],
448
- num_samples=len(klass_index)
449
- if not self.replacement and len(klass_index) < len(left_inds)
450
- else len(left_inds),
451
- replacement=self.replacement,
452
- )
675
+ # Memory hot spot: this indexing materializes a temporary weights tensor for the class
676
+ current_element_weights_slice = self.element_weights[klass_index]
677
+
678
+ if self.replacement:
679
+ right_inds = torch.multinomial(
680
+ current_element_weights_slice,
681
+ num_samples=count,
682
+ replacement=True,
683
+ )
684
+ else:
685
+ num_to_sample = min(count, len(klass_index))
686
+ right_inds = torch.multinomial(
687
+ current_element_weights_slice,
688
+ num_samples=num_to_sample,
689
+ replacement=False,
690
+ )
453
691
  elif self.replacement:
454
- right_inds = torch.randint(
455
- len(klass_index),
456
- size=(len(left_inds),),
457
- )
692
+ right_inds = torch.randint(len(klass_index), size=(count,))
458
693
  else:
459
- maxelem = (
460
- len(left_inds)
461
- if len(left_inds) < len(klass_index)
462
- else len(klass_index)
463
- )
464
- right_inds = torch.randperm(len(klass_index))[:maxelem]
465
- sample_indices[left_inds[: len(right_inds)]] = klass_index[right_inds]
466
- # if there are more left_inds than right_inds, we need to drop the extra ones
467
- if len(right_inds) < len(left_inds):
468
- sample_indices[left_inds[len(right_inds) :]] = -1
469
- # drop all -1
470
- sample_indices = sample_indices[sample_indices != -1]
471
- # torch shuffle
472
- sample_indices = sample_indices[torch.randperm(len(sample_indices))]
473
- self.num_samples = len(sample_indices)
474
- yield from iter(sample_indices.tolist())
694
+ num_to_sample = min(count, len(klass_index))
695
+ right_inds = torch.randperm(len(klass_index))[:num_to_sample]
696
+
697
+ # Get actual indices
698
+ sampled_indices = klass_index[right_inds]
699
+ result_indices_list.append(sampled_indices)
700
+
701
+ # Combine all indices
702
+ if result_indices_list: # Check if the list is not empty
703
+ final_result_indices = torch.cat(
704
+ result_indices_list
705
+ ) # Use the list with the appended new name
706
+
707
+ # Shuffle the combined indices
708
+ shuffled_indices = final_result_indices[
709
+ torch.randperm(len(final_result_indices))
710
+ ]
711
+ self.num_samples = len(shuffled_indices)
712
+ yield from shuffled_indices.tolist()
713
+ else:
714
+ self.num_samples = 0
715
+ yield from iter([])
475
716
 
476
717
  def __len__(self):
477
718
  return self.num_samples
719
+
720
+ def _merge_chunk_results(self, results_list):
721
+ """Merge results from multiple chunks into a single dictionary.
722
+
723
+ Args:
724
+ results_list: list of dictionaries mapping class labels to index arrays
725
+
726
+ Returns:
727
+ merged dictionary with PyTorch tensors
728
+ """
729
+ merged = {}
730
+
731
+ # Collect all labels across all chunks
732
+ all_labels = set()
733
+ for chunk_result in results_list:
734
+ all_labels.update(chunk_result.keys())
735
+
736
+ # For each unique label
737
+ for label in all_labels:
738
+ # Collect indices from all chunks where this label appears
739
+ indices_lists = [
740
+ chunk_result[label]
741
+ for chunk_result in results_list
742
+ if label in chunk_result
743
+ ]
744
+
745
+ if indices_lists:
746
+ # Concatenate all indices for this label
747
+ merged[label] = torch.tensor(
748
+ np.concatenate(indices_lists), dtype=torch.long
749
+ )
750
+ else:
751
+ merged[label] = torch.tensor([], dtype=torch.long)
752
+
753
+ return merged
754
+
755
+ def _build_class_indices_parallel(self, labels, chunk_size, n_workers=None):
756
+ """Build class indices in parallel across multiple workers.
757
+
758
+ Args:
759
+ labels: array of class labels
760
+ chunk_size: size of chunks to process
761
+ n_workers: number of parallel workers
762
+
763
+ Returns:
764
+ dictionary mapping class labels to tensors of indices
765
+ """
766
+ n = len(labels)
767
+ results = []
768
+ # Create chunks of the labels array with proper sizing
769
+ n_chunks = (n + chunk_size - 1) // chunk_size # Ceiling division
770
+ print(f"Processing {n:,} elements in {n_chunks} chunks...")
771
+
772
+ # Process in chunks to limit memory usage
773
+ with ProcessPoolExecutor(
774
+ max_workers=n_workers, mp_context=mp.get_context("spawn")
775
+ ) as executor:
776
+ # Submit chunks for processing
777
+ futures = []
778
+ for i in range(n_chunks):
779
+ start_idx = i * chunk_size
780
+ end_idx = min((i + 1) * chunk_size, n)
781
+ # We pass only chunk boundaries, not the data itself
782
+ # This avoids unnecessary copies during process creation
783
+ futures.append(
784
+ executor.submit(
785
+ self._process_chunk_with_slice,
786
+ (start_idx, end_idx, labels),
787
+ )
788
+ )
789
+
790
+ # Collect results as they complete with progress reporting
791
+ for future in tqdm(
792
+ as_completed(futures), total=len(futures), desc="Processing chunks"
793
+ ):
794
+ results.append(future.result())
795
+
796
+ # Merge results from all chunks
797
+ print("Merging results from all chunks...")
798
+ merged_results = self._merge_chunk_results(results)
799
+
800
+ return merged_results
801
+
802
+ def _process_chunk_with_slice(self, slice_info):
803
+ """Process a slice of the labels array by indices.
804
+
805
+ Args:
806
+ slice_info: tuple of (start_idx, end_idx, labels_array) where
807
+ start_idx and end_idx define the slice to process
808
+
809
+ Returns:
810
+ dict mapping class labels to arrays of indices
811
+ """
812
+ start_idx, end_idx, labels_array = slice_info
813
+
814
+ # We're processing a slice of the original array
815
+ labels_slice = labels_array[start_idx:end_idx]
816
+ chunk_indices = {}
817
+
818
+ # Create a direct map of indices
819
+ indices = np.arange(start_idx, end_idx)
820
+
821
+ # Get unique labels in this slice for more efficient processing
822
+ unique_labels = np.unique(labels_slice)
823
+ # For each valid label, find its indices
824
+ for label in unique_labels:
825
+ # Find positions where this label appears (using direct boolean indexing)
826
+ label_mask = labels_slice == label
827
+ chunk_indices[int(label)] = indices[label_mask]
828
+
829
+ return chunk_indices
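Finally, a small self-contained toy run of the rewritten LabelWeightedSampler, independent of lamindb. It is a hedged sketch: the import path is assumed (the class is defined in the module shown in this diff), the sizes are tiny, and the last class is given weight zero, mirroring how DataModule reserves the highest label for held-out indices.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from scdataloader.datamodule import LabelWeightedSampler  # assumed module path

    if __name__ == "__main__":  # needed because the sampler spawns worker processes
        # 90 toy elements in three weighted classes; class 3 gets weight 0, the way
        # DataModule uses the trailing extra class for validation/test indices.
        labels = np.array([0] * 50 + [1] * 30 + [2] * 6 + [3] * 4)
        label_weights = [1.0, 1.0, 1.0, 0.0]

        sampler = LabelWeightedSampler(
            label_weights=label_weights,  # rescaled internally by class size
            labels=labels,
            num_samples=32,
            replacement=True,
        )

        dataset = TensorDataset(torch.arange(len(labels)))
        loader = DataLoader(dataset, sampler=sampler, batch_size=8)
        for (batch,) in loader:
            print(batch.tolist())  # indices drawn only from classes 0, 1 and 2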