scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/__main__.py +4 -5
- scdataloader/collator.py +76 -78
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +134 -77
- scdataloader/datamodule.py +638 -245
- scdataloader/mapped.py +104 -43
- scdataloader/preprocess.py +136 -110
- scdataloader/utils.py +158 -52
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/METADATA +6 -7
- scdataloader-2.0.2.dist-info/RECORD +16 -0
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/WHEEL +1 -1
- scdataloader-2.0.2.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.2.dist-info/RECORD +0 -16
- scdataloader-1.9.2.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/entry_points.txt +0 -0
scdataloader/datamodule.py
CHANGED
Removed lines whose original content was not captured in this view appear as `[…]` below.

```diff
@@ -1,51 +1,55 @@
+import math
+import multiprocessing as mp
 import os
-[…]
+import random
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from functools import partial
+from typing import List, Optional, Sequence, Union
 
 import lamindb as ln
 import lightning as L
 import numpy as np
 import pandas as pd
 import torch
-from torch.utils.data import DataLoader, Sampler
+from torch.utils.data import DataLoader, Sampler, Subset
 from torch.utils.data.sampler import (
     RandomSampler,
     SequentialSampler,
     SubsetRandomSampler,
     WeightedRandomSampler,
 )
+from tqdm import tqdm
 
 from .collator import Collator
 from .data import Dataset
-from .utils import getBiomartTable
+from .utils import fileToList, getBiomartTable, listToFile
 
 FILE_DIR = os.path.dirname(os.path.abspath(__file__))
+NNZ_SCALE = 1000
 
 
 class DataModule(L.LightningDataModule):
     def __init__(
         self,
         collection_name: str,
-        clss_to_weight: […]
-        organisms: list = ["NCBITaxon:9606"],
+        clss_to_weight: List[str] = ["organism_ontology_term_id"],
         weight_scaler: int = 10,
         n_samples_per_epoch: int = 2_000_000,
         validation_split: float = 0.2,
         test_split: float = 0,
-        gene_embeddings: str = "",
         use_default_col: bool = True,
-        gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-        clss_to_predict: […]
-        hierarchical_clss: […]
+        clss_to_predict: List[str] = ["organism_ontology_term_id"],
+        hierarchical_clss: List[str] = [],
         # this is for the collator
         how: str = "random expr",
-[…]
+        organism_col: str = "organism_ontology_term_id",
         max_len: int = 1000,
-        add_zero_genes: int = 100,
         replacement: bool = True,
-[…]
+        gene_subset: Optional[list[str]] = None,
         tp_name: Optional[str] = None,  # "heat_diff"
-        assays_to_drop: […]
+        assays_to_drop: List[str] = [
             # "EFO:0008853", #patch seq
             # "EFO:0010961", # visium
             "EFO:0030007",  # ATACseq
@@ -53,6 +57,14 @@ class DataModule(L.LightningDataModule):
         ],
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
+        store_location: str = None,
+        force_recompute_indices: bool = False,
+        sampler_workers: int = None,
+        sampler_chunk_size: int = None,
+        organisms: Optional[str] = None,
+        genedf: Optional[pd.DataFrame] = None,
+        n_bins: int = 0,
+        curiculum: int = 0,
         **kwargs,
     ):
         """
@@ -64,101 +76,74 @@ class DataModule(L.LightningDataModule):
 
         Args:
             collection_name (str): The lamindb collection to be used.
-            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
             weight_scaler (int, optional): how much more you will see the most present vs less present category.
             n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                 it will use a full dataset and will round to the nearest dataset's cell count.
-[…]
+            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+            clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
+            gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
                 the file must have ensembl_gene_id as index.
                 This is used to subset the available genes further to the ones that have embeddings in your model.
-            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
-            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
-                any genes within this distance of each other will be considered at the same position.
-            clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
-            assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
-            do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
             max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
-            add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
             how (str, optional): The method to use for the collator. Defaults to "random expr".
-[…]
+            organism_col (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
-            hierarchical_clss ([…]
+            hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
             metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
-            clss_to_predict ([…]
+            clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
+            store_location (str, optional): The location to store the sampler indices. Defaults to None.
+            force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
+            sampler_workers (int, optional): The number of workers to use for the sampler. Defaults to None (auto-determined).
+            sampler_chunk_size (int, optional): The size of the chunks to use for the sampler. Defaults to None (auto-determined).
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
                 see @file data.py and @file collator.py for more details about some of the parameters
         """
-        if […]
-[…]
-            organisms=organisms,
-            clss_to_predict=clss_to_predict,
-            hierarchical_clss=hierarchical_clss,
-            metacell_mode=metacell_mode,
-            get_knn_cells=get_knn_cells,
+        if "organism_ontology_term_id" not in clss_to_predict:
+            raise ValueError(
+                "need 'organism_ontology_term_id' in the set of classes at least"
             )
+        if metacell_mode > 0 and get_knn_cells:
+            raise ValueError(
+                "cannot use metacell mode and get_knn_cells at the same time"
+            )
+        mdataset = Dataset(
+            ln.Collection.filter(key=collection_name, is_latest=True).first(),
+            clss_to_predict=clss_to_predict,
+            hierarchical_clss=hierarchical_clss,
+            metacell_mode=metacell_mode,
+            get_knn_cells=get_knn_cells,
+            store_location=store_location,
+            force_recompute_indices=force_recompute_indices,
+            genedf=genedf,
+        )
         # and location
         self.metacell_mode = bool(metacell_mode)
         self.gene_pos = None
         self.collection_name = collection_name
-        if […]
-[…]
-                biomart = pd.read_parquet(do_gene_pos)
-            else:
-                # and annotations
-                if organisms != ["NCBITaxon:9606"]:
-                    raise ValueError(
-                        "need to provide your own table as this automated function only works for humans for now"
-                    )
-                biomart = getBiomartTable(
-                    attributes=["start_position", "chromosome_name"],
-                    useCache=True,
-                ).set_index("ensembl_gene_id")
-                biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
-                biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
-                c = []
-                i = 0
-                prev_position = -100000
-                prev_chromosome = None
-                for _, r in biomart.iterrows():
-                    if (
-                        r["chromosome_name"] != prev_chromosome
-                        or r["start_position"] - prev_position > gene_position_tolerance
-                    ):
-                        i += 1
-                    c.append(i)
-                    prev_position = r["start_position"]
-                    prev_chromosome = r["chromosome_name"]
-                print(f"reduced the size to {len(set(c)) / len(biomart)}")
-                biomart["pos"] = c
-            mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
-            self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
-[…]
-        if gene_embeddings != "":
-            mdataset.genedf = mdataset.genedf.join(
-                pd.read_parquet(gene_embeddings), how="inner"
-            )
-            if do_gene_pos:
-                self.gene_pos = mdataset.genedf["pos"].tolist()
+        if gene_subset is not None:
+            tokeep = set(mdataset.genedf.index.tolist())
+            gene_subset = [u for u in gene_subset if u in tokeep]
         self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
         # we might want not to order the genes by expression (or do it?)
         # we might want to not introduce zeros and
         if use_default_col:
             kwargs["collate_fn"] = Collator(
-                organisms=organisms,
+                organisms=mdataset.organisms if organisms is None else organisms,
                 how=how,
-                valid_genes=[…]
+                valid_genes=gene_subset,
                 max_len=max_len,
-[…]
-                org_to_id=mdataset.encoder[organism_name],
+                org_to_id=mdataset.encoder[organism_col],
                 tp_name=tp_name,
-                organism_name=[…]
-                class_names=[…]
+                organism_name=organism_col,
+                class_names=list(self.classes.keys()),
+                genedf=genedf,
+                n_bins=n_bins,
             )
+        self.n_bins = n_bins
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
```
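The constructor rework above drops the gene-position and embedding handling (`gene_embeddings`, `gene_position_tolerance`, `add_zero_genes`, the biomart lookup) and adds caching, organism, and binning options (`store_location`, `force_recompute_indices`, `sampler_workers`, `sampler_chunk_size`, `genedf`, `n_bins`, `curiculum`). A minimal usage sketch of the 2.0-style call is below; the collection key, label columns, and cache path are placeholders, and it assumes `DataModule` is still re-exported at the package root (otherwise import it from `scdataloader.datamodule`).

```python
# Hypothetical usage of the 2.0.x constructor shown in the hunks above.
# The collection key, label columns, and store_location are placeholders.
from scdataloader import DataModule

dm = DataModule(
    collection_name="my-lamindb-collection-key",          # placeholder key
    clss_to_predict=["organism_ontology_term_id"],
    clss_to_weight=["organism_ontology_term_id", "nnz"],   # "nnz" enables the sigmoid element weights
    weight_scaler=10,
    n_samples_per_epoch=2_000_000,
    validation_split=0.2,
    max_len=1000,
    store_location="/tmp/scdataloader_cache",              # caches split/sampler indices between runs
    batch_size=64,                                         # forwarded to the PyTorch DataLoader via **kwargs
    num_workers=8,
)
dm.setup()                      # builds (or reloads) train/valid/test indices
loader = dm.train_dataloader()  # uses LabelWeightedSampler when clss_to_weight is non-empty
```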
```diff
@@ -173,10 +158,19 @@ class DataModule(L.LightningDataModule):
         self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.sampler_workers = sampler_workers
+        self.sampler_chunk_size = sampler_chunk_size
+        self.store_location = store_location
         self.nnz = None
+        self.idx_full = None
+        self.max_len = max_len
         self.test_datasets = []
+        self.force_recompute_indices = force_recompute_indices
+        self.curiculum = curiculum
+        self.valid_idx = []
         self.test_idx = []
         super().__init__()
+        print("finished init")
 
     def __repr__(self):
         return (
@@ -192,9 +186,11 @@ class DataModule(L.LightningDataModule):
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + (
-[…]
+            + (
+                "\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)"
+                if self.idx_full is not None
+                else ")"
+            )
         )
 
     @property
@@ -238,6 +234,49 @@ class DataModule(L.LightningDataModule):
         """
         return self.dataset.genedf.index.tolist()
 
+    @property
+    def genes_dict(self):
+        return {
+            i: self.dataset.genedf.index[self.dataset.genedf.organism == i].tolist()
+            for i in self.dataset.organisms
+        }
+
+    def set_valid_genes_collator(self, genes):
+        self.kwargs["collate_fn"]._setup(
+            # cannot use genedf there since I am purposefully decreasing it...
+            # genedf=self.dataset.genedf,
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=genes,
+        )
+
+    @property
+    def encoders(self):
+        return self.dataset.encoder
+
+    @encoders.setter
+    def encoders(self, encoders):
+        self.dataset.encoder = encoders
+        self.kwargs["collate_fn"].org_to_id = encoders[
+            self.kwargs["collate_fn"].organism_name
+        ]
+        self.kwargs["collate_fn"]._setup(
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=self.genes,
+        )
+
+    @property
+    def organisms(self):
+        return self.dataset.organisms
+
+    @organisms.setter
+    def organisms(self, organisms):
+        self.dataset.organisms = organisms
+        self.kwargs["collate_fn"].organisms = organisms
+        self.kwargs["collate_fn"]._setup(
+            org_to_id=self.kwargs["collate_fn"].org_to_id,
+            valid_genes=self.genes,
+        )
+
     @property
     def num_datasets(self):
         return len(self.dataset.mapped_dataset.storages)
```
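The new `encoders` and `organisms` setters above exist to keep the collator consistent with the dataset: replacing either one re-runs the collator's `_setup` with the current gene list. The snippet below is only a self-contained illustration of that setter-forwarding pattern with stand-in classes, not scdataloader's actual objects.

```python
# Illustrative only: _Collator/_Module are stand-ins, not the scdataloader classes.
class _Collator:
    def __init__(self):
        self.organisms = []

    def _setup(self, organisms):
        # re-derive any organism-dependent state here
        self.organisms = list(organisms)


class _Module:
    def __init__(self):
        self._organisms = ["NCBITaxon:9606"]
        self.collate_fn = _Collator()
        self.collate_fn._setup(self._organisms)

    @property
    def organisms(self):
        return self._organisms

    @organisms.setter
    def organisms(self, value):
        self._organisms = list(value)
        # forward the change so the collator never goes stale
        self.collate_fn._setup(self._organisms)


m = _Module()
m.organisms = ["NCBITaxon:9606", "NCBITaxon:10090"]
assert m.collate_fn.organisms == m.organisms
```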
```diff
@@ -251,116 +290,194 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
-[…]
-        else:
-            idx_full = np.arange(self.n_samples)
-        if len_test > 0:
-            # this way we work on some never seen datasets
-            # keeping at least one
-            len_test = (
-                len_test
-                if len_test > self.dataset.mapped_dataset.n_obs_list[0]
-                else self.dataset.mapped_dataset.n_obs_list[0]
+        print("setting up the datamodule")
+        start_time = time.time()
+        if (
+            self.store_location is None
+            or not os.path.exists(os.path.join(self.store_location, "train_labels.npy"))
+            or self.force_recompute_indices
+        ):
+            if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
+                self.nnz = self.dataset.mapped_dataset.get_merged_labels(
+                    "nnz", is_cat=False
+                )
+                self.clss_to_weight.remove("nnz")
+                # Sigmoid scaling with 2 parameters
+                midpoint = 2000
+                steepness = 0.003
+                # Apply sigmoid transformation
+                # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint)))
+                # Then scale to [1, NNZ_SCALE] range
+                sigmoid_values = 1 / (1 + np.exp(-steepness * (self.nnz - midpoint)))
+                self.nnz = 1 + ((NNZ_SCALE - 1) * sigmoid_values)
+            if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+                labels = self.dataset.get_label_cats(
+                    self.clss_to_weight,
+                )
+            else:
+                labels = np.zeros(self.n_samples, dtype=int)
+            if isinstance(self.validation_split, int):
+                len_valid = self.validation_split
+            else:
+                len_valid = int(self.n_samples * self.validation_split)
+            if isinstance(self.test_split, int):
+                len_test = self.test_split
+            else:
+                len_test = int(self.n_samples * self.test_split)
+            assert len_test + len_valid < self.n_samples, (
+                "test set + valid set size is configured to be larger than entire dataset."
             )
-            cs = 0
-            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
-                if cs + c > len_test:
-                    break
-                else:
-                    self.test_datasets.append(
-                        self.dataset.mapped_dataset.path_list[i].path
-                    )
-                    cs += c
-            len_test = cs
-            self.test_idx = idx_full[:len_test]
-            idx_full = idx_full[len_test:]
-        else:
-            self.test_idx = None
 
-[…]
+            idx_full = []
+            if len(self.assays_to_drop) > 0:
+                badloc = np.isin(
+                    self.dataset.mapped_dataset.get_merged_labels(
+                        "assay_ontology_term_id"
+                    ),
+                    self.assays_to_drop,
+                )
+                idx_full = np.arange(len(labels))[~badloc]
+            else:
+                idx_full = np.arange(self.n_samples)
+            if len_test > 0:
+                # this way we work on some never seen datasets
+                # keeping at least one
+                len_test = (
+                    len_test
+                    if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                    else self.dataset.mapped_dataset.n_obs_list[0]
+                )
+                cs = 0
+                d_size = list(enumerate(self.dataset.mapped_dataset.n_obs_list))
+                random.Random(42).shuffle(d_size)  # always same order
+                for i, c in d_size:
+                    if cs + c > len_test:
+                        break
+                    else:
+                        self.test_datasets.append(
+                            self.dataset.mapped_dataset.path_list[i].path
+                        )
+                        cs += c
+                len_test = cs
+                self.test_idx = idx_full[:len_test]
+                idx_full = idx_full[len_test:]
+            else:
+                self.test_idx = None
+
+            np.random.shuffle(idx_full)
+            if len_valid > 0:
+                self.valid_idx = idx_full[:len_valid].copy()
+                # store it for later
+                idx_full = idx_full[len_valid:]
+            else:
+                self.valid_idx = None
+            labels[~np.isin(np.arange(self.n_samples), idx_full)] = labels.max() + 1
+            # some labels will now not exist anymore as replaced by len(weights) - 1.
+            # this means that the associated weights should be 0.
+            # by doing np.bincount(labels)*weights this will be taken into account
+            self.train_labels = labels
+            self.idx_full = idx_full
+            if self.store_location is not None:
+                if (
+                    not os.path.exists(
+                        os.path.join(self.store_location, "train_labels.npy")
+                    )
+                    or self.force_recompute_indices
+                ):
+                    os.makedirs(self.store_location, exist_ok=True)
+                    if self.nnz is not None:
+                        np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
+                    np.save(
+                        os.path.join(self.store_location, "train_labels.npy"),
+                        self.train_labels,
+                    )
+                    np.save(
+                        os.path.join(self.store_location, "idx_full.npy"), self.idx_full
+                    )
+                    if self.test_idx is not None:
+                        np.save(
+                            os.path.join(self.store_location, "test_idx.npy"), self.test_idx
+                        )
+                    if self.valid_idx is not None:
+                        np.save(
+                            os.path.join(self.store_location, "valid_idx.npy"),
+                            self.valid_idx,
+                        )
+                    listToFile(
+                        self.test_datasets,
+                        os.path.join(self.store_location, "test_datasets.txt"),
+                    )
+        else:
+            self.nnz = (
+                np.load(os.path.join(self.store_location, "nnz.npy"), mmap_mode="r")
+                if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
+                else None
+            )
+            self.train_labels = np.load(
+                os.path.join(self.store_location, "train_labels.npy")
+            )
+            self.idx_full = np.load(
+                os.path.join(self.store_location, "idx_full.npy"), mmap_mode="r"
+            )
+            self.test_idx = (
+                np.load(os.path.join(self.store_location, "test_idx.npy"))
+                if os.path.exists(os.path.join(self.store_location, "test_idx.npy"))
+                else None
+            )
+            self.valid_idx = (
+                np.load(os.path.join(self.store_location, "valid_idx.npy"))
+                if os.path.exists(
+                    os.path.join(self.store_location, "valid_idx.npy")
+                )
+                else None
+            )
+            self.test_datasets = fileToList(
+                os.path.join(self.store_location, "test_datasets.txt")
+            )
+            print("loaded from store")
+        print(f"done setup, took {time.time() - start_time:.2f} seconds")
         return self.test_datasets
 
     def train_dataloader(self, **kwargs):
-[…]
+        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+            try:
+                print("Setting up the parallel train sampler...")
+                # Create the optimized parallel sampler
+                print(f"Using {self.sampler_workers} workers for class indexing")
+                train_sampler = LabelWeightedSampler(
+                    labels=self.train_labels,
+                    weight_scaler=self.weight_scaler,
+                    num_samples=int(self.n_samples_per_epoch),
+                    element_weights=self.nnz,
+                    replacement=self.replacement,
+                    n_workers=self.sampler_workers,
+                    chunk_size=self.sampler_chunk_size,
+                    store_location=self.store_location,
+                    force_recompute_indices=self.force_recompute_indices,
+                    curiculum=self.curiculum,
+                )
+            except ValueError as e:
+                raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
+            dataset = None
+        else:
+            dataset = Subset(self.dataset, self.idx_full)
+            train_sampler = RankShardSampler(len(dataset))
+        current_loader_kwargs = kwargs.copy()
+        current_loader_kwargs.update(self.kwargs)
         return DataLoader(
-            self.dataset,
+            self.dataset if dataset is None else dataset,
             sampler=train_sampler,
-            **[…]
-            **kwargs,
+            **current_loader_kwargs,
         )
 
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset,
-                sampler=SubsetRandomSampler(self.valid_idx),
+                Subset(self.dataset, self.valid_idx),
                 **self.kwargs,
             )
             if self.valid_idx is not None
-            else […]
+            else []
         )
 
     def test_dataloader(self):
```
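The new `setup()` above converts each cell's number of expressed genes ("nnz") into a sampling weight through a fixed sigmoid, so weighting by "nnz" favors deeper cells without letting them dominate outright. Below is a stand-alone check of that mapping, using the constants from the diff (NNZ_SCALE = 1000, midpoint = 2000, steepness = 0.003); the example nnz values are arbitrary.

```python
# Stand-alone check of the nnz -> element-weight mapping from setup() above.
# Constants mirror the diff: NNZ_SCALE = 1000, midpoint = 2000, steepness = 0.003.
import numpy as np

NNZ_SCALE = 1000
midpoint, steepness = 2000, 0.003

nnz = np.array([200, 1_000, 2_000, 4_000, 8_000], dtype=float)
sigmoid = 1 / (1 + np.exp(-steepness * (nnz - midpoint)))
weights = 1 + (NNZ_SCALE - 1) * sigmoid

for n, w in zip(nnz, weights):
    print(f"nnz={n:5.0f} -> element weight {w:6.1f}")
# A cell with ~2000 expressed genes lands at the midpoint (~500); sparse cells
# get weights of only a few, and very deep cells approach NNZ_SCALE, so the
# sampler favors information-rich cells without discarding shallow ones.
```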
```diff
@@ -369,109 +486,385 @@ class DataModule(L.LightningDataModule):
             self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
         )
         if self.test_idx is not None
-        else […]
+        else []
     )
 
     def predict_dataloader(self):
+        subset = Subset(self.dataset, self.idx_full)
         return DataLoader(
-[…]
-            sampler=[…]
+            subset,
+            sampler=RankShardSampler(len(subset)),
             **self.kwargs,
         )
 
-    # def teardown(self):
-    # clean up state after the trainer stops, delete files...
-    # called on every process in DDP
-    # pass
-[…]
 
 class LabelWeightedSampler(Sampler[int]):
-[…]
+    """
+    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
+
+    This sampler is designed to handle very large datasets efficiently, with optimizations for:
+    1. Parallel building of class indices
+    2. Chunked processing for large arrays
+    3. Efficient memory management
+    4. Proper handling of replacement and non-replacement sampling
+    """
+
+    label_weights: torch.Tensor
+    klass_indices: dict[int, torch.Tensor]
     num_samples: int
-[…]
+    element_weights: Optional[torch.Tensor]
     replacement: bool
-    # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
-    # this will result a class-balanced sampling, no matter how imbalance the labels are.
 
     def __init__(
         self,
-[…]
-        labels: Sequence[int],
+        labels: np.ndarray,
         num_samples: int,
         replacement: bool = True,
-[…]
+        weight_scaler: Optional[float] = None,
+        element_weights: Optional[Sequence[float]] = None,
+        n_workers: int = None,
+        chunk_size: int = None,  # Process 10M elements per chunk
+        store_location: str = None,
+        force_recompute_indices: bool = False,
+        curiculum: int = 0,
     ) -> None:
         """
+        Initialize the sampler with parallel processing for large datasets.
 
-        :[…]
-[…]
+        Args:
+            weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
+            labels: Class label for each dataset element (length = dataset size)
+            num_samples: Number of samples to draw
+            replacement: Whether to sample with replacement
+            element_weights: Optional weights for each element within classes
+            n_workers: Number of parallel workers to use (default: number of CPUs-1)
+            chunk_size: Size of chunks to process in parallel (default: 10M elements)
         """
-[…]
+        print("Initializing optimized parallel weighted sampler...")
        super(LabelWeightedSampler, self).__init__(None)
-[…]
+        self.count = 0
+        self.curiculum = curiculum
+
+        # Compute label weights (incorporating class frequencies)
+        # Directly use labels as numpy array without conversion
+        counts = np.bincount(labels)
+        counts[-1] = 0  # Ensure the weight for the 'NONE' class is zero
+        label_weights = (weight_scaler * counts) / (counts + weight_scaler)
+        self.label_weights = torch.as_tensor(
+            label_weights, dtype=torch.float32
+        ).share_memory_()
+
+        # Store element weights if provided
+        if element_weights is not None:
+            self.element_weights = torch.as_tensor(
+                element_weights, dtype=torch.float32
+            ).share_memory_()
+        else:
+            self.element_weights = None
+
         self.replacement = replacement
         self.num_samples = num_samples
-[…]
+        if (
+            store_location is None
+            or not os.path.exists(os.path.join(store_location, "klass_indices.pt"))
+            or force_recompute_indices
+        ):
+            # Set number of workers (default to CPU count - 1, but at least 1)
+            if n_workers is None:
+                # Check if running on SLURM
+                n_workers = min(20, max(1, mp.cpu_count() - 1))
+                if "SLURM_CPUS_PER_TASK" in os.environ:
+                    n_workers = min(
+                        20, max(1, int(os.environ["SLURM_CPUS_PER_TASK"]) - 1)
+                    )
+
+            # Try to auto-determine optimal chunk size based on memory
+            if chunk_size is None:
+                try:
+                    import psutil
+
+                    # Check if running on SLURM
+                    available_memory = psutil.virtual_memory().available
+                    for name in [
+                        "SLURM_MEM_PER_NODE",
+                        "SLURM_MEM_PER_CPU",
+                        "SLURM_MEM_PER_GPU",
+                        "SLURM_MEM_PER_TASK",
+                    ]:
+                        if name in os.environ:
+                            available_memory = (
+                                int(os.environ[name]) * 1024 * 1024
+                            )  # Convert MB to bytes
+                            break
+
+                    # Use at most 50% of available memory across all workers
+                    memory_per_worker = 0.5 * available_memory / n_workers
+                    # Rough estimate: each label takes 4 bytes, each index 8 bytes
+                    bytes_per_element = 12
+                    chunk_size = min(
+                        max(100_000, int(memory_per_worker / bytes_per_element / 3)),
+                        2_000_000,
+                    )
+                    print(f"Auto-determined chunk size: {chunk_size:,} elements")
+                except (ImportError, KeyError):
+                    chunk_size = 2_000_000
+                    print(f"Using default chunk size: {chunk_size:,} elements")
+
+            # Parallelize the class indices building
+            print(f"Building class indices in parallel with {n_workers} workers...")
+            klass_indices = self._build_class_indices_parallel(
+                labels, chunk_size, n_workers
+            )
+
+            # Convert klass_indices to a single tensor and offset vector
+            all_indices = []
+            offsets = []
+            current_offset = 0
+
+            # Sort keys to ensure consistent ordering
+            keys = klass_indices.keys()
+
+            # Build concatenated tensor and track offsets
+            for i in range(max(keys) + 1):
+                offsets.append(current_offset)
+                if i in keys:
+                    indices = klass_indices[i]
+                    all_indices.append(indices)
+                    current_offset += len(indices)
+
+            # Convert to tensors
+            self.klass_indices = torch.cat(all_indices).to(torch.int32).share_memory_()
+            self.klass_offsets = torch.tensor(offsets, dtype=torch.long).share_memory_()
+        if store_location is not None:
+            store_path = os.path.join(store_location, "klass_indices.pt")
+            if os.path.exists(store_path) and not force_recompute_indices:
+                self.klass_indices = torch.load(store_path).share_memory_()
+                self.klass_offsets = torch.load(
+                    store_path.replace(".pt", "_offsets.pt")
+                ).share_memory_()
+                print(f"Loaded sampler indices from {store_path}")
+            else:
+                torch.save(self.klass_indices, store_path)
+                torch.save(self.klass_offsets, store_path.replace(".pt", "_offsets.pt"))
+                print(f"Saved sampler indices to {store_path}")
+        print(f"Done initializing sampler with {len(self.klass_offsets)} classes")
 
     def __iter__(self):
+        self.count += 1
+        # Sample classes according to their weights
+        print("sampling a new batch of size", self.num_samples)
+
         sample_labels = torch.multinomial(
-[…]
+            (
+                self.label_weights ** min(1, ((self.count + 5) / self.curiculum))
+                if self.curiculum
+                else self.label_weights
+            ),
             num_samples=self.num_samples,
             replacement=True,
         )
-[…]
+        # Get counts of each class in sample_labels
+        unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
+
+        # Initialize result tensor
+        result_indices_list = []  # Changed name to avoid conflict if you had result_indices elsewhere
+
+        # Process only the classes that were actually sampled
+        for i, (label, count) in tqdm(
+            enumerate(zip(unique_samples.tolist(), sample_counts.tolist())),
+            total=len(unique_samples),
+            desc="Processing classes in sampler",
+        ):
+            klass_index = self.klass_indices[
+                self.klass_offsets[label] : self.klass_offsets[label + 1]
+            ]
+
             if klass_index.numel() == 0:
                 continue
-[…]
-                continue
+
+            # Sample elements from this class
             if self.element_weights is not None:
-[…]
+                # This is a critical point for memory
+                current_element_weights_slice = self.element_weights[klass_index]
+
+                if current_element_weights_slice.shape[0] >= (2**24) - 1:
+                    ind = torch.randperm(len(klass_index))[: (2**24) - 10]
+                    klass_index = klass_index[ind]
+                    current_element_weights_slice = current_element_weights_slice[ind]
+
+                if self.replacement:
+                    right_inds = torch.multinomial(
+                        current_element_weights_slice,
+                        num_samples=count,
+                        replacement=True,
+                    )
+                else:
+                    num_to_sample = min(count, len(klass_index))
+                    right_inds = torch.multinomial(
+                        current_element_weights_slice,
+                        num_samples=num_to_sample,
+                        replacement=False,
+                    )
             elif self.replacement:
-                right_inds = torch.randint(
-[…]
+                right_inds = torch.randint(len(klass_index), size=(count,))
+            else:
+                num_to_sample = min(count, len(klass_index))
+                right_inds = torch.randperm(len(klass_index))[:num_to_sample]
+
+            # Get actual indices
+            sampled_indices = klass_index[right_inds]
+            result_indices_list.append(sampled_indices)
+
+        # Combine all indices
+        if result_indices_list:  # Check if the list is not empty
+            final_result_indices = torch.cat(
+                result_indices_list
+            )  # Use the list with the appended new name
+
+            # Shuffle the combined indices
+            shuffled_indices = final_result_indices[
+                torch.randperm(len(final_result_indices))
+            ]
+            self.num_samples = len(shuffled_indices)
+            yield from shuffled_indices.tolist()
+        else:
+            self.num_samples = 0
+            yield from iter([])
+
+    def __len__(self):
+        return self.num_samples
+
+    def _merge_chunk_results(self, results_list):
+        """Merge results from multiple chunks into a single dictionary.
+
+        Args:
+            results_list: list of dictionaries mapping class labels to index arrays
+
+        Returns:
+            merged dictionary with PyTorch tensors
+        """
+        merged = {}
+
+        # Collect all labels across all chunks
+        all_labels = set()
+        for chunk_result in results_list:
+            all_labels.update(chunk_result.keys())
+
+        # For each unique label
+        for label in all_labels:
+            # Collect indices from all chunks where this label appears
+            indices_lists = [
+                chunk_result[label]
+                for chunk_result in results_list
+                if label in chunk_result
+            ]
+
+            if indices_lists:
+                # Concatenate all indices for this label
+                merged[label] = torch.tensor(
+                    np.concatenate(indices_lists), dtype=torch.long
                 )
             else:
-[…]
+                merged[label] = torch.tensor([], dtype=torch.long)
+
+        return merged
+
+    def _build_class_indices_parallel(self, labels, chunk_size, n_workers=None):
+        """Build class indices in parallel across multiple workers.
+
+        Args:
+            labels: array of class labels
+            n_workers: number of parallel workers
+            chunk_size: size of chunks to process
+
+        Returns:
+            dictionary mapping class labels to tensors of indices
+        """
+        n = len(labels)
+        results = []
+        # Create chunks of the labels array with proper sizing
+        n_chunks = (n + chunk_size - 1) // chunk_size  # Ceiling division
+        print(f"Processing {n:,} elements in {n_chunks} chunks...")
+
+        # Process in chunks to limit memory usage
+        with ProcessPoolExecutor(
+            max_workers=n_workers, mp_context=mp.get_context("spawn")
+        ) as executor:
+            # Submit chunks for processing
+            futures = []
+            for i in range(n_chunks):
+                start_idx = i * chunk_size
+                end_idx = min((i + 1) * chunk_size, n)
+                # We pass only chunk boundaries, not the data itself
+                # This avoids unnecessary copies during process creation
+                futures.append(
+                    executor.submit(
+                        self._process_chunk_with_slice,
+                        (start_idx, end_idx, labels),
+                    )
                 )
-[…]
+
+            # Collect results as they complete with progress reporting
+            for future in tqdm(
+                as_completed(futures), total=len(futures), desc="Processing chunks"
+            ):
+                results.append(future.result())
+
+        # Merge results from all chunks
+        print("Merging results from all chunks...")
+        merged_results = self._merge_chunk_results(results)
+
+        return merged_results
+
+    def _process_chunk_with_slice(self, slice_info):
+        """Process a slice of the labels array by indices.
+
+        Args:
+            slice_info: tuple of (start_idx, end_idx, labels_array) where
+                start_idx and end_idx define the slice to process
+
+        Returns:
+            dict mapping class labels to arrays of indices
+        """
+        start_idx, end_idx, labels_array = slice_info
+
+        # We're processing a slice of the original array
+        labels_slice = labels_array[start_idx:end_idx]
+        chunk_indices = {}
+
+        # Create a direct map of indices
+        indices = np.arange(start_idx, end_idx)
+
+        # Get unique labels in this slice for more efficient processing
+        unique_labels = np.unique(labels_slice)
+        # For each valid label, find its indices
+        for label in unique_labels:
+            # Find positions where this label appears (using direct boolean indexing)
+            label_mask = labels_slice == label
+            chunk_indices[int(label)] = indices[label_mask]
+
+        return chunk_indices
+
+
+class RankShardSampler(Sampler[int]):
+    """Shards a dataset contiguously across ranks without padding or duplicates.
+    Preserves the existing order (e.g., your pre-shuffled idx_full)."""
+
+    def __init__(self, data_len: int):
+        self.data_len = data_len
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            self.rank = torch.distributed.get_rank()
+            self.world_size = torch.distributed.get_world_size()
+        else:
+            self.rank, self.world_size = 0, 1
+
+        # contiguous chunk per rank (last rank may be shorter)
+        per_rank = math.ceil(self.data_len / self.world_size)
+        self.start = self.rank * per_rank
+        self.end = min(self.start + per_rank, self.data_len)
+
+    def __iter__(self):
+        return iter(range(self.start, self.end))
 
     def __len__(self):
-        return self.[…]
+        return self.end - self.start
```