scdataloader 1.6.4__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +2 -0
- scdataloader/__main__.py +38 -8
- scdataloader/collator.py +6 -2
- scdataloader/config.py +99 -0
- scdataloader/data.py +42 -35
- scdataloader/datamodule.py +123 -40
- scdataloader/mapped.py +700 -0
- scdataloader/preprocess.py +229 -86
- scdataloader/utils.py +49 -27
- {scdataloader-1.6.4.dist-info → scdataloader-1.7.0.dist-info}/METADATA +7 -6
- scdataloader-1.7.0.dist-info/RECORD +15 -0
- {scdataloader-1.6.4.dist-info → scdataloader-1.7.0.dist-info}/WHEEL +1 -1
- scdataloader-1.6.4.dist-info/RECORD +0 -14
- {scdataloader-1.6.4.dist-info → scdataloader-1.7.0.dist-info}/licenses/LICENSE +0 -0
scdataloader/datamodule.py
CHANGED
```diff
@@ -16,6 +16,10 @@ from torch.utils.data.sampler import (
 from .collator import Collator
 from .data import Dataset
 from .utils import getBiomartTable
+import os
+
+
+FILE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 class DataModule(L.LightningDataModule):
@@ -32,22 +36,24 @@ class DataModule(L.LightningDataModule):
         use_default_col: bool = True,
         gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-
-        all_clss: list = ["organism_ontology_term_id"],
+        clss_to_predict: list = ["organism_ontology_term_id"],
         hierarchical_clss: list = [],
         # this is for the collator
         how: str = "random expr",
         organism_name: str = "organism_ontology_term_id",
         max_len: int = 1000,
         add_zero_genes: int = 100,
+        replacement: bool = True,
         do_gene_pos: Union[bool, str] = True,
         tp_name: Optional[str] = None,  # "heat_diff"
         assays_to_drop: list = [
-            "EFO:0008853",
-            "EFO:0010961",
-            "EFO:0030007",
-            "EFO:0030062",
+            # "EFO:0008853",  # patch seq
+            # "EFO:0010961",  # visium
+            "EFO:0030007",  # ATACseq
+            # "EFO:0030062",  # slide-seq
         ],
+        restart_num: int = 0,
+        metacell_mode: float = 0.0,
         **kwargs,
     ):
         """
```
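The signature above renames `all_clss` to `clss_to_predict` and adds three knobs: `replacement` (forwarded to the train sampler), `restart_num` (reseeds the sampler when resuming a run), and `metacell_mode`. A minimal sketch of the new call surface, assuming a configured lamindb instance; the collection name and `batch_size` are hypothetical:

```python
# Hypothetical usage of the 1.7.0 signature; requires a lamindb instance
# that actually contains the named collection.
from scdataloader.datamodule import DataModule

dm = DataModule(
    collection_name="my-collection",                # hypothetical name
    clss_to_predict=["organism_ontology_term_id"],  # renamed from all_clss
    replacement=True,    # new: train sampler draws with replacement
    restart_num=0,       # new: set > 0 when resuming a previous run
    metacell_mode=0.0,   # new: probability of using metacell mode
    batch_size=64,       # forwarded to torch's DataLoader via **kwargs
)
dm.setup()
train_loader = dm.train_dataloader()
```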
```diff
@@ -59,7 +65,6 @@ class DataModule(L.LightningDataModule):
 
         Args:
             collection_name (str): The lamindb collection to be used.
-            clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
             weight_scaler (int, optional): how much more you will see the most present vs less present category.
             train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
@@ -81,23 +86,23 @@ class DataModule(L.LightningDataModule):
             organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
             hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
-
-
+            metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
+            clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
-
+            restart_num (int, optional): The number of the restart if we are continuing a previous run -> /!\ NEEDS TO BE SET. Defaults to 0.
         see @file data.py and @file collator.py for more details about some of the parameters
         """
         if collection_name is not None:
             mdataset = Dataset(
                 ln.Collection.filter(name=collection_name).first(),
                 organisms=organisms,
-
-                clss_to_pred=clss_to_pred,
+                clss_to_predict=clss_to_predict,
                 hierarchical_clss=hierarchical_clss,
+                metacell_mode=metacell_mode,
             )
-            # print(mdataset)
             # and location
             self.gene_pos = None
+        self.collection_name = collection_name
         if do_gene_pos:
             if type(do_gene_pos) is str:
                 print("seeing a string: loading gene positions as biomart parquet file")
@@ -151,11 +156,12 @@ class DataModule(L.LightningDataModule):
                 org_to_id=mdataset.encoder[organism_name],
                 tp_name=tp_name,
                 organism_name=organism_name,
-                class_names=
+                class_names=clss_to_predict,
             )
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
+        self.replacement = replacement
         self.kwargs = kwargs
         if "sampler" in self.kwargs:
             self.kwargs.pop("sampler")
@@ -166,6 +172,8 @@ class DataModule(L.LightningDataModule):
         self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.nnz = None
+        self.restart_num = restart_num
         self.test_datasets = []
         self.test_idx = []
         super().__init__()
@@ -184,12 +192,8 @@ class DataModule(L.LightningDataModule):
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + (
-
-                + str((self.train_weights != 0).sum())
-                + ")\n)"
-            )
-            if self.train_weights is not None
+            + ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
+            if self.idx_full is not None
             else ")"
         )
 
```
```diff
@@ -247,9 +251,17 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
+        SCALE = 10
         if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
+            if "nnz" in self.clss_to_weight:
+                self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
+                self.clss_to_weight.remove("nnz")
+                (
+                    (self.nnz.max() / SCALE)
+                    / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
+                ).min()
             weights, labels = self.dataset.get_label_weights(
-                self.clss_to_weight, scaler=self.weight_scaler
+                self.clss_to_weight, scaler=self.weight_scaler, return_categories=True
             )
         else:
             weights = np.ones(1)
```
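The new `nnz` branch pulls per-cell non-zero gene counts out of the mapped collection and probes a dampening curve that maps sparse cells to weights near 1 and dense cells toward 0. A toy evaluation of that expression (the counts are hypothetical):

```python
# Toy evaluation of the nnz dampening expression from the hunk above.
import numpy as np

SCALE = 10
nnz = np.array([200, 1_000, 5_000, 10_000])  # non-zero genes per cell (toy)
w = (nnz.max() / SCALE) / ((1 + nnz - nnz.min()) + (nnz.max() / SCALE))
print(w.round(3))  # [0.999 0.555 0.172 0.093] -> sparse cells weigh most
print(w.min())     # the .min() computed in the hunk
```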
```diff
@@ -268,12 +280,11 @@ class DataModule(L.LightningDataModule):
 
         idx_full = []
         if len(self.assays_to_drop) > 0:
-
-            self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
-
-
-
-            idx_full = np.array(idx_full)
+            badloc = np.isin(
+                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
+                self.assays_to_drop,
+            )
+            idx_full = np.arange(len(labels))[~badloc]
         else:
             idx_full = np.arange(self.n_samples)
         if len_test > 0:
```
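The rewritten filter is a single vectorized pass: `np.isin` flags every sample whose assay term appears in `assays_to_drop`, and `~badloc` keeps the rest. A self-contained sketch (the assay labels are toy values):

```python
# Sketch of the np.isin-based assay filter; labels are toy values.
import numpy as np

assays = np.array(["EFO:0030007", "EFO:0009922", "EFO:0030007", "EFO:0009922"])
assays_to_drop = ["EFO:0030007"]           # ATACseq, the remaining default
badloc = np.isin(assays, assays_to_drop)   # True where the assay is dropped
idx_full = np.arange(len(assays))[~badloc]
print(idx_full)  # [1 3]
```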
```diff
@@ -302,12 +313,15 @@ class DataModule(L.LightningDataModule):
             np.random.shuffle(idx_full)
         if len_valid > 0:
             self.valid_idx = idx_full[:len_valid].copy()
+            # store it for later
             idx_full = idx_full[len_valid:]
         else:
             self.valid_idx = None
         weights = np.concatenate([weights, np.zeros(1)])
         labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
-
+        # some labels will now not exist anymore as replaced by len(weights) - 1.
+        # this means that the associated weights should be 0.
+        # by doing np.bincount(labels)*weights this will be taken into account
         self.train_weights = weights
         self.train_labels = labels
         self.idx_full = idx_full
```
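The comment block added here describes a sink-class trick: samples excluded from `idx_full` are relabeled to an extra class whose weight is 0, so `np.bincount(labels) * weights` gives them zero sampling mass. A toy check:

```python
# Toy check of the zero-weight sink class described in the new comments.
import numpy as np

weights = np.array([0.5, 1.0, 0.0])    # last entry is the sink class
labels = np.array([0, 1, 2, 2, 0, 1])  # 2 = held out for test/validation
print(np.bincount(labels) * weights)   # [1. 2. 0.] -> held-out never drawn
```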
```diff
@@ -319,17 +333,30 @@ class DataModule(L.LightningDataModule):
         # int(self.n_samples*self.train_oversampling_per_epoch),
         # replacement=True,
         # )
-
-
-
-
+        try:
+            train_sampler = LabelWeightedSampler(
+                self.train_weights,
+                self.train_labels,
+                num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+                element_weights=self.nnz,
+                replacement=self.replacement,
+                restart_num=self.restart_num,
+            )
+        except ValueError as e:
+            raise ValueError(e + "have you run `datamodule.setup()`?")
+        return DataLoader(
+            self.dataset,
+            sampler=train_sampler,
+            **self.kwargs,
+            **kwargs,
         )
-        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
 
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset,
+                self.dataset,
+                sampler=SubsetRandomSampler(self.valid_idx),
+                **self.kwargs,
             )
             if self.valid_idx is not None
             else None
@@ -346,7 +373,9 @@ class DataModule(L.LightningDataModule):
 
     def predict_dataloader(self):
         return DataLoader(
-            self.dataset,
+            self.dataset,
+            sampler=SubsetRandomSampler(self.idx_full),
+            **self.kwargs,
         )
 
     # def teardown(self):
```
```diff
@@ -359,18 +388,27 @@ class LabelWeightedSampler(Sampler[int]):
     label_weights: Sequence[float]
     klass_indices: Sequence[Sequence[int]]
     num_samples: int
-
+    nnz: Optional[Sequence[int]]
+    replacement: bool
+    restart_num: int
     # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
     # this will result a class-balanced sampling, no matter how imbalance the labels are.
-
+
     def __init__(
-        self,
+        self,
+        label_weights: Sequence[float],
+        labels: Sequence[int],
+        num_samples: int,
+        replacement: bool = True,
+        element_weights: Sequence[float] = None,
+        restart_num=0,
     ) -> None:
         """
 
         :param label_weights: list(len=num_classes)[float], weights for each class.
         :param labels: list(len=dataset_len)[int], labels of a dataset.
         :param num_samples: number of samples.
+        :param restart_num: if we are continuing a previous run, we need to restart the sampler from the same point.
         """
 
         super(LabelWeightedSampler, self).__init__(None)
```
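With the expanded constructor, the sampler can be exercised on its own; a hedged sketch with toy weights and labels (note that the debug `print` added in `__iter__` below will also fire):

```python
# Toy instantiation of the expanded constructor; all values hypothetical.
from scdataloader.datamodule import LabelWeightedSampler

sampler = LabelWeightedSampler(
    label_weights=[1.0, 1.0, 0.0],  # last class is the zero-weight sink
    labels=[0, 0, 1, 1, 1, 2],      # one class label per dataset element
    num_samples=4,                  # draws per epoch
    replacement=True,
    element_weights=None,           # e.g. per-cell nnz values
    restart_num=0,                  # > 0 makes every epoch reuse one seed
)
print(list(sampler))  # e.g. [3, 0, 4, 1] -- dataset indices
```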
```diff
@@ -379,24 +417,69 @@ class LabelWeightedSampler(Sampler[int]):
 
         self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
         self.labels = torch.as_tensor(labels, dtype=torch.int)
+        self.element_weights = (
+            torch.as_tensor(element_weights, dtype=torch.float32)
+            if element_weights is not None
+            else None
+        )
+        self.replacement = replacement
         self.num_samples = num_samples
+        self.restart_num = restart_num
         # list of tensor.
         self.klass_indices = [
             (self.labels == i_klass).nonzero().squeeze(1)
             for i_klass in range(len(label_weights))
         ]
+        self.klass_sizes = [len(klass_indices) for klass_indices in self.klass_indices]
 
     def __iter__(self):
         sample_labels = torch.multinomial(
-            self.label_weights,
+            self.label_weights,
+            num_samples=self.num_samples,
+            replacement=True,
+            generator=None
+            if self.restart_num == 0
+            else torch.Generator().manual_seed(self.restart_num),
         )
         sample_indices = torch.empty_like(sample_labels)
+
         for i_klass, klass_index in enumerate(self.klass_indices):
             if klass_index.numel() == 0:
                 continue
             left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
-
+            if len(left_inds) == 0:
+                continue
+            if self.element_weights is not None:
+                right_inds = torch.multinomial(
+                    self.element_weights[klass_index],
+                    num_samples=len(klass_index)
+                    if not self.replacement and len(klass_index) < len(left_inds)
+                    else len(left_inds),
+                    replacement=self.replacement,
+                    generator=None
+                    if self.restart_num == 0
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            elif self.replacement:
+                right_inds = torch.randint(
+                    len(klass_index),
+                    size=(len(left_inds),),
+                    generator=None
+                    if self.restart_num == 0
+                    else torch.Generator().manual_seed(self.restart_num),
+                )
+            else:
+                maxelem = (
+                    len(left_inds)
+                    if len(left_inds) < len(klass_index)
+                    else len(klass_index)
+                )
+                right_inds = torch.randperm(len(klass_index))[:maxelem]
             sample_indices[left_inds] = klass_index[right_inds]
+        # torch shuffle
+        sample_indices = sample_indices[torch.randperm(len(sample_indices))]
+        print(sample_indices.tolist()[:10], sample_labels[:10])
+        # raise Exception("stop")
         yield from iter(sample_indices.tolist())
 
     def __len__(self):
```
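One consequence of the `restart_num` plumbing is worth noting: when it is non-zero, a fresh `torch.Generator` seeded with the same value is created on every `__iter__` call, so successive epochs reproduce identical draws. A minimal demonstration of that generator behavior:

```python
# Demonstrates the fixed-seed draw used whenever restart_num != 0.
import torch

def draw(restart_num: int) -> torch.Tensor:
    gen = None if restart_num == 0 else torch.Generator().manual_seed(restart_num)
    return torch.multinomial(
        torch.ones(5), num_samples=8, replacement=True, generator=gen
    )

assert torch.equal(draw(3), draw(3))  # resuming replays the same epoch draws
```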