scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +66 -42
- scdataloader/collator.py +136 -67
- scdataloader/config.py +112 -0
- scdataloader/data.py +160 -169
- scdataloader/datamodule.py +403 -0
- scdataloader/mapped.py +285 -109
- scdataloader/preprocess.py +240 -109
- scdataloader/utils.py +162 -70
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/METADATA +87 -18
- scdataloader-1.0.1.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import lamindb as ln
|
|
4
|
+
|
|
5
|
+
from torch.utils.data.sampler import (
|
|
6
|
+
WeightedRandomSampler,
|
|
7
|
+
SubsetRandomSampler,
|
|
8
|
+
SequentialSampler,
|
|
9
|
+
RandomSampler,
|
|
10
|
+
)
|
|
11
|
+
import torch
|
|
12
|
+
from torch.utils.data import DataLoader, Sampler
|
|
13
|
+
import lightning as L
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Union, Sequence
|
|
16
|
+
|
|
17
|
+
from .data import Dataset
|
|
18
|
+
from .collator import Collator
|
|
19
|
+
from .utils import getBiomartTable
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DataModule(L.LightningDataModule):
    def __init__(
        self,
        collection_name: str,
        clss_to_weight: list = ["organism_ontology_term_id"],
        organisms: list = ["NCBITaxon:9606"],
        weight_scaler: int = 10,
        train_oversampling_per_epoch: float = 0.1,
        validation_split: float = 0.2,
        test_split: float = 0,
        gene_embeddings: str = "",
        use_default_col: bool = True,
        gene_position_tolerance: int = 10_000,
        # this is for the mappedCollection
        clss_to_pred: list = ["organism_ontology_term_id"],
        all_clss: list = ["organism_ontology_term_id"],
        hierarchical_clss: list = [],
        # this is for the collator
        how: str = "random expr",
        organism_name: str = "organism_ontology_term_id",
        max_len: int = 1000,
        add_zero_genes: int = 100,
        do_gene_pos: Union[bool, str] = True,
        tp_name: Optional[str] = None,  # "heat_diff"
        assays_to_drop: list = [
            "EFO:0008853",
            "EFO:0010961",
            "EFO:0030007",
            "EFO:0030062",
        ],
        **kwargs,
    ):
        """
        DataModule a pytorch lighting datamodule directly from a lamin Collection.
        it can work with bare pytorch too

        It implements train / val / test dataloaders. the train is weighted random, val is random,
        test is one to many separated datasets.
        This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.

        NOTE: the list defaults below are shared between calls; this class only reads them
        and never mutates them.

        Args:
            collection_name (str): The lamindb collection to be used.
            clss_to_weight (list, optional): The classes to weight in the trainer's weighted random sampler.
                Also passed to the collator as `class_names`. Defaults to ["organism_ontology_term_id"].
            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
            weight_scaler (int, optional): how much more you will see the most present vs less present category.
                Weighting is disabled when it is <= 0. Defaults to 10.
            train_oversampling_per_epoch (float, optional): The number of cells drawn per training epoch,
                as a fraction of the whole dataset. Defaults to 0.1.
            validation_split (float, optional): The proportion of the dataset to include in the validation split.
                An int is taken as an absolute cell count. Defaults to 0.2.
            test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                it will use a full dataset and will round to the nearest dataset's cell count.
                An int is taken as an absolute cell count.
            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
                the file must have ensembl_gene_id as index.
                This is used to subset the available genes further to the ones that have embeddings in your model.
            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
                any genes within this distance of the previous gene on the same chromosome
                will be considered at the same position.
            clss_to_pred (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
            all_clss (list, optional): List of all classes. Defaults to ["organism_ontology_term_id"].
            hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
            how (str, optional): The method to use for the collator. Defaults to "random expr".
            organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
            max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
            add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
            do_gene_pos (Union[bool, str], optional): Whether to use gene positions. If a string, it is read as
                a parquet file (indexed by ensembl_gene_id, with a "pos" column). Defaults to True.
            tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
            assays_to_drop (list, optional): List of assay ontology ids whose cells are excluded from every split.
                Defaults to ["EFO:0008853", "EFO:0010961", "EFO:0030007", "EFO:0030062"].
            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
                Any "sampler" entry is discarded; each dataloader chooses its own sampler.

        Raises:
            ValueError: if collection_name is None, or if gene positions must be fetched
                automatically for non-human organisms.

        see @file data.py and @file collator.py for more details about some of the parameters
        """
        if collection_name is None:
            # without a collection there is no dataset to build anything from
            raise ValueError("collection_name must be the name of a lamindb Collection")
        mdataset = Dataset(
            ln.Collection.filter(name=collection_name).first(),
            organisms=organisms,
            obs=all_clss,
            clss_to_pred=clss_to_pred,
            hierarchical_clss=hierarchical_clss,
        )
        # and location
        self.gene_pos = None
        if do_gene_pos:
            if isinstance(do_gene_pos, str):
                print("seeing a string: loading gene positions as biomart parquet file")
                biomart = pd.read_parquet(do_gene_pos)
            else:
                # and annotations
                if organisms != ["NCBITaxon:9606"]:
                    raise ValueError(
                        "need to provide your own table as this automated function only works for humans for now"
                    )
                biomart = getBiomartTable(
                    attributes=["start_position", "chromosome_name"]
                ).set_index("ensembl_gene_id")
                biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
                biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
                # walk genes in genomic order; a new position id starts on a new
                # chromosome or when the gap to the previous gene exceeds the tolerance
                positions = []
                group = 0
                prev_position = -100000
                prev_chromosome = None
                for chromosome, start in zip(
                    biomart["chromosome_name"], biomart["start_position"]
                ):
                    if (
                        chromosome != prev_chromosome
                        or start - prev_position > gene_position_tolerance
                    ):
                        group += 1
                    positions.append(group)
                    prev_position = start
                    prev_chromosome = chromosome
                print(f"reduced the size to {len(set(positions))/len(biomart)}")
                biomart["pos"] = positions
            mdataset.genedf = biomart.loc[mdataset.genedf.index]
            self.gene_pos = mdataset.genedf["pos"].tolist()

        if gene_embeddings != "":
            # keep only genes that also have an embedding
            mdataset.genedf = mdataset.genedf.join(
                pd.read_parquet(gene_embeddings), how="inner"
            )
            if do_gene_pos:
                self.gene_pos = mdataset.genedf["pos"].tolist()
        self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
        # we might want not to order the genes by expression (or do it?)
        # we might want to not introduce zeros and
        if use_default_col:
            kwargs["collate_fn"] = Collator(
                organisms=organisms,
                how=how,
                valid_genes=mdataset.genedf.index.tolist(),
                max_len=max_len,
                add_zero_genes=add_zero_genes,
                org_to_id=mdataset.encoder[organism_name],
                tp_name=tp_name,
                organism_name=organism_name,
                class_names=clss_to_weight,
            )
        self.validation_split = validation_split
        self.test_split = test_split
        self.dataset = mdataset
        self.kwargs = kwargs
        # every dataloader sets its own sampler
        self.kwargs.pop("sampler", None)
        self.assays_to_drop = assays_to_drop
        self.n_samples = len(mdataset)
        self.weight_scaler = weight_scaler
        self.train_oversampling_per_epoch = train_oversampling_per_epoch
        self.clss_to_weight = clss_to_weight
        # the attributes below are filled by setup()
        self.train_weights = None
        self.train_labels = None
        self.idx_full = None
        self.valid_idx = None
        self.test_datasets = []
        self.test_idx = []
        super().__init__()

    def __repr__(self):
        n_test = len(self.test_idx) if self.test_idx is not None else 0
        if self.train_weights is not None and self.idx_full is not None:
            # number of cells available to the training sampler
            train_part = (
                "\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)"
            )
        else:
            train_part = ")"
        return (
            f"DataLoader(\n"
            f"\twith a dataset=({self.dataset.__repr__()}\n)\n"
            f"\tvalidation_split={self.validation_split},\n"
            f"\ttest_split={self.test_split},\n"
            f"\tn_samples={self.n_samples},\n"
            f"\tweight_scaler={self.weight_scaler},\n"
            f"\ttrain_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
            f"\tassays_to_drop={self.assays_to_drop},\n"
            f"\tnum_datasets={len(self.dataset.mapped_dataset.storages)},\n"
            f"\ttest datasets={str(self.test_datasets)},\n"
            f"perc test: {str(n_test / self.n_samples)},\n"
            f"\tclss_to_weight={self.clss_to_weight}\n" + train_part
        )

    @property
    def decoders(self):
        """
        decoders the decoders for any labels that would have been encoded

        Returns:
            dict[str, dict[int, str]]: for each class, the inverse of its encoder mapping.
        """
        decoders = {}
        for k, v in self.dataset.encoder.items():
            decoders[k] = {va: ke for ke, va in v.items()}
        return decoders

    @property
    def labels_hierarchy(self):
        """
        labels_hierarchy the hierarchy of labels for any cls that would have a hierarchy

        Returns:
            dict[str, dict]: for each class in the dataset's labels_groupings, maps the encoded
                id of a grouping label to the list of encoded ids of the labels it groups.
        """
        labels_hierarchy = {}
        for k, dic in self.dataset.labels_groupings.items():
            rdic = {}
            for sk, v in dic.items():
                rdic[self.dataset.encoder[k][sk]] = [
                    self.dataset.encoder[k][i] for i in list(v)
                ]
            labels_hierarchy[k] = rdic
        return labels_hierarchy

    @property
    def genes(self):
        """
        genes the genes used in this datamodule

        Returns:
            list: the index of the dataset's gene table.
        """
        return self.dataset.genedf.index.tolist()

    @property
    def num_datasets(self):
        """Number of storages (files) in the underlying mapped collection."""
        return len(self.dataset.mapped_dataset.storages)

    def setup(self, stage=None):
        """
        setup method is used to prepare the data for the training, validation, and test sets.
        It shuffles the data, calculates weights for each set, and creates samplers for each set.

        The test set is made of whole datasets: the leading datasets of the collection, taken
        until the requested test size would be exceeded (at least the first one).
        Cells whose assay is in `assays_to_drop` are excluded from every split.

        Args:
            stage (str, optional): The stage of the model training process.
            It can be either 'fit' or 'test'. Defaults to None.

        Returns:
            list: the paths of the datasets used as the test set.
        """
        # setup() may run once per stage; start from a clean test-dataset list
        self.test_datasets = []
        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
            weights, labels = self.dataset.get_label_weights(
                self.clss_to_weight, scaler=self.weight_scaler
            )
        else:
            weights = np.ones(1)
            # labels are class indices (the sampler calls np.bincount on them),
            # so they must be integers
            labels = np.zeros(self.n_samples, dtype=int)
        if isinstance(self.validation_split, int):
            len_valid = self.validation_split
        else:
            len_valid = int(self.n_samples * self.validation_split)
        if isinstance(self.test_split, int):
            len_test = self.test_split
        else:
            len_test = int(self.n_samples * self.test_split)
        assert (
            len_test + len_valid < self.n_samples
        ), "test set + valid set size is configured to be larger than entire dataset."

        if len(self.assays_to_drop) > 0:
            to_drop = set(self.assays_to_drop)
            idx_full = np.array(
                [
                    i
                    for i, a in enumerate(
                        self.dataset.mapped_dataset.get_merged_labels(
                            "assay_ontology_term_id"
                        )
                    )
                    if a not in to_drop
                ],
                dtype=int,
            )
        else:
            idx_full = np.arange(self.n_samples)
        if len_test > 0:
            # this way we work on some never seen datasets
            # keeping at least one
            n_obs_list = self.dataset.mapped_dataset.n_obs_list
            len_test = max(len_test, n_obs_list[0])
            cs = 0
            for i, c in enumerate(n_obs_list):
                if cs + c > len_test:
                    break
                self.test_datasets.append(
                    self.dataset.mapped_dataset._path_list[i].path
                )
                cs += c
            len_test = cs
            # the test datasets hold the cells with global index < cs; select by
            # value because idx_full has gaps where assays were dropped, so a
            # positional slice would spill into the next (non-test) dataset
            is_test = idx_full < len_test
            self.test_idx = idx_full[is_test]
            idx_full = idx_full[~is_test]
        else:
            self.test_idx = None

        np.random.shuffle(idx_full)
        if len_valid > 0:
            self.valid_idx = idx_full[:len_valid].copy()
            idx_full = idx_full[len_valid:]
        else:
            self.valid_idx = None
        # extra zero-weight class for every cell outside the training split
        weights = np.concatenate([weights, np.zeros(1)])
        labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1

        self.train_weights = weights
        self.train_labels = labels
        self.idx_full = idx_full
        return self.test_datasets

    def train_dataloader(self, **kwargs):
        """Weighted random training loader drawing n_samples * train_oversampling_per_epoch cells."""
        train_sampler = LabelWeightedSampler(
            self.train_weights,
            self.train_labels,
            num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
        )
        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)

    def val_dataloader(self):
        """Random-order loader over the validation cells, or None without a validation split."""
        return (
            DataLoader(
                self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
            )
            if self.valid_idx is not None
            else None
        )

    def test_dataloader(self):
        """In-order loader over the test cells, or None without a test split."""
        if self.test_idx is None:
            return None
        # iterate the test indices themselves (SequentialSampler would yield
        # 0..len(test_idx)-1 instead of the selected cells)
        return DataLoader(
            self.dataset, sampler=np.asarray(self.test_idx).tolist(), **self.kwargs
        )

    def predict_dataloader(self):
        """Random-order loader over the training cells."""
        return DataLoader(
            self.dataset, sampler=SubsetRandomSampler(self.idx_full), **self.kwargs
        )
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class LabelWeightedSampler(Sampler[int]):
    """Sample dataset indices by first drawing a label, then a uniform member of it.

    Each label's draw probability is its weight times its number of members, so
    every individual index of label `c` is drawn with probability proportional
    to `label_weights[c]`. Labels with weight 0 (or with no members) are never
    drawn. Sampling is always with replacement.
    """

    label_weights: Sequence[float]
    klass_indices: Sequence[Sequence[int]]
    num_samples: int

    def __init__(
        self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
    ) -> None:
        """

        :param label_weights: list(len=num_classes)[float], weights for each class.
        :param labels: list(len=dataset_len)[int], labels of a dataset. Integer-valued
            floats are accepted and cast to int.
        :param num_samples: number of samples.
        :raises ValueError: if a label is >= len(label_weights).
        """

        super(LabelWeightedSampler, self).__init__(None)
        # np.bincount needs integer input
        labels = np.asarray(labels).astype(np.int64)
        label_weights = np.asarray(label_weights, dtype=np.float64)
        # minlength keeps one count per weight even when the last classes have
        # no members (e.g. an empty "excluded" class), so the product below
        # always has matching shapes
        counts = np.bincount(labels, minlength=len(label_weights))
        if len(counts) > len(label_weights):
            raise ValueError(
                f"labels contain class {len(counts) - 1} but only "
                f"{len(label_weights)} label weights were given"
            )
        # reweight labels from counter otherwise same weight to labels that have many elements vs a few
        label_weights = label_weights * counts

        self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.int)
        self.num_samples = num_samples
        # one 1-D tensor of dataset indices per class
        self.klass_indices = [
            (self.labels == i_klass).nonzero().squeeze(1)
            for i_klass in range(len(label_weights))
        ]

    def __iter__(self):
        # draw a class for every sample, then a uniform member of that class
        sample_labels = torch.multinomial(
            self.label_weights, num_samples=self.num_samples, replacement=True
        )
        sample_indices = torch.empty_like(sample_labels)
        for i_klass, klass_index in enumerate(self.klass_indices):
            if klass_index.numel() == 0:
                # empty classes have weight 0 and are never drawn
                continue
            left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
            right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
            sample_indices[left_inds] = klass_index[right_inds]
        yield from iter(sample_indices.tolist())

    def __len__(self):
        return self.num_samples
|