scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +66 -42
- scdataloader/collator.py +136 -67
- scdataloader/config.py +112 -0
- scdataloader/data.py +160 -169
- scdataloader/datamodule.py +403 -0
- scdataloader/mapped.py +285 -109
- scdataloader/preprocess.py +240 -109
- scdataloader/utils.py +162 -70
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/METADATA +87 -18
- scdataloader-1.0.1.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/entry_points.txt +0 -0
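The headline change is structural: the monolithic scdataloader/dataloader.py (shown deleted below) disappears, while datamodule.py and config.py are new; the coarse-label tables that dataloader.py flagged with "# TODO: put in config" plausibly account for the new config.py. For downstream code this presumably means an import-path change. A hedged sketch, assuming the DataModule class kept its name in the new module (the 1.0.1 location is inferred from the file list above, not confirmed by this diff):

    # 0.0.3: class defined in the deleted file shown below
    from scdataloader.dataloader import DataModule
    # 1.0.1: assumed new location, inferred from the added datamodule.py
    from scdataloader.datamodule import DataModule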
scdataloader/dataloader.py
DELETED

@@ -1,318 +0,0 @@
-import numpy as np
-import pandas as pd
-import lamindb as ln
-
-from torch.utils.data.sampler import (
-    WeightedRandomSampler,
-    SubsetRandomSampler,
-    SequentialSampler,
-)
-from torch.utils.data import DataLoader
-import lightning as L
-
-from typing import Optional
-
-from .data import Dataset
-from .collator import Collator
-from .mapped import MappedDataset
-from .utils import getBiomartTable
-
-# TODO: put in config
-COARSE_TISSUE = {
-    "adipose tissue": "",
-    "bladder organ": "",
-    "blood": "",
-    "bone marrow": "",
-    "brain": "",
-    "breast": "",
-    "esophagus": "",
-    "eye": "",
-    "embryo": "",
-    "fallopian tube": "",
-    "gall bladder": "",
-    "heart": "",
-    "intestine": "",
-    "kidney": "",
-    "liver": "",
-    "lung": "",
-    "lymph node": "",
-    "musculature of body": "",
-    "nose": "",
-    "ovary": "",
-    "pancreas": "",
-    "placenta": "",
-    "skin of body": "",
-    "spinal cord": "",
-    "spleen": "",
-    "stomach": "",
-    "thymus": "",
-    "thyroid gland": "",
-    "tongue": "",
-    "uterus": "",
-}
-
-COARSE_ANCESTRY = {
-    "African": "",
-    "Chinese": "",
-    "East Asian": "",
-    "Eskimo": "",
-    "European": "",
-    "Greater Middle Eastern (Middle Eastern, North African or Persian)": "",
-    "Hispanic or Latin American": "",
-    "Native American": "",
-    "Oceanian": "",
-    "South Asian": "",
-}
-
-COARSE_DEVELOPMENT_STAGE = {
-    "Embryonic human": "",
-    "Fetal": "",
-    "Immature": "",
-    "Mature": "",
-}
-
-COARSE_ASSAY = {
-    "10x 3'": "",
-    "10x 5'": "",
-    "10x multiome": "",
-    "CEL-seq2": "",
-    "Drop-seq": "",
-    "GEXSCOPE technology": "",
-    "inDrop": "",
-    "microwell-seq": "",
-    "sci-Plex": "",
-    "sci-RNA-seq": "",
-    "Seq-Well": "",
-    "Slide-seq": "",
-    "Smart-seq": "",
-    "SPLiT-seq": "",
-    "TruDrop": "",
-    "Visium Spatial Gene Expression": "",
-}
-
-
-class DataModule(L.LightningDataModule):
-    """
-    Base class for all data loaders
-    """
-
-    def __init__(
-        self,
-        mdataset: Optional[MappedDataset] = None,
-        collection_name=None,
-        organisms: list = ["NCBITaxon:9606"],
-        weight_scaler: int = 30,
-        train_oversampling=1,
-        label_to_weight: list = [],
-        label_to_pred: list = [],
-        validation_split: float = 0.2,
-        test_split: float = 0,
-        use_default_col=True,
-        all_labels=[],
-        hierarchical_labels=[],
-        how="most expr",
-        organism_name="organism_ontology_term_id",
-        max_len=1000,
-        add_zero_genes=100,
-        do_gene_pos=True,
-        gene_embeddings="",
-        gene_position_tolerance=10_000,
-        **kwargs,
-    ):
-        """
-        Initializes the DataModule.
-
-        Args:
-            dataset (MappedDataset): The dataset to be used.
-            weight_scaler (int, optional): The weight scaler for weighted random sampling. Defaults to 30.
-            label_to_weight (list, optional): List of labels to weight. Defaults to [].
-            validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
-            test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
-            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
-        """
-        if collection_name is not None:
-            mdataset = Dataset(
-                ln.Collection.filter(name=collection_name).first(),
-                organisms=organisms,
-                obs=all_labels,
-                clss_to_pred=label_to_pred,
-                hierarchical_clss=hierarchical_labels,
-            )
-        print(mdataset)
-        # and location
-        if do_gene_pos:
-            # and annotations
-            biomart = getBiomartTable(
-                attributes=["start_position", "chromosome_name"]
-            ).set_index("ensembl_gene_id")
-            biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
-            biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
-            c = []
-            i = 0
-            prev_position = -100000
-            prev_chromosome = None
-            for _, r in biomart.iterrows():
-                if (
-                    r["chromosome_name"] != prev_chromosome
-                    or r["start_position"] - prev_position > gene_position_tolerance
-                ):
-                    i += 1
-                c.append(i)
-                prev_position = r["start_position"]
-                prev_chromosome = r["chromosome_name"]
-            print(f"reduced the size to {len(set(c))/len(biomart)}")
-            biomart["pos"] = c
-            mdataset.genedf = biomart.loc[
-                mdataset.genedf[mdataset.genedf.index.isin(biomart.index)].index
-            ]
-            self.gene_pos = mdataset.genedf["pos"].tolist()
-
-        if gene_embeddings != "":
-            mdataset.genedf = mdataset.genedf.join(
-                pd.read_parquet(gene_embeddings), how="inner"
-            )
-            if do_gene_pos:
-                self.gene_pos = mdataset.genedf["pos"].tolist()
-        self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
-        # we might want not to order the genes by expression (or do it?)
-        # we might want to not introduce zeros and
-        if use_default_col:
-            kwargs["collate_fn"] = Collator(
-                organisms=organisms,
-                how=how,
-                valid_genes=mdataset.genedf.index.tolist(),
-                max_len=max_len,
-                add_zero_genes=add_zero_genes,
-                org_to_id=mdataset.encoder[organism_name],
-                tp_name="heat_diff",
-                organism_name=organism_name,
-                class_names=label_to_weight,
-            )
-        self.validation_split = validation_split
-        self.test_split = test_split
-        self.dataset = mdataset
-        self.kwargs = kwargs
-        self.n_samples = len(mdataset)
-        self.weight_scaler = weight_scaler
-        self.train_oversampling = train_oversampling
-        self.label_to_weight = label_to_weight
-        super().__init__()
-
-    @property
-    def decoders(self):
-        decoders = {}
-        for k, v in self.dataset.encoder.items():
-            decoders[k] = {va: ke for ke, va in v.items()}
-        return decoders
-
-    @property
-    def cls_hierarchy(self):
-        cls_hierarchy = {}
-        for k, dic in self.dataset.class_groupings.items():
-            rdic = {}
-            for sk, v in dic.items():
-                rdic[self.dataset.encoder[k][sk]] = [
-                    self.dataset.encoder[k][i] for i in list(v)
-                ]
-            cls_hierarchy[k] = rdic
-        return cls_hierarchy
-
-    @property
-    def genes(self):
-        return self.dataset.genedf.index.tolist()
-
-    def setup(self, stage=None):
-        """
-        setup method is used to prepare the data for the training, validation, and test sets.
-        It shuffles the data, calculates weights for each set, and creates samplers for each set.
-
-        Args:
-            stage (str, optional): The stage of the model training process.
-                It can be either 'fit' or 'test'. Defaults to None.
-        """
-
-        if len(self.label_to_weight) > 0:
-            weights = self.dataset.get_label_weights(
-                self.label_to_weight, scaler=self.weight_scaler
-            )
-        else:
-            weights = np.ones(self.n_samples)
-        if isinstance(self.validation_split, int):
-            len_valid = self.validation_split
-        else:
-            len_valid = int(self.n_samples * self.validation_split)
-        if isinstance(self.test_split, int):
-            len_test = self.test_split
-        else:
-            len_test = int(self.n_samples * self.test_split)
-        assert (
-            len_test + len_valid < self.n_samples
-        ), "test set + valid set size is configured to be larger than entire dataset."
-        idx_full = np.arange(self.n_samples)
-        if len_test > 0:
-            # this way we work on some never seen datasets
-            # keeping at least one
-            len_test = (
-                len_test
-                if len_test > self.dataset.mapped_dataset.n_obs_list[0]
-                else self.dataset.mapped_dataset.n_obs_list[0]
-            )
-            cs = 0
-            test_datasets = []
-            print("these files will be considered test datasets:")
-            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
-                if cs + c > len_test:
-                    break
-                else:
-                    print("    " + self.dataset.mapped_dataset.path_list[i].path)
-                    test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
-                    cs += c
-
-            len_test = cs
-            print("perc test: ", len_test / self.n_samples)
-            test_idx = idx_full[:len_test]
-            idx_full = idx_full[len_test:]
-            self.test_sampler = SequentialSampler(test_idx)
-        else:
-            self.test_sampler = None
-            test_datasets = None
-
-        np.random.shuffle(idx_full)
-        if len_valid > 0:
-            valid_idx = idx_full[:len_valid]
-            idx_full = idx_full[len_valid:]
-            self.valid_sampler = SubsetRandomSampler(valid_idx)
-        else:
-            self.valid_sampler = None
-
-        weights[~idx_full] = 0
-        self.train_sampler = WeightedRandomSampler(
-            weights,
-            int(len(idx_full) * self.train_oversampling),
-            replacement=True,
-        )
-        return test_datasets
-
-    def train_dataloader(self, **kwargs):
-        return DataLoader(
-            self.dataset, sampler=self.train_sampler, **self.kwargs, **kwargs
-        )
-
-    def val_dataloader(self):
-        return (
-            DataLoader(self.dataset, sampler=self.valid_sampler, **self.kwargs)
-            if self.valid_sampler is not None
-            else None
-        )
-
-    def test_dataloader(self):
-        return (
-            DataLoader(self.dataset, sampler=self.test_sampler, **self.kwargs)
-            if self.test_sampler is not None
-            else None
-        )
-
-    # def teardown(self):
-    # clean up state after the trainer stops, delete files...
-    # called on every process in DDP
-    # pass
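For reference, the removed DataModule was driven through the standard Lightning hooks: __init__ builds the Dataset from a lamindb Collection and installs the default Collator, setup() carves out whole files as the test set (so test data comes from never-seen datasets, per the comment above) and builds the three samplers, and the *_dataloader() methods wrap everything in torch DataLoaders. A minimal usage sketch against the 0.0.3 API, assuming a configured lamindb instance; the collection name, label column, and batch size are hypothetical placeholders:

    # Sketch only: "my-collection" and the label column are made-up examples.
    from scdataloader.dataloader import DataModule  # module removed in 1.0.1

    dm = DataModule(
        collection_name="my-collection",  # resolved via ln.Collection.filter(...).first()
        label_to_weight=["cell_type_ontology_term_id"],  # feeds WeightedRandomSampler
        validation_split=0.2,
        test_split=0.1,
        batch_size=64,  # forwarded to torch.utils.data.DataLoader via **kwargs
    )
    test_files = dm.setup()  # returns the paths of the held-out test datasets
    train_dl = dm.train_dataloader()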
scdataloader-0.0.3.dist-info/RECORD
DELETED

@@ -1,15 +0,0 @@
-scdataloader/VERSION,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-scdataloader/__init__.py,sha256=cuj9n8np6jXU05e0VzDkUQv4CYJI6StvQ0TAsURS7wg,122
-scdataloader/__main__.py,sha256=x-EDMcfJscSM5ViRZmH0ekCm7QoYRgRF7qVeNKg2Dyc,5733
-scdataloader/base.py,sha256=M1gD59OffRdLOgS1vHKygOomUoAMuzjpRtAfM3SBKF8,338
-scdataloader/collator.py,sha256=vV4kuygk_x_HthyitKvJNn1yDzcL1COMvsP8N5vaME0,9524
-scdataloader/data.py,sha256=8G-ric6pmHf1U4X_0VnTS-nKcA6ztKtrhWJwjXsmUV0,13029
-scdataloader/dataloader.py,sha256=MqASZkmu3FG0z_cIG6L_7_T1uJd5iyVtAyEgap8Fv6c,10281
-scdataloader/mapped.py,sha256=ldBgCXnbFQUlEJ7dSWFgJ0654b6e_AK41mMxAgRn1hM,12635
-scdataloader/preprocess.py,sha256=aX69Z7cDrRG0qBa1yKngMyfkLK7DvAUnUXkiKUE-bbo,22550
-scdataloader/utils.py,sha256=7RKTZIAw0fLAmG31ph0WVvzEQPyPLRUpWjjg3P53ofc,17282
-scdataloader-0.0.3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
-scdataloader-0.0.3.dist-info/METADATA,sha256=jY_1yqWY5KYiy1jiaRSvo9z5bQcmedLBaRPk5J8WHlo,38289
-scdataloader-0.0.3.dist-info/WHEEL,sha256=d2fvjOD7sXsVzChCqf0Ty0JbHKBaLYwDbGQDwQTnJ50,88
-scdataloader-0.0.3.dist-info/entry_points.txt,sha256=nLqucZaa5wiF7-1FCgMXO916WDQ9Qm0TcxQp0f1DwE4,59
-scdataloader-0.0.3.dist-info/RECORD,,
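Each RECORD row follows the wheel spec: the file path, then sha256= plus the urlsafe-base64 digest with padding stripped, then the size in bytes (the RECORD file itself lists an empty hash and size). A minimal sketch of recomputing an entry for verification; the path argument is whichever file you want to check:

    import base64
    import hashlib

    def record_entry(path: str) -> str:
        """Rebuild a wheel RECORD line: path, urlsafe-b64 sha256 (no padding), size."""
        with open(path, "rb") as f:
            data = f.read()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
        return f"{path},sha256={digest.decode()},{len(data)}"

    # e.g. record_entry("scdataloader/VERSION"), run inside the unpacked 0.0.3
    # wheel, should reproduce the first line of the RECORD above.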
{scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/LICENSE
File without changes

{scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/WHEEL
File without changes

{scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/entry_points.txt
File without changes