scdataloader 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- scdataloader/__main__.py +4 -5
- scdataloader/collator.py +65 -56
- scdataloader/data.py +38 -54
- scdataloader/datamodule.py +124 -83
- scdataloader/mapped.py +27 -25
- scdataloader/preprocess.py +31 -16
- scdataloader/utils.py +120 -20
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.2.dist-info}/METADATA +5 -5
- scdataloader-2.0.2.dist-info/RECORD +16 -0
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.2.dist-info}/WHEEL +1 -1
- scdataloader-2.0.0.dist-info/RECORD +0 -16
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.2.dist-info}/entry_points.txt +0 -0
- {scdataloader-2.0.0.dist-info → scdataloader-2.0.2.dist-info}/licenses/LICENSE +0 -0
scdataloader/datamodule.py
CHANGED
@@ -1,17 +1,18 @@
+import math
 import multiprocessing as mp
 import os
 import random
 import time
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from functools import partial
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union

 import lamindb as ln
 import lightning as L
 import numpy as np
 import pandas as pd
 import torch
-from torch.utils.data import DataLoader, Sampler
+from torch.utils.data import DataLoader, Sampler, Subset
 from torch.utils.data.sampler import (
     RandomSampler,
     SequentialSampler,
@@ -25,32 +26,30 @@ from .data import Dataset
 from .utils import fileToList, getBiomartTable, listToFile

 FILE_DIR = os.path.dirname(os.path.abspath(__file__))
+NNZ_SCALE = 1000


 class DataModule(L.LightningDataModule):
     def __init__(
         self,
         collection_name: str,
-        clss_to_weight:
+        clss_to_weight: List[str] = ["organism_ontology_term_id"],
         weight_scaler: int = 10,
         n_samples_per_epoch: int = 2_000_000,
         validation_split: float = 0.2,
         test_split: float = 0,
-        gene_embeddings: str = "",
         use_default_col: bool = True,
-        gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-        clss_to_predict:
-        hierarchical_clss:
+        clss_to_predict: List[str] = ["organism_ontology_term_id"],
+        hierarchical_clss: List[str] = [],
         # this is for the collator
         how: str = "random expr",
-
+        organism_col: str = "organism_ontology_term_id",
         max_len: int = 1000,
-        add_zero_genes: int = 100,
         replacement: bool = True,
-
+        gene_subset: Optional[list[str]] = None,
         tp_name: Optional[str] = None,  # "heat_diff"
-        assays_to_drop:
+        assays_to_drop: List[str] = [
             # "EFO:0008853", #patch seq
             # "EFO:0010961", # visium
             "EFO:0030007",  # ATACseq
@@ -62,6 +61,10 @@ class DataModule(L.LightningDataModule):
         force_recompute_indices: bool = False,
         sampler_workers: int = None,
         sampler_chunk_size: int = None,
+        organisms: Optional[str] = None,
+        genedf: Optional[pd.DataFrame] = None,
+        n_bins: int = 0,
+        curiculum: int = 0,
         **kwargs,
     ):
         """
@@ -78,23 +81,19 @@ class DataModule(L.LightningDataModule):
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                 it will use a full dataset and will round to the nearest dataset's cell count.
-
+            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+            clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
+            gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
                 the file must have ensembl_gene_id as index.
                 This is used to subset the available genes further to the ones that have embeddings in your model.
-            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
-            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
-                any genes within this distance of each other will be considered at the same position.
-            clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
-            assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
-            do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
             max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
-            add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
             how (str, optional): The method to use for the collator. Defaults to "random expr".
-
+            organism_col (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
-            hierarchical_clss (
+            hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
             metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
-            clss_to_predict (
+            clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
             store_location (str, optional): The location to store the sampler indices. Defaults to None.
             force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
@@ -107,44 +106,44 @@ class DataModule(L.LightningDataModule):
             raise ValueError(
                 "need 'organism_ontology_term_id' in the set of classes at least"
             )
+        if metacell_mode > 0 and get_knn_cells:
+            raise ValueError(
+                "cannot use metacell mode and get_knn_cells at the same time"
+            )
         mdataset = Dataset(
-            ln.Collection.filter(
+            ln.Collection.filter(key=collection_name, is_latest=True).first(),
             clss_to_predict=clss_to_predict,
             hierarchical_clss=hierarchical_clss,
             metacell_mode=metacell_mode,
             get_knn_cells=get_knn_cells,
             store_location=store_location,
             force_recompute_indices=force_recompute_indices,
+            genedf=genedf,
         )
         # and location
         self.metacell_mode = bool(metacell_mode)
         self.gene_pos = None
         self.collection_name = collection_name
-        if
-
-
-            self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
-        if gene_embeddings != "":
-            mdataset.genedf = mdataset.genedf.join(
-                pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
-            )
-        if do_gene_pos:
-            self.gene_pos = mdataset.genedf["pos"].tolist()
+        if gene_subset is not None:
+            tokeep = set(mdataset.genedf.index.tolist())
+            gene_subset = [u for u in gene_subset if u in tokeep]
         self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
         # we might want not to order the genes by expression (or do it?)
         # we might want to not introduce zeros and
         if use_default_col:
             kwargs["collate_fn"] = Collator(
-                organisms=mdataset.organisms,
+                organisms=mdataset.organisms if organisms is None else organisms,
                 how=how,
-                valid_genes=
+                valid_genes=gene_subset,
                 max_len=max_len,
-
-                org_to_id=mdataset.encoder[organism_name],
+                org_to_id=mdataset.encoder[organism_col],
                 tp_name=tp_name,
-                organism_name=
+                organism_name=organism_col,
                 class_names=list(self.classes.keys()),
+                genedf=genedf,
+                n_bins=n_bins,
             )
+        self.n_bins = n_bins
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
@@ -163,8 +162,12 @@ class DataModule(L.LightningDataModule):
         self.sampler_chunk_size = sampler_chunk_size
         self.store_location = store_location
         self.nnz = None
+        self.idx_full = None
+        self.max_len = max_len
         self.test_datasets = []
         self.force_recompute_indices = force_recompute_indices
+        self.curiculum = curiculum
+        self.valid_idx = []
         self.test_idx = []
         super().__init__()
         print("finished init")
@@ -183,9 +186,11 @@ class DataModule(L.LightningDataModule):
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + (
-
-
+            + (
+                "\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)"
+                if self.idx_full is not None
+                else ")"
+            )
         )

     @property
@@ -229,12 +234,17 @@ class DataModule(L.LightningDataModule):
         """
         return self.dataset.genedf.index.tolist()

-    @
-    def
-
-
+    @property
+    def genes_dict(self):
+        return {
+            i: self.dataset.genedf.index[self.dataset.genedf.organism == i].tolist()
+            for i in self.dataset.organisms
+        }
+
+    def set_valid_genes_collator(self, genes):
         self.kwargs["collate_fn"]._setup(
-            genedf
+            # cannot use genedf there since I am purposefully decreasing it...
+            # genedf=self.dataset.genedf,
             org_to_id=self.kwargs["collate_fn"].org_to_id,
             valid_genes=genes,
         )
@@ -280,14 +290,11 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
-        SCALE = 10
         print("setting up the datamodule")
         start_time = time.time()
         if (
             self.store_location is None
-            or not os.path.exists(
-                os.path.join(self.store_location, "train_weights.npy")
-            )
+            or not os.path.exists(os.path.join(self.store_location, "train_labels.npy"))
             or self.force_recompute_indices
         ):
             if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
@@ -295,18 +302,19 @@ class DataModule(L.LightningDataModule):
                     "nnz", is_cat=False
                 )
                 self.clss_to_weight.remove("nnz")
-
-
-
-
+                # Sigmoid scaling with 2 parameters
+                midpoint = 2000
+                steepness = 0.003
+                # Apply sigmoid transformation
+                # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint)))
+                # Then scale to [1, NNZ_SCALE] range
+                sigmoid_values = 1 / (1 + np.exp(-steepness * (self.nnz - midpoint)))
+                self.nnz = 1 + ((NNZ_SCALE - 1) * sigmoid_values)
             if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
-
+                labels = self.dataset.get_label_cats(
                     self.clss_to_weight,
-                    scaler=self.weight_scaler,
-                    return_categories=True,
                 )
             else:
-                weights = np.ones(1)
                 labels = np.zeros(self.n_samples, dtype=int)
             if isinstance(self.validation_split, int):
                 len_valid = self.validation_split
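In this hunk the per-cell `nnz` counts are no longer used raw: they are squashed through a sigmoid into the `[1, NNZ_SCALE]` range, so cells with very few or very many expressed genes no longer dominate the element weights. A minimal sketch of that transformation, using the `midpoint` and `steepness` values hard-coded above (the example counts are made up):

```python
import numpy as np

NNZ_SCALE = 1000   # upper bound of the rescaled weights (from the diff)
midpoint = 2000    # nnz value mapped to the middle of the sigmoid
steepness = 0.003  # controls how sharp the transition is

nnz = np.array([200, 1000, 2000, 5000, 20000])  # illustrative per-cell gene counts
sigmoid = 1 / (1 + np.exp(-steepness * (nnz - midpoint)))
weights = 1 + (NNZ_SCALE - 1) * sigmoid

# cells far below the midpoint get weights near 1,
# cells far above it saturate near NNZ_SCALE
print(weights.round(1))  # approx. [5.5, 48.4, 500.5, 999.9, 1000.0]
```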
@@ -363,28 +371,22 @@ class DataModule(L.LightningDataModule):
                 idx_full = idx_full[len_valid:]
             else:
                 self.valid_idx = None
-
-            labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+            labels[~np.isin(np.arange(self.n_samples), idx_full)] = labels.max() + 1
             # some labels will now not exist anymore as replaced by len(weights) - 1.
             # this means that the associated weights should be 0.
             # by doing np.bincount(labels)*weights this will be taken into account
-            self.train_weights = weights
             self.train_labels = labels
             self.idx_full = idx_full
             if self.store_location is not None:
                 if (
                     not os.path.exists(
-                        os.path.join(self.store_location, "
+                        os.path.join(self.store_location, "train_labels.npy")
                     )
                     or self.force_recompute_indices
                 ):
                     os.makedirs(self.store_location, exist_ok=True)
                     if self.nnz is not None:
                         np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
-                    np.save(
-                        os.path.join(self.store_location, "train_weights.npy"),
-                        self.train_weights,
-                    )
                     np.save(
                         os.path.join(self.store_location, "train_labels.npy"),
                         self.train_labels,
@@ -411,9 +413,6 @@ class DataModule(L.LightningDataModule):
                 if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
                 else None
             )
-            self.train_weights = np.load(
-                os.path.join(self.store_location, "train_weights.npy")
-            )
             self.train_labels = np.load(
                 os.path.join(self.store_location, "train_labels.npy")
             )
@@ -446,8 +445,8 @@ class DataModule(L.LightningDataModule):
                 # Create the optimized parallel sampler
                 print(f"Using {self.sampler_workers} workers for class indexing")
                 train_sampler = LabelWeightedSampler(
-                    label_weights=self.train_weights,
                     labels=self.train_labels,
+                    weight_scaler=self.weight_scaler,
                     num_samples=int(self.n_samples_per_epoch),
                     element_weights=self.nnz,
                     replacement=self.replacement,
@@ -455,15 +454,18 @@ class DataModule(L.LightningDataModule):
                     chunk_size=self.sampler_chunk_size,
                     store_location=self.store_location,
                     force_recompute_indices=self.force_recompute_indices,
+                    curiculum=self.curiculum,
                 )
             except ValueError as e:
                 raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
+            dataset = None
         else:
-
+            dataset = Subset(self.dataset, self.idx_full)
+            train_sampler = RankShardSampler(len(dataset))
         current_loader_kwargs = kwargs.copy()
         current_loader_kwargs.update(self.kwargs)
         return DataLoader(
-            self.dataset,
+            self.dataset if dataset is None else dataset,
             sampler=train_sampler,
             **current_loader_kwargs,
         )
@@ -471,12 +473,11 @@ class DataModule(L.LightningDataModule):
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset,
-                sampler=SubsetRandomSampler(self.valid_idx),
+                Subset(self.dataset, self.valid_idx),
                 **self.kwargs,
             )
             if self.valid_idx is not None
-            else
+            else []
         )

     def test_dataloader(self):
@@ -485,20 +486,21 @@ class DataModule(L.LightningDataModule):
                 self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
             )
             if self.test_idx is not None
-            else
+            else []
         )

     def predict_dataloader(self):
+        subset = Subset(self.dataset, self.idx_full)
         return DataLoader(
-
-            sampler=
+            subset,
+            sampler=RankShardSampler(len(subset)),
             **self.kwargs,
         )


 class LabelWeightedSampler(Sampler[int]):
     """
-    A weighted random sampler that samples from a dataset with respect
+    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.

     This sampler is designed to handle very large datasets efficiently, with optimizations for:
     1. Parallel building of class indices
@@ -515,21 +517,22 @@ class LabelWeightedSampler(Sampler[int]):

     def __init__(
         self,
-        label_weights: Sequence[float],
         labels: np.ndarray,
         num_samples: int,
         replacement: bool = True,
+        weight_scaler: Optional[float] = None,
         element_weights: Optional[Sequence[float]] = None,
         n_workers: int = None,
         chunk_size: int = None,  # Process 10M elements per chunk
         store_location: str = None,
         force_recompute_indices: bool = False,
+        curiculum: int = 0,
     ) -> None:
         """
         Initialize the sampler with parallel processing for large datasets.

         Args:
-
+            weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
             labels: Class label for each dataset element (length = dataset size)
             num_samples: Number of samples to draw
             replacement: Whether to sample with replacement
@@ -539,10 +542,14 @@ class LabelWeightedSampler(Sampler[int]):
         """
         print("Initializing optimized parallel weighted sampler...")
         super(LabelWeightedSampler, self).__init__(None)
+        self.count = 0
+        self.curiculum = curiculum

         # Compute label weights (incorporating class frequencies)
         # Directly use labels as numpy array without conversion
-
+        counts = np.bincount(labels)
+        counts[-1] = 0  # Ensure the weight for the 'NONE' class is zero
+        label_weights = (weight_scaler * counts) / (counts + weight_scaler)
         self.label_weights = torch.as_tensor(
             label_weights, dtype=torch.float32
         ).share_memory_()
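Instead of receiving precomputed `label_weights`, the sampler now derives them from the label counts with a saturating formula: classes smaller than `weight_scaler` keep a weight roughly proportional to their size, while larger classes saturate toward `weight_scaler`, and the extra label assigned to held-out (validation/test) cells is zeroed so those cells are never drawn. A small sketch of that behaviour, with illustrative counts:

```python
import numpy as np

def class_weights(counts: np.ndarray, weight_scaler: float) -> np.ndarray:
    # (s * n) / (n + s): ~n when n << s (natural frequency), ~s when n >> s (capped)
    return (weight_scaler * counts) / (counts + weight_scaler)

counts = np.array([5, 100, 10_000, 1_000_000])
print(class_weights(counts, weight_scaler=10).round(2))
# -> [ 3.33  9.09  9.99 10.  ]
# tiny classes keep weights close to their raw counts, abundant classes are all
# capped near the scaler, so sampling interpolates between natural-frequency
# and class-balanced draws
```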
@@ -643,11 +650,16 @@ class LabelWeightedSampler(Sampler[int]):
         print(f"Done initializing sampler with {len(self.klass_offsets)} classes")

     def __iter__(self):
+        self.count += 1
         # Sample classes according to their weights
         print("sampling a new batch of size", self.num_samples)

         sample_labels = torch.multinomial(
-
+            (
+                self.label_weights ** min(1, ((self.count + 5) / self.curiculum))
+                if self.curiculum
+                else self.label_weights
+            ),
             num_samples=self.num_samples,
             replacement=True,
         )
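The new `curiculum` option anneals the class weights across epochs: `self.count` increases on every `__iter__` call, and while the exponent `min(1, (count + 5) / curiculum)` is below 1 the weight distribution is flattened (closer to class-balanced), reaching the full weights once enough epochs have passed. A rough illustration of the schedule, with made-up weights:

```python
import numpy as np

label_weights = np.array([1.0, 10.0, 100.0])
curiculum = 20  # epochs over which to anneal (identifier spelled as in the code)

for count in (1, 5, 15, 30):                    # __iter__ call number
    exponent = min(1, (count + 5) / curiculum)  # same schedule as the sampler uses
    w = label_weights ** exponent
    print(count, round(exponent, 2), (w / w.sum()).round(3))
# count=1  -> exponent 0.3: probs ~ [0.143, 0.286, 0.571]  (flattened)
# count>=15 -> exponent 1.0: probs ~ [0.009, 0.090, 0.901]  (full weights)
```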
@@ -675,6 +687,11 @@ class LabelWeightedSampler(Sampler[int]):
             # This is a critical point for memory
             current_element_weights_slice = self.element_weights[klass_index]

+            if current_element_weights_slice.shape[0] >= (2**24) - 1:
+                ind = torch.randperm(len(klass_index))[: (2**24) - 10]
+                klass_index = klass_index[ind]
+                current_element_weights_slice = current_element_weights_slice[ind]
+
             if self.replacement:
                 right_inds = torch.multinomial(
                     current_element_weights_slice,
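The cap just below 2**24 works around a known `torch.multinomial` limitation: it refuses weight vectors with more than 2^24 categories, so classes larger than that are first randomly subsampled to fit under the limit. A hedged sketch of the same guard in isolation (the helper name is ours, not part of the package):

```python
import torch

MAX_CATEGORIES = 2**24  # torch.multinomial rejects longer weight vectors

def safe_multinomial(weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Subsample the candidate pool before calling multinomial if it is too large."""
    if weights.shape[0] >= MAX_CATEGORIES - 1:
        keep = torch.randperm(weights.shape[0])[: MAX_CATEGORIES - 10]
        # multinomial returns positions within `keep`, so map them back
        picked = torch.multinomial(weights[keep], num_samples, replacement=True)
        return keep[picked]
    return torch.multinomial(weights, num_samples, replacement=True)
```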
@@ -827,3 +844,27 @@ class LabelWeightedSampler(Sampler[int]):
             chunk_indices[int(label)] = indices[label_mask]

         return chunk_indices
+
+
+class RankShardSampler(Sampler[int]):
+    """Shards a dataset contiguously across ranks without padding or duplicates.
+    Preserves the existing order (e.g., your pre-shuffled idx_full)."""
+
+    def __init__(self, data_len: int):
+        self.data_len = data_len
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            self.rank = torch.distributed.get_rank()
+            self.world_size = torch.distributed.get_world_size()
+        else:
+            self.rank, self.world_size = 0, 1
+
+        # contiguous chunk per rank (last rank may be shorter)
+        per_rank = math.ceil(self.data_len / self.world_size)
+        self.start = self.rank * per_rank
+        self.end = min(self.start + per_rank, self.data_len)
+
+    def __iter__(self):
+        return iter(range(self.start, self.end))
+
+    def __len__(self):
+        return self.end - self.start
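The new `RankShardSampler` gives each distributed rank a contiguous, non-overlapping slice of the already-shuffled indices; unlike padding-based distributed samplers, ranks can end up with slightly different lengths. A small sketch of how the slices fall out, assuming 4 ranks and 10 samples:

```python
import math

def shard(data_len: int, rank: int, world_size: int) -> range:
    # same arithmetic as RankShardSampler: contiguous chunk, last rank may be shorter
    per_rank = math.ceil(data_len / world_size)
    start = rank * per_rank
    return range(start, min(start + per_rank, data_len))

for rank in range(4):
    print(rank, list(shard(10, rank, 4)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8]
# 3 [9]
```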
scdataloader/mapped.py
CHANGED
@@ -10,7 +10,7 @@ from __future__ import annotations
 import os
 from collections import Counter
 from functools import reduce
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, List, Literal

 import numpy as np
 import pandas as pd
@@ -117,20 +117,20 @@ class MappedCollection:

     def __init__(
         self,
-        path_list:
-        layers_keys: str |
-        obs_keys: str |
-        obsm_keys: str |
+        path_list: List[UPathStr],
+        layers_keys: str | List[str] | None = None,
+        obs_keys: str | List[str] | None = None,
+        obsm_keys: str | List[str] | None = None,
         obs_filter: dict[str, str | tuple[str, ...]] | None = None,
         join: Literal["inner", "outer"] | None = "inner",
-        encode_labels: bool |
+        encode_labels: bool | List[str] = True,
         unknown_label: str | dict[str, str] | None = None,
         cache_categories: bool = True,
         parallel: bool = False,
         dtype: str | None = None,
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
-        meta_assays:
+        meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
         store_location: str | None = None,
         force_recompute_indices: bool = False,
     ):
@@ -200,7 +200,9 @@ class MappedCollection:
                 self._cache_categories(self.obs_keys)
                 torch.save(self._cache_cats, self.store_location)
             else:
-                self._cache_cats = torch.load(
+                self._cache_cats = torch.load(
+                    self.store_location, weights_only=False
+                )
                 print(f"Loaded categories from {self.store_location}")
         self.encoders: dict = {}
         if self.encode_labels:
|
|
|
348
350
|
vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
|
|
349
351
|
return all(vrs_sort_status)
|
|
350
352
|
|
|
351
|
-
def check_vars_non_aligned(self, vars: pd.Index |
|
|
353
|
+
def check_vars_non_aligned(self, vars: pd.Index | List) -> List[int]:
|
|
352
354
|
"""Returns indices of objects with non-aligned variables.
|
|
353
355
|
|
|
354
356
|
Args:
|
|
@@ -380,7 +382,7 @@ class MappedCollection:
|
|
|
380
382
|
return (self.n_obs, self.n_vars)
|
|
381
383
|
|
|
382
384
|
@property
|
|
383
|
-
def original_shapes(self) ->
|
|
385
|
+
def original_shapes(self) -> List[tuple[int, int]]:
|
|
384
386
|
"""Shapes of the underlying AnnData objects (with `obs_filter` applied)."""
|
|
385
387
|
if self.n_vars_list is None:
|
|
386
388
|
n_vars_list = [None] * len(self.n_obs_list)
|
|
@@ -437,7 +439,20 @@ class MappedCollection:
|
|
|
437
439
|
print(out)
|
|
438
440
|
raise
|
|
439
441
|
|
|
440
|
-
if self.
|
|
442
|
+
if self.get_knn_cells:
|
|
443
|
+
distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
|
|
444
|
+
nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
|
|
445
|
+
out["knn_cells"] = np.array(
|
|
446
|
+
[
|
|
447
|
+
self._get_data_idx(
|
|
448
|
+
lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
|
|
449
|
+
)
|
|
450
|
+
for i in nn_idx
|
|
451
|
+
],
|
|
452
|
+
dtype=int,
|
|
453
|
+
)
|
|
454
|
+
out["knn_cells_info"] = distances[nn_idx]
|
|
455
|
+
elif self.metacell_mode > 0:
|
|
441
456
|
if (
|
|
442
457
|
len(self.meta_assays) > 0
|
|
443
458
|
and "assay_ontology_term_id" in self.obs_keys
|
|
@@ -454,19 +469,6 @@ class MappedCollection:
|
|
|
454
469
|
out[layers_key] += self._get_data_idx(
|
|
455
470
|
lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
|
|
456
471
|
)
|
|
457
|
-
elif self.get_knn_cells:
|
|
458
|
-
distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
|
|
459
|
-
nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
|
|
460
|
-
out["knn_cells"] = np.array(
|
|
461
|
-
[
|
|
462
|
-
self._get_data_idx(
|
|
463
|
-
lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
|
|
464
|
-
)
|
|
465
|
-
for i in nn_idx
|
|
466
|
-
],
|
|
467
|
-
dtype=int,
|
|
468
|
-
)
|
|
469
|
-
out["distances"] = distances[nn_idx]
|
|
470
472
|
|
|
471
473
|
return out
|
|
472
474
|
|
|
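The k-NN branch (moved ahead of the metacell branch in the two hunks above) ranks neighbours by transforming the sparse distance row with `-1 / (d - 1e-6)`: stored neighbours (d > 0) become large negative values ordered by proximity, while the implicit zeros of the sparse row map to a large positive value and sort last, so `argsort(...)[:6]` returns the nearest stored neighbours first. A small numeric check of that ordering, with made-up distances:

```python
import numpy as np

# one row of obsp["distances"]: zeros mean "not a stored neighbour"
distances = np.array([0.0, 0.8, 0.0, 0.3, 2.5, 0.0, 1.1])

key = -1 / (distances - 1e-6)
nn_idx = np.argsort(key)[:6]
print(nn_idx)             # [3 1 6 4 0 2]: the four real neighbours by increasing
                          # distance, then padding drawn from the zero entries
print(distances[nn_idx])  # [0.3 0.8 1.1 2.5 0.  0. ]
```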
@@ -541,7 +543,7 @@ class MappedCollection:

     def get_label_weights(
         self,
-        obs_keys: str |
+        obs_keys: str | List[str],
         scaler: float | None = None,
         return_categories: bool = False,
     ):