scdataloader 1.6.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +2 -0
- scdataloader/__main__.py +98 -36
- scdataloader/collator.py +13 -7
- scdataloader/config.py +99 -0
- scdataloader/data.py +48 -35
- scdataloader/datamodule.py +138 -44
- scdataloader/mapped.py +656 -0
- scdataloader/preprocess.py +239 -91
- scdataloader/utils.py +71 -27
- {scdataloader-1.6.4.dist-info → scdataloader-1.8.0.dist-info}/METADATA +10 -8
- scdataloader-1.8.0.dist-info/RECORD +16 -0
- {scdataloader-1.6.4.dist-info → scdataloader-1.8.0.dist-info}/WHEEL +1 -1
- scdataloader-1.8.0.dist-info/entry_points.txt +2 -0
- scdataloader-1.6.4.dist-info/RECORD +0 -14
- {scdataloader-1.6.4.dist-info → scdataloader-1.8.0.dist-info}/licenses/LICENSE +0 -0
scdataloader/datamodule.py
CHANGED
@@ -1,3 +1,4 @@
+import os
 from typing import Optional, Sequence, Union
 
 import lamindb as ln
@@ -15,7 +16,9 @@ from torch.utils.data.sampler import (
 
 from .collator import Collator
 from .data import Dataset
-from .utils import getBiomartTable
+from .utils import getBiomartTable, slurm_restart_count
+
+FILE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 class DataModule(L.LightningDataModule):
@@ -32,22 +35,24 @@ class DataModule(L.LightningDataModule):
         use_default_col: bool = True,
         gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-
-        all_clss: list = ["organism_ontology_term_id"],
+        clss_to_predict: list = ["organism_ontology_term_id"],
         hierarchical_clss: list = [],
         # this is for the collator
         how: str = "random expr",
         organism_name: str = "organism_ontology_term_id",
         max_len: int = 1000,
         add_zero_genes: int = 100,
+        replacement: bool = True,
         do_gene_pos: Union[bool, str] = True,
         tp_name: Optional[str] = None,  # "heat_diff"
         assays_to_drop: list = [
-            "EFO:0008853",
-            "EFO:0010961",
-            "EFO:0030007",
-            "EFO:0030062",
+            # "EFO:0008853", #patch seq
+            # "EFO:0010961", # visium
+            "EFO:0030007",  # ATACseq
+            # "EFO:0030062", # slide-seq
         ],
+        metacell_mode: float = 0.0,
+        modify_seed_on_requeue: bool = True,
         **kwargs,
     ):
         """
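Taken together, this hunk renames `all_clss` to `clss_to_predict` and adds `replacement`, `metacell_mode`, and `modify_seed_on_requeue` knobs. A minimal sketch of how the 1.8.0 constructor could be called, assuming a lamindb collection named "my-collection" exists (the collection name and `batch_size` value are illustrative; per the docstring, extra kwargs are forwarded to the torch DataLoader):

from scdataloader import DataModule

# Sketch only: "my-collection" is a placeholder collection name.
dm = DataModule(
    collection_name="my-collection",
    clss_to_predict=["organism_ontology_term_id"],  # renamed from all_clss
    replacement=True,                 # new in 1.8.0: sample with replacement
    metacell_mode=0.2,                # new: probability of metacell mode
    modify_seed_on_requeue=True,      # new: reseed the sampler after a requeue
    batch_size=64,                    # forwarded to torch.utils.data.DataLoader
)
dm.setup()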
@@ -59,7 +64,6 @@
 
         Args:
             collection_name (str): The lamindb collection to be used.
-            clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
             weight_scaler (int, optional): how much more you will see the most present vs less present category.
             train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
@@ -81,23 +85,24 @@
             organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
             hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
-
-
+            metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
+            clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
+            modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
-
         see @file data.py and @file collator.py for more details about some of the parameters
         """
         if collection_name is not None:
             mdataset = Dataset(
                 ln.Collection.filter(name=collection_name).first(),
                 organisms=organisms,
-
-                clss_to_pred=clss_to_pred,
+                clss_to_predict=clss_to_predict,
                 hierarchical_clss=hierarchical_clss,
+                metacell_mode=metacell_mode,
             )
-            # print(mdataset)
             # and location
+        self.metacell_mode = bool(metacell_mode)
         self.gene_pos = None
+        self.collection_name = collection_name
         if do_gene_pos:
             if type(do_gene_pos) is str:
                 print("seeing a string: loading gene positions as biomart parquet file")
@@ -127,7 +132,7 @@
                 c.append(i)
                 prev_position = r["start_position"]
                 prev_chromosome = r["chromosome_name"]
-            print(f"reduced the size to {len(set(c))/len(biomart)}")
+            print(f"reduced the size to {len(set(c)) / len(biomart)}")
             biomart["pos"] = c
             mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
             self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
@@ -151,11 +156,13 @@
                 org_to_id=mdataset.encoder[organism_name],
                 tp_name=tp_name,
                 organism_name=organism_name,
-                class_names=
+                class_names=clss_to_predict,
+                metacell_mode=bool(metacell_mode),
             )
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
+        self.replacement = replacement
         self.kwargs = kwargs
         if "sampler" in self.kwargs:
             self.kwargs.pop("sampler")
@@ -166,6 +173,9 @@
         self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.modify_seed_on_requeue = modify_seed_on_requeue
+        self.nnz = None
+        self.restart_num = 0
         self.test_datasets = []
         self.test_idx = []
         super().__init__()
@@ -184,12 +194,8 @@
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + (
-                "\twith train_dataset size of=("
-                + str((self.train_weights != 0).sum())
-                + ")\n)"
-            )
-            if self.train_weights is not None
+            + ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
+            if self.idx_full is not None
             else ")"
         )
 
@@ -247,13 +253,23 @@
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
+        SCALE = 10
+        if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
+            self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
+            self.clss_to_weight.remove("nnz")
+            (
+                (self.nnz.max() / SCALE)
+                / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
+            ).min()
         if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
             weights, labels = self.dataset.get_label_weights(
-                self.clss_to_weight,
+                self.clss_to_weight,
+                scaler=self.weight_scaler,
+                return_categories=True,
             )
         else:
             weights = np.ones(1)
-            labels = np.zeros(self.n_samples)
+            labels = np.zeros(self.n_samples, dtype=int)
         if isinstance(self.validation_split, int):
             len_valid = self.validation_split
         else:
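The new `nnz` branch builds a decaying weight curve from per-cell nonzero counts. A standalone numeric illustration of the expression (not package code; example values only). Note that in the hunk the `.min()` result is not bound to a name; the raw counts stored in `self.nnz` are what later feed the sampler's `element_weights`:

import numpy as np

SCALE = 10
nnz = np.array([100, 1_000, 10_000])  # example per-cell nonzero gene counts
w = (nnz.max() / SCALE) / ((1 + nnz - nnz.min()) + (nnz.max() / SCALE))
print(w.round(3))  # [0.999 0.526 0.092] -> weight shrinks as nnz grows
print(w.min())     # ~0.092, the quantity the hunk evaluates with .min()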
@@ -268,12 +284,11 @@
 
         idx_full = []
         if len(self.assays_to_drop) > 0:
-
-            self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
-
-
-
-            idx_full = np.array(idx_full)
+            badloc = np.isin(
+                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
+                self.assays_to_drop,
+            )
+            idx_full = np.arange(len(labels))[~badloc]
         else:
             idx_full = np.arange(self.n_samples)
         if len_test > 0:
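The rewrite appears to replace per-element filtering with a single vectorized mask. A small self-contained example of the same `np.isin` pattern (assay values other than the retained default drop are illustrative):

import numpy as np

assays = np.array(["EFO:0030007", "EFO:0008913", "EFO:0030007", "EFO:0009922"])
badloc = np.isin(assays, ["EFO:0030007"])  # ATACseq, the remaining default drop
idx_full = np.arange(len(assays))[~badloc]
print(idx_full)  # [1 3] -> positions whose assay is kept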
@@ -302,12 +317,15 @@
         np.random.shuffle(idx_full)
         if len_valid > 0:
             self.valid_idx = idx_full[:len_valid].copy()
+            # store it for later
             idx_full = idx_full[len_valid:]
         else:
             self.valid_idx = None
         weights = np.concatenate([weights, np.zeros(1)])
         labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
-
+        # some labels will now not exist anymore as replaced by len(weights) - 1.
+        # this means that the associated weights should be 0.
+        # by doing np.bincount(labels)*weights this will be taken into account
         self.train_weights = weights
         self.train_labels = labels
         self.idx_full = idx_full
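The new comment block is worth unpacking: samples routed to validation/test are re-labelled to an extra class whose weight is zero, so `np.bincount(labels) * weights` silently excludes them. A toy check (standalone, not package code):

import numpy as np

weights = np.concatenate([np.array([1.0, 2.0]), np.zeros(1)])  # extra zero-weight class
labels = np.array([0, 1, 1, 0, 0])
labels[[1, 4]] = len(weights) - 1     # pretend these two are held out
print(np.bincount(labels) * weights)  # [2. 2. 0.] -> held-out class contributes 0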
@@ -319,17 +337,31 @@
         # int(self.n_samples*self.train_oversampling_per_epoch),
         # replacement=True,
         # )
-
-
-
-
+        try:
+            train_sampler = LabelWeightedSampler(
+                label_weights=self.train_weights,
+                labels=self.train_labels,
+                num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+                element_weights=self.nnz,
+                replacement=self.replacement,
+                restart_num=self.restart_num,
+                modify_seed_on_requeue=self.modify_seed_on_requeue,
+            )
+        except ValueError as e:
+            raise ValueError(e + "have you run `datamodule.setup()`?")
+        return DataLoader(
+            self.dataset,
+            sampler=train_sampler,
+            **self.kwargs,
+            **kwargs,
         )
-        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
 
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset,
+                self.dataset,
+                sampler=SubsetRandomSampler(self.valid_idx),
+                **self.kwargs,
             )
             if self.valid_idx is not None
             else None
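For intuition, `LabelWeightedSampler` with uniform class weights draws classes evenly regardless of how imbalanced they are in the dataset. A hedged standalone check against the 1.8.0 signature shown below (assumes `slurm_restart_count` returns 0 outside a SLURM job):

import torch
from scdataloader.datamodule import LabelWeightedSampler

labels = [0] * 90 + [1] * 10      # class 1 is 9x rarer than class 0
sampler = LabelWeightedSampler(
    label_weights=[1.0, 1.0],     # equal weights -> class-balanced draws
    labels=labels,
    num_samples=1_000,
)
drawn = torch.as_tensor(list(sampler))
print((torch.as_tensor(labels)[drawn] == 1).float().mean())  # ~0.5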
@@ -346,7 +378,9 @@
 
     def predict_dataloader(self):
         return DataLoader(
-            self.dataset,
+            self.dataset,
+            sampler=SubsetRandomSampler(self.idx_full),
+            **self.kwargs,
         )
 
     # def teardown(self):
@@ -359,18 +393,29 @@ class LabelWeightedSampler(Sampler[int]):
     label_weights: Sequence[float]
     klass_indices: Sequence[Sequence[int]]
     num_samples: int
-
+    nnz: Optional[Sequence[int]]
+    replacement: bool
+    restart_num: int
+    modify_seed_on_requeue: bool
     # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
     # this will result a class-balanced sampling, no matter how imbalance the labels are.
-
+
     def __init__(
-        self,
+        self,
+        label_weights: Sequence[float],
+        labels: Sequence[int],
+        num_samples: int,
+        replacement: bool = True,
+        element_weights: Sequence[float] = None,
+        restart_num: int = 0,
+        modify_seed_on_requeue: bool = True,
     ) -> None:
         """
 
         :param label_weights: list(len=num_classes)[float], weights for each class.
         :param labels: list(len=dataset_len)[int], labels of a dataset.
         :param num_samples: number of samples.
+        :param restart_num: if we are continuing a previous run, we need to restart the sampler from the same point.
         """
 
         super(LabelWeightedSampler, self).__init__(None)
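The `restart_num` plumbing added here is what lets a requeued job sample a new but reproducible order, as wired up in `__iter__` below. A sketch of the idea; the assumption that `slurm_restart_count(use_mine=True)` reduces to reading `SLURM_RESTART_COUNT` is mine and not confirmed by this diff:

import os
import torch

# Assumption: slurm_restart_count(use_mine=True) behaves roughly like this.
restart_num = int(os.environ.get("SLURM_RESTART_COUNT", "0"))
generator = (
    None if restart_num == 0 else torch.Generator().manual_seed(restart_num)
)
# Same multinomial call shape as in __iter__: a fresh, seeded draw per restart.
draws = torch.multinomial(torch.ones(5), num_samples=3, replacement=True, generator=generator)
print(draws)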
@@ -379,24 +424,73 @@ class LabelWeightedSampler(Sampler[int]):
 
         self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
         self.labels = torch.as_tensor(labels, dtype=torch.int)
+        self.element_weights = (
+            torch.as_tensor(element_weights, dtype=torch.float32)
+            if element_weights is not None
+            else None
+        )
+        self.replacement = replacement
         self.num_samples = num_samples
+        self.restart_num = slurm_restart_count(use_mine=True) + restart_num
+        self.modify_seed_on_requeue = modify_seed_on_requeue
         # list of tensor.
         self.klass_indices = [
             (self.labels == i_klass).nonzero().squeeze(1)
             for i_klass in range(len(label_weights))
         ]
+        self.klass_sizes = [len(klass_indices) for klass_indices in self.klass_indices]
 
     def __iter__(self):
         sample_labels = torch.multinomial(
-            self.label_weights,
+            self.label_weights,
+            num_samples=self.num_samples,
+            replacement=True,
+            generator=None
+            if self.restart_num == 0 and not self.modify_seed_on_requeue
+            else torch.Generator().manual_seed(self.restart_num),
         )
         sample_indices = torch.empty_like(sample_labels)
         for i_klass, klass_index in enumerate(self.klass_indices):
             if klass_index.numel() == 0:
                 continue
             left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
-
-
+            if len(left_inds) == 0:
+                continue
+            if self.element_weights is not None:
+                right_inds = torch.multinomial(
+                    self.element_weights[klass_index],
+                    num_samples=len(klass_index)
+                    if not self.replacement and len(klass_index) < len(left_inds)
+                    else len(left_inds),
+                    replacement=self.replacement,
+                    generator=None
+                    if self.restart_num == 0 and not self.modify_seed_on_requeue
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            elif self.replacement:
+                right_inds = torch.randint(
+                    len(klass_index),
+                    size=(len(left_inds),),
+                    generator=None
+                    if self.restart_num == 0 and not self.modify_seed_on_requeue
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            else:
+                maxelem = (
+                    len(left_inds)
+                    if len(left_inds) < len(klass_index)
+                    else len(klass_index)
+                )
+                right_inds = torch.randperm(len(klass_index))[:maxelem]
+            sample_indices[left_inds[: len(right_inds)]] = klass_index[right_inds]
+            if len(right_inds) < len(left_inds):
+                sample_indices[left_inds[len(right_inds) :]] = -1
+        # drop all -1
+        sample_indices = sample_indices[sample_indices != -1]
+        # torch shuffle
+        sample_indices = sample_indices[torch.randperm(len(sample_indices))]
+        self.num_samples = len(sample_indices)
+        # raise Exception("stop")
         yield from iter(sample_indices.tolist())
 
     def __len__(self):