scdataloader 1.9.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/__init__.py +2 -1
- scdataloader/collator.py +30 -42
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +116 -43
- scdataloader/datamodule.py +555 -225
- scdataloader/mapped.py +84 -18
- scdataloader/preprocess.py +108 -94
- scdataloader/utils.py +39 -33
- {scdataloader-1.9.1.dist-info → scdataloader-2.0.0.dist-info}/METADATA +13 -5
- scdataloader-2.0.0.dist-info/RECORD +16 -0
- scdataloader-2.0.0.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.1.dist-info/RECORD +0 -16
- scdataloader-1.9.1.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.1.dist-info → scdataloader-2.0.0.dist-info}/WHEEL +0 -0
- {scdataloader-1.9.1.dist-info → scdataloader-2.0.0.dist-info}/entry_points.txt +0 -0
scdataloader/datamodule.py
CHANGED
@@ -1,4 +1,9 @@
+import multiprocessing as mp
 import os
+import random
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from functools import partial
 from typing import Optional, Sequence, Union

 import lamindb as ln
@@ -13,10 +18,11 @@ from torch.utils.data.sampler import (
     SubsetRandomSampler,
     WeightedRandomSampler,
 )
+from tqdm import tqdm

 from .collator import Collator
 from .data import Dataset
-from .utils import getBiomartTable,
+from .utils import fileToList, getBiomartTable, listToFile

 FILE_DIR = os.path.dirname(os.path.abspath(__file__))

@@ -26,9 +32,8 @@ class DataModule(L.LightningDataModule):
         self,
         collection_name: str,
         clss_to_weight: list = ["organism_ontology_term_id"],
-        organisms: list = ["NCBITaxon:9606"],
         weight_scaler: int = 10,
-
+        n_samples_per_epoch: int = 2_000_000,
         validation_split: float = 0.2,
         test_split: float = 0,
         gene_embeddings: str = "",
@@ -43,7 +48,7 @@ class DataModule(L.LightningDataModule):
         max_len: int = 1000,
         add_zero_genes: int = 100,
         replacement: bool = True,
-        do_gene_pos:
+        do_gene_pos: str = "",
         tp_name: Optional[str] = None,  # "heat_diff"
         assays_to_drop: list = [
             # "EFO:0008853", #patch seq
@@ -53,7 +58,10 @@ class DataModule(L.LightningDataModule):
         ],
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
-
+        store_location: str = None,
+        force_recompute_indices: bool = False,
+        sampler_workers: int = None,
+        sampler_chunk_size: int = None,
         **kwargs,
     ):
         """
@@ -65,9 +73,8 @@ class DataModule(L.LightningDataModule):

         Args:
             collection_name (str): The lamindb collection to be used.
-            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
             weight_scaler (int, optional): how much more you will see the most present vs less present category.
-
+            n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                 it will use a full dataset and will round to the nearest dataset's cell count.
@@ -88,61 +95,38 @@ class DataModule(L.LightningDataModule):
             hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
             metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
             clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
-            modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
             get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
+            store_location (str, optional): The location to store the sampler indices. Defaults to None.
+            force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
+            sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
+            sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
        see @file data.py and @file collator.py for more details about some of the parameters
        """
-        if
-
-
-            organisms=organisms,
-            clss_to_predict=clss_to_predict,
-            hierarchical_clss=hierarchical_clss,
-            metacell_mode=metacell_mode,
-            get_knn_cells=get_knn_cells,
+        if "organism_ontology_term_id" not in clss_to_predict:
+            raise ValueError(
+                "need 'organism_ontology_term_id' in the set of classes at least"
        )
+        mdataset = Dataset(
+            ln.Collection.filter(name=collection_name, is_latest=True).first(),
+            clss_to_predict=clss_to_predict,
+            hierarchical_clss=hierarchical_clss,
+            metacell_mode=metacell_mode,
+            get_knn_cells=get_knn_cells,
+            store_location=store_location,
+            force_recompute_indices=force_recompute_indices,
+        )
        # and location
        self.metacell_mode = bool(metacell_mode)
        self.gene_pos = None
        self.collection_name = collection_name
        if do_gene_pos:
-
-            print("seeing a string: loading gene positions as biomart parquet file")
-            biomart = pd.read_parquet(do_gene_pos)
-        else:
-            # and annotations
-            if organisms != ["NCBITaxon:9606"]:
-                raise ValueError(
-                    "need to provide your own table as this automated function only works for humans for now"
-                )
-            biomart = getBiomartTable(
-                attributes=["start_position", "chromosome_name"],
-                useCache=True,
-            ).set_index("ensembl_gene_id")
-            biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
-            biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
-            c = []
-            i = 0
-            prev_position = -100000
-            prev_chromosome = None
-            for _, r in biomart.iterrows():
-                if (
-                    r["chromosome_name"] != prev_chromosome
-                    or r["start_position"] - prev_position > gene_position_tolerance
-                ):
-                    i += 1
-                c.append(i)
-                prev_position = r["start_position"]
-                prev_chromosome = r["chromosome_name"]
-            print(f"reduced the size to {len(set(c)) / len(biomart)}")
-            biomart["pos"] = c
+            biomart = pd.read_parquet(do_gene_pos)
            mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
            self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
-
        if gene_embeddings != "":
            mdataset.genedf = mdataset.genedf.join(
-                pd.read_parquet(gene_embeddings), how="inner"
+                pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
            )
            if do_gene_pos:
                self.gene_pos = mdataset.genedf["pos"].tolist()
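The constructor no longer takes an `organisms` argument or computes gene positions from BioMart on the fly: organisms now come from the mapped dataset, and `do_gene_pos` must point to a precomputed parquet table with a `pos` column. A minimal usage sketch of the 2.0 constructor follows; it is not taken from the package, and the collection name, cache path, and DataLoader kwargs are assumptions.

```python
# Hypothetical sketch of constructing the 2.0 DataModule (names and paths are assumptions).
from scdataloader import DataModule

datamodule = DataModule(
    collection_name="my-collection",            # assumed lamindb collection name
    clss_to_weight=["organism_ontology_term_id"],
    n_samples_per_epoch=2_000_000,              # per-epoch sample budget (new name in 2.0)
    do_gene_pos="gene_positions.parquet",       # assumed precomputed table with a "pos" column
    store_location="/tmp/scdataloader_cache",   # new in 2.0: where sampler indices get cached
    force_recompute_indices=False,
    sampler_workers=8,                          # None lets the sampler pick based on CPUs / SLURM
    batch_size=64,                              # forwarded to the torch DataLoader
    num_workers=4,
)
```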
@@ -151,7 +135,7 @@ class DataModule(L.LightningDataModule):
        # we might want to not introduce zeros and
        if use_default_col:
            kwargs["collate_fn"] = Collator(
-                organisms=organisms,
+                organisms=mdataset.organisms,
                how=how,
                valid_genes=mdataset.genedf.index.tolist(),
                max_len=max_len,
@@ -159,7 +143,7 @@ class DataModule(L.LightningDataModule):
                org_to_id=mdataset.encoder[organism_name],
                tp_name=tp_name,
                organism_name=organism_name,
-                class_names=
+                class_names=list(self.classes.keys()),
            )
        self.validation_split = validation_split
        self.test_split = test_split
@@ -171,16 +155,19 @@ class DataModule(L.LightningDataModule):
        self.assays_to_drop = assays_to_drop
        self.n_samples = len(mdataset)
        self.weight_scaler = weight_scaler
-        self.
+        self.n_samples_per_epoch = n_samples_per_epoch
        self.clss_to_weight = clss_to_weight
        self.train_weights = None
        self.train_labels = None
-        self.
+        self.sampler_workers = sampler_workers
+        self.sampler_chunk_size = sampler_chunk_size
+        self.store_location = store_location
        self.nnz = None
-        self.restart_num = 0
        self.test_datasets = []
+        self.force_recompute_indices = force_recompute_indices
        self.test_idx = []
        super().__init__()
+        print("finished init")

    def __repr__(self):
        return (
@@ -190,7 +177,7 @@ class DataModule(L.LightningDataModule):
            f"\ttest_split={self.test_split},\n"
            f"\tn_samples={self.n_samples},\n"
            f"\tweight_scaler={self.weight_scaler},\n"
-            f"\
+            f"\tn_samples_per_epoch={self.n_samples_per_epoch},\n"
            f"\tassays_to_drop={self.assays_to_drop},\n"
            f"\tnum_datasets={len(self.dataset.mapped_dataset.storages)},\n"
            f"\ttest datasets={str(self.test_datasets)},\n"
@@ -242,6 +229,44 @@ class DataModule(L.LightningDataModule):
        """
        return self.dataset.genedf.index.tolist()

+    @genes.setter
+    def genes(self, genes):
+        self.dataset.genedf = self.dataset.genedf.loc[genes]
+        self.kwargs["collate_fn"].genes = genes
+        self.kwargs["collate_fn"]._setup(
+            genedf=self.dataset.genedf,
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=genes,
+        )
+
+    @property
+    def encoders(self):
+        return self.dataset.encoder
+
+    @encoders.setter
+    def encoders(self, encoders):
+        self.dataset.encoder = encoders
+        self.kwargs["collate_fn"].org_to_id = encoders[
+            self.kwargs["collate_fn"].organism_name
+        ]
+        self.kwargs["collate_fn"]._setup(
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=self.genes,
+        )
+
+    @property
+    def organisms(self):
+        return self.dataset.organisms
+
+    @organisms.setter
+    def organisms(self, organisms):
+        self.dataset.organisms = organisms
+        self.kwargs["collate_fn"].organisms = organisms
+        self.kwargs["collate_fn"]._setup(
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=self.genes,
+        )
+
    @property
    def num_datasets(self):
        return len(self.dataset.mapped_dataset.storages)
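The new `genes`, `encoders`, and `organisms` setters keep the underlying `Dataset` and the default `Collator` in sync, so a model can restrict or re-map the datamodule after construction. A small illustrative sketch, not from the package; `pretrained_encoders` is a hypothetical dict of label encoders:

```python
# Hypothetical sketch: each setter re-runs the collator's _setup() so batches stay consistent.
subset = datamodule.genes[:1000]            # keep only the first 1000 genes, as an example
datamodule.genes = subset                   # filters dataset.genedf and updates the collator

datamodule.encoders = pretrained_encoders   # hypothetical encoders dict from a pretrained model
print(datamodule.organisms)                 # now read from the mapped dataset, not a constructor arg
```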
@@ -256,106 +281,191 @@ class DataModule(L.LightningDataModule):
            It can be either 'fit' or 'test'. Defaults to None.
        """
        SCALE = 10
-
-
-
-
-
-
-        ).min()
-        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
-            weights, labels = self.dataset.get_label_weights(
-                self.clss_to_weight,
-                scaler=self.weight_scaler,
-                return_categories=True,
+        print("setting up the datamodule")
+        start_time = time.time()
+        if (
+            self.store_location is None
+            or not os.path.exists(
+                os.path.join(self.store_location, "train_weights.npy")
            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.
-            )
-
-
-
-
-
-
-
-
-
-                else self.dataset.mapped_dataset.n_obs_list[0]
+            or self.force_recompute_indices
+        ):
+            if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
+                self.nnz = self.dataset.mapped_dataset.get_merged_labels(
+                    "nnz", is_cat=False
+                )
+                self.clss_to_weight.remove("nnz")
+                (
+                    (self.nnz.max() / SCALE)
+                    / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
+                ).min()
+            if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+                weights, labels = self.dataset.get_label_weights(
+                    self.clss_to_weight,
+                    scaler=self.weight_scaler,
+                    return_categories=True,
+                )
+            else:
+                weights = np.ones(1)
+                labels = np.zeros(self.n_samples, dtype=int)
+            if isinstance(self.validation_split, int):
+                len_valid = self.validation_split
+            else:
+                len_valid = int(self.n_samples * self.validation_split)
+            if isinstance(self.test_split, int):
+                len_test = self.test_split
+            else:
+                len_test = int(self.n_samples * self.test_split)
+            assert len_test + len_valid < self.n_samples, (
+                "test set + valid set size is configured to be larger than entire dataset."
            )
-            cs = 0
-            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
-                if cs + c > len_test:
-                    break
-                else:
-                    self.test_datasets.append(
-                        self.dataset.mapped_dataset.path_list[i].path
-                    )
-                    cs += c
-            len_test = cs
-            self.test_idx = idx_full[:len_test]
-            idx_full = idx_full[len_test:]
-        else:
-            self.test_idx = None

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            idx_full = []
+            if len(self.assays_to_drop) > 0:
+                badloc = np.isin(
+                    self.dataset.mapped_dataset.get_merged_labels(
+                        "assay_ontology_term_id"
+                    ),
+                    self.assays_to_drop,
+                )
+                idx_full = np.arange(len(labels))[~badloc]
+            else:
+                idx_full = np.arange(self.n_samples)
+            if len_test > 0:
+                # this way we work on some never seen datasets
+                # keeping at least one
+                len_test = (
+                    len_test
+                    if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                    else self.dataset.mapped_dataset.n_obs_list[0]
+                )
+                cs = 0
+                d_size = list(enumerate(self.dataset.mapped_dataset.n_obs_list))
+                random.Random(42).shuffle(d_size)  # always same order
+                for i, c in d_size:
+                    if cs + c > len_test:
+                        break
+                    else:
+                        self.test_datasets.append(
+                            self.dataset.mapped_dataset.path_list[i].path
+                        )
+                        cs += c
+                len_test = cs
+                self.test_idx = idx_full[:len_test]
+                idx_full = idx_full[len_test:]
+            else:
+                self.test_idx = None
+
+            np.random.shuffle(idx_full)
+            if len_valid > 0:
+                self.valid_idx = idx_full[:len_valid].copy()
+                # store it for later
+                idx_full = idx_full[len_valid:]
+            else:
+                self.valid_idx = None
+            weights = np.concatenate([weights, np.zeros(1)])
+            labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+            # some labels will now not exist anymore as replaced by len(weights) - 1.
+            # this means that the associated weights should be 0.
+            # by doing np.bincount(labels)*weights this will be taken into account
+            self.train_weights = weights
+            self.train_labels = labels
+            self.idx_full = idx_full
+            if self.store_location is not None:
+                if (
+                    not os.path.exists(
+                        os.path.join(self.store_location, "train_weights.npy")
+                    )
+                    or self.force_recompute_indices
+                ):
+                    os.makedirs(self.store_location, exist_ok=True)
+                    if self.nnz is not None:
+                        np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
+                    np.save(
+                        os.path.join(self.store_location, "train_weights.npy"),
+                        self.train_weights,
+                    )
+                    np.save(
+                        os.path.join(self.store_location, "train_labels.npy"),
+                        self.train_labels,
+                    )
+                    np.save(
+                        os.path.join(self.store_location, "idx_full.npy"), self.idx_full
+                    )
+                    if self.test_idx is not None:
+                        np.save(
+                            os.path.join(self.store_location, "test_idx.npy"), self.test_idx
+                        )
+                    if self.valid_idx is not None:
+                        np.save(
+                            os.path.join(self.store_location, "valid_idx.npy"),
+                            self.valid_idx,
+                        )
+                    listToFile(
+                        self.test_datasets,
+                        os.path.join(self.store_location, "test_datasets.txt"),
+                    )
+        else:
+            self.nnz = (
+                np.load(os.path.join(self.store_location, "nnz.npy"), mmap_mode="r")
+                if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
+                else None
+            )
+            self.train_weights = np.load(
+                os.path.join(self.store_location, "train_weights.npy")
+            )
+            self.train_labels = np.load(
+                os.path.join(self.store_location, "train_labels.npy")
+            )
+            self.idx_full = np.load(
+                os.path.join(self.store_location, "idx_full.npy"), mmap_mode="r"
+            )
+            self.test_idx = (
+                np.load(os.path.join(self.store_location, "test_idx.npy"))
+                if os.path.exists(os.path.join(self.store_location, "test_idx.npy"))
+                else None
+            )
+            self.valid_idx = (
+                np.load(os.path.join(self.store_location, "valid_idx.npy"))
+                if os.path.exists(
+                    os.path.join(self.store_location, "valid_idx.npy")
+                )
+                else None
+            )
+            self.test_datasets = fileToList(
+                os.path.join(self.store_location, "test_datasets.txt")
+            )
+            print("loaded from store")
+        print(f"done setup, took {time.time() - start_time:.2f} seconds")
        return self.test_datasets

    def train_dataloader(self, **kwargs):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+            try:
+                print("Setting up the parallel train sampler...")
+                # Create the optimized parallel sampler
+                print(f"Using {self.sampler_workers} workers for class indexing")
+                train_sampler = LabelWeightedSampler(
+                    label_weights=self.train_weights,
+                    labels=self.train_labels,
+                    num_samples=int(self.n_samples_per_epoch),
+                    element_weights=self.nnz,
+                    replacement=self.replacement,
+                    n_workers=self.sampler_workers,
+                    chunk_size=self.sampler_chunk_size,
+                    store_location=self.store_location,
+                    force_recompute_indices=self.force_recompute_indices,
+                )
+            except ValueError as e:
+                raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
+        else:
+            train_sampler = SubsetRandomSampler(self.idx_full)
+        current_loader_kwargs = kwargs.copy()
+        current_loader_kwargs.update(self.kwargs)
        return DataLoader(
            self.dataset,
            sampler=train_sampler,
-            **
-            **kwargs,
+            **current_loader_kwargs,
        )

    def val_dataloader(self):
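With `store_location` set, the first `setup()` call computes the split indices and label weights and writes them out as `.npy` files plus `test_datasets.txt`; later calls load them instead, and `force_recompute_indices=True` invalidates the cache. A hedged sketch of that flow, not from the package; paths reuse the assumptions above:

```python
# Hypothetical sketch of the cached setup flow.
datamodule.setup()                            # first run: computes and saves train_weights.npy, idx_full.npy, ...
train_loader = datamodule.train_dataloader()  # builds the LabelWeightedSampler from those arrays

for batch in train_loader:
    ...                                       # training step
    break

# A later process pointing at the same store_location skips the expensive recomputation:
datamodule2 = DataModule(
    collection_name="my-collection",
    store_location="/tmp/scdataloader_cache",
)
datamodule2.setup()                           # prints "loaded from store"
```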
@@ -385,115 +495,335 @@ class DataModule(L.LightningDataModule):
            **self.kwargs,
        )

-    # def teardown(self):
-    # clean up state after the trainer stops, delete files...
-    # called on every process in DDP
-    # pass
-

 class LabelWeightedSampler(Sampler[int]):
-
-
+    """
+    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
+
+    This sampler is designed to handle very large datasets efficiently, with optimizations for:
+    1. Parallel building of class indices
+    2. Chunked processing for large arrays
+    3. Efficient memory management
+    4. Proper handling of replacement and non-replacement sampling
+    """
+
+    label_weights: torch.Tensor
+    klass_indices: dict[int, torch.Tensor]
    num_samples: int
-
+    element_weights: Optional[torch.Tensor]
    replacement: bool
-    restart_num: int
-    modify_seed_on_requeue: bool
-    # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
-    # this will result a class-balanced sampling, no matter how imbalance the labels are.

    def __init__(
        self,
        label_weights: Sequence[float],
-        labels:
+        labels: np.ndarray,
        num_samples: int,
        replacement: bool = True,
-        element_weights: Sequence[float] = None,
-
-
+        element_weights: Optional[Sequence[float]] = None,
+        n_workers: int = None,
+        chunk_size: int = None,  # Process 10M elements per chunk
+        store_location: str = None,
+        force_recompute_indices: bool = False,
    ) -> None:
        """
+        Initialize the sampler with parallel processing for large datasets.

-        :
-
-
-
+        Args:
+            label_weights: Weights for each class (length = number of classes)
+            labels: Class label for each dataset element (length = dataset size)
+            num_samples: Number of samples to draw
+            replacement: Whether to sample with replacement
+            element_weights: Optional weights for each element within classes
+            n_workers: Number of parallel workers to use (default: number of CPUs-1)
+            chunk_size: Size of chunks to process in parallel (default: 10M elements)
        """
-
+        print("Initializing optimized parallel weighted sampler...")
        super(LabelWeightedSampler, self).__init__(None)
-
-
-
-
-        self.
-
-
-
-
-
+
+        # Compute label weights (incorporating class frequencies)
+        # Directly use labels as numpy array without conversion
+        label_weights = np.asarray(label_weights) * np.bincount(labels)
+        self.label_weights = torch.as_tensor(
+            label_weights, dtype=torch.float32
+        ).share_memory_()
+
+        # Store element weights if provided
+        if element_weights is not None:
+            self.element_weights = torch.as_tensor(
+                element_weights, dtype=torch.float32
+            ).share_memory_()
+        else:
+            self.element_weights = None
+
        self.replacement = replacement
        self.num_samples = num_samples
-
-
-
-
-
-
-
-
+        if (
+            store_location is None
+            or not os.path.exists(os.path.join(store_location, "klass_indices.pt"))
+            or force_recompute_indices
+        ):
+            # Set number of workers (default to CPU count - 1, but at least 1)
+            if n_workers is None:
+                # Check if running on SLURM
+                n_workers = min(20, max(1, mp.cpu_count() - 1))
+                if "SLURM_CPUS_PER_TASK" in os.environ:
+                    n_workers = min(
+                        20, max(1, int(os.environ["SLURM_CPUS_PER_TASK"]) - 1)
+                    )
+
+            # Try to auto-determine optimal chunk size based on memory
+            if chunk_size is None:
+                try:
+                    import psutil
+
+                    # Check if running on SLURM
+                    available_memory = psutil.virtual_memory().available
+                    for name in [
+                        "SLURM_MEM_PER_NODE",
+                        "SLURM_MEM_PER_CPU",
+                        "SLURM_MEM_PER_GPU",
+                        "SLURM_MEM_PER_TASK",
+                    ]:
+                        if name in os.environ:
+                            available_memory = (
+                                int(os.environ[name]) * 1024 * 1024
+                            )  # Convert MB to bytes
+                            break
+
+                    # Use at most 50% of available memory across all workers
+                    memory_per_worker = 0.5 * available_memory / n_workers
+                    # Rough estimate: each label takes 4 bytes, each index 8 bytes
+                    bytes_per_element = 12
+                    chunk_size = min(
+                        max(100_000, int(memory_per_worker / bytes_per_element / 3)),
+                        2_000_000,
+                    )
+                    print(f"Auto-determined chunk size: {chunk_size:,} elements")
+                except (ImportError, KeyError):
+                    chunk_size = 2_000_000
+                    print(f"Using default chunk size: {chunk_size:,} elements")
+
+            # Parallelize the class indices building
+            print(f"Building class indices in parallel with {n_workers} workers...")
+            klass_indices = self._build_class_indices_parallel(
+                labels, chunk_size, n_workers
+            )
+
+            # Convert klass_indices to a single tensor and offset vector
+            all_indices = []
+            offsets = []
+            current_offset = 0
+
+            # Sort keys to ensure consistent ordering
+            keys = klass_indices.keys()
+
+            # Build concatenated tensor and track offsets
+            for i in range(max(keys) + 1):
+                offsets.append(current_offset)
+                if i in keys:
+                    indices = klass_indices[i]
+                    all_indices.append(indices)
+                    current_offset += len(indices)
+
+            # Convert to tensors
+            self.klass_indices = torch.cat(all_indices).to(torch.int32).share_memory_()
+            self.klass_offsets = torch.tensor(offsets, dtype=torch.long).share_memory_()
+        if store_location is not None:
+            store_path = os.path.join(store_location, "klass_indices.pt")
+            if os.path.exists(store_path) and not force_recompute_indices:
+                self.klass_indices = torch.load(store_path).share_memory_()
+                self.klass_offsets = torch.load(
+                    store_path.replace(".pt", "_offsets.pt")
+                ).share_memory_()
+                print(f"Loaded sampler indices from {store_path}")
+            else:
+                torch.save(self.klass_indices, store_path)
+                torch.save(self.klass_offsets, store_path.replace(".pt", "_offsets.pt"))
+                print(f"Saved sampler indices to {store_path}")
+        print(f"Done initializing sampler with {len(self.klass_offsets)} classes")

    def __iter__(self):
+        # Sample classes according to their weights
+        print("sampling a new batch of size", self.num_samples)
+
        sample_labels = torch.multinomial(
            self.label_weights,
            num_samples=self.num_samples,
            replacement=True,
-            generator=None
-            if self.restart_num == 0 and not self.modify_seed_on_requeue
-            else torch.Generator().manual_seed(self.restart_num),
        )
-
-
+        # Get counts of each class in sample_labels
+        unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
+
+        # Initialize result tensor
+        result_indices_list = []  # Changed name to avoid conflict if you had result_indices elsewhere
+
+        # Process only the classes that were actually sampled
+        for i, (label, count) in tqdm(
+            enumerate(zip(unique_samples.tolist(), sample_counts.tolist())),
+            total=len(unique_samples),
+            desc="Processing classes in sampler",
+        ):
+            klass_index = self.klass_indices[
+                self.klass_offsets[label] : self.klass_offsets[label + 1]
+            ]
+
            if klass_index.numel() == 0:
                continue
-
-
-                continue
+
+            # Sample elements from this class
            if self.element_weights is not None:
-
-
-
-
-
-
-
-
-
-
+                # This is a critical point for memory
+                current_element_weights_slice = self.element_weights[klass_index]
+
+                if self.replacement:
+                    right_inds = torch.multinomial(
+                        current_element_weights_slice,
+                        num_samples=count,
+                        replacement=True,
+                    )
+                else:
+                    num_to_sample = min(count, len(klass_index))
+                    right_inds = torch.multinomial(
+                        current_element_weights_slice,
+                        num_samples=num_to_sample,
+                        replacement=False,
+                    )
            elif self.replacement:
-                right_inds = torch.randint(
-                    len(klass_index),
-                    size=(len(left_inds),),
-                    generator=None
-                    if self.restart_num == 0 and not self.modify_seed_on_requeue
-                    else torch.Generator().manual_seed(self.restart_num),
-                )
+                right_inds = torch.randint(len(klass_index), size=(count,))
            else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                num_to_sample = min(count, len(klass_index))
+                right_inds = torch.randperm(len(klass_index))[:num_to_sample]
+
+            # Get actual indices
+            sampled_indices = klass_index[right_inds]
+            result_indices_list.append(sampled_indices)
+
+        # Combine all indices
+        if result_indices_list:  # Check if the list is not empty
+            final_result_indices = torch.cat(
+                result_indices_list
+            )  # Use the list with the appended new name
+
+            # Shuffle the combined indices
+            shuffled_indices = final_result_indices[
+                torch.randperm(len(final_result_indices))
+            ]
+            self.num_samples = len(shuffled_indices)
+            yield from shuffled_indices.tolist()
+        else:
+            self.num_samples = 0
+            yield from iter([])

    def __len__(self):
        return self.num_samples
+
+    def _merge_chunk_results(self, results_list):
+        """Merge results from multiple chunks into a single dictionary.
+
+        Args:
+            results_list: list of dictionaries mapping class labels to index arrays
+
+        Returns:
+            merged dictionary with PyTorch tensors
+        """
+        merged = {}
+
+        # Collect all labels across all chunks
+        all_labels = set()
+        for chunk_result in results_list:
+            all_labels.update(chunk_result.keys())
+
+        # For each unique label
+        for label in all_labels:
+            # Collect indices from all chunks where this label appears
+            indices_lists = [
+                chunk_result[label]
+                for chunk_result in results_list
+                if label in chunk_result
+            ]
+
+            if indices_lists:
+                # Concatenate all indices for this label
+                merged[label] = torch.tensor(
+                    np.concatenate(indices_lists), dtype=torch.long
+                )
+            else:
+                merged[label] = torch.tensor([], dtype=torch.long)
+
+        return merged
+
+    def _build_class_indices_parallel(self, labels, chunk_size, n_workers=None):
+        """Build class indices in parallel across multiple workers.
+
+        Args:
+            labels: array of class labels
+            n_workers: number of parallel workers
+            chunk_size: size of chunks to process
+
+        Returns:
+            dictionary mapping class labels to tensors of indices
+        """
+        n = len(labels)
+        results = []
+        # Create chunks of the labels array with proper sizing
+        n_chunks = (n + chunk_size - 1) // chunk_size  # Ceiling division
+        print(f"Processing {n:,} elements in {n_chunks} chunks...")
+
+        # Process in chunks to limit memory usage
+        with ProcessPoolExecutor(
+            max_workers=n_workers, mp_context=mp.get_context("spawn")
+        ) as executor:
+            # Submit chunks for processing
+            futures = []
+            for i in range(n_chunks):
+                start_idx = i * chunk_size
+                end_idx = min((i + 1) * chunk_size, n)
+                # We pass only chunk boundaries, not the data itself
+                # This avoids unnecessary copies during process creation
+                futures.append(
+                    executor.submit(
+                        self._process_chunk_with_slice,
+                        (start_idx, end_idx, labels),
+                    )
+                )
+
+            # Collect results as they complete with progress reporting
+            for future in tqdm(
+                as_completed(futures), total=len(futures), desc="Processing chunks"
+            ):
+                results.append(future.result())
+
+        # Merge results from all chunks
+        print("Merging results from all chunks...")
+        merged_results = self._merge_chunk_results(results)
+
+        return merged_results
+
+    def _process_chunk_with_slice(self, slice_info):
+        """Process a slice of the labels array by indices.
+
+        Args:
+            slice_info: tuple of (start_idx, end_idx, labels_array) where
+                start_idx and end_idx define the slice to process
+
+        Returns:
+            dict mapping class labels to arrays of indices
+        """
+        start_idx, end_idx, labels_array = slice_info
+
+        # We're processing a slice of the original array
+        labels_slice = labels_array[start_idx:end_idx]
+        chunk_indices = {}
+
+        # Create a direct map of indices
+        indices = np.arange(start_idx, end_idx)
+
+        # Get unique labels in this slice for more efficient processing
+        unique_labels = np.unique(labels_slice)
+        # For each valid label, find its indices
+        for label in unique_labels:
+            # Find positions where this label appears (using direct boolean indexing)
+            label_mask = labels_slice == label
+            chunk_indices[int(label)] = indices[label_mask]
+
+        return chunk_indices
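The sampler's core data structure is a single concatenated index tensor plus per-class offsets, so drawing from class `c` reduces to one slice. Below is a toy illustration of that layout, not the package's code; the boundary handling for the last class is simplified here for the sake of the example.

```python
# Hypothetical illustration of the flattened class-index layout used by LabelWeightedSampler.
import torch

# dataset positions grouped by class: class 0 -> [0, 3, 5], class 1 -> [1, 2, 4]
klass_indices = torch.tensor([0, 3, 5, 1, 2, 4], dtype=torch.int32)
klass_offsets = torch.tensor([0, 3], dtype=torch.long)  # start of each class within klass_indices

def class_slice(label: int) -> torch.Tensor:
    # offsets[label] .. offsets[label + 1] delimit the class; the last class runs to the end
    start = klass_offsets[label]
    end = klass_offsets[label + 1] if label + 1 < len(klass_offsets) else len(klass_indices)
    return klass_indices[start:end]

print(class_slice(0))  # tensor([0, 3, 5], dtype=torch.int32)
print(class_slice(1))  # tensor([1, 2, 4], dtype=torch.int32)
```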