scdataloader-0.0.2-py3-none-any.whl → scdataloader-0.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/__init__.py +4 -0
- scdataloader/__main__.py +188 -0
- scdataloader/collator.py +263 -0
- scdataloader/data.py +142 -159
- scdataloader/dataloader.py +318 -0
- scdataloader/mapped.py +24 -25
- scdataloader/preprocess.py +126 -145
- scdataloader/utils.py +99 -76
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/METADATA +33 -7
- scdataloader-0.0.3.dist-info/RECORD +15 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/WHEEL +1 -1
- scdataloader-0.0.2.dist-info/RECORD +0 -12
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/entry_points.txt +0 -0
scdataloader/dataloader.py
ADDED

@@ -0,0 +1,318 @@
+import numpy as np
+import pandas as pd
+import lamindb as ln
+
+from torch.utils.data.sampler import (
+    WeightedRandomSampler,
+    SubsetRandomSampler,
+    SequentialSampler,
+)
+from torch.utils.data import DataLoader
+import lightning as L
+
+from typing import Optional
+
+from .data import Dataset
+from .collator import Collator
+from .mapped import MappedDataset
+from .utils import getBiomartTable
+
+# TODO: put in config
+COARSE_TISSUE = {
+    "adipose tissue": "",
+    "bladder organ": "",
+    "blood": "",
+    "bone marrow": "",
+    "brain": "",
+    "breast": "",
+    "esophagus": "",
+    "eye": "",
+    "embryo": "",
+    "fallopian tube": "",
+    "gall bladder": "",
+    "heart": "",
+    "intestine": "",
+    "kidney": "",
+    "liver": "",
+    "lung": "",
+    "lymph node": "",
+    "musculature of body": "",
+    "nose": "",
+    "ovary": "",
+    "pancreas": "",
+    "placenta": "",
+    "skin of body": "",
+    "spinal cord": "",
+    "spleen": "",
+    "stomach": "",
+    "thymus": "",
+    "thyroid gland": "",
+    "tongue": "",
+    "uterus": "",
+}
+
+COARSE_ANCESTRY = {
+    "African": "",
+    "Chinese": "",
+    "East Asian": "",
+    "Eskimo": "",
+    "European": "",
+    "Greater Middle Eastern (Middle Eastern, North African or Persian)": "",
+    "Hispanic or Latin American": "",
+    "Native American": "",
+    "Oceanian": "",
+    "South Asian": "",
+}
+
+COARSE_DEVELOPMENT_STAGE = {
+    "Embryonic human": "",
+    "Fetal": "",
+    "Immature": "",
+    "Mature": "",
+}
+
+COARSE_ASSAY = {
+    "10x 3'": "",
+    "10x 5'": "",
+    "10x multiome": "",
+    "CEL-seq2": "",
+    "Drop-seq": "",
+    "GEXSCOPE technology": "",
+    "inDrop": "",
+    "microwell-seq": "",
+    "sci-Plex": "",
+    "sci-RNA-seq": "",
+    "Seq-Well": "",
+    "Slide-seq": "",
+    "Smart-seq": "",
+    "SPLiT-seq": "",
+    "TruDrop": "",
+    "Visium Spatial Gene Expression": "",
+}
+
+
+class DataModule(L.LightningDataModule):
+    """
+    Base class for all data loaders
+    """
+
+    def __init__(
+        self,
+        mdataset: Optional[MappedDataset] = None,
+        collection_name=None,
+        organisms: list = ["NCBITaxon:9606"],
+        weight_scaler: int = 30,
+        train_oversampling=1,
+        label_to_weight: list = [],
+        label_to_pred: list = [],
+        validation_split: float = 0.2,
+        test_split: float = 0,
+        use_default_col=True,
+        all_labels=[],
+        hierarchical_labels=[],
+        how="most expr",
+        organism_name="organism_ontology_term_id",
+        max_len=1000,
+        add_zero_genes=100,
+        do_gene_pos=True,
+        gene_embeddings="",
+        gene_position_tolerance=10_000,
+        **kwargs,
+    ):
+        """
+        Initializes the DataModule.
+
+        Args:
+            dataset (MappedDataset): The dataset to be used.
+            weight_scaler (int, optional): The weight scaler for weighted random sampling. Defaults to 30.
+            label_to_weight (list, optional): List of labels to weight. Defaults to [].
+            validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
+            test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
+            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+        """
+        if collection_name is not None:
+            mdataset = Dataset(
+                ln.Collection.filter(name=collection_name).first(),
+                organisms=organisms,
+                obs=all_labels,
+                clss_to_pred=label_to_pred,
+                hierarchical_clss=hierarchical_labels,
+            )
+        print(mdataset)
+        # and location
+        if do_gene_pos:
+            # and annotations
+            biomart = getBiomartTable(
+                attributes=["start_position", "chromosome_name"]
+            ).set_index("ensembl_gene_id")
+            biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
+            biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
+            c = []
+            i = 0
+            prev_position = -100000
+            prev_chromosome = None
+            for _, r in biomart.iterrows():
+                if (
+                    r["chromosome_name"] != prev_chromosome
+                    or r["start_position"] - prev_position > gene_position_tolerance
+                ):
+                    i += 1
+                c.append(i)
+                prev_position = r["start_position"]
+                prev_chromosome = r["chromosome_name"]
+            print(f"reduced the size to {len(set(c))/len(biomart)}")
+            biomart["pos"] = c
+            mdataset.genedf = biomart.loc[
+                mdataset.genedf[mdataset.genedf.index.isin(biomart.index)].index
+            ]
+            self.gene_pos = mdataset.genedf["pos"].tolist()
+
+        if gene_embeddings != "":
+            mdataset.genedf = mdataset.genedf.join(
+                pd.read_parquet(gene_embeddings), how="inner"
+            )
+            if do_gene_pos:
+                self.gene_pos = mdataset.genedf["pos"].tolist()
+        self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
+        # we might want not to order the genes by expression (or do it?)
+        # we might want to not introduce zeros and
+        if use_default_col:
+            kwargs["collate_fn"] = Collator(
+                organisms=organisms,
+                how=how,
+                valid_genes=mdataset.genedf.index.tolist(),
+                max_len=max_len,
+                add_zero_genes=add_zero_genes,
+                org_to_id=mdataset.encoder[organism_name],
+                tp_name="heat_diff",
+                organism_name=organism_name,
+                class_names=label_to_weight,
+            )
+        self.validation_split = validation_split
+        self.test_split = test_split
+        self.dataset = mdataset
+        self.kwargs = kwargs
+        self.n_samples = len(mdataset)
+        self.weight_scaler = weight_scaler
+        self.train_oversampling = train_oversampling
+        self.label_to_weight = label_to_weight
+        super().__init__()
+
+    @property
+    def decoders(self):
+        decoders = {}
+        for k, v in self.dataset.encoder.items():
+            decoders[k] = {va: ke for ke, va in v.items()}
+        return decoders
+
+    @property
+    def cls_hierarchy(self):
+        cls_hierarchy = {}
+        for k, dic in self.dataset.class_groupings.items():
+            rdic = {}
+            for sk, v in dic.items():
+                rdic[self.dataset.encoder[k][sk]] = [
+                    self.dataset.encoder[k][i] for i in list(v)
+                ]
+            cls_hierarchy[k] = rdic
+        return cls_hierarchy
+
+    @property
+    def genes(self):
+        return self.dataset.genedf.index.tolist()
+
+    def setup(self, stage=None):
+        """
+        setup method is used to prepare the data for the training, validation, and test sets.
+        It shuffles the data, calculates weights for each set, and creates samplers for each set.
+
+        Args:
+            stage (str, optional): The stage of the model training process.
+                It can be either 'fit' or 'test'. Defaults to None.
+        """
+
+        if len(self.label_to_weight) > 0:
+            weights = self.dataset.get_label_weights(
+                self.label_to_weight, scaler=self.weight_scaler
+            )
+        else:
+            weights = np.ones(self.n_samples)
+        if isinstance(self.validation_split, int):
+            len_valid = self.validation_split
+        else:
+            len_valid = int(self.n_samples * self.validation_split)
+        if isinstance(self.test_split, int):
+            len_test = self.test_split
+        else:
+            len_test = int(self.n_samples * self.test_split)
+        assert (
+            len_test + len_valid < self.n_samples
+        ), "test set + valid set size is configured to be larger than entire dataset."
+        idx_full = np.arange(self.n_samples)
+        if len_test > 0:
+            # this way we work on some never seen datasets
+            # keeping at least one
+            len_test = (
+                len_test
+                if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                else self.dataset.mapped_dataset.n_obs_list[0]
+            )
+            cs = 0
+            test_datasets = []
+            print("these files will be considered test datasets:")
+            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
+                if cs + c > len_test:
+                    break
+                else:
+                    print(" " + self.dataset.mapped_dataset.path_list[i].path)
+                    test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
+                    cs += c
+
+            len_test = cs
+            print("perc test: ", len_test / self.n_samples)
+            test_idx = idx_full[:len_test]
+            idx_full = idx_full[len_test:]
+            self.test_sampler = SequentialSampler(test_idx)
+        else:
+            self.test_sampler = None
+            test_datasets = None
+
+        np.random.shuffle(idx_full)
+        if len_valid > 0:
+            valid_idx = idx_full[:len_valid]
+            idx_full = idx_full[len_valid:]
+            self.valid_sampler = SubsetRandomSampler(valid_idx)
+        else:
+            self.valid_sampler = None
+
+        weights[~idx_full] = 0
+        self.train_sampler = WeightedRandomSampler(
+            weights,
+            int(len(idx_full) * self.train_oversampling),
+            replacement=True,
+        )
+        return test_datasets
+
+    def train_dataloader(self, **kwargs):
+        return DataLoader(
+            self.dataset, sampler=self.train_sampler, **self.kwargs, **kwargs
+        )
+
+    def val_dataloader(self):
+        return (
+            DataLoader(self.dataset, sampler=self.valid_sampler, **self.kwargs)
+            if self.valid_sampler is not None
+            else None
+        )
+
+    def test_dataloader(self):
+        return (
+            DataLoader(self.dataset, sampler=self.test_sampler, **self.kwargs)
+            if self.test_sampler is not None
+            else None
+        )
+
+    # def teardown(self):
+    # clean up state after the trainer stops, delete files...
+    # called on every process in DDP
+    # pass
scdataloader/mapped.py
CHANGED

@@ -80,10 +80,13 @@ class MappedDataset:
         join_vars: Optional[Literal["auto", "inner", "None"]] = "auto",
         encode_labels: Optional[Union[bool, List[str]]] = False,
         parallel: bool = False,
+        unknown_class: str = "unknown",
     ):
         self.storages = []
         self.conns = []
         self.parallel = parallel
+        self.unknown_class = unknown_class
+        self.path_list = path_list
         self._make_connections(path_list, parallel)

         self.n_obs_list = []
@@ -96,22 +99,16 @@ class MappedDataset:
                     self.n_obs_list.append(X.attrs["shape"][0])
         self.n_obs = sum(self.n_obs_list)

-        self.indices = np.hstack(
-            [np.arange(n_obs) for n_obs in self.n_obs_list]
-        )
-        self.storage_idx = np.repeat(
-            np.arange(len(self.storages)), self.n_obs_list
-        )
+        self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
+        self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

         self.join_vars = join_vars if len(path_list) > 1 else None
         self.var_indices = None
-        if self.join_vars:
+        if self.join_vars != "None":
            self._make_join_vars()

         self.encode_labels = encode_labels
-        self.label_keys = (
-            [label_keys] if isinstance(label_keys, str) else label_keys
-        )
+        self.label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
         if isinstance(encode_labels, bool):
             if encode_labels:
                 encode_labels = label_keys
@@ -122,6 +119,8 @@ class MappedDataset:
             for label in encode_labels:
                 cats = self.get_merged_categories(label)
                 self.encoders[label] = {cat: i for i, cat in enumerate(cats)}
+                if unknown_class in self.encoders[label]:
+                    self.encoders[label][unknown_class] = -1
         else:
             self.encoders = {}
         self._closed = False
@@ -157,9 +156,15 @@ class MappedDataset:
                 raise ValueError(
                     "The provided AnnData objects don't have shared varibales."
                 )
-            self.var_indices = [
-                vrs.get_indexer(self.var_joint) for vrs in var_list
-            ]
+            self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+
+    def _check_aligned_vars(self, vars: list):
+        i = 0
+        for storage in self.storages:
+            with _Connect(storage) as store:
+                if vars == _safer_read_index(store["var"]).tolist():
+                    i += 1
+        print("{}% are aligned".format(i * 100 / len(self.storages)))

     def __len__(self):
         return self.n_obs
@@ -172,14 +177,14 @@ class MappedDataset:
         else:
             var_idxs = None
         with _Connect(self.storages[storage_idx]) as store:
-            out = [self.get_data_idx(store, obs_idx, var_idxs)]
+            out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
             if self.label_keys is not None:
                 for i, label in enumerate(self.label_keys):
                     label_idx = self.get_label_idx(store, obs_idx, label)
                     if label in self.encoders:
-                        out.append(self.encoders[label][label_idx])
+                        out.update({label: self.encoders[label][label_idx]})
                     else:
-                        out.append(label_idx)
+                        out.update({label: label_idx})
             return out

     def uns(self, idx, key):
@@ -240,9 +245,7 @@ class MappedDataset:
             if i == 0:
                 labels = self.get_merged_labels(val)
             else:
-                labels += "_" + self.get_merged_labels(val).astype(str).astype(
-                    "O"
-                )
+                labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
         counter = Counter(labels)  # type: ignore
         counter = np.array([counter[label] for label in labels])
         weights = scaler / (counter + scaler)
@@ -255,9 +258,7 @@ class MappedDataset:
         for storage in self.storages:
             with _Connect(storage) as store:
                 codes = self.get_codes(store, label_key)
-                labels = (
-                    decode(codes) if isinstance(codes[0], bytes) else codes
-                )
+                labels = decode(codes) if isinstance(codes[0], bytes) else codes
                 cats = self.get_categories(store, label_key)
                 if cats is not None:
                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
@@ -277,9 +278,7 @@ class MappedDataset:
                     cats_merge.update(cats)
                 else:
                     codes = self.get_codes(store, label_key)
-                    codes = (
-                        decode(codes) if isinstance(codes[0], bytes) else codes
-                    )
+                    codes = decode(codes) if isinstance(codes[0], bytes) else codes
                     cats_merge.update(codes)
         return cats_merge

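One behavioural note on the encoder change above: when a label's merged categories contain the unknown_class value, that category is remapped to -1 instead of keeping its enumerated index. A tiny illustration, with invented category names, of what the resulting mapping looks like:

# Illustration only; category names are invented, the logic mirrors the diff above.
cats = ["B cell", "T cell", "unknown"]            # merged categories for one label
encoder = {cat: i for i, cat in enumerate(cats)}  # {"B cell": 0, "T cell": 1, "unknown": 2}

unknown_class = "unknown"                         # new constructor default
if unknown_class in encoder:
    encoder[unknown_class] = -1                   # unknown cells get a sentinel index

assert encoder == {"B cell": 0, "T cell": 1, "unknown": -1}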