scdataloader 2.0.0__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- scdataloader/__main__.py +4 -5
- scdataloader/collator.py +65 -56
- scdataloader/data.py +38 -54
- scdataloader/datamodule.py +139 -86
- scdataloader/mapped.py +27 -25
- scdataloader/preprocess.py +31 -16
- scdataloader/utils.py +120 -20
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.3.dist-info}/METADATA +5 -5
- scdataloader-2.0.3.dist-info/RECORD +16 -0
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.3.dist-info}/WHEEL +1 -1
- scdataloader-2.0.0.dist-info/RECORD +0 -16
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.3.dist-info}/entry_points.txt +0 -0
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.3.dist-info}/licenses/LICENSE +0 -0
scdataloader/datamodule.py
CHANGED
@@ -1,17 +1,18 @@
+import math
 import multiprocessing as mp
 import os
 import random
 import time
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from functools import partial
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import lamindb as ln
 import lightning as L
 import numpy as np
 import pandas as pd
 import torch
-from torch.utils.data import DataLoader, Sampler
+from torch.utils.data import DataLoader, Sampler, Subset
 from torch.utils.data.sampler import (
     RandomSampler,
     SequentialSampler,
@@ -25,32 +26,30 @@ from .data import Dataset
 from .utils import fileToList, getBiomartTable, listToFile
 
 FILE_DIR = os.path.dirname(os.path.abspath(__file__))
+NNZ_SCALE = 1000
 
 
 class DataModule(L.LightningDataModule):
     def __init__(
         self,
         collection_name: str,
-        clss_to_weight:
+        clss_to_weight: List[str] = ["organism_ontology_term_id"],
         weight_scaler: int = 10,
         n_samples_per_epoch: int = 2_000_000,
         validation_split: float = 0.2,
         test_split: float = 0,
-        gene_embeddings: str = "",
         use_default_col: bool = True,
-        gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-        clss_to_predict:
-        hierarchical_clss:
+        clss_to_predict: List[str] = ["organism_ontology_term_id"],
+        hierarchical_clss: List[str] = [],
         # this is for the collator
         how: str = "random expr",
-
+        organism_col: str = "organism_ontology_term_id",
         max_len: int = 1000,
-        add_zero_genes: int = 100,
         replacement: bool = True,
-
+        gene_subset: Optional[list[str]] = None,
         tp_name: Optional[str] = None,  # "heat_diff"
-        assays_to_drop:
+        assays_to_drop: List[str] = [
             # "EFO:0008853", #patch seq
             # "EFO:0010961", # visium
             "EFO:0030007",  # ATACseq
@@ -62,6 +61,11 @@ class DataModule(L.LightningDataModule):
         force_recompute_indices: bool = False,
         sampler_workers: int = None,
         sampler_chunk_size: int = None,
+        organisms: Optional[str] = None,
+        genedf: Optional[pd.DataFrame] = None,
+        n_bins: int = 0,
+        curiculum: int = 0,
+        start_at: int = 0,
         **kwargs,
     ):
         """
@@ -78,23 +82,19 @@ class DataModule(L.LightningDataModule):
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                 it will use a full dataset and will round to the nearest dataset's cell count.
-
+            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+            clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
+            gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
                 the file must have ensembl_gene_id as index.
                 This is used to subset the available genes further to the ones that have embeddings in your model.
-            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
-            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
-                any genes within this distance of each other will be considered at the same position.
-            clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
-            assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
-            do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
             max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
-            add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
             how (str, optional): The method to use for the collator. Defaults to "random expr".
-
+            organism_col (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
-            hierarchical_clss (
+            hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
             metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
-            clss_to_predict (
+            clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
             store_location (str, optional): The location to store the sampler indices. Defaults to None.
             force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
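For orientation, a minimal usage sketch of the reworked constructor under the new 2.0.3 arguments. Everything below is illustrative rather than taken from the package docs: the collection key and label names are made up, `batch_size`/`num_workers` are assumed to be forwarded to the DataLoader via `**kwargs`, and `DataModule` is assumed to be importable from the package root.

    from scdataloader import DataModule

    dm = DataModule(
        collection_name="my-collection",  # hypothetical LaminDB collection key
        clss_to_weight=["organism_ontology_term_id", "nnz"],
        clss_to_predict=["organism_ontology_term_id", "cell_type_ontology_term_id"],
        n_bins=0,       # forwarded to the Collator
        curiculum=50,   # new curriculum knob forwarded to LabelWeightedSampler
        start_at=0,     # resume offset used by RankShardSampler
        batch_size=64,
        num_workers=8,
    )
    dm.setup()
    train_loader = dm.train_dataloader()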
@@ -107,44 +107,44 @@ class DataModule(L.LightningDataModule):
             raise ValueError(
                 "need 'organism_ontology_term_id' in the set of classes at least"
             )
+        if metacell_mode > 0 and get_knn_cells:
+            raise ValueError(
+                "cannot use metacell mode and get_knn_cells at the same time"
+            )
         mdataset = Dataset(
-            ln.Collection.filter(
+            ln.Collection.filter(key=collection_name, is_latest=True).first(),
             clss_to_predict=clss_to_predict,
             hierarchical_clss=hierarchical_clss,
             metacell_mode=metacell_mode,
             get_knn_cells=get_knn_cells,
             store_location=store_location,
             force_recompute_indices=force_recompute_indices,
+            genedf=genedf,
         )
         # and location
         self.metacell_mode = bool(metacell_mode)
         self.gene_pos = None
         self.collection_name = collection_name
-        if
-
-
-            self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
-        if gene_embeddings != "":
-            mdataset.genedf = mdataset.genedf.join(
-                pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
-            )
-            if do_gene_pos:
-                self.gene_pos = mdataset.genedf["pos"].tolist()
+        if gene_subset is not None:
+            tokeep = set(mdataset.genedf.index.tolist())
+            gene_subset = [u for u in gene_subset if u in tokeep]
         self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
         # we might want not to order the genes by expression (or do it?)
         # we might want to not introduce zeros and
         if use_default_col:
             kwargs["collate_fn"] = Collator(
-                organisms=mdataset.organisms,
+                organisms=mdataset.organisms if organisms is None else organisms,
                 how=how,
-                valid_genes=
+                valid_genes=gene_subset,
                 max_len=max_len,
-
-                org_to_id=mdataset.encoder[organism_name],
+                org_to_id=mdataset.encoder[organism_col],
                 tp_name=tp_name,
-                organism_name=
+                organism_name=organism_col,
                 class_names=list(self.classes.keys()),
+                genedf=genedf,
+                n_bins=n_bins,
             )
+        self.n_bins = n_bins
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
@@ -163,8 +163,13 @@ class DataModule(L.LightningDataModule):
         self.sampler_chunk_size = sampler_chunk_size
         self.store_location = store_location
         self.nnz = None
+        self.start_at = start_at
+        self.idx_full = None
+        self.max_len = max_len
         self.test_datasets = []
         self.force_recompute_indices = force_recompute_indices
+        self.curiculum = curiculum
+        self.valid_idx = []
         self.test_idx = []
         super().__init__()
         print("finished init")
@@ -183,9 +188,11 @@ class DataModule(L.LightningDataModule):
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + (
-
-
+            + (
+                "\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)"
+                if self.idx_full is not None
+                else ")"
+            )
         )
 
     @property
@@ -229,12 +236,17 @@ class DataModule(L.LightningDataModule):
         """
         return self.dataset.genedf.index.tolist()
 
-    @
-    def
-
-
+    @property
+    def genes_dict(self):
+        return {
+            i: self.dataset.genedf.index[self.dataset.genedf.organism == i].tolist()
+            for i in self.dataset.organisms
+        }
+
+    def set_valid_genes_collator(self, genes):
         self.kwargs["collate_fn"]._setup(
-            genedf
+            # cannot use genedf there since I am purposefully decreasing it...
+            # genedf=self.dataset.genedf,
             org_to_id=self.kwargs["collate_fn"].org_to_id,
             valid_genes=genes,
         )
@@ -280,14 +292,11 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
-        SCALE = 10
         print("setting up the datamodule")
         start_time = time.time()
         if (
             self.store_location is None
-            or not os.path.exists(
-                os.path.join(self.store_location, "train_weights.npy")
-            )
+            or not os.path.exists(os.path.join(self.store_location, "train_labels.npy"))
             or self.force_recompute_indices
         ):
             if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
@@ -295,18 +304,19 @@ class DataModule(L.LightningDataModule):
                     "nnz", is_cat=False
                 )
                 self.clss_to_weight.remove("nnz")
-
-
-
-
+                # Sigmoid scaling with 2 parameters
+                midpoint = 2000
+                steepness = 0.003
+                # Apply sigmoid transformation
+                # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint)))
+                # Then scale to [1, NNZ_SCALE] range
+                sigmoid_values = 1 / (1 + np.exp(-steepness * (self.nnz - midpoint)))
+                self.nnz = 1 + ((NNZ_SCALE - 1) * sigmoid_values)
             if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
-
+                labels = self.dataset.get_label_cats(
                     self.clss_to_weight,
-                    scaler=self.weight_scaler,
-                    return_categories=True,
                 )
             else:
-                weights = np.ones(1)
                 labels = np.zeros(self.n_samples, dtype=int)
             if isinstance(self.validation_split, int):
                 len_valid = self.validation_split
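The hunk above replaces the old flat handling of `nnz` with a sigmoid squashing of each cell's non-zero gene count into the [1, NNZ_SCALE] range, centred at 2000 genes. A standalone sketch of that arithmetic (the example counts are invented):

    import numpy as np

    NNZ_SCALE = 1000
    midpoint, steepness = 2000, 0.003

    nnz = np.array([200, 2000, 8000])  # hypothetical per-cell non-zero gene counts
    sigmoid = 1 / (1 + np.exp(-steepness * (nnz - midpoint)))
    element_weights = 1 + (NNZ_SCALE - 1) * sigmoid
    print(element_weights.round(1))    # -> roughly [5.5, 500.5, 1000.0]

Shallow cells thus get element weights near 1 and deep cells saturate near NNZ_SCALE before being handed to the sampler as `element_weights`.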
@@ -316,9 +326,9 @@ class DataModule(L.LightningDataModule):
                 len_test = self.test_split
             else:
                 len_test = int(self.n_samples * self.test_split)
-            assert
-
-            )
+            assert (
+                len_test + len_valid < self.n_samples
+            ), "test set + valid set size is configured to be larger than entire dataset."
 
             idx_full = []
             if len(self.assays_to_drop) > 0:
@@ -363,28 +373,22 @@ class DataModule(L.LightningDataModule):
                 idx_full = idx_full[len_valid:]
             else:
                 self.valid_idx = None
-
-            labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+            labels[~np.isin(np.arange(self.n_samples), idx_full)] = labels.max() + 1
             # some labels will now not exist anymore as replaced by len(weights) - 1.
             # this means that the associated weights should be 0.
             # by doing np.bincount(labels)*weights this will be taken into account
-            self.train_weights = weights
             self.train_labels = labels
             self.idx_full = idx_full
             if self.store_location is not None:
                 if (
                     not os.path.exists(
-                        os.path.join(self.store_location, "
+                        os.path.join(self.store_location, "train_labels.npy")
                     )
                     or self.force_recompute_indices
                 ):
                     os.makedirs(self.store_location, exist_ok=True)
                     if self.nnz is not None:
                         np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
-                    np.save(
-                        os.path.join(self.store_location, "train_weights.npy"),
-                        self.train_weights,
-                    )
                     np.save(
                         os.path.join(self.store_location, "train_labels.npy"),
                         self.train_labels,
@@ -411,9 +415,6 @@ class DataModule(L.LightningDataModule):
             if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
             else None
         )
-        self.train_weights = np.load(
-            os.path.join(self.store_location, "train_weights.npy")
-        )
         self.train_labels = np.load(
             os.path.join(self.store_location, "train_labels.npy")
         )
@@ -446,8 +447,8 @@ class DataModule(L.LightningDataModule):
                 # Create the optimized parallel sampler
                 print(f"Using {self.sampler_workers} workers for class indexing")
                 train_sampler = LabelWeightedSampler(
-                    label_weights=self.train_weights,
                     labels=self.train_labels,
+                    weight_scaler=self.weight_scaler,
                     num_samples=int(self.n_samples_per_epoch),
                     element_weights=self.nnz,
                     replacement=self.replacement,
@@ -455,15 +456,18 @@ class DataModule(L.LightningDataModule):
                     chunk_size=self.sampler_chunk_size,
                     store_location=self.store_location,
                     force_recompute_indices=self.force_recompute_indices,
+                    curiculum=self.curiculum,
                 )
             except ValueError as e:
                 raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
+            dataset = None
         else:
-
+            dataset = Subset(self.dataset, self.idx_full)
+            train_sampler = RankShardSampler(len(dataset), start_at=self.start_at)
         current_loader_kwargs = kwargs.copy()
         current_loader_kwargs.update(self.kwargs)
         return DataLoader(
-            self.dataset,
+            self.dataset if dataset is None else dataset,
            sampler=train_sampler,
             **current_loader_kwargs,
         )
@@ -471,12 +475,11 @@ class DataModule(L.LightningDataModule):
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset,
-                sampler=SubsetRandomSampler(self.valid_idx),
+                Subset(self.dataset, self.valid_idx),
                 **self.kwargs,
             )
             if self.valid_idx is not None
-            else
+            else []
         )
 
     def test_dataloader(self):
@@ -485,20 +488,21 @@ class DataModule(L.LightningDataModule):
                 self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
             )
             if self.test_idx is not None
-            else
+            else []
         )
 
     def predict_dataloader(self):
+        subset = Subset(self.dataset, self.idx_full)
         return DataLoader(
             self.dataset,
-            sampler=
+            sampler=RankShardSampler(len(subset), start_at=self.start_at),
             **self.kwargs,
         )
 
 
 class LabelWeightedSampler(Sampler[int]):
     """
-    A weighted random sampler that samples from a dataset with respect
+    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.
 
     This sampler is designed to handle very large datasets efficiently, with optimizations for:
     1. Parallel building of class indices
@@ -515,21 +519,22 @@ class LabelWeightedSampler(Sampler[int]):
 
     def __init__(
         self,
-        label_weights: Sequence[float],
         labels: np.ndarray,
         num_samples: int,
         replacement: bool = True,
+        weight_scaler: Optional[float] = None,
         element_weights: Optional[Sequence[float]] = None,
         n_workers: int = None,
        chunk_size: int = None,  # Process 10M elements per chunk
         store_location: str = None,
         force_recompute_indices: bool = False,
+        curiculum: int = 0,
     ) -> None:
         """
         Initialize the sampler with parallel processing for large datasets.
 
         Args:
-
+            weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
             labels: Class label for each dataset element (length = dataset size)
             num_samples: Number of samples to draw
             replacement: Whether to sample with replacement
@@ -539,10 +544,14 @@ class LabelWeightedSampler(Sampler[int]):
         """
         print("Initializing optimized parallel weighted sampler...")
         super(LabelWeightedSampler, self).__init__(None)
+        self.count = 0
+        self.curiculum = curiculum
 
         # Compute label weights (incorporating class frequencies)
         # Directly use labels as numpy array without conversion
-
+        counts = np.bincount(labels)
+        counts[-1] = 0  # Ensure the weight for the 'NONE' class is zero
+        label_weights = (weight_scaler * counts) / (counts + weight_scaler)
         self.label_weights = torch.as_tensor(
             label_weights, dtype=torch.float32
         ).share_memory_()
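With `label_weights` no longer passed in, the sampler now derives per-class weights from the label counts and `weight_scaler` as s*c / (c + s): the weight grows roughly linearly with class size for classes smaller than the scaler and saturates at the scaler for large ones, so abundant classes end up sampled near-uniformly while tiny classes stay roughly proportional to their size. A quick numeric illustration (class sizes invented):

    import numpy as np

    weight_scaler = 10
    counts = np.array([5, 100, 10_000, 1_000_000])          # cells per class
    weights = (weight_scaler * counts) / (counts + weight_scaler)
    print(weights.round(2))                                  # -> [3.33, 9.09, 9.99, 10.0]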
@@ -643,11 +652,16 @@ class LabelWeightedSampler(Sampler[int]):
         print(f"Done initializing sampler with {len(self.klass_offsets)} classes")
 
     def __iter__(self):
+        self.count += 1
         # Sample classes according to their weights
         print("sampling a new batch of size", self.num_samples)
 
         sample_labels = torch.multinomial(
-
+            (
+                self.label_weights ** min(1, ((self.count + 5) / self.curiculum))
+                if self.curiculum
+                else self.label_weights
+            ),
             num_samples=self.num_samples,
             replacement=True,
         )
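The `curiculum` knob anneals this class distribution over epochs: `self.count` increments on every `__iter__`, and raising the weights to `min(1, (count + 5) / curiculum)` starts close to uniform class sampling and converges to the full weighting once `count + 5 >= curiculum`. A small check of the exponent's effect, reusing the weights from the sketch above:

    import numpy as np

    curiculum = 50
    weights = np.array([3.33, 9.09, 10.0])
    for count in (1, 10, 45):
        exponent = min(1, (count + 5) / curiculum)
        print(count, exponent, (weights ** exponent).round(2))
    # 1   0.12  [1.16 1.3  1.32]   nearly flat early on
    # 10  0.3   [1.43 1.94 2.  ]
    # 45  1     [3.33 9.09 10. ]   full class weighting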
@@ -655,7 +669,9 @@ class LabelWeightedSampler(Sampler[int]):
         unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)
 
         # Initialize result tensor
-        result_indices_list =
+        result_indices_list = (
+            []
+        )  # Changed name to avoid conflict if you had result_indices elsewhere
 
         # Process only the classes that were actually sampled
         for i, (label, count) in tqdm(
@@ -675,6 +691,11 @@ class LabelWeightedSampler(Sampler[int]):
             # This is a critical point for memory
             current_element_weights_slice = self.element_weights[klass_index]
 
+            if current_element_weights_slice.shape[0] >= (2**24) - 1:
+                ind = torch.randperm(len(klass_index))[: (2**24) - 10]
+                klass_index = klass_index[ind]
+                current_element_weights_slice = current_element_weights_slice[ind]
+
             if self.replacement:
                 right_inds = torch.multinomial(
                     current_element_weights_slice,
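This new guard works around `torch.multinomial`'s hard limit on the number of categories it accepts (its error message reads "number of categories cannot exceed 2^24"): classes at or above the cap are randomly subsampled to just under it before weights are drawn. A hedged sketch of the same guard as a standalone helper (the function name is mine, not the package's):

    import torch

    MAX_CATS = 2**24  # torch.multinomial's category limit

    def sample_within_class(klass_index: torch.Tensor, weights: torch.Tensor, n: int) -> torch.Tensor:
        # Draw n dataset indices from one class, capping the multinomial input size.
        if weights.shape[0] >= MAX_CATS - 1:
            keep = torch.randperm(len(klass_index))[: MAX_CATS - 10]
            klass_index, weights = klass_index[keep], weights[keep]
        return klass_index[torch.multinomial(weights, n, replacement=True)]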
@@ -827,3 +848,35 @@ class LabelWeightedSampler(Sampler[int]):
             chunk_indices[int(label)] = indices[label_mask]
 
         return chunk_indices
+
+
+class RankShardSampler(Sampler[int]):
+    """Shards a dataset contiguously across ranks without padding or duplicates.
+    Preserves the existing order (e.g., your pre-shuffled idx_full)."""
+
+    def __init__(self, data_len: int, start_at: int = 0) -> None:
+        self.data_len = data_len
+        self.start_at = start_at
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            self.rank = torch.distributed.get_rank()
+            self.world_size = torch.distributed.get_world_size()
+        else:
+            self.rank, self.world_size = 0, 1
+
+        # contiguous chunk per rank (last rank may be shorter)
+        if self.start_at > 0:
+            print(
+                "!!!!ATTTENTION: make sure that you are running on the exact same \
+                number of GPU as your previous run!!!!!"
+            )
+        print(f"Sharding data of size {data_len} over {self.world_size} ranks")
+        per_rank = math.ceil((self.data_len - self.start_at) / self.world_size)
+        self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
+        self.end = min(self.start + per_rank, self.data_len)
+        print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")
+
+    def __iter__(self):
+        return iter(range(self.start, self.end))
+
+    def __len__(self):
+        return self.end - self.start
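The new RankShardSampler gives every rank one contiguous, non-overlapping slice of the (already shuffled) index list, with no padding or duplication, and `start_at` skips an initial offset (hence the warning about re-running on the same number of GPUs). A tiny standalone illustration of the same index math (the helper function is mine):

    import math

    def shard_bounds(data_len: int, world_size: int, rank: int, start_at: int = 0) -> range:
        # Contiguous per-rank index range, mirroring RankShardSampler.
        per_rank = math.ceil((data_len - start_at) / world_size)
        start = int(start_at / world_size + rank * per_rank)
        end = min(start + per_rank, data_len)
        return range(start, end)

    print([list(shard_bounds(10, 3, r)) for r in range(3)])
    # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]  (last rank is simply shorter)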
scdataloader/mapped.py
CHANGED
@@ -10,7 +10,7 @@ from __future__ import annotations
 import os
 from collections import Counter
 from functools import reduce
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, List, Literal
 
 import numpy as np
 import pandas as pd
@@ -117,20 +117,20 @@ class MappedCollection:
 
     def __init__(
         self,
-        path_list:
-        layers_keys: str |
-        obs_keys: str |
-        obsm_keys: str |
+        path_list: List[UPathStr],
+        layers_keys: str | List[str] | None = None,
+        obs_keys: str | List[str] | None = None,
+        obsm_keys: str | List[str] | None = None,
         obs_filter: dict[str, str | tuple[str, ...]] | None = None,
         join: Literal["inner", "outer"] | None = "inner",
-        encode_labels: bool |
+        encode_labels: bool | List[str] = True,
         unknown_label: str | dict[str, str] | None = None,
         cache_categories: bool = True,
         parallel: bool = False,
         dtype: str | None = None,
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
-        meta_assays:
+        meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
         store_location: str | None = None,
         force_recompute_indices: bool = False,
     ):
@@ -200,7 +200,9 @@ class MappedCollection:
                 self._cache_categories(self.obs_keys)
                 torch.save(self._cache_cats, self.store_location)
             else:
-                self._cache_cats = torch.load(
+                self._cache_cats = torch.load(
+                    self.store_location, weights_only=False
+                )
                 print(f"Loaded categories from {self.store_location}")
         self.encoders: dict = {}
         if self.encode_labels:
@@ -348,7 +350,7 @@ class MappedCollection:
         vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
         return all(vrs_sort_status)
 
-    def check_vars_non_aligned(self, vars: pd.Index |
+    def check_vars_non_aligned(self, vars: pd.Index | List) -> List[int]:
         """Returns indices of objects with non-aligned variables.
 
         Args:
@@ -380,7 +382,7 @@ class MappedCollection:
         return (self.n_obs, self.n_vars)
 
     @property
-    def original_shapes(self) ->
+    def original_shapes(self) -> List[tuple[int, int]]:
         """Shapes of the underlying AnnData objects (with `obs_filter` applied)."""
         if self.n_vars_list is None:
             n_vars_list = [None] * len(self.n_obs_list)
@@ -437,7 +439,20 @@ class MappedCollection:
             print(out)
             raise
 
-        if self.
+        if self.get_knn_cells:
+            distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
+            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+            out["knn_cells"] = np.array(
+                [
+                    self._get_data_idx(
+                        lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
+                    )
+                    for i in nn_idx
+                ],
+                dtype=int,
+            )
+            out["knn_cells_info"] = distances[nn_idx]
+        elif self.metacell_mode > 0:
             if (
                 len(self.meta_assays) > 0
                 and "assay_ontology_term_id" in self.obs_keys
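In the `get_knn_cells` branch (which now takes precedence over metacell mode), a row of `obsp["distances"]` is mostly zeros, with non-zero entries only for the stored neighbors; the `-1 / (d - 1e-6)` transform maps those zeros to a large positive value so they sort last, and the nearest stored neighbors come first in the argsort. A small numeric check of that trick:

    import numpy as np

    # one row of obsp["distances"]; zeros mean "not a stored neighbor"
    distances = np.array([0.0, 0.8, 0.0, 0.3, 1.2, 0.0, 0.5])
    order = np.argsort(-1 / (distances - 1e-6))
    print(order[:3])   # -> [3 6 1], the three closest stored neighbors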
@@ -454,19 +469,6 @@ class MappedCollection:
                     out[layers_key] += self._get_data_idx(
                         lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
                     )
-        elif self.get_knn_cells:
-            distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
-            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
-            out["knn_cells"] = np.array(
-                [
-                    self._get_data_idx(
-                        lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
-                    )
-                    for i in nn_idx
-                ],
-                dtype=int,
-            )
-            out["distances"] = distances[nn_idx]
 
         return out
 
@@ -541,7 +543,7 @@ class MappedCollection:
 
     def get_label_weights(
         self,
-        obs_keys: str |
+        obs_keys: str | List[str],
         scaler: float | None = None,
         return_categories: bool = False,
     ):