scdataloader 0.0.4__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- scdataloader/VERSION +1 -1
- scdataloader/__main__.py +3 -0
- scdataloader/collator.py +56 -31
- scdataloader/config.py +6 -0
- scdataloader/data.py +98 -87
- scdataloader/datamodule.py +66 -38
- scdataloader/mapped.py +266 -105
- scdataloader/preprocess.py +3 -207
- scdataloader/utils.py +57 -8
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.1.dist-info}/METADATA +45 -20
- scdataloader-1.0.1.dist-info/RECORD +16 -0
- scdataloader-0.0.4.dist-info/RECORD +0 -16
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.1.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.1.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.1.dist-info}/entry_points.txt +0 -0
scdataloader/datamodule.py
CHANGED
```diff
@@ -6,6 +6,7 @@ from torch.utils.data.sampler import (
     WeightedRandomSampler,
     SubsetRandomSampler,
     SequentialSampler,
+    RandomSampler,
 )
 import torch
 from torch.utils.data import DataLoader, Sampler
```
```diff
@@ -22,7 +23,7 @@ class DataModule(L.LightningDataModule):
     def __init__(
         self,
         collection_name: str,
-
+        clss_to_weight: list = ["organism_ontology_term_id"],
         organisms: list = ["NCBITaxon:9606"],
         weight_scaler: int = 10,
         train_oversampling_per_epoch: float = 0.1,
```
```diff
@@ -32,9 +33,9 @@ class DataModule(L.LightningDataModule):
         use_default_col: bool = True,
         gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-
-
-
+        clss_to_pred: list = ["organism_ontology_term_id"],
+        all_clss: list = ["organism_ontology_term_id"],
+        hierarchical_clss: list = [],
         # this is for the collator
         how: str = "random expr",
         organism_name: str = "organism_ontology_term_id",
```
```diff
@@ -59,36 +60,55 @@ class DataModule(L.LightningDataModule):
 
         Args:
             collection_name (str): The lamindb collection to be used.
-
-            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
-                any genes within this distance of each other will be considered at the same position.
-            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
-                the file must have ensembl_gene_id as index.
-                This is used to subset the available genes further to the ones that have embeddings in your model.
+            clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
             organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
-
+            weight_scaler (int, optional): how much more you will see the most present vs less present category.
+            train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                 it will use a full dataset and will round to the nearest dataset's cell count.
-
+            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+                the file must have ensembl_gene_id as index.
+                This is used to subset the available genes further to the ones that have embeddings in your model.
+            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
+                any genes within this distance of each other will be considered at the same position.
+            clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
+            do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
+            max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
+            add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
+            how (str, optional): The method to use for the collator. Defaults to "random expr".
+            organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
+            tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
+            hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
+            all_clss (list, optional): List of all classes. Defaults to ["organism_ontology_term_id"].
+            clss_to_pred (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+
+        see @file data.py and @file collator.py for more details about some of the parameters
         """
         if collection_name is not None:
             mdataset = Dataset(
                 ln.Collection.filter(name=collection_name).first(),
                 organisms=organisms,
-                obs=
-                clss_to_pred=
-                hierarchical_clss=
+                obs=all_clss,
+                clss_to_pred=clss_to_pred,
+                hierarchical_clss=hierarchical_clss,
             )
-            print(mdataset)
+            # print(mdataset)
             # and location
+            self.gene_pos = None
             if do_gene_pos:
                 if type(do_gene_pos) is str:
                     print("seeing a string: loading gene positions as biomart parquet file")
                     biomart = pd.read_parquet(do_gene_pos)
                 else:
                     # and annotations
+                    if organisms != ["NCBITaxon:9606"]:
+                        raise ValueError(
+                            "need to provide your own table as this automated function only works for humans for now"
+                        )
                     biomart = getBiomartTable(
                         attributes=["start_position", "chromosome_name"]
                     ).set_index("ensembl_gene_id")
```
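Taken together with the docstring above, the constructor now takes the weighting and prediction label sets explicitly. A minimal usage sketch of the 1.0.1 signature; the collection name, label columns, and batch size below are illustrative assumptions, not values taken from this diff:

```python
# Hypothetical usage; only the parameter names come from the diff above.
from scdataloader.datamodule import DataModule

datamodule = DataModule(
    collection_name="my-lamindb-collection",
    organisms=["NCBITaxon:9606"],
    clss_to_weight=["cell_type_ontology_term_id"],  # drives the weighted sampler
    clss_to_pred=["cell_type_ontology_term_id"],    # labels the model will predict
    all_clss=["cell_type_ontology_term_id", "organism_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
    batch_size=64,  # forwarded to torch's DataLoader via **kwargs
)
datamodule.setup()  # computes weights and carves out whole test datasets
```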
```diff
@@ -118,7 +138,7 @@ class DataModule(L.LightningDataModule):
             )
             if do_gene_pos:
                 self.gene_pos = mdataset.genedf["pos"].tolist()
-        self.
+        self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
         # we might want not to order the genes by expression (or do it?)
         # we might want to not introduce zeros and
         if use_default_col:
```
```diff
@@ -131,19 +151,23 @@ class DataModule(L.LightningDataModule):
                 org_to_id=mdataset.encoder[organism_name],
                 tp_name=tp_name,
                 organism_name=organism_name,
-                class_names=
+                class_names=clss_to_weight,
             )
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
         self.kwargs = kwargs
+        if "sampler" in self.kwargs:
+            self.kwargs.pop("sampler")
         self.assays_to_drop = assays_to_drop
         self.n_samples = len(mdataset)
         self.weight_scaler = weight_scaler
         self.train_oversampling_per_epoch = train_oversampling_per_epoch
-        self.
+        self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
+        self.test_datasets = []
+        self.test_idx = []
         super().__init__()
 
     def __repr__(self):
```
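Popping a user-supplied `sampler` from `kwargs` is defensive: each dataloader method below passes its own `sampler=` argument alongside `**self.kwargs`, so a leftover `sampler` key would make the `DataLoader(...)` call fail with a duplicate-keyword `TypeError`.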
```diff
@@ -154,8 +178,12 @@ class DataModule(L.LightningDataModule):
             f"\ttest_split={self.test_split},\n"
             f"\tn_samples={self.n_samples},\n"
             f"\tweight_scaler={self.weight_scaler},\n"
-            f"\
-            f"\
+            f"\ttrain_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+            f"\tassays_to_drop={self.assays_to_drop},\n"
+            f"\tnum_datasets={len(self.dataset.mapped_dataset.storages)},\n"
+            f"\ttest datasets={str(self.test_datasets)},\n"
+            f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
+            f"\tclss_to_weight={self.clss_to_weight}\n"
             + (
                 "\twith train_dataset size of=("
                 + str((self.train_weights != 0).sum())
```
```diff
@@ -179,22 +207,22 @@ class DataModule(L.LightningDataModule):
         return decoders
 
     @property
-    def
+    def labels_hierarchy(self):
         """
-
+        labels_hierarchy the hierarchy of labels for any cls that would have a hierarchy
 
         Returns:
             dict[str, dict[str, str]]
         """
-
-        for k, dic in self.dataset.
+        labels_hierarchy = {}
+        for k, dic in self.dataset.labels_groupings.items():
             rdic = {}
             for sk, v in dic.items():
                 rdic[self.dataset.encoder[k][sk]] = [
                     self.dataset.encoder[k][i] for i in list(v)
                 ]
-
-        return
+            labels_hierarchy[k] = rdic
+        return labels_hierarchy
 
     @property
     def genes(self):
```
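The renamed `labels_hierarchy` property converts the ontology groupings from term-id space into the integer codes produced by the dataset's encoders. A toy illustration of the same transformation, with invented terms and codes:

```python
# Invented encoder and groupings, mirroring the property's logic above.
encoder = {"cell_type_ontology_term_id": {"CL:A": 0, "CL:B": 1, "CL:C": 2}}
labels_groupings = {"cell_type_ontology_term_id": {"CL:A": ["CL:B", "CL:C"]}}

labels_hierarchy = {}
for k, dic in labels_groupings.items():
    rdic = {}
    for sk, v in dic.items():
        # encoded parent -> list of encoded children
        rdic[encoder[k][sk]] = [encoder[k][i] for i in list(v)]
    labels_hierarchy[k] = rdic

print(labels_hierarchy)  # {'cell_type_ontology_term_id': {0: [1, 2]}}
```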
```diff
@@ -219,9 +247,9 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
-        if len(self.
+        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
             weights, labels = self.dataset.get_label_weights(
-                self.
+                self.clss_to_weight, scaler=self.weight_scaler
             )
         else:
             weights = np.ones(1)
```
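`get_label_weights` is defined in data.py and is not part of this diff; per the docstring, `weight_scaler` bounds how much more often the most frequent category is drawn than the least frequent one. A plausible sketch of such scaler-capped inverse-frequency weighting (not the library's actual implementation):

```python
import numpy as np

def sketch_label_weights(labels: np.ndarray, scaler: int = 10) -> np.ndarray:
    """Illustrative only; assumes every class id in `labels` occurs at least once."""
    counts = np.bincount(labels)
    weights = 1.0 / counts  # rare classes get larger weights
    # cap the correction: rare classes are drawn at most `scaler` times
    # more often than the most frequent class
    weights = np.minimum(weights, weights.min() * scaler)
    return weights[labels]

per_sample = sketch_label_weights(np.array([0, 0, 0, 0, 1]), scaler=2)
print(per_sample)  # -> [0.25 0.25 0.25 0.25 0.5]
```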
```diff
@@ -248,7 +276,6 @@ class DataModule(L.LightningDataModule):
             idx_full = np.array(idx_full)
         else:
             idx_full = np.arange(self.n_samples)
-        test_datasets = []
         if len_test > 0:
             # this way we work on some never seen datasets
             # keeping at least one
```
```diff
@@ -258,17 +285,15 @@ class DataModule(L.LightningDataModule):
                 else self.dataset.mapped_dataset.n_obs_list[0]
             )
             cs = 0
-            print("these files will be considered test datasets:")
             for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
                 if cs + c > len_test:
                     break
                 else:
-
-
+                    self.test_datasets.append(
+                        self.dataset.mapped_dataset._path_list[i].path
+                    )
                     cs += c
-
             len_test = cs
-            print("perc test: ", len_test / self.n_samples)
             self.test_idx = idx_full[:len_test]
             idx_full = idx_full[len_test:]
         else:
```
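Because the test split is rounded down to whole datasets, the realized split can be smaller than requested. A worked example of the greedy accumulation above, with invented cell counts:

```python
# Invented per-dataset cell counts; the real values come from
# self.dataset.mapped_dataset.n_obs_list.
n_obs_list = [500, 300, 800]
len_test = 700  # requested: test_split * n_samples

test_datasets, cs = [], 0
for i, c in enumerate(n_obs_list):
    if cs + c > len_test:
        break
    test_datasets.append(i)  # the real code appends the dataset's path
    cs += c
len_test = cs

print(test_datasets, len_test)  # [0] 500 -- dataset 1 would overshoot 700
```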
```diff
@@ -286,8 +311,7 @@ class DataModule(L.LightningDataModule):
         self.train_weights = weights
         self.train_labels = labels
         self.idx_full = idx_full
-
-        return test_datasets
+        return self.test_datasets
 
     def train_dataloader(self, **kwargs):
         # train_sampler = WeightedRandomSampler(
```
```diff
@@ -299,7 +323,6 @@ class DataModule(L.LightningDataModule):
             self.train_weights,
             self.train_labels,
             num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
-            # replacement=True,
         )
         return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
 
```
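The custom sampler class wired in here falls outside this hunk's context lines; judging from the commented-out code it replaced, it plays the role of torch's `WeightedRandomSampler` over per-label weights. A rough, self-contained stand-in (all values invented):

```python
import torch
from torch.utils.data.sampler import WeightedRandomSampler

# Invented stand-in values; in the DataModule these come from setup().
train_weights = torch.tensor([0.1, 1.0])   # per-class weights
train_labels = torch.tensor([0, 0, 0, 1])  # per-sample class ids
n_samples, train_oversampling_per_epoch = 4, 0.5

sampler = WeightedRandomSampler(
    weights=train_weights[train_labels],  # expand to per-sample weights
    num_samples=int(n_samples * train_oversampling_per_epoch),
    replacement=True,
)
print(list(sampler))  # e.g. [3, 1] -- the rare label-1 cell is oversampled
```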
```diff
@@ -321,6 +344,11 @@ class DataModule(L.LightningDataModule):
             else None
         )
 
+    def predict_dataloader(self):
+        return DataLoader(
+            self.dataset, sampler=SubsetRandomSampler(self.idx_full), **self.kwargs
+        )
+
     # def teardown(self):
     #     clean up state after the trainer stops, delete files...
     #     called on every process in DDP
```
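Note that the new `predict_dataloader` draws from `self.idx_full`, which after `setup()` holds the indices that remain once the whole test datasets have been carved out, and it uses `SubsetRandomSampler`, so prediction order is shuffled.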