scdataloader 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- scdataloader/__init__.py +4 -0
- scdataloader/__main__.py +188 -0
- scdataloader/collator.py +263 -0
- scdataloader/data.py +142 -159
- scdataloader/dataloader.py +318 -0
- scdataloader/mapped.py +24 -25
- scdataloader/preprocess.py +126 -145
- scdataloader/utils.py +99 -76
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/METADATA +33 -7
- scdataloader-0.0.3.dist-info/RECORD +15 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/WHEEL +1 -1
- scdataloader-0.0.2.dist-info/RECORD +0 -12
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/entry_points.txt +0 -0
scdataloader/data.py
CHANGED
```diff
@@ -4,46 +4,47 @@ import lamindb as ln
 import lnschema_bionty as lb
 import pandas as pd
 from torch.utils.data import Dataset as torchDataset
-
+from typing import Union
 from scdataloader import mapped
+import warnings

 # TODO: manage load gene embeddings to make
 # from scprint.dataloader.embedder import embed
-from scdataloader.utils import get_ancestry_mapping,
+from scdataloader.utils import get_ancestry_mapping, load_genes

 LABELS_TOADD = {
-    "assay_ontology_term_id":
-        "10x transcription profiling",
-        "spatial transcriptomics",
-        "10x 3' transcription profiling",
-        "10x 5' transcription profiling",
-
-    "disease_ontology_term_id":
-        "metabolic disease",
-        "chronic kidney disease",
-        "chromosomal disorder",
-        "infectious disease",
-        "inflammatory disease",
+    "assay_ontology_term_id": {
+        "10x transcription profiling": "EFO:0030003",
+        "spatial transcriptomics": "EFO:0008994",
+        "10x 3' transcription profiling": "EFO:0030003",
+        "10x 5' transcription profiling": "EFO:0030004",
+    },
+    "disease_ontology_term_id": {
+        "metabolic disease": "MONDO:0005066",
+        "chronic kidney disease": "MONDO:0005300",
+        "chromosomal disorder": "MONDO:0019040",
+        "infectious disease": "MONDO:0005550",
+        "inflammatory disease": "MONDO:0021166",
         # "immune system disease",
-        "disorder of development or morphogenesis",
-        "mitochondrial disease",
-        "psychiatric disorder",
-        "cancer or benign tumor",
-        "neoplasm",
-
-    "cell_type_ontology_term_id":
-        "progenitor cell",
-        "hematopoietic cell",
-        "myoblast",
-        "myeloid cell",
-        "neuron",
-        "electrically active cell",
-        "epithelial cell",
-        "secretory cell",
-        "stem cell",
-        "non-terminally differentiated cell",
-        "supporting cell",
-
+        "disorder of development or morphogenesis": "MONDO:0021147",
+        "mitochondrial disease": "MONDO:0044970",
+        "psychiatric disorder": "MONDO:0002025",
+        "cancer or benign tumor": "MONDO:0002025",
+        "neoplasm": "MONDO:0005070",
+    },
+    "cell_type_ontology_term_id": {
+        "progenitor cell": "CL:0011026",
+        "hematopoietic cell": "CL:0000988",
+        "myoblast": "CL:0000056",
+        "myeloid cell": "CL:0000763",
+        "neuron": "CL:0000540",
+        "electrically active cell": "CL:0000211",
+        "epithelial cell": "CL:0000066",
+        "secretory cell": "CL:0000151",
+        "stem cell": "CL:0000034",
+        "non-terminally differentiated cell": "CL:0000055",
+        "supporting cell": "CL:0000630",
+    },
 }

```
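The headline change in this hunk: `LABELS_TOADD` turns from flat lists of term names into nested dicts mapping each name to its ontology ID (EFO for assays, MONDO for diseases, CL for cell types). A minimal sketch of how such a mapping gets consumed, mirroring the `define_hierarchies` hunk further down; the `cats` set here is a toy stand-in for what `mapped_dataset.get_merged_categories(label)` returns:

```python
# Toy rehearsal of the lookup pattern used in define_hierarchies below:
# the ontology-ID values of the relevant sub-dict are unioned into the
# category set of a label column.
LABELS_TOADD = {
    "cell_type_ontology_term_id": {
        "neuron": "CL:0000540",
        "stem cell": "CL:0000034",
    },
}

label = "cell_type_ontology_term_id"
cats = {"CL:0000066", "CL:0000540"}  # categories already present in the dataset

addition = set(LABELS_TOADD.get(label, {}).values())
cats |= addition
print(sorted(cats))  # ['CL:0000034', 'CL:0000066', 'CL:0000540']
```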
```diff
@@ -66,16 +67,14 @@ class Dataset(torchDataset):
         gene_embedding: dataframe containing the gene embeddings
         organisms (list[str]): list of organisms to load
         obs (list[str]): list of observations to load
-
-
+        clss_to_pred (list[str]): list of observations to encode
+        hierarchical_clss: list of observations to map to a hierarchy
     """

-    lamin_dataset: ln.
+    lamin_dataset: ln.Collection
    genedf: pd.DataFrame = None
-    gene_embedding: pd.DataFrame =
-
-    )
-    organisms: list[str] = field(
+    # gene_embedding: pd.DataFrame = None # TODO: make it part of specialized dataset
+    organisms: Union[list[str], str] = field(
         default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
     )
     obs: list[str] = field(
```
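The dataclass now types `lamin_dataset` as `ln.Collection`, retires the `gene_embedding` field, and widens `organisms` to accept a single string or a list. A hypothetical construction call under those signatures; the collection name and obs keys below are placeholders, not values from the package:

```python
import lamindb as ln

from scdataloader.data import Dataset

# "my-collection" is a placeholder; any ln.Collection of AnnData artifacts
# registered in the current lamindb instance would do.
collection = ln.Collection.filter(name="my-collection").one()

dataset = Dataset(
    lamin_dataset=collection,
    clss_to_pred=["cell_type_ontology_term_id"],       # labels encoded for prediction
    hierarchical_clss=["cell_type_ontology_term_id"],  # labels expanded along the ontology
)
print(len(dataset))  # number of cells across the collection's artifacts
```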
```diff
@@ -94,43 +93,63 @@ class Dataset(torchDataset):
             "nnz",
         ]
     )
-
-
+    # set of obs to prepare for prediction (encode)
+    clss_to_pred: list[str] = field(default_factory=list)
+    # set of obs that need to be hierarchically prepared
+    hierarchical_clss: list[str] = field(default_factory=list)
+    join_vars: str = "None"

     def __post_init__(self):
         self.mapped_dataset = mapped.mapped(
             self.lamin_dataset,
             label_keys=self.obs,
-            encode_labels=self.
+            encode_labels=self.clss_to_pred,
             stream=True,
             parallel=True,
-            join_vars=
+            join_vars=self.join_vars,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
         )
         # generate tree from ontologies
-        if len(self.
-            self.define_hierarchies(self.
+        if len(self.hierarchical_clss) > 0:
+            self.define_hierarchies(self.hierarchical_clss)
+        if len(self.clss_to_pred) > 0:
+            for clss in self.clss_to_pred:
+                if clss not in self.hierarchical_clss:
+                    # otherwise it's already been done
+                    self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
+                        clss
+                    )
+                    update = {}
+                    c = 0
+                    for k, v in self.mapped_dataset.encoders[clss].items():
+                        if k == self.mapped_dataset.unknown_class:
+                            update.update({k: v})
+                            c += 1
+                            self.class_topred[clss] -= set([k])
+                        else:
+                            update.update({k: v - c})
+                    self.mapped_dataset.encoders[clss] = update

         if self.genedf is None:
-            self.genedf =
+            self.genedf = load_genes(self.organisms)

-
-
-
-
-            # [self.genedf.set_index("ensembl_gene_id"), self.gene_embedding],
-            # axis=1,
-            # join="inner",
-            # )
-        self.genedf.columns = self.genedf.columns.astype(str)
+        self.genedf.columns = self.genedf.columns.astype(str)
+        for organism in self.organisms:
+            ogenedf = self.genedf[self.genedf.organism == organism]
+            self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())

     def __len__(self, **kwargs):
         return self.mapped_dataset.__len__(**kwargs)

+    @property
+    def encoder(self):
+        return self.mapped_dataset.encoders
+
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
+        #item.update({"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)})
         # ret = {}
         # ret["count"] = item[0]
         # for i, val in enumerate(self.obs):
```
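The new block in `__post_init__` re-packs each label encoder: every category after the `unknown_class` entry shifts down by one code, while the unknown label keeps its original code and is removed from `class_topred`. A toy run of that loop, with plain dicts standing in for `mapped_dataset.encoders` and the unknown sentinel:

```python
# Plain-dict rehearsal of the encoder re-packing loop added in __post_init__.
unknown_class = "unknown"
encoders = {"cell_type": {"unknown": 0, "T cell": 1, "B cell": 2}}
class_topred = {"cell_type": {"unknown", "T cell", "B cell"}}

clss = "cell_type"
update = {}
c = 0
for k, v in encoders[clss].items():
    if k == unknown_class:
        update[k] = v              # the unknown label keeps its original code
        c += 1
        class_topred[clss] -= {k}  # ...and stops being a prediction target
    else:
        update[k] = v - c          # categories after it shift down by one
encoders[clss] = update

print(encoders[clss])              # {'unknown': 0, 'T cell': 0, 'B cell': 1}
print(sorted(class_topred[clss]))  # ['B cell', 'T cell']
```

Note that the unknown code can end up coinciding with a real category's code, as in this toy run; only the entries left in `class_topred` are treated as prediction targets.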
```diff
@@ -143,8 +162,7 @@ class Dataset(torchDataset):
     def __repr__(self):
         print(
             "total dataset size is {} Gb".format(
-                sum([file.size for file in self.lamin_dataset.artifacts.all()])
-                / 1e9
+                sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
             )
         )
         print("---")
@@ -158,111 +176,14 @@ class Dataset(torchDataset):
                 sum([len(self.class_topred[i]) for i in self.class_topred])
             )
         )
-        print("embedding size is {}".format(self.gene_embedding.shape[1]))
+        # print("embedding size is {}".format(self.gene_embedding.shape[1]))
         return ""

     def get_label_weights(self, *args, **kwargs):
         return self.mapped_dataset.get_label_weights(*args, **kwargs)

     def get_unseen_mapped_dataset_elements(self, idx):
-        return [
-            str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")
-        ]
-
-    def use_prior_network(
-        self, name="collectri", organism="human", split_complexes=True
-    ):
-        """
-        use_prior_network loads a prior GRN from a list of available networks.
-
-        Args:
-            name (str, optional): name of the network to load. Defaults to "collectri".
-            organism (str, optional): organism to load the network for. Defaults to "human".
-            split_complexes (bool, optional): whether to split complexes into individual genes. Defaults to True.
-
-        Raises:
-            ValueError: if the provided name is not amongst the available names.
-        """
-        # TODO: use omnipath instead
-        if name == "tflink":
-            TFLINK = "https://cdn.netbiol.org/tflink/download_files/TFLink_Homo_sapiens_interactions_All_simpleFormat_v1.0.tsv.gz"
-            net = pd_load_cached(TFLINK)
-            net = net.rename(
-                columns={"Name.TF": "regulator", "Name.Target": "target"}
-            )
-        elif name == "htftarget":
-            HTFTARGET = "http://bioinfo.life.hust.edu.cn/static/hTFtarget/file_download/tf-target-infomation.txt"
-            net = pd_load_cached(HTFTARGET)
-            net = net.rename(columns={"TF": "regulator"})
-        elif name == "collectri":
-            import decoupler as dc
-
-            net = dc.get_collectri(
-                organism=organism, split_complexes=split_complexes
-            )
-            net = net.rename(columns={"source": "regulator"})
-        else:
-            raise ValueError(
-                f"provided name: '{name}' is not amongst the available names."
-            )
-        self.add_prior_network(net)
-
-    def add_prior_network(self, prior_network: pd.DataFrame, init_len):
-        # validate the network dataframe
-        required_columns: list[str] = ["target", "regulators"]
-        optional_columns: list[str] = ["type", "weight"]
-
-        for column in required_columns:
-            assert (
-                column in prior_network.columns
-            ), f"Column '{column}' is missing in the provided network dataframe."
-
-        for column in optional_columns:
-            if column not in prior_network.columns:
-                print(
-                    f"Optional column '{column}' is not present in the provided network dataframe."
-                )
-
-        assert (
-            prior_network["target"].dtype == "str"
-        ), "Column 'target' should be of dtype 'str'."
-        assert (
-            prior_network["regulators"].dtype == "str"
-        ), "Column 'regulators' should be of dtype 'str'."
-
-        if "type" in prior_network.columns:
-            assert (
-                prior_network["type"].dtype == "str"
-            ), "Column 'type' should be of dtype 'str'."
-
-        if "weight" in prior_network.columns:
-            assert (
-                prior_network["weight"].dtype == "float"
-            ), "Column 'weight' should be of dtype 'float'."
-
-        # TODO: check that we match the genes in the network to the genes in the dataset
-
-        print(
-            "loaded {:.2f}% of the edges".format(
-                (len(prior_network) / init_len) * 100
-            )
-        )
-        # TODO: transform it into a sparse matrix
-        self.prior_network = prior_network
-        self.network_size = len(prior_network)
-        # self.overlap =
-        # self.edge_freq
-
-    def load_genes(self, organisms):
-        organismdf = []
-        for o in organisms:
-            organism = lb.Gene(
-                organism=lb.Organism.filter(ontology_id=o).one()
-            ).df()
-            organism["organism"] = o
-            organismdf.append(organism)
-
-        return pd.concat(organismdf)
+        return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

     # def load_embeddings(self, genedfs, embedding_size=128, cache=True):
     #     embeddings = []
```
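Beyond dropping the prior-network helpers and the method version of `load_genes` (now imported from `scdataloader.utils`), this hunk collapses `get_unseen_mapped_dataset_elements` to one line. The `str(i)[2:-1]` idiom it keeps is worth spelling out: applied to a bytes value, `str()` yields the repr `b'...'`, and the slice strips that wrapper:

```python
# Why str(i)[2:-1] recovers a gene name stored as bytes:
i = b"ENSG00000139618"
print(str(i))        # b'ENSG00000139618'  (the repr, including the b'' wrapper)
print(str(i)[2:-1])  # ENSG00000139618
# i.decode() would be the more conventional spelling of the same thing:
print(i.decode())    # ENSG00000139618
```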
```diff
@@ -344,10 +265,72 @@ class Dataset(torchDataset):
             )
         )
         cats = self.mapped_dataset.get_merged_categories(label)
-
+        addition = set(LABELS_TOADD.get(label, {}).values())
+        cats |= addition
         groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
         for i, j in groupings.items():
             if len(j) == 0:
                 groupings.pop(i)
         self.class_groupings[label] = groupings
-
+        if label in self.clss_to_pred:
+            # if we have added new labels, we need to update the encoder with them too.
+            mlength = len(self.mapped_dataset.encoders[label])
+            mlength -= (
+                1
+                if self.mapped_dataset.unknown_class
+                in self.mapped_dataset.encoders[label].keys()
+                else 0
+            )
+
+            for i, v in enumerate(
+                addition - set(self.mapped_dataset.encoders[label].keys())
+            ):
+                self.mapped_dataset.encoders[label].update({v: mlength + i})
+            # we need to change the ordering so that the things that can't be predicted appear afterward
+
+            self.class_topred[label] = lclass
+            c = 0
+            d = 0
+            update = {}
+            mlength = len(lclass)
+            # import pdb
+
+            # pdb.set_trace()
+            mlength -= (
+                1
+                if self.mapped_dataset.unknown_class
+                in self.mapped_dataset.encoders[label].keys()
+                else 0
+            )
+            for k, v in self.mapped_dataset.encoders[label].items():
+                if k in self.class_groupings[label].keys():
+                    update.update({k: mlength + c})
+                    c += 1
+                elif k == self.mapped_dataset.unknown_class:
+                    update.update({k: v})
+                    d += 1
+                    self.class_topred[label] -= set([k])
+                else:
+                    update.update({k: (v - c) - d})
+            self.mapped_dataset.encoders[label] = update
+
+
+class SimpleAnnDataset:
+    def __init__(self, adata, obs_to_output=[], layer=None):
+        self.adata = adata
+        self.obs_to_output = obs_to_output
+        self.layer = layer
+
+    def __len__(self):
+        return self.adata.shape[0]
+
+    def __getitem__(self, idx):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=DeprecationWarning)
+            if self.layer is not None:
+                out = {"x": self.adata.layers[self.layer][idx].toarray().reshape(-1)}
+            else:
+                out = {"x": self.adata.X[idx].toarray().reshape(-1)}
+            for i in self.obs_to_output:
+                out.update({i: self.adata.obs.iloc[idx][i]})
+            return out
```
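The new `SimpleAnnDataset` is a thin torch-style dataset over an in-memory AnnData: each item is a dict with the densified expression row under `"x"` plus any requested obs columns, so it plugs straight into a `DataLoader`. A small sketch with a toy AnnData; the shapes, values, and obs column below are illustrative:

```python
import anndata as ad
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from torch.utils.data import DataLoader

from scdataloader.data import SimpleAnnDataset

# toy AnnData: 4 cells x 3 genes, one annotation column
adata = ad.AnnData(
    X=csr_matrix(np.arange(12, dtype=np.float32).reshape(4, 3)),
    obs=pd.DataFrame({"cell_type": ["T", "B", "T", "NK"]}),
)

dataset = SimpleAnnDataset(adata, obs_to_output=["cell_type"])
print(len(dataset))  # 4
print(dataset[0])    # {'x': array([0., 1., 2.], dtype=float32), 'cell_type': 'T'}

# composes with a standard torch DataLoader
for batch in DataLoader(dataset, batch_size=2):
    print(batch["x"].shape)  # torch.Size([2, 3])
```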