scdataloader 1.9.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/collator.py +13 -24
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +116 -43
- scdataloader/datamodule.py +551 -199
- scdataloader/mapped.py +77 -18
- scdataloader/preprocess.py +106 -95
- scdataloader/utils.py +39 -33
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.0.dist-info}/METADATA +3 -4
- scdataloader-2.0.0.dist-info/RECORD +16 -0
- scdataloader-2.0.0.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.2.dist-info/RECORD +0 -16
- scdataloader-1.9.2.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.0.dist-info}/WHEEL +0 -0
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.0.dist-info}/entry_points.txt +0 -0
scdataloader/datamodule.py
CHANGED
@@ -1,4 +1,9 @@
+import multiprocessing as mp
 import os
+import random
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from functools import partial
 from typing import Optional, Sequence, Union
 
 import lamindb as ln
@@ -13,10 +18,11 @@ from torch.utils.data.sampler import (
     SubsetRandomSampler,
     WeightedRandomSampler,
 )
+from tqdm import tqdm
 
 from .collator import Collator
 from .data import Dataset
-from .utils import getBiomartTable
+from .utils import fileToList, getBiomartTable, listToFile
 
 FILE_DIR = os.path.dirname(os.path.abspath(__file__))
 
@@ -26,7 +32,6 @@ class DataModule(L.LightningDataModule):
         self,
         collection_name: str,
         clss_to_weight: list = ["organism_ontology_term_id"],
-        organisms: list = ["NCBITaxon:9606"],
         weight_scaler: int = 10,
         n_samples_per_epoch: int = 2_000_000,
         validation_split: float = 0.2,
@@ -43,7 +48,7 @@ class DataModule(L.LightningDataModule):
         max_len: int = 1000,
         add_zero_genes: int = 100,
         replacement: bool = True,
-        do_gene_pos:
+        do_gene_pos: str = "",
         tp_name: Optional[str] = None,  # "heat_diff"
         assays_to_drop: list = [
             # "EFO:0008853", #patch seq
@@ -53,6 +58,10 @@ class DataModule(L.LightningDataModule):
         ],
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
+        store_location: str = None,
+        force_recompute_indices: bool = False,
+        sampler_workers: int = None,
+        sampler_chunk_size: int = None,
         **kwargs,
     ):
         """
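A minimal sketch of how the 2.0.0 constructor might be called with the new caching and sampler arguments. The collection name, cache path, worker count, chunk size and batch size below are placeholder values; only the keyword names come from this diff:

from scdataloader.datamodule import DataModule

datamodule = DataModule(
    collection_name="my-lamindb-collection",   # placeholder collection name
    clss_to_weight=["organism_ontology_term_id"],
    clss_to_predict=["organism_ontology_term_id"],
    store_location="/tmp/scdataloader_cache",  # new in 2.0.0: cache for sampler indices
    force_recompute_indices=False,             # new in 2.0.0: ignore any existing cache if True
    sampler_workers=8,                         # new in 2.0.0: parallel workers for class indexing
    sampler_chunk_size=1_000_000,              # new in 2.0.0: elements per worker chunk
    batch_size=64,                             # forwarded to the PyTorch DataLoader
)

Note that `organisms` is no longer a constructor argument; the organisms are now taken from the Dataset itself.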
@@ -64,7 +73,6 @@ class DataModule(L.LightningDataModule):
 
         Args:
             collection_name (str): The lamindb collection to be used.
-            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
             weight_scaler (int, optional): how much more you will see the most present vs less present category.
             n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
@@ -88,59 +96,37 @@ class DataModule(L.LightningDataModule):
             metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
             clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
+            store_location (str, optional): The location to store the sampler indices. Defaults to None.
+            force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
+            sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
+            sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
             see @file data.py and @file collator.py for more details about some of the parameters
         """
-        if
-
-
-            organisms=organisms,
-            clss_to_predict=clss_to_predict,
-            hierarchical_clss=hierarchical_clss,
-            metacell_mode=metacell_mode,
-            get_knn_cells=get_knn_cells,
+        if "organism_ontology_term_id" not in clss_to_predict:
+            raise ValueError(
+                "need 'organism_ontology_term_id' in the set of classes at least"
             )
+        mdataset = Dataset(
+            ln.Collection.filter(name=collection_name, is_latest=True).first(),
+            clss_to_predict=clss_to_predict,
+            hierarchical_clss=hierarchical_clss,
+            metacell_mode=metacell_mode,
+            get_knn_cells=get_knn_cells,
+            store_location=store_location,
+            force_recompute_indices=force_recompute_indices,
+        )
         # and location
         self.metacell_mode = bool(metacell_mode)
         self.gene_pos = None
         self.collection_name = collection_name
         if do_gene_pos:
-
-                print("seeing a string: loading gene positions as biomart parquet file")
-                biomart = pd.read_parquet(do_gene_pos)
-            else:
-                # and annotations
-                if organisms != ["NCBITaxon:9606"]:
-                    raise ValueError(
-                        "need to provide your own table as this automated function only works for humans for now"
-                    )
-                biomart = getBiomartTable(
-                    attributes=["start_position", "chromosome_name"],
-                    useCache=True,
-                ).set_index("ensembl_gene_id")
-                biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
-                biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
-                c = []
-                i = 0
-                prev_position = -100000
-                prev_chromosome = None
-                for _, r in biomart.iterrows():
-                    if (
-                        r["chromosome_name"] != prev_chromosome
-                        or r["start_position"] - prev_position > gene_position_tolerance
-                    ):
-                        i += 1
-                    c.append(i)
-                    prev_position = r["start_position"]
-                    prev_chromosome = r["chromosome_name"]
-                print(f"reduced the size to {len(set(c)) / len(biomart)}")
-                biomart["pos"] = c
+            biomart = pd.read_parquet(do_gene_pos)
             mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
             self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
-
         if gene_embeddings != "":
             mdataset.genedf = mdataset.genedf.join(
-                pd.read_parquet(gene_embeddings), how="inner"
+                pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
            )
             if do_gene_pos:
                 self.gene_pos = mdataset.genedf["pos"].tolist()
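Since the automatic biomart lookup was removed, `do_gene_pos` now expects a parquet file that already contains a `pos` column. A hedged sketch of how such a file might be precomputed, mirroring the binning logic deleted above; the `gene_position_tolerance` value and output path are placeholders, not package constants:

import pandas as pd
from scdataloader.utils import getBiomartTable

gene_position_tolerance = 10_000  # assumed tolerance, tune as needed

biomart = getBiomartTable(
    attributes=["start_position", "chromosome_name"], useCache=True
).set_index("ensembl_gene_id")
biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
biomart = biomart.sort_values(by=["chromosome_name", "start_position"])

pos, i = [], 0
prev_position, prev_chromosome = -100_000, None
for _, r in biomart.iterrows():
    # open a new positional bin when the chromosome changes or the gap is large
    if (
        r["chromosome_name"] != prev_chromosome
        or r["start_position"] - prev_position > gene_position_tolerance
    ):
        i += 1
    pos.append(i)
    prev_position = r["start_position"]
    prev_chromosome = r["chromosome_name"]
biomart["pos"] = pos
biomart.to_parquet("gene_pos.parquet")  # then pass do_gene_pos="gene_pos.parquet"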
@@ -149,7 +135,7 @@ class DataModule(L.LightningDataModule):
         # we might want to not introduce zeros and
         if use_default_col:
             kwargs["collate_fn"] = Collator(
-                organisms=organisms,
+                organisms=mdataset.organisms,
                 how=how,
                 valid_genes=mdataset.genedf.index.tolist(),
                 max_len=max_len,
@@ -157,7 +143,7 @@ class DataModule(L.LightningDataModule):
                 org_to_id=mdataset.encoder[organism_name],
                 tp_name=tp_name,
                 organism_name=organism_name,
-                class_names=
+                class_names=list(self.classes.keys()),
             )
         self.validation_split = validation_split
         self.test_split = test_split
@@ -173,10 +159,15 @@ class DataModule(L.LightningDataModule):
         self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.sampler_workers = sampler_workers
+        self.sampler_chunk_size = sampler_chunk_size
+        self.store_location = store_location
         self.nnz = None
         self.test_datasets = []
+        self.force_recompute_indices = force_recompute_indices
         self.test_idx = []
         super().__init__()
+        print("finished init")
 
     def __repr__(self):
         return (
@@ -238,6 +229,44 @@ class DataModule(L.LightningDataModule):
         """
         return self.dataset.genedf.index.tolist()
 
+    @genes.setter
+    def genes(self, genes):
+        self.dataset.genedf = self.dataset.genedf.loc[genes]
+        self.kwargs["collate_fn"].genes = genes
+        self.kwargs["collate_fn"]._setup(
+            genedf=self.dataset.genedf,
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=genes,
+        )
+
+    @property
+    def encoders(self):
+        return self.dataset.encoder
+
+    @encoders.setter
+    def encoders(self, encoders):
+        self.dataset.encoder = encoders
+        self.kwargs["collate_fn"].org_to_id = encoders[
+            self.kwargs["collate_fn"].organism_name
+        ]
+        self.kwargs["collate_fn"]._setup(
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=self.genes,
+        )
+
+    @property
+    def organisms(self):
+        return self.dataset.organisms
+
+    @organisms.setter
+    def organisms(self, organisms):
+        self.dataset.organisms = organisms
+        self.kwargs["collate_fn"].organisms = organisms
+        self.kwargs["collate_fn"]._setup(
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=self.genes,
+        )
+
     @property
     def num_datasets(self):
         return len(self.dataset.mapped_dataset.storages)
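The new genes/encoders/organisms setters keep the underlying Dataset and the Collator in sync. A hedged sketch of restricting a DataModule to a pretrained model's gene vocabulary via the genes setter; the gene IDs are placeholders and `datamodule` is assumed to be an already-constructed instance using the default collator:

pretrained_genes = {"ENSG00000139618", "ENSG00000141510"}  # placeholder Ensembl IDs
kept = [g for g in datamodule.genes if g in pretrained_genes]
datamodule.genes = kept  # re-indexes genedf and re-runs the collator's _setup()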
@@ -252,104 +281,191 @@ class DataModule(L.LightningDataModule):
                 It can be either 'fit' or 'test'. Defaults to None.
         """
         SCALE = 10
-
-
-
-
-
-
-        ).min()
-        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
-            weights, labels = self.dataset.get_label_weights(
-                self.clss_to_weight,
-                scaler=self.weight_scaler,
-                return_categories=True,
+        print("setting up the datamodule")
+        start_time = time.time()
+        if (
+            self.store_location is None
+            or not os.path.exists(
+                os.path.join(self.store_location, "train_weights.npy")
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.
-            )
-
-
-
-
-
-
-
-
-
-                else self.dataset.mapped_dataset.n_obs_list[0]
+            or self.force_recompute_indices
+        ):
+            if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
+                self.nnz = self.dataset.mapped_dataset.get_merged_labels(
+                    "nnz", is_cat=False
+                )
+                self.clss_to_weight.remove("nnz")
+                (
+                    (self.nnz.max() / SCALE)
+                    / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
+                ).min()
+            if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+                weights, labels = self.dataset.get_label_weights(
+                    self.clss_to_weight,
+                    scaler=self.weight_scaler,
+                    return_categories=True,
+                )
+            else:
+                weights = np.ones(1)
+                labels = np.zeros(self.n_samples, dtype=int)
+            if isinstance(self.validation_split, int):
+                len_valid = self.validation_split
+            else:
+                len_valid = int(self.n_samples * self.validation_split)
+            if isinstance(self.test_split, int):
+                len_test = self.test_split
+            else:
+                len_test = int(self.n_samples * self.test_split)
+            assert len_test + len_valid < self.n_samples, (
+                "test set + valid set size is configured to be larger than entire dataset."
             )
-            cs = 0
-            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
-                if cs + c > len_test:
-                    break
-                else:
-                    self.test_datasets.append(
-                        self.dataset.mapped_dataset.path_list[i].path
-                    )
-                    cs += c
-            len_test = cs
-            self.test_idx = idx_full[:len_test]
-            idx_full = idx_full[len_test:]
-        else:
-            self.test_idx = None
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            idx_full = []
+            if len(self.assays_to_drop) > 0:
+                badloc = np.isin(
+                    self.dataset.mapped_dataset.get_merged_labels(
+                        "assay_ontology_term_id"
+                    ),
+                    self.assays_to_drop,
+                )
+                idx_full = np.arange(len(labels))[~badloc]
+            else:
+                idx_full = np.arange(self.n_samples)
+            if len_test > 0:
+                # this way we work on some never seen datasets
+                # keeping at least one
+                len_test = (
+                    len_test
+                    if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                    else self.dataset.mapped_dataset.n_obs_list[0]
+                )
+                cs = 0
+                d_size = list(enumerate(self.dataset.mapped_dataset.n_obs_list))
+                random.Random(42).shuffle(d_size)  # always same order
+                for i, c in d_size:
+                    if cs + c > len_test:
+                        break
+                    else:
+                        self.test_datasets.append(
+                            self.dataset.mapped_dataset.path_list[i].path
+                        )
+                        cs += c
+                len_test = cs
+                self.test_idx = idx_full[:len_test]
+                idx_full = idx_full[len_test:]
+            else:
+                self.test_idx = None
+
+            np.random.shuffle(idx_full)
+            if len_valid > 0:
+                self.valid_idx = idx_full[:len_valid].copy()
+                # store it for later
+                idx_full = idx_full[len_valid:]
+            else:
+                self.valid_idx = None
+            weights = np.concatenate([weights, np.zeros(1)])
+            labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+            # some labels will now not exist anymore as replaced by len(weights) - 1.
+            # this means that the associated weights should be 0.
+            # by doing np.bincount(labels)*weights this will be taken into account
+            self.train_weights = weights
+            self.train_labels = labels
+            self.idx_full = idx_full
+            if self.store_location is not None:
+                if (
+                    not os.path.exists(
+                        os.path.join(self.store_location, "train_weights.npy")
+                    )
+                    or self.force_recompute_indices
+                ):
+                    os.makedirs(self.store_location, exist_ok=True)
+                    if self.nnz is not None:
+                        np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
+                    np.save(
+                        os.path.join(self.store_location, "train_weights.npy"),
+                        self.train_weights,
+                    )
+                    np.save(
+                        os.path.join(self.store_location, "train_labels.npy"),
+                        self.train_labels,
+                    )
+                    np.save(
+                        os.path.join(self.store_location, "idx_full.npy"), self.idx_full
+                    )
+                    if self.test_idx is not None:
+                        np.save(
+                            os.path.join(self.store_location, "test_idx.npy"), self.test_idx
+                        )
+                    if self.valid_idx is not None:
+                        np.save(
+                            os.path.join(self.store_location, "valid_idx.npy"),
+                            self.valid_idx,
+                        )
+                    listToFile(
+                        self.test_datasets,
+                        os.path.join(self.store_location, "test_datasets.txt"),
+                    )
+        else:
+            self.nnz = (
+                np.load(os.path.join(self.store_location, "nnz.npy"), mmap_mode="r")
+                if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
+                else None
+            )
+            self.train_weights = np.load(
+                os.path.join(self.store_location, "train_weights.npy")
+            )
+            self.train_labels = np.load(
+                os.path.join(self.store_location, "train_labels.npy")
+            )
+            self.idx_full = np.load(
+                os.path.join(self.store_location, "idx_full.npy"), mmap_mode="r"
+            )
+            self.test_idx = (
+                np.load(os.path.join(self.store_location, "test_idx.npy"))
+                if os.path.exists(os.path.join(self.store_location, "test_idx.npy"))
+                else None
+            )
+            self.valid_idx = (
+                np.load(os.path.join(self.store_location, "valid_idx.npy"))
+                if os.path.exists(
+                    os.path.join(self.store_location, "valid_idx.npy")
+                )
+                else None
+            )
+            self.test_datasets = fileToList(
+                os.path.join(self.store_location, "test_datasets.txt")
+            )
+            print("loaded from store")
+        print(f"done setup, took {time.time() - start_time:.2f} seconds")
         return self.test_datasets
 
     def train_dataloader(self, **kwargs):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+            try:
+                print("Setting up the parallel train sampler...")
+                # Create the optimized parallel sampler
+                print(f"Using {self.sampler_workers} workers for class indexing")
+                train_sampler = LabelWeightedSampler(
+                    label_weights=self.train_weights,
+                    labels=self.train_labels,
+                    num_samples=int(self.n_samples_per_epoch),
+                    element_weights=self.nnz,
+                    replacement=self.replacement,
+                    n_workers=self.sampler_workers,
+                    chunk_size=self.sampler_chunk_size,
+                    store_location=self.store_location,
+                    force_recompute_indices=self.force_recompute_indices,
+                )
+            except ValueError as e:
+                raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
+        else:
+            train_sampler = SubsetRandomSampler(self.idx_full)
+        current_loader_kwargs = kwargs.copy()
+        current_loader_kwargs.update(self.kwargs)
         return DataLoader(
             self.dataset,
             sampler=train_sampler,
-            **
-            **kwargs,
+            **current_loader_kwargs,
         )
 
     def val_dataloader(self):
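With store_location set, a later call to setup() (with force_recompute_indices=False) takes the else branch above and reloads the split from disk instead of recomputing it. A hedged sketch of what the cache directory might contain afterwards; the path is a placeholder, and nnz.npy, test_idx.npy and valid_idx.npy only appear when the corresponding arrays were computed:

import os

datamodule.setup()  # datamodule from the earlier constructor sketch
print(sorted(os.listdir("/tmp/scdataloader_cache")))
# e.g. ['idx_full.npy', 'test_datasets.txt', 'test_idx.npy',
#       'train_labels.npy', 'train_weights.npy', 'valid_idx.npy']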
@@ -379,99 +495,335 @@ class DataModule(L.LightningDataModule):
             **self.kwargs,
         )
 
-    # def teardown(self):
-    #     clean up state after the trainer stops, delete files...
-    #     called on every process in DDP
-    #     pass
-
 
 class LabelWeightedSampler(Sampler[int]):
-
-
+    """
+    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
+
+    This sampler is designed to handle very large datasets efficiently, with optimizations for:
+    1. Parallel building of class indices
+    2. Chunked processing for large arrays
+    3. Efficient memory management
+    4. Proper handling of replacement and non-replacement sampling
+    """
+
+    label_weights: torch.Tensor
+    klass_indices: dict[int, torch.Tensor]
     num_samples: int
-
+    element_weights: Optional[torch.Tensor]
     replacement: bool
-    # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
-    # this will result a class-balanced sampling, no matter how imbalance the labels are.
 
     def __init__(
         self,
         label_weights: Sequence[float],
-        labels:
+        labels: np.ndarray,
         num_samples: int,
         replacement: bool = True,
-        element_weights: Sequence[float] = None,
+        element_weights: Optional[Sequence[float]] = None,
+        n_workers: int = None,
+        chunk_size: int = None,  # Process 10M elements per chunk
+        store_location: str = None,
+        force_recompute_indices: bool = False,
    ) -> None:
         """
+        Initialize the sampler with parallel processing for large datasets.
 
-        :
-
-
+        Args:
+            label_weights: Weights for each class (length = number of classes)
+            labels: Class label for each dataset element (length = dataset size)
+            num_samples: Number of samples to draw
+            replacement: Whether to sample with replacement
+            element_weights: Optional weights for each element within classes
+            n_workers: Number of parallel workers to use (default: number of CPUs-1)
+            chunk_size: Size of chunks to process in parallel (default: 10M elements)
         """
-
+        print("Initializing optimized parallel weighted sampler...")
         super(LabelWeightedSampler, self).__init__(None)
-
-
-
-
-        self.
-
-
-
-
-
+
+        # Compute label weights (incorporating class frequencies)
+        # Directly use labels as numpy array without conversion
+        label_weights = np.asarray(label_weights) * np.bincount(labels)
+        self.label_weights = torch.as_tensor(
+            label_weights, dtype=torch.float32
+        ).share_memory_()
+
+        # Store element weights if provided
+        if element_weights is not None:
+            self.element_weights = torch.as_tensor(
+                element_weights, dtype=torch.float32
+            ).share_memory_()
+        else:
+            self.element_weights = None
+
         self.replacement = replacement
         self.num_samples = num_samples
-
-
-
-
-
-
+        if (
+            store_location is None
+            or not os.path.exists(os.path.join(store_location, "klass_indices.pt"))
+            or force_recompute_indices
+        ):
+            # Set number of workers (default to CPU count - 1, but at least 1)
+            if n_workers is None:
+                # Check if running on SLURM
+                n_workers = min(20, max(1, mp.cpu_count() - 1))
+                if "SLURM_CPUS_PER_TASK" in os.environ:
+                    n_workers = min(
+                        20, max(1, int(os.environ["SLURM_CPUS_PER_TASK"]) - 1)
+                    )
+
+            # Try to auto-determine optimal chunk size based on memory
+            if chunk_size is None:
+                try:
+                    import psutil
+
+                    # Check if running on SLURM
+                    available_memory = psutil.virtual_memory().available
+                    for name in [
+                        "SLURM_MEM_PER_NODE",
+                        "SLURM_MEM_PER_CPU",
+                        "SLURM_MEM_PER_GPU",
+                        "SLURM_MEM_PER_TASK",
+                    ]:
+                        if name in os.environ:
+                            available_memory = (
+                                int(os.environ[name]) * 1024 * 1024
+                            )  # Convert MB to bytes
+                            break
+
+                    # Use at most 50% of available memory across all workers
+                    memory_per_worker = 0.5 * available_memory / n_workers
+                    # Rough estimate: each label takes 4 bytes, each index 8 bytes
+                    bytes_per_element = 12
+                    chunk_size = min(
+                        max(100_000, int(memory_per_worker / bytes_per_element / 3)),
+                        2_000_000,
+                    )
+                    print(f"Auto-determined chunk size: {chunk_size:,} elements")
+                except (ImportError, KeyError):
+                    chunk_size = 2_000_000
+                    print(f"Using default chunk size: {chunk_size:,} elements")
+
+            # Parallelize the class indices building
+            print(f"Building class indices in parallel with {n_workers} workers...")
+            klass_indices = self._build_class_indices_parallel(
+                labels, chunk_size, n_workers
+            )
+
+            # Convert klass_indices to a single tensor and offset vector
+            all_indices = []
+            offsets = []
+            current_offset = 0
+
+            # Sort keys to ensure consistent ordering
+            keys = klass_indices.keys()
+
+            # Build concatenated tensor and track offsets
+            for i in range(max(keys) + 1):
+                offsets.append(current_offset)
+                if i in keys:
+                    indices = klass_indices[i]
+                    all_indices.append(indices)
+                    current_offset += len(indices)
+
+            # Convert to tensors
+            self.klass_indices = torch.cat(all_indices).to(torch.int32).share_memory_()
+            self.klass_offsets = torch.tensor(offsets, dtype=torch.long).share_memory_()
+        if store_location is not None:
+            store_path = os.path.join(store_location, "klass_indices.pt")
+            if os.path.exists(store_path) and not force_recompute_indices:
+                self.klass_indices = torch.load(store_path).share_memory_()
+                self.klass_offsets = torch.load(
+                    store_path.replace(".pt", "_offsets.pt")
+                ).share_memory_()
+                print(f"Loaded sampler indices from {store_path}")
+            else:
+                torch.save(self.klass_indices, store_path)
+                torch.save(self.klass_offsets, store_path.replace(".pt", "_offsets.pt"))
+                print(f"Saved sampler indices to {store_path}")
+        print(f"Done initializing sampler with {len(self.klass_offsets)} classes")
 
     def __iter__(self):
+        # Sample classes according to their weights
+        print("sampling a new batch of size", self.num_samples)
+
         sample_labels = torch.multinomial(
             self.label_weights,
             num_samples=self.num_samples,
             replacement=True,
         )
-
-
+        # Get counts of each class in sample_labels
+        unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
+
+        # Initialize result tensor
+        result_indices_list = []  # Changed name to avoid conflict if you had result_indices elsewhere
+
+        # Process only the classes that were actually sampled
+        for i, (label, count) in tqdm(
+            enumerate(zip(unique_samples.tolist(), sample_counts.tolist())),
+            total=len(unique_samples),
+            desc="Processing classes in sampler",
+        ):
+            klass_index = self.klass_indices[
+                self.klass_offsets[label] : self.klass_offsets[label + 1]
+            ]
+
             if klass_index.numel() == 0:
                 continue
-
-
-                continue
+
+            # Sample elements from this class
             if self.element_weights is not None:
-
-
-
-
-
-
-
+                # This is a critical point for memory
+                current_element_weights_slice = self.element_weights[klass_index]
+
+                if self.replacement:
+                    right_inds = torch.multinomial(
+                        current_element_weights_slice,
+                        num_samples=count,
+                        replacement=True,
+                    )
+                else:
+                    num_to_sample = min(count, len(klass_index))
+                    right_inds = torch.multinomial(
+                        current_element_weights_slice,
+                        num_samples=num_to_sample,
+                        replacement=False,
+                    )
             elif self.replacement:
-                right_inds = torch.randint(
-                    len(klass_index),
-                    size=(len(left_inds),),
-                )
+                right_inds = torch.randint(len(klass_index), size=(count,))
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                num_to_sample = min(count, len(klass_index))
+                right_inds = torch.randperm(len(klass_index))[:num_to_sample]
+
+            # Get actual indices
+            sampled_indices = klass_index[right_inds]
+            result_indices_list.append(sampled_indices)
+
+        # Combine all indices
+        if result_indices_list:  # Check if the list is not empty
+            final_result_indices = torch.cat(
+                result_indices_list
+            )  # Use the list with the appended new name
+
+            # Shuffle the combined indices
+            shuffled_indices = final_result_indices[
+                torch.randperm(len(final_result_indices))
+            ]
+            self.num_samples = len(shuffled_indices)
+            yield from shuffled_indices.tolist()
+        else:
+            self.num_samples = 0
+            yield from iter([])
 
     def __len__(self):
         return self.num_samples
+
+    def _merge_chunk_results(self, results_list):
+        """Merge results from multiple chunks into a single dictionary.
+
+        Args:
+            results_list: list of dictionaries mapping class labels to index arrays
+
+        Returns:
+            merged dictionary with PyTorch tensors
+        """
+        merged = {}
+
+        # Collect all labels across all chunks
+        all_labels = set()
+        for chunk_result in results_list:
+            all_labels.update(chunk_result.keys())
+
+        # For each unique label
+        for label in all_labels:
+            # Collect indices from all chunks where this label appears
+            indices_lists = [
+                chunk_result[label]
+                for chunk_result in results_list
+                if label in chunk_result
+            ]
+
+            if indices_lists:
+                # Concatenate all indices for this label
+                merged[label] = torch.tensor(
+                    np.concatenate(indices_lists), dtype=torch.long
+                )
+            else:
+                merged[label] = torch.tensor([], dtype=torch.long)
+
+        return merged
+
+    def _build_class_indices_parallel(self, labels, chunk_size, n_workers=None):
+        """Build class indices in parallel across multiple workers.
+
+        Args:
+            labels: array of class labels
+            n_workers: number of parallel workers
+            chunk_size: size of chunks to process
+
+        Returns:
+            dictionary mapping class labels to tensors of indices
+        """
+        n = len(labels)
+        results = []
+        # Create chunks of the labels array with proper sizing
+        n_chunks = (n + chunk_size - 1) // chunk_size  # Ceiling division
+        print(f"Processing {n:,} elements in {n_chunks} chunks...")
+
+        # Process in chunks to limit memory usage
+        with ProcessPoolExecutor(
+            max_workers=n_workers, mp_context=mp.get_context("spawn")
+        ) as executor:
+            # Submit chunks for processing
+            futures = []
+            for i in range(n_chunks):
+                start_idx = i * chunk_size
+                end_idx = min((i + 1) * chunk_size, n)
+                # We pass only chunk boundaries, not the data itself
+                # This avoids unnecessary copies during process creation
+                futures.append(
+                    executor.submit(
+                        self._process_chunk_with_slice,
+                        (start_idx, end_idx, labels),
+                    )
+                )
+
+            # Collect results as they complete with progress reporting
+            for future in tqdm(
+                as_completed(futures), total=len(futures), desc="Processing chunks"
+            ):
+                results.append(future.result())
+
+        # Merge results from all chunks
+        print("Merging results from all chunks...")
+        merged_results = self._merge_chunk_results(results)
+
+        return merged_results
+
+    def _process_chunk_with_slice(self, slice_info):
+        """Process a slice of the labels array by indices.
+
+        Args:
+            slice_info: tuple of (start_idx, end_idx, labels_array) where
+                start_idx and end_idx define the slice to process
+
+        Returns:
+            dict mapping class labels to arrays of indices
+        """
+        start_idx, end_idx, labels_array = slice_info
+
+        # We're processing a slice of the original array
+        labels_slice = labels_array[start_idx:end_idx]
+        chunk_indices = {}
+
+        # Create a direct map of indices
+        indices = np.arange(start_idx, end_idx)
+
+        # Get unique labels in this slice for more efficient processing
+        unique_labels = np.unique(labels_slice)
+        # For each valid label, find its indices
+        for label in unique_labels:
+            # Find positions where this label appears (using direct boolean indexing)
+            label_mask = labels_slice == label
+            chunk_indices[int(label)] = indices[label_mask]
+
+        return chunk_indices