scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,51 +1,55 @@
1
+ import math
2
+ import multiprocessing as mp
1
3
  import os
2
- from typing import Optional, Sequence, Union
4
+ import random
5
+ import time
6
+ from concurrent.futures import ProcessPoolExecutor, as_completed
7
+ from functools import partial
8
+ from typing import List, Optional, Sequence, Union
3
9
 
4
10
  import lamindb as ln
5
11
  import lightning as L
6
12
  import numpy as np
7
13
  import pandas as pd
8
14
  import torch
9
- from torch.utils.data import DataLoader, Sampler
15
+ from torch.utils.data import DataLoader, Sampler, Subset
10
16
  from torch.utils.data.sampler import (
11
17
  RandomSampler,
12
18
  SequentialSampler,
13
19
  SubsetRandomSampler,
14
20
  WeightedRandomSampler,
15
21
  )
22
+ from tqdm import tqdm
16
23
 
17
24
  from .collator import Collator
18
25
  from .data import Dataset
19
- from .utils import getBiomartTable
26
+ from .utils import fileToList, getBiomartTable, listToFile
20
27
 
21
28
  FILE_DIR = os.path.dirname(os.path.abspath(__file__))
29
+ NNZ_SCALE = 1000
22
30
 
23
31
 
24
32
  class DataModule(L.LightningDataModule):
25
33
  def __init__(
26
34
  self,
27
35
  collection_name: str,
28
- clss_to_weight: list = ["organism_ontology_term_id"],
29
- organisms: list = ["NCBITaxon:9606"],
36
+ clss_to_weight: List[str] = ["organism_ontology_term_id"],
30
37
  weight_scaler: int = 10,
31
38
  n_samples_per_epoch: int = 2_000_000,
32
39
  validation_split: float = 0.2,
33
40
  test_split: float = 0,
34
- gene_embeddings: str = "",
35
41
  use_default_col: bool = True,
36
- gene_position_tolerance: int = 10_000,
37
42
  # this is for the mappedCollection
38
- clss_to_predict: list = ["organism_ontology_term_id"],
39
- hierarchical_clss: list = [],
43
+ clss_to_predict: List[str] = ["organism_ontology_term_id"],
44
+ hierarchical_clss: List[str] = [],
40
45
  # this is for the collator
41
46
  how: str = "random expr",
42
- organism_name: str = "organism_ontology_term_id",
47
+ organism_col: str = "organism_ontology_term_id",
43
48
  max_len: int = 1000,
44
- add_zero_genes: int = 100,
45
49
  replacement: bool = True,
46
- do_gene_pos: Union[bool, str] = True,
50
+ gene_subset: Optional[list[str]] = None,
47
51
  tp_name: Optional[str] = None, # "heat_diff"
48
- assays_to_drop: list = [
52
+ assays_to_drop: List[str] = [
49
53
  # "EFO:0008853", #patch seq
50
54
  # "EFO:0010961", # visium
51
55
  "EFO:0030007", # ATACseq
@@ -53,6 +57,14 @@ class DataModule(L.LightningDataModule):
53
57
  ],
54
58
  metacell_mode: float = 0.0,
55
59
  get_knn_cells: bool = False,
60
+ store_location: str = None,
61
+ force_recompute_indices: bool = False,
62
+ sampler_workers: int = None,
63
+ sampler_chunk_size: int = None,
64
+ organisms: Optional[str] = None,
65
+ genedf: Optional[pd.DataFrame] = None,
66
+ n_bins: int = 0,
67
+ curiculum: int = 0,
56
68
  **kwargs,
57
69
  ):
58
70
  """
@@ -64,101 +76,74 @@ class DataModule(L.LightningDataModule):
64
76
 
65
77
  Args:
66
78
  collection_name (str): The lamindb collection to be used.
67
- organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
68
79
  weight_scaler (int, optional): how much more you will see the most present vs less present category.
69
80
  n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
70
81
  validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
71
82
  test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
72
83
  it will use a full dataset and will round to the nearest dataset's cell count.
73
- gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
84
+ use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
85
+ clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
86
+ assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
87
+ gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
74
88
  the file must have ensembl_gene_id as index.
75
89
  This is used to subset the available genes further to the ones that have embeddings in your model.
76
- use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
77
- gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
78
- any genes within this distance of each other will be considered at the same position.
79
- clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
80
- assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
81
- do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
82
90
  max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
83
- add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
84
91
  how (str, optional): The method to use for the collator. Defaults to "random expr".
85
- organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
92
+ organism_col (str, optional): The name of the obs column holding the organism ontology term. Defaults to "organism_ontology_term_id".
86
93
  tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
87
- hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
94
+ hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
88
95
  metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
89
- clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
96
+ clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
90
97
get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cell. Defaults to False.
98
+ store_location (str, optional): The location to store the sampler indices. Defaults to None.
99
+ force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
100
+ sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
101
+ sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
91
102
  **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
92
103
  see @file data.py and @file collator.py for more details about some of the parameters
93
104
  """
94
- if collection_name is not None:
95
- mdataset = Dataset(
96
- ln.Collection.filter(name=collection_name).first(),
97
- organisms=organisms,
98
- clss_to_predict=clss_to_predict,
99
- hierarchical_clss=hierarchical_clss,
100
- metacell_mode=metacell_mode,
101
- get_knn_cells=get_knn_cells,
105
+ if "organism_ontology_term_id" not in clss_to_predict:
106
+ raise ValueError(
107
+ "need 'organism_ontology_term_id' in the set of classes at least"
102
108
  )
109
+ if metacell_mode > 0 and get_knn_cells:
110
+ raise ValueError(
111
+ "cannot use metacell mode and get_knn_cells at the same time"
112
+ )
113
+ mdataset = Dataset(
114
+ ln.Collection.filter(key=collection_name, is_latest=True).first(),
115
+ clss_to_predict=clss_to_predict,
116
+ hierarchical_clss=hierarchical_clss,
117
+ metacell_mode=metacell_mode,
118
+ get_knn_cells=get_knn_cells,
119
+ store_location=store_location,
120
+ force_recompute_indices=force_recompute_indices,
121
+ genedf=genedf,
122
+ )
103
123
  # and location
104
124
  self.metacell_mode = bool(metacell_mode)
105
125
  self.gene_pos = None
106
126
  self.collection_name = collection_name
107
- if do_gene_pos:
108
- if type(do_gene_pos) is str:
109
- print("seeing a string: loading gene positions as biomart parquet file")
110
- biomart = pd.read_parquet(do_gene_pos)
111
- else:
112
- # and annotations
113
- if organisms != ["NCBITaxon:9606"]:
114
- raise ValueError(
115
- "need to provide your own table as this automated function only works for humans for now"
116
- )
117
- biomart = getBiomartTable(
118
- attributes=["start_position", "chromosome_name"],
119
- useCache=True,
120
- ).set_index("ensembl_gene_id")
121
- biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
122
- biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
123
- c = []
124
- i = 0
125
- prev_position = -100000
126
- prev_chromosome = None
127
- for _, r in biomart.iterrows():
128
- if (
129
- r["chromosome_name"] != prev_chromosome
130
- or r["start_position"] - prev_position > gene_position_tolerance
131
- ):
132
- i += 1
133
- c.append(i)
134
- prev_position = r["start_position"]
135
- prev_chromosome = r["chromosome_name"]
136
- print(f"reduced the size to {len(set(c)) / len(biomart)}")
137
- biomart["pos"] = c
138
- mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
139
- self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
140
-
141
- if gene_embeddings != "":
142
- mdataset.genedf = mdataset.genedf.join(
143
- pd.read_parquet(gene_embeddings), how="inner"
144
- )
145
- if do_gene_pos:
146
- self.gene_pos = mdataset.genedf["pos"].tolist()
127
+ if gene_subset is not None:
128
+ tokeep = set(mdataset.genedf.index.tolist())
129
+ gene_subset = [u for u in gene_subset if u in tokeep]
147
130
  self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
148
131
  # we might want not to order the genes by expression (or do it?)
149
132
  # we might want to not introduce zeros and
150
133
  if use_default_col:
151
134
  kwargs["collate_fn"] = Collator(
152
- organisms=organisms,
135
+ organisms=mdataset.organisms if organisms is None else organisms,
153
136
  how=how,
154
- valid_genes=mdataset.genedf.index.tolist(),
137
+ valid_genes=gene_subset,
155
138
  max_len=max_len,
156
- add_zero_genes=add_zero_genes,
157
- org_to_id=mdataset.encoder[organism_name],
139
+ org_to_id=mdataset.encoder[organism_col],
158
140
  tp_name=tp_name,
159
- organism_name=organism_name,
160
- class_names=clss_to_predict,
141
+ organism_name=organism_col,
142
+ class_names=list(self.classes.keys()),
143
+ genedf=genedf,
144
+ n_bins=n_bins,
161
145
  )
146
+ self.n_bins = n_bins
162
147
  self.validation_split = validation_split
163
148
  self.test_split = test_split
164
149
  self.dataset = mdataset
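
For orientation, here is a minimal sketch of how the reworked constructor might be used, assuming DataModule is importable from the package root; the collection key, cache path, worker counts and batch size below are illustrative, only the parameter names are taken from the new signature above:

    from scdataloader import DataModule

    datamodule = DataModule(
        collection_name="my-lamin-collection",      # hypothetical lamindb collection key
        clss_to_weight=["organism_ontology_term_id"],
        clss_to_predict=["organism_ontology_term_id"],
        organism_col="organism_ontology_term_id",
        max_len=2000,
        store_location="/tmp/scdataloader_cache",   # where sampler indices get cached
        sampler_workers=8,
        sampler_chunk_size=1_000_000,
        batch_size=64,                              # forwarded to the torch DataLoader via **kwargs
        num_workers=4,
    )
    datamodule.setup()
    train_loader = datamodule.train_dataloader()
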
@@ -173,10 +158,19 @@ class DataModule(L.LightningDataModule):
173
158
  self.clss_to_weight = clss_to_weight
174
159
  self.train_weights = None
175
160
  self.train_labels = None
161
+ self.sampler_workers = sampler_workers
162
+ self.sampler_chunk_size = sampler_chunk_size
163
+ self.store_location = store_location
176
164
  self.nnz = None
165
+ self.idx_full = None
166
+ self.max_len = max_len
177
167
  self.test_datasets = []
168
+ self.force_recompute_indices = force_recompute_indices
169
+ self.curiculum = curiculum
170
+ self.valid_idx = []
178
171
  self.test_idx = []
179
172
  super().__init__()
173
+ print("finished init")
180
174
 
181
175
  def __repr__(self):
182
176
  return (
@@ -192,9 +186,11 @@ class DataModule(L.LightningDataModule):
192
186
  f"\ttest datasets={str(self.test_datasets)},\n"
193
187
  f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
194
188
  f"\tclss_to_weight={self.clss_to_weight}\n"
195
- + ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
196
- if self.idx_full is not None
197
- else ")"
189
+ + (
190
+ "\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)"
191
+ if self.idx_full is not None
192
+ else ")"
193
+ )
198
194
  )
199
195
 
200
196
  @property
@@ -238,6 +234,49 @@ class DataModule(L.LightningDataModule):
238
234
  """
239
235
  return self.dataset.genedf.index.tolist()
240
236
 
237
+ @property
238
+ def genes_dict(self):
239
+ return {
240
+ i: self.dataset.genedf.index[self.dataset.genedf.organism == i].tolist()
241
+ for i in self.dataset.organisms
242
+ }
243
+
244
+ def set_valid_genes_collator(self, genes):
245
+ self.kwargs["collate_fn"]._setup(
246
+ # cannot use genedf there since I am purposefully decreasing it...
247
+ # genedf=self.dataset.genedf,
248
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
249
+ valid_genes=genes,
250
+ )
251
+
252
+ @property
253
+ def encoders(self):
254
+ return self.dataset.encoder
255
+
256
+ @encoders.setter
257
+ def encoders(self, encoders):
258
+ self.dataset.encoder = encoders
259
+ self.kwargs["collate_fn"].org_to_id = encoders[
260
+ self.kwargs["collate_fn"].organism_name
261
+ ]
262
+ self.kwargs["collate_fn"]._setup(
263
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
264
+ valid_genes=self.genes,
265
+ )
266
+
267
+ @property
268
+ def organisms(self):
269
+ return self.dataset.organisms
270
+
271
+ @organisms.setter
272
+ def organisms(self, organisms):
273
+ self.dataset.organisms = organisms
274
+ self.kwargs["collate_fn"].organisms = organisms
275
+ self.kwargs["collate_fn"]._setup(
276
+ org_to_id=self.kwargs["collate_fn"].org_to_id,
277
+ valid_genes=self.genes,
278
+ )
279
+
241
280
  @property
242
281
  def num_datasets(self):
243
282
  return len(self.dataset.mapped_dataset.storages)
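
The new properties wire the dataset and the collator together: assigning encoders or organisms pushes the mapping into the collator and re-runs its _setup(), so gene lists and label encoders stay in sync. A rough usage sketch, assuming organisms are stored as NCBITaxon ids as elsewhere in this module; the encoder dict name is hypothetical:

    # reuse label encoders from a pretrained model so class ids line up;
    # the setter forwards org_to_id to the collator and re-runs _setup()
    datamodule.encoders = pretrained_model_encoders  # hypothetical dict: class name -> {label: id}

    # genes_dict groups the gene index by organism
    human_genes = datamodule.genes_dict["NCBITaxon:9606"]

    # shrink the collator's gene vocabulary without touching the dataset's genedf
    datamodule.set_valid_genes_collator(human_genes[:2000])
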
@@ -251,116 +290,194 @@ class DataModule(L.LightningDataModule):
251
290
  stage (str, optional): The stage of the model training process.
252
291
  It can be either 'fit' or 'test'. Defaults to None.
253
292
  """
254
- SCALE = 10
255
- if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
256
- self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
257
- self.clss_to_weight.remove("nnz")
258
- (
259
- (self.nnz.max() / SCALE)
260
- / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
261
- ).min()
262
- if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
263
- weights, labels = self.dataset.get_label_weights(
264
- self.clss_to_weight,
265
- scaler=self.weight_scaler,
266
- return_categories=True,
267
- )
268
- else:
269
- weights = np.ones(1)
270
- labels = np.zeros(self.n_samples, dtype=int)
271
- if isinstance(self.validation_split, int):
272
- len_valid = self.validation_split
273
- else:
274
- len_valid = int(self.n_samples * self.validation_split)
275
- if isinstance(self.test_split, int):
276
- len_test = self.test_split
277
- else:
278
- len_test = int(self.n_samples * self.test_split)
279
- assert (
280
- len_test + len_valid < self.n_samples
281
- ), "test set + valid set size is configured to be larger than entire dataset."
282
-
283
- idx_full = []
284
- if len(self.assays_to_drop) > 0:
285
- badloc = np.isin(
286
- self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
287
- self.assays_to_drop,
288
- )
289
- idx_full = np.arange(len(labels))[~badloc]
290
- else:
291
- idx_full = np.arange(self.n_samples)
292
- if len_test > 0:
293
- # this way we work on some never seen datasets
294
- # keeping at least one
295
- len_test = (
296
- len_test
297
- if len_test > self.dataset.mapped_dataset.n_obs_list[0]
298
- else self.dataset.mapped_dataset.n_obs_list[0]
293
+ print("setting up the datamodule")
294
+ start_time = time.time()
295
+ if (
296
+ self.store_location is None
297
+ or not os.path.exists(os.path.join(self.store_location, "train_labels.npy"))
298
+ or self.force_recompute_indices
299
+ ):
300
+ if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
301
+ self.nnz = self.dataset.mapped_dataset.get_merged_labels(
302
+ "nnz", is_cat=False
303
+ )
304
+ self.clss_to_weight.remove("nnz")
305
+ # Sigmoid scaling with 2 parameters
306
+ midpoint = 2000
307
+ steepness = 0.003
308
+ # Apply sigmoid transformation
309
+ # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint)))
310
+ # Then scale to [1, NNZ_SCALE] range
311
+ sigmoid_values = 1 / (1 + np.exp(-steepness * (self.nnz - midpoint)))
312
+ self.nnz = 1 + ((NNZ_SCALE - 1) * sigmoid_values)
313
+ if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
314
+ labels = self.dataset.get_label_cats(
315
+ self.clss_to_weight,
316
+ )
317
+ else:
318
+ labels = np.zeros(self.n_samples, dtype=int)
319
+ if isinstance(self.validation_split, int):
320
+ len_valid = self.validation_split
321
+ else:
322
+ len_valid = int(self.n_samples * self.validation_split)
323
+ if isinstance(self.test_split, int):
324
+ len_test = self.test_split
325
+ else:
326
+ len_test = int(self.n_samples * self.test_split)
327
+ assert len_test + len_valid < self.n_samples, (
328
+ "test set + valid set size is configured to be larger than entire dataset."
299
329
  )
300
- cs = 0
301
- for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
302
- if cs + c > len_test:
303
- break
304
- else:
305
- self.test_datasets.append(
306
- self.dataset.mapped_dataset.path_list[i].path
307
- )
308
- cs += c
309
- len_test = cs
310
- self.test_idx = idx_full[:len_test]
311
- idx_full = idx_full[len_test:]
312
- else:
313
- self.test_idx = None
314
330
 
315
- np.random.shuffle(idx_full)
316
- if len_valid > 0:
317
- self.valid_idx = idx_full[:len_valid].copy()
318
- # store it for later
319
- idx_full = idx_full[len_valid:]
320
- else:
321
- self.valid_idx = None
322
- weights = np.concatenate([weights, np.zeros(1)])
323
- labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
324
- # some labels will now not exist anymore as replaced by len(weights) - 1.
325
- # this means that the associated weights should be 0.
326
- # by doing np.bincount(labels)*weights this will be taken into account
327
- self.train_weights = weights
328
- self.train_labels = labels
329
- self.idx_full = idx_full
331
+ idx_full = []
332
+ if len(self.assays_to_drop) > 0:
333
+ badloc = np.isin(
334
+ self.dataset.mapped_dataset.get_merged_labels(
335
+ "assay_ontology_term_id"
336
+ ),
337
+ self.assays_to_drop,
338
+ )
339
+ idx_full = np.arange(len(labels))[~badloc]
340
+ else:
341
+ idx_full = np.arange(self.n_samples)
342
+ if len_test > 0:
343
+ # this way we work on some never seen datasets
344
+ # keeping at least one
345
+ len_test = (
346
+ len_test
347
+ if len_test > self.dataset.mapped_dataset.n_obs_list[0]
348
+ else self.dataset.mapped_dataset.n_obs_list[0]
349
+ )
350
+ cs = 0
351
+ d_size = list(enumerate(self.dataset.mapped_dataset.n_obs_list))
352
+ random.Random(42).shuffle(d_size) # always same order
353
+ for i, c in d_size:
354
+ if cs + c > len_test:
355
+ break
356
+ else:
357
+ self.test_datasets.append(
358
+ self.dataset.mapped_dataset.path_list[i].path
359
+ )
360
+ cs += c
361
+ len_test = cs
362
+ self.test_idx = idx_full[:len_test]
363
+ idx_full = idx_full[len_test:]
364
+ else:
365
+ self.test_idx = None
366
+
367
+ np.random.shuffle(idx_full)
368
+ if len_valid > 0:
369
+ self.valid_idx = idx_full[:len_valid].copy()
370
+ # store it for later
371
+ idx_full = idx_full[len_valid:]
372
+ else:
373
+ self.valid_idx = None
374
+ labels[~np.isin(np.arange(self.n_samples), idx_full)] = labels.max() + 1
375
+ # some labels will now not exist anymore as replaced by len(weights) - 1.
376
+ # this means that the associated weights should be 0.
377
+ # by doing np.bincount(labels)*weights this will be taken into account
378
+ self.train_labels = labels
379
+ self.idx_full = idx_full
380
+ if self.store_location is not None:
381
+ if (
382
+ not os.path.exists(
383
+ os.path.join(self.store_location, "train_labels.npy")
384
+ )
385
+ or self.force_recompute_indices
386
+ ):
387
+ os.makedirs(self.store_location, exist_ok=True)
388
+ if self.nnz is not None:
389
+ np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
390
+ np.save(
391
+ os.path.join(self.store_location, "train_labels.npy"),
392
+ self.train_labels,
393
+ )
394
+ np.save(
395
+ os.path.join(self.store_location, "idx_full.npy"), self.idx_full
396
+ )
397
+ if self.test_idx is not None:
398
+ np.save(
399
+ os.path.join(self.store_location, "test_idx.npy"), self.test_idx
400
+ )
401
+ if self.valid_idx is not None:
402
+ np.save(
403
+ os.path.join(self.store_location, "valid_idx.npy"),
404
+ self.valid_idx,
405
+ )
406
+ listToFile(
407
+ self.test_datasets,
408
+ os.path.join(self.store_location, "test_datasets.txt"),
409
+ )
410
+ else:
411
+ self.nnz = (
412
+ np.load(os.path.join(self.store_location, "nnz.npy"), mmap_mode="r")
413
+ if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
414
+ else None
415
+ )
416
+ self.train_labels = np.load(
417
+ os.path.join(self.store_location, "train_labels.npy")
418
+ )
419
+ self.idx_full = np.load(
420
+ os.path.join(self.store_location, "idx_full.npy"), mmap_mode="r"
421
+ )
422
+ self.test_idx = (
423
+ np.load(os.path.join(self.store_location, "test_idx.npy"))
424
+ if os.path.exists(os.path.join(self.store_location, "test_idx.npy"))
425
+ else None
426
+ )
427
+ self.valid_idx = (
428
+ np.load(os.path.join(self.store_location, "valid_idx.npy"))
429
+ if os.path.exists(
430
+ os.path.join(self.store_location, "valid_idx.npy")
431
+ )
432
+ else None
433
+ )
434
+ self.test_datasets = fileToList(
435
+ os.path.join(self.store_location, "test_datasets.txt")
436
+ )
437
+ print("loaded from store")
438
+ print(f"done setup, took {time.time() - start_time:.2f} seconds")
330
439
  return self.test_datasets
331
440
 
332
441
  def train_dataloader(self, **kwargs):
333
- # train_sampler = WeightedRandomSampler(
334
- # self.train_weights[self.train_labels],
335
- # int(self.n_samples*self.n_samples_per_epoch),
336
- # replacement=True,
337
- # )
338
- try:
339
- train_sampler = LabelWeightedSampler(
340
- label_weights=self.train_weights,
341
- labels=self.train_labels,
342
- num_samples=int(self.n_samples_per_epoch),
343
- element_weights=self.nnz,
344
- replacement=self.replacement,
345
- )
346
- except ValueError as e:
347
- raise ValueError(e + "have you run `datamodule.setup()`?")
442
+ if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
443
+ try:
444
+ print("Setting up the parallel train sampler...")
445
+ # Create the optimized parallel sampler
446
+ print(f"Using {self.sampler_workers} workers for class indexing")
447
+ train_sampler = LabelWeightedSampler(
448
+ labels=self.train_labels,
449
+ weight_scaler=self.weight_scaler,
450
+ num_samples=int(self.n_samples_per_epoch),
451
+ element_weights=self.nnz,
452
+ replacement=self.replacement,
453
+ n_workers=self.sampler_workers,
454
+ chunk_size=self.sampler_chunk_size,
455
+ store_location=self.store_location,
456
+ force_recompute_indices=self.force_recompute_indices,
457
+ curiculum=self.curiculum,
458
+ )
459
+ except ValueError as e:
460
+ raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
461
+ dataset = None
462
+ else:
463
+ dataset = Subset(self.dataset, self.idx_full)
464
+ train_sampler = RankShardSampler(len(dataset))
465
+ current_loader_kwargs = kwargs.copy()
466
+ current_loader_kwargs.update(self.kwargs)
348
467
  return DataLoader(
349
- self.dataset,
468
+ self.dataset if dataset is None else dataset,
350
469
  sampler=train_sampler,
351
- **self.kwargs,
352
- **kwargs,
470
+ **current_loader_kwargs,
353
471
  )
354
472
 
355
473
  def val_dataloader(self):
356
474
  return (
357
475
  DataLoader(
358
- self.dataset,
359
- sampler=SubsetRandomSampler(self.valid_idx),
476
+ Subset(self.dataset, self.valid_idx),
360
477
  **self.kwargs,
361
478
  )
362
479
  if self.valid_idx is not None
363
- else None
480
+ else []
364
481
  )
365
482
 
366
483
  def test_dataloader(self):
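
The element weights computed in setup() above replace the old ratio-based nnz scaling with a sigmoid squashed into [1, NNZ_SCALE]. A standalone restatement of that transform, using the hard-coded midpoint and steepness from the hunk above (the example nnz values are made up):

    import numpy as np

    NNZ_SCALE = 1000
    midpoint, steepness = 2000, 0.003

    def nnz_weight(nnz):
        # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint))), rescaled to [1, NNZ_SCALE]
        s = 1 / (1 + np.exp(-steepness * (nnz - midpoint)))
        return 1 + (NNZ_SCALE - 1) * s

    nnz_weight(np.array([200, 2000, 10_000]))  # ~ [5.5, 500.5, 1000.0]

Cells with few expressed genes keep a weight close to 1 while deeply sequenced cells approach NNZ_SCALE, so the sampler favors high-count cells without ever zeroing out shallow ones. When store_location is set, setup() also caches its outputs (nnz.npy, train_labels.npy, idx_full.npy, test_idx.npy, valid_idx.npy, test_datasets.txt), and later runs reload them unless force_recompute_indices is passed.
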
@@ -369,109 +486,385 @@ class DataModule(L.LightningDataModule):
369
486
  self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
370
487
  )
371
488
  if self.test_idx is not None
372
- else None
489
+ else []
373
490
  )
374
491
 
375
492
  def predict_dataloader(self):
493
+ subset = Subset(self.dataset, self.idx_full)
376
494
  return DataLoader(
377
- self.dataset,
378
- sampler=SubsetRandomSampler(self.idx_full),
495
+ subset,
496
+ sampler=RankShardSampler(len(subset)),
379
497
  **self.kwargs,
380
498
  )
381
499
 
382
- # def teardown(self):
383
- # clean up state after the trainer stops, delete files...
384
- # called on every process in DDP
385
- # pass
386
-
387
500
 
388
501
  class LabelWeightedSampler(Sampler[int]):
389
- label_weights: Sequence[float]
390
- klass_indices: Sequence[Sequence[int]]
502
+ """
503
+ A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
504
+
505
+ This sampler is designed to handle very large datasets efficiently, with optimizations for:
506
+ 1. Parallel building of class indices
507
+ 2. Chunked processing for large arrays
508
+ 3. Efficient memory management
509
+ 4. Proper handling of replacement and non-replacement sampling
510
+ """
511
+
512
+ label_weights: torch.Tensor
513
+ klass_indices: dict[int, torch.Tensor]
391
514
  num_samples: int
392
- nnz: Optional[Sequence[int]]
515
+ element_weights: Optional[torch.Tensor]
393
516
  replacement: bool
394
- # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
395
- # this will result a class-balanced sampling, no matter how imbalance the labels are.
396
517
 
397
518
  def __init__(
398
519
  self,
399
- label_weights: Sequence[float],
400
- labels: Sequence[int],
520
+ labels: np.ndarray,
401
521
  num_samples: int,
402
522
  replacement: bool = True,
403
- element_weights: Sequence[float] = None,
523
+ weight_scaler: Optional[float] = None,
524
+ element_weights: Optional[Sequence[float]] = None,
525
+ n_workers: int = None,
526
+ chunk_size: int = None,  # None = auto-determined from available memory
527
+ store_location: str = None,
528
+ force_recompute_indices: bool = False,
529
+ curiculum: int = 0,
404
530
  ) -> None:
405
531
  """
532
+ Initialize the sampler with parallel processing for large datasets.
406
533
 
407
- :param label_weights: list(len=num_classes)[float], weights for each class.
408
- :param labels: list(len=dataset_len)[int], labels of a dataset.
409
- :param num_samples: number of samples.
534
+ Args:
535
+ weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
536
+ labels: Class label for each dataset element (length = dataset size)
537
+ num_samples: Number of samples to draw
538
+ replacement: Whether to sample with replacement
539
+ element_weights: Optional weights for each element within classes
540
+ n_workers: Number of parallel workers to use (default: CPU count - 1, capped at 20)
541
+ chunk_size: Size of chunks to process in parallel (default: auto-determined from available memory, capped at 2M elements)
410
542
  """
411
-
543
+ print("Initializing optimized parallel weighted sampler...")
412
544
  super(LabelWeightedSampler, self).__init__(None)
413
- # reweight labels from counter otherwsie same weight to labels that have many elements vs a few
414
- label_weights = np.array(label_weights) * np.bincount(labels)
415
-
416
- self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
417
- self.labels = torch.as_tensor(labels, dtype=torch.int)
418
- self.element_weights = (
419
- torch.as_tensor(element_weights, dtype=torch.float32)
420
- if element_weights is not None
421
- else None
422
- )
545
+ self.count = 0
546
+ self.curiculum = curiculum
547
+
548
+ # Compute label weights (incorporating class frequencies)
549
+ # weight_i = (scaler * count_i) / (count_i + scaler): rare classes keep roughly their count, abundant classes saturate near the scaler
550
+ counts = np.bincount(labels)
551
+ counts[-1] = 0 # Ensure the weight for the 'NONE' class is zero
552
+ label_weights = (weight_scaler * counts) / (counts + weight_scaler)
553
+ self.label_weights = torch.as_tensor(
554
+ label_weights, dtype=torch.float32
555
+ ).share_memory_()
556
+
557
+ # Store element weights if provided
558
+ if element_weights is not None:
559
+ self.element_weights = torch.as_tensor(
560
+ element_weights, dtype=torch.float32
561
+ ).share_memory_()
562
+ else:
563
+ self.element_weights = None
564
+
423
565
  self.replacement = replacement
424
566
  self.num_samples = num_samples
425
- # list of tensor.
426
- self.klass_indices = [
427
- (self.labels == i_klass).nonzero().squeeze(1)
428
- for i_klass in range(len(label_weights))
429
- ]
430
- self.klass_sizes = [len(klass_indices) for klass_indices in self.klass_indices]
567
+ if (
568
+ store_location is None
569
+ or not os.path.exists(os.path.join(store_location, "klass_indices.pt"))
570
+ or force_recompute_indices
571
+ ):
572
+ # Set number of workers (default to CPU count - 1, but at least 1)
573
+ if n_workers is None:
574
+ # Check if running on SLURM
575
+ n_workers = min(20, max(1, mp.cpu_count() - 1))
576
+ if "SLURM_CPUS_PER_TASK" in os.environ:
577
+ n_workers = min(
578
+ 20, max(1, int(os.environ["SLURM_CPUS_PER_TASK"]) - 1)
579
+ )
580
+
581
+ # Try to auto-determine optimal chunk size based on memory
582
+ if chunk_size is None:
583
+ try:
584
+ import psutil
585
+
586
+ # Check if running on SLURM
587
+ available_memory = psutil.virtual_memory().available
588
+ for name in [
589
+ "SLURM_MEM_PER_NODE",
590
+ "SLURM_MEM_PER_CPU",
591
+ "SLURM_MEM_PER_GPU",
592
+ "SLURM_MEM_PER_TASK",
593
+ ]:
594
+ if name in os.environ:
595
+ available_memory = (
596
+ int(os.environ[name]) * 1024 * 1024
597
+ ) # Convert MB to bytes
598
+ break
599
+
600
+ # Use at most 50% of available memory across all workers
601
+ memory_per_worker = 0.5 * available_memory / n_workers
602
+ # Rough estimate: each label takes 4 bytes, each index 8 bytes
603
+ bytes_per_element = 12
604
+ chunk_size = min(
605
+ max(100_000, int(memory_per_worker / bytes_per_element / 3)),
606
+ 2_000_000,
607
+ )
608
+ print(f"Auto-determined chunk size: {chunk_size:,} elements")
609
+ except (ImportError, KeyError):
610
+ chunk_size = 2_000_000
611
+ print(f"Using default chunk size: {chunk_size:,} elements")
612
+
613
+ # Parallelize the class indices building
614
+ print(f"Building class indices in parallel with {n_workers} workers...")
615
+ klass_indices = self._build_class_indices_parallel(
616
+ labels, chunk_size, n_workers
617
+ )
618
+
619
+ # Convert klass_indices to a single tensor and offset vector
620
+ all_indices = []
621
+ offsets = []
622
+ current_offset = 0
623
+
624
+ # Walk class ids in order (0..max label) so the offsets stay consistent
625
+ keys = klass_indices.keys()
626
+
627
+ # Build concatenated tensor and track offsets
628
+ for i in range(max(keys) + 1):
629
+ offsets.append(current_offset)
630
+ if i in keys:
631
+ indices = klass_indices[i]
632
+ all_indices.append(indices)
633
+ current_offset += len(indices)
634
+
635
+ # Convert to tensors
636
+ self.klass_indices = torch.cat(all_indices).to(torch.int32).share_memory_()
637
+ self.klass_offsets = torch.tensor(offsets, dtype=torch.long).share_memory_()
638
+ if store_location is not None:
639
+ store_path = os.path.join(store_location, "klass_indices.pt")
640
+ if os.path.exists(store_path) and not force_recompute_indices:
641
+ self.klass_indices = torch.load(store_path).share_memory_()
642
+ self.klass_offsets = torch.load(
643
+ store_path.replace(".pt", "_offsets.pt")
644
+ ).share_memory_()
645
+ print(f"Loaded sampler indices from {store_path}")
646
+ else:
647
+ torch.save(self.klass_indices, store_path)
648
+ torch.save(self.klass_offsets, store_path.replace(".pt", "_offsets.pt"))
649
+ print(f"Saved sampler indices to {store_path}")
650
+ print(f"Done initializing sampler with {len(self.klass_offsets)} classes")
431
651
 
432
652
  def __iter__(self):
653
+ self.count += 1
654
+ # Sample classes according to their weights
655
+ print("sampling a new batch of size", self.num_samples)
656
+
433
657
  sample_labels = torch.multinomial(
434
- self.label_weights,
658
+ (
659
+ self.label_weights ** min(1, ((self.count + 5) / self.curiculum))
660
+ if self.curiculum
661
+ else self.label_weights
662
+ ),
435
663
  num_samples=self.num_samples,
436
664
  replacement=True,
437
665
  )
438
- sample_indices = torch.empty_like(sample_labels)
439
- for i_klass, klass_index in enumerate(self.klass_indices):
666
+ # Get counts of each class in sample_labels
667
+ unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
668
+
669
+ # Initialize result tensor
670
+ result_indices_list = []  # per-class index tensors, concatenated after the loop
671
+
672
+ # Process only the classes that were actually sampled
673
+ for i, (label, count) in tqdm(
674
+ enumerate(zip(unique_samples.tolist(), sample_counts.tolist())),
675
+ total=len(unique_samples),
676
+ desc="Processing classes in sampler",
677
+ ):
678
+ klass_index = self.klass_indices[
679
+ self.klass_offsets[label] : self.klass_offsets[label + 1]
680
+ ]
681
+
440
682
  if klass_index.numel() == 0:
441
683
  continue
442
- left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
443
- if len(left_inds) == 0:
444
- continue
684
+
685
+ # Sample elements from this class
445
686
  if self.element_weights is not None:
446
- right_inds = torch.multinomial(
447
- self.element_weights[klass_index],
448
- num_samples=len(klass_index)
449
- if not self.replacement and len(klass_index) < len(left_inds)
450
- else len(left_inds),
451
- replacement=self.replacement,
452
- )
687
+ # This is a critical point for memory
688
+ current_element_weights_slice = self.element_weights[klass_index]
689
+
690
+ if current_element_weights_slice.shape[0] >= (2**24) - 1:
691
+ ind = torch.randperm(len(klass_index))[: (2**24) - 10]
692
+ klass_index = klass_index[ind]
693
+ current_element_weights_slice = current_element_weights_slice[ind]
694
+
695
+ if self.replacement:
696
+ right_inds = torch.multinomial(
697
+ current_element_weights_slice,
698
+ num_samples=count,
699
+ replacement=True,
700
+ )
701
+ else:
702
+ num_to_sample = min(count, len(klass_index))
703
+ right_inds = torch.multinomial(
704
+ current_element_weights_slice,
705
+ num_samples=num_to_sample,
706
+ replacement=False,
707
+ )
453
708
  elif self.replacement:
454
- right_inds = torch.randint(
455
- len(klass_index),
456
- size=(len(left_inds),),
709
+ right_inds = torch.randint(len(klass_index), size=(count,))
710
+ else:
711
+ num_to_sample = min(count, len(klass_index))
712
+ right_inds = torch.randperm(len(klass_index))[:num_to_sample]
713
+
714
+ # Get actual indices
715
+ sampled_indices = klass_index[right_inds]
716
+ result_indices_list.append(sampled_indices)
717
+
718
+ # Combine all indices
719
+ if result_indices_list: # Check if the list is not empty
720
+ final_result_indices = torch.cat(
721
+ result_indices_list
722
+ )
723
+
724
+ # Shuffle the combined indices
725
+ shuffled_indices = final_result_indices[
726
+ torch.randperm(len(final_result_indices))
727
+ ]
728
+ self.num_samples = len(shuffled_indices)
729
+ yield from shuffled_indices.tolist()
730
+ else:
731
+ self.num_samples = 0
732
+ yield from iter([])
733
+
734
+ def __len__(self):
735
+ return self.num_samples
736
+
737
+ def _merge_chunk_results(self, results_list):
738
+ """Merge results from multiple chunks into a single dictionary.
739
+
740
+ Args:
741
+ results_list: list of dictionaries mapping class labels to index arrays
742
+
743
+ Returns:
744
+ merged dictionary with PyTorch tensors
745
+ """
746
+ merged = {}
747
+
748
+ # Collect all labels across all chunks
749
+ all_labels = set()
750
+ for chunk_result in results_list:
751
+ all_labels.update(chunk_result.keys())
752
+
753
+ # For each unique label
754
+ for label in all_labels:
755
+ # Collect indices from all chunks where this label appears
756
+ indices_lists = [
757
+ chunk_result[label]
758
+ for chunk_result in results_list
759
+ if label in chunk_result
760
+ ]
761
+
762
+ if indices_lists:
763
+ # Concatenate all indices for this label
764
+ merged[label] = torch.tensor(
765
+ np.concatenate(indices_lists), dtype=torch.long
457
766
  )
458
767
  else:
459
- maxelem = (
460
- len(left_inds)
461
- if len(left_inds) < len(klass_index)
462
- else len(klass_index)
768
+ merged[label] = torch.tensor([], dtype=torch.long)
769
+
770
+ return merged
771
+
772
+ def _build_class_indices_parallel(self, labels, chunk_size, n_workers=None):
773
+ """Build class indices in parallel across multiple workers.
774
+
775
+ Args:
776
+ labels: array of class labels
777
+ n_workers: number of parallel workers
778
+ chunk_size: size of chunks to process
779
+
780
+ Returns:
781
+ dictionary mapping class labels to tensors of indices
782
+ """
783
+ n = len(labels)
784
+ results = []
785
+ # Create chunks of the labels array with proper sizing
786
+ n_chunks = (n + chunk_size - 1) // chunk_size # Ceiling division
787
+ print(f"Processing {n:,} elements in {n_chunks} chunks...")
788
+
789
+ # Process in chunks to limit memory usage
790
+ with ProcessPoolExecutor(
791
+ max_workers=n_workers, mp_context=mp.get_context("spawn")
792
+ ) as executor:
793
+ # Submit chunks for processing
794
+ futures = []
795
+ for i in range(n_chunks):
796
+ start_idx = i * chunk_size
797
+ end_idx = min((i + 1) * chunk_size, n)
798
+ # We pass only chunk boundaries, not the data itself
799
+ # This avoids unnecessary copies during process creation
800
+ futures.append(
801
+ executor.submit(
802
+ self._process_chunk_with_slice,
803
+ (start_idx, end_idx, labels),
804
+ )
463
805
  )
464
- right_inds = torch.randperm(len(klass_index))[:maxelem]
465
- sample_indices[left_inds[: len(right_inds)]] = klass_index[right_inds]
466
- # if there are more left_inds than right_inds, we need to drop the extra ones
467
- if len(right_inds) < len(left_inds):
468
- sample_indices[left_inds[len(right_inds) :]] = -1
469
- # drop all -1
470
- sample_indices = sample_indices[sample_indices != -1]
471
- # torch shuffle
472
- sample_indices = sample_indices[torch.randperm(len(sample_indices))]
473
- self.num_samples = len(sample_indices)
474
- yield from iter(sample_indices.tolist())
806
+
807
+ # Collect results as they complete with progress reporting
808
+ for future in tqdm(
809
+ as_completed(futures), total=len(futures), desc="Processing chunks"
810
+ ):
811
+ results.append(future.result())
812
+
813
+ # Merge results from all chunks
814
+ print("Merging results from all chunks...")
815
+ merged_results = self._merge_chunk_results(results)
816
+
817
+ return merged_results
818
+
819
+ def _process_chunk_with_slice(self, slice_info):
820
+ """Process a slice of the labels array by indices.
821
+
822
+ Args:
823
+ slice_info: tuple of (start_idx, end_idx, labels_array) where
824
+ start_idx and end_idx define the slice to process
825
+
826
+ Returns:
827
+ dict mapping class labels to arrays of indices
828
+ """
829
+ start_idx, end_idx, labels_array = slice_info
830
+
831
+ # We're processing a slice of the original array
832
+ labels_slice = labels_array[start_idx:end_idx]
833
+ chunk_indices = {}
834
+
835
+ # Create a direct map of indices
836
+ indices = np.arange(start_idx, end_idx)
837
+
838
+ # Get unique labels in this slice for more efficient processing
839
+ unique_labels = np.unique(labels_slice)
840
+ # For each valid label, find its indices
841
+ for label in unique_labels:
842
+ # Find positions where this label appears (using direct boolean indexing)
843
+ label_mask = labels_slice == label
844
+ chunk_indices[int(label)] = indices[label_mask]
845
+
846
+ return chunk_indices
847
+
848
+
849
+ class RankShardSampler(Sampler[int]):
850
+ """Shards a dataset contiguously across ranks without padding or duplicates.
851
+ Preserves the existing order (e.g., your pre-shuffled idx_full)."""
852
+
853
+ def __init__(self, data_len: int):
854
+ self.data_len = data_len
855
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
856
+ self.rank = torch.distributed.get_rank()
857
+ self.world_size = torch.distributed.get_world_size()
858
+ else:
859
+ self.rank, self.world_size = 0, 1
860
+
861
+ # contiguous chunk per rank (last rank may be shorter)
862
+ per_rank = math.ceil(self.data_len / self.world_size)
863
+ self.start = self.rank * per_rank
864
+ self.end = min(self.start + per_rank, self.data_len)
865
+
866
+ def __iter__(self):
867
+ return iter(range(self.start, self.end))
475
868
 
476
869
  def __len__(self):
477
- return self.num_samples
870
+ return self.end - self.start
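
To make the new samplers' arithmetic concrete, two small illustrations (the numbers are made up): RankShardSampler hands each rank a contiguous slice of the already-shuffled idx_full, and LabelWeightedSampler's class weights saturate around weight_scaler for abundant classes.

    # RankShardSampler: data_len=10 on 4 ranks -> per_rank = ceil(10/4) = 3
    #   rank 0 -> [0, 1, 2], rank 1 -> [3, 4, 5], rank 2 -> [6, 7, 8], rank 3 -> [9]
    # (no padding, no duplicates; on a single process it is simply range(data_len))
    list(RankShardSampler(data_len=10))

    # LabelWeightedSampler class weights with weight_scaler=10:
    #   a class with 5 cells         -> 10 * 5   / (5 + 10)    ~ 3.3
    #   a class with 1_000_000 cells -> 10 * 1e6 / (1e6 + 10)  ~ 10.0
    # so the most abundant class is drawn at most on the order of weight_scaler
    # times more often than a rare one, instead of proportionally to its raw count.
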