scdataloader-0.0.2-py3-none-any.whl → scdataloader-0.0.4-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +4 -0
- scdataloader/__main__.py +209 -0
- scdataloader/collator.py +307 -0
- scdataloader/config.py +106 -0
- scdataloader/data.py +181 -218
- scdataloader/datamodule.py +375 -0
- scdataloader/mapped.py +46 -32
- scdataloader/preprocess.py +524 -208
- scdataloader/utils.py +189 -123
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/METADATA +77 -7
- scdataloader-0.0.4.dist-info/RECORD +16 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/WHEEL +1 -1
- scdataloader-0.0.2.dist-info/RECORD +0 -12
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.4.dist-info}/entry_points.txt +0 -0
scdataloader/data.py
CHANGED

```diff
@@ -1,84 +1,53 @@
 from dataclasses import dataclass, field

 import lamindb as ln
-import
+import bionty as bt
 import pandas as pd
 from torch.utils.data import Dataset as torchDataset
-
+from typing import Union, Optional, Literal
 from scdataloader import mapped
+import warnings
+
+from anndata import AnnData
+
+from scdataloader.utils import get_ancestry_mapping, load_genes

-
-# from scprint.dataloader.embedder import embed
-from scdataloader.utils import get_ancestry_mapping, pd_load_cached
-
-LABELS_TOADD = {
-    "assay_ontology_term_id": [
-        "10x transcription profiling",
-        "spatial transcriptomics",
-        "10x 3' transcription profiling",
-        "10x 5' transcription profiling",
-    ],
-    "disease_ontology_term_id": [
-        "metabolic disease",
-        "chronic kidney disease",
-        "chromosomal disorder",
-        "infectious disease",
-        "inflammatory disease",
-        # "immune system disease",
-        "disorder of development or morphogenesis",
-        "mitochondrial disease",
-        "psychiatric disorder",
-        "cancer or benign tumor",
-        "neoplasm",
-    ],
-    "cell_type_ontology_term_id": [
-        "progenitor cell",
-        "hematopoietic cell",
-        "myoblast",
-        "myeloid cell",
-        "neuron",
-        "electrically active cell",
-        "epithelial cell",
-        "secretory cell",
-        "stem cell",
-        "non-terminally differentiated cell",
-        "supporting cell",
-    ],
-}
+from .config import LABELS_TOADD


 @dataclass
 class Dataset(torchDataset):
     """
-    Dataset class to load a bunch of anndata from a lamin dataset in a memory efficient way.
+    Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory efficient way.

-
+    This serves as a wrapper around lamin's mappedCollection to provide more features,
+    mostly, the management of hierarchical labels, the encoding of labels, the management of multiple species
+
+    For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.

     .. note::

-        A
+        A related data loader exists `here
         <https://github.com/Genentech/scimilarity>`__.

-
+    Args:
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        gene_embedding: dataframe containing the gene embeddings
         organisms (list[str]): list of organisms to load
-
-
-
+            (for now only validates the the genes map to this organism)
+        obs (list[str]): list of observations to load from the Collection
+        clss_to_pred (list[str]): list of observations to encode
+        join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
+        hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
     """

-    lamin_dataset: ln.
-    genedf: pd.DataFrame = None
-
-        None  # TODO: make it part of specialized dataset
-    )
-    organisms: list[str] = field(
+    lamin_dataset: ln.Collection
+    genedf: Optional[pd.DataFrame] = None
+    organisms: Optional[Union[list[str], str]] = field(
        default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
     )
-    obs: list[str] = field(
+    obs: Optional[list[str]] = field(
         default_factory=[
             "self_reported_ethnicity_ontology_term_id",
             "assay_ontology_term_id",
```
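For orientation, here is a minimal usage sketch of the reworked constructor. The collection name and label choices are hypothetical, and it assumes a lamindb instance that already holds a Collection of AnnData artifacts:

```python
import lamindb as ln
from scdataloader.data import Dataset

# hypothetical collection; any lamindb Collection of AnnData artifacts would do
collection = ln.Collection.filter(name="scdataloader-test").one()

dataset = Dataset(
    lamin_dataset=collection,
    organisms=["NCBITaxon:9606"],  # only validate genes against human
    obs=["cell_type_ontology_term_id", "sex_ontology_term_id"],
    clss_to_pred=["cell_type_ontology_term_id"],  # labels to encode for prediction
    hierarchical_clss=["cell_type_ontology_term_id"],  # labels expanded via the ontology
)
print(len(dataset))  # number of cells across the collection
print(dataset)       # __repr__ summarizes cells, genes, labels and classes to predict
```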
```diff
@@ -89,48 +58,75 @@ class Dataset(torchDataset):
             "sex_ontology_term_id",
             #'dataset_id',
             #'cell_culture',
-            "dpt_group",
-            "heat_diff",
-            "nnz",
+            #"dpt_group",
+            #"heat_diff",
+            #"nnz",
         ]
     )
-
-
+    # set of obs to prepare for prediction (encode)
+    clss_to_pred: Optional[list[str]] = field(default_factory=list)
+    # set of obs that need to be hierarchically prepared
+    hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+    join_vars: Optional[Literal["auto", "inner", "None"]] = "None"

     def __post_init__(self):
         self.mapped_dataset = mapped.mapped(
             self.lamin_dataset,
             label_keys=self.obs,
-            encode_labels=self.
+            encode_labels=self.clss_to_pred,
             stream=True,
             parallel=True,
-            join_vars=
+            join_vars=self.join_vars,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
         )
+        self.class_groupings = {}
+        self.class_topred = {}
         # generate tree from ontologies
-        if len(self.
-            self.define_hierarchies(self.
+        if len(self.hierarchical_clss) > 0:
+            self.define_hierarchies(self.hierarchical_clss)
+        if len(self.clss_to_pred) > 0:
+            for clss in self.clss_to_pred:
+                if clss not in self.hierarchical_clss:
+                    # otherwise it's already been done
+                    self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
+                        clss
+                    )
+                    update = {}
+                    c = 0
+                    for k, v in self.mapped_dataset.encoders[clss].items():
+                        if k == self.mapped_dataset.unknown_class:
+                            update.update({k: v})
+                            c += 1
+                            self.class_topred[clss] -= set([k])
+                        else:
+                            update.update({k: v - c})
+                    self.mapped_dataset.encoders[clss] = update

         if self.genedf is None:
-            self.genedf =
-
-
-
-
-
-            # [self.genedf.set_index("ensembl_gene_id"), self.gene_embedding],
-            # axis=1,
-            # join="inner",
-            # )
-        self.genedf.columns = self.genedf.columns.astype(str)
+            self.genedf = load_genes(self.organisms)
+
+        self.genedf.columns = self.genedf.columns.astype(str)
+        for organism in self.organisms:
+            ogenedf = self.genedf[self.genedf.organism == organism]
+            self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())

     def __len__(self, **kwargs):
         return self.mapped_dataset.__len__(**kwargs)

+    @property
+    def encoder(self):
+        return self.mapped_dataset.encoders
+
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
+        # import pdb
+
+        # pdb.set_trace()
+        # item.update(
+        #     {"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)}
+        # )
         # ret = {}
         # ret["count"] = item[0]
         # for i, val in enumerate(self.obs):
```
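With `clss_to_pred` set, `__post_init__` re-indexes the label encoders so that the unknown class is kept aside and the remaining categories stay contiguous; the new `encoder` property exposes the result. A short sketch of how this might be consumed, continuing from the hypothetical `dataset` above (whether the default PyTorch collate handles the raw items depends on what `mapped.mapped()` yields per item, so the loop body is left open):

```python
from torch.utils.data import DataLoader

# label -> integer code mapping after the unknown class has been set aside
print(dataset.encoder["cell_type_ontology_term_id"])

# per-label weights, forwarded to mappedCollection.get_label_weights
weights = dataset.get_label_weights("cell_type_ontology_term_id")

loader = DataLoader(dataset, batch_size=32, num_workers=2)
for item in loader:
    ...  # each item bundles the expression values with the encoded obs labels
    break
```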
```diff
@@ -141,149 +137,36 @@ class Dataset(torchDataset):
         return item

     def __repr__(self):
-
-            "total dataset size is {} Gb".format(
-                sum([file.size for file in self.lamin_dataset.artifacts.all()])
-                / 1e9
+        return (
+            "total dataset size is {} Gb\n".format(
+                sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
             )
-
-
-
-
-
-
-
-
-            "
-
+            + "---\n"
+            + "dataset contains:\n"
+            + "   {} cells\n".format(self.mapped_dataset.__len__())
+            + "   {} genes\n".format(self.genedf.shape[0])
+            + "   {} labels\n".format(len(self.obs))
+            + "   {} clss_to_pred\n".format(len(self.clss_to_pred))
+            + "   {} hierarchical_clss\n".format(len(self.hierarchical_clss))
+            + "   {} join_vars\n".format(len(self.join_vars))
+            + "   {} organisms\n".format(len(self.organisms))
+            + (
+                "dataset contains {} classes to predict\n".format(
+                    sum([len(self.class_topred[i]) for i in self.class_topred])
+                )
+                if len(self.class_topred) > 0
+                else ""
             )
         )
-        print("embedding size is {}".format(self.gene_embedding.shape[1]))
-        return ""

     def get_label_weights(self, *args, **kwargs):
         return self.mapped_dataset.get_label_weights(*args, **kwargs)

-    def get_unseen_mapped_dataset_elements(self, idx):
-        return [
-            str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")
-        ]
-
-    def use_prior_network(
-        self, name="collectri", organism="human", split_complexes=True
-    ):
-        """
-        use_prior_network loads a prior GRN from a list of available networks.
-
-        Args:
-            name (str, optional): name of the network to load. Defaults to "collectri".
-            organism (str, optional): organism to load the network for. Defaults to "human".
-            split_complexes (bool, optional): whether to split complexes into individual genes. Defaults to True.
-
-        Raises:
-            ValueError: if the provided name is not amongst the available names.
-        """
-        # TODO: use omnipath instead
-        if name == "tflink":
-            TFLINK = "https://cdn.netbiol.org/tflink/download_files/TFLink_Homo_sapiens_interactions_All_simpleFormat_v1.0.tsv.gz"
-            net = pd_load_cached(TFLINK)
-            net = net.rename(
-                columns={"Name.TF": "regulator", "Name.Target": "target"}
-            )
-        elif name == "htftarget":
-            HTFTARGET = "http://bioinfo.life.hust.edu.cn/static/hTFtarget/file_download/tf-target-infomation.txt"
-            net = pd_load_cached(HTFTARGET)
-            net = net.rename(columns={"TF": "regulator"})
-        elif name == "collectri":
-            import decoupler as dc
-
-            net = dc.get_collectri(
-                organism=organism, split_complexes=split_complexes
-            )
-            net = net.rename(columns={"source": "regulator"})
-        else:
-            raise ValueError(
-                f"provided name: '{name}' is not amongst the available names."
-            )
-        self.add_prior_network(net)
-
-    def add_prior_network(self, prior_network: pd.DataFrame, init_len):
-        # validate the network dataframe
-        required_columns: list[str] = ["target", "regulators"]
-        optional_columns: list[str] = ["type", "weight"]
-
-        for column in required_columns:
-            assert (
-                column in prior_network.columns
-            ), f"Column '{column}' is missing in the provided network dataframe."
-
-        for column in optional_columns:
-            if column not in prior_network.columns:
-                print(
-                    f"Optional column '{column}' is not present in the provided network dataframe."
-                )
-
-        assert (
-            prior_network["target"].dtype == "str"
-        ), "Column 'target' should be of dtype 'str'."
-        assert (
-            prior_network["regulators"].dtype == "str"
-        ), "Column 'regulators' should be of dtype 'str'."
+    def get_unseen_mapped_dataset_elements(self, idx: int):
+        return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

-
-            prior_network["type"].dtype == "str"
-        ), "Column 'type' should be of dtype 'str'."
-
-        if "weight" in prior_network.columns:
-            assert (
-                prior_network["weight"].dtype == "float"
-            ), "Column 'weight' should be of dtype 'float'."
-
-        # TODO: check that we match the genes in the network to the genes in the dataset
-
-        print(
-            "loaded {:.2f}% of the edges".format(
-                (len(prior_network) / init_len) * 100
-            )
-        )
-        # TODO: transform it into a sparse matrix
-        self.prior_network = prior_network
-        self.network_size = len(prior_network)
-        # self.overlap =
-        # self.edge_freq
-
-    def load_genes(self, organisms):
-        organismdf = []
-        for o in organisms:
-            organism = lb.Gene(
-                organism=lb.Organism.filter(ontology_id=o).one()
-            ).df()
-            organism["organism"] = o
-            organismdf.append(organism)
-
-        return pd.concat(organismdf)
-
-    # def load_embeddings(self, genedfs, embedding_size=128, cache=True):
-    #     embeddings = []
-    #     for o in self.organisms:
-    #         genedf = genedfs[genedfs.organism == o]
-    #         org_name = lb.Organism.filter(ontology_id=o).one().scientific_name
-    #         embedding = embed(
-    #             genedf=genedf,
-    #             organism=org_name,
-    #             cache=cache,
-    #             fasta_path="/tmp/data/fasta/",
-    #             embedding_size=embedding_size,
-    #         )
-    #         genedf = pd.concat(
-    #             [genedf.set_index("ensembl_gene_id"), embedding], axis=1, join="inner"
-    #         )
-    #         genedf.columns = genedf.columns.astype(str)
-    #         embeddings.append(genedf)
-    #     return pd.concat(embeddings)
-
-
-    def define_hierarchies(self, labels):
+    def define_hierarchies(self, labels: list[str]):
+        # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
         self.class_groupings = {}
         self.class_topred = {}
         for label in labels:
```
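The gene-table helper removed here is not gone: per the import change at the top of the file, `load_genes` now lives in `scdataloader.utils`, and `__post_init__` calls it when `genedf` is not provided. A minimal sketch of building the gene table directly (it assumes a lamindb/bionty instance with the gene registry populated):

```python
from scdataloader.utils import load_genes

# organism ontology ids, as in the Dataset defaults above
genedf = load_genes(["NCBITaxon:9606", "NCBITaxon:10090"])

# one block of genes per organism; the index is what __post_init__ checks
# against the collection's var names via _check_aligned_vars
print(genedf["organism"].value_counts())
print(genedf.index[:5])
```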
```diff
@@ -302,37 +185,37 @@ class Dataset(torchDataset):
                 )
             elif label == "cell_type_ontology_term_id":
                 parentdf = (
-
+                    bt.CellType.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "tissue_ontology_term_id":
                 parentdf = (
-
+                    bt.Tissue.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "disease_ontology_term_id":
                 parentdf = (
-
+                    bt.Disease.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "development_stage_ontology_term_id":
                 parentdf = (
-
+                    bt.DevelopmentalStage.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "assay_ontology_term_id":
                 parentdf = (
-
+                    bt.ExperimentalFactor.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "self_reported_ethnicity_ontology_term_id":
                 parentdf = (
-
+                    bt.Ethnicity.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
```
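Every branch above follows the same bionty pattern: pull all terms of a registry together with their parents' ontology ids, indexed by each term's own ontology id. Shown once outside the class, assuming a populated bionty instance:

```python
import bionty as bt

# one row per cell type term, with a "parents__ontology_id" column listing its parents
parentdf = (
    bt.CellType.filter()
    .df(include=["parents__ontology_id"])
    .set_index("ontology_id")
)
print(parentdf.head())
```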
```diff
@@ -344,10 +227,90 @@ class Dataset(torchDataset):
                     )
                 )
             cats = self.mapped_dataset.get_merged_categories(label)
-
+            addition = set(LABELS_TOADD.get(label, {}).values())
+            cats |= addition
+            # import pdb
+
+            # pdb.set_trace()
             groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
                     groupings.pop(i)
             self.class_groupings[label] = groupings
-
+            if label in self.clss_to_pred:
+                # if we have added new labels, we need to update the encoder with them too.
+                mlength = len(self.mapped_dataset.encoders[label])
+                mlength -= (
+                    1
+                    if self.mapped_dataset.unknown_class
+                    in self.mapped_dataset.encoders[label].keys()
+                    else 0
+                )
+
+                for i, v in enumerate(
+                    addition - set(self.mapped_dataset.encoders[label].keys())
+                ):
+                    self.mapped_dataset.encoders[label].update({v: mlength + i})
+                # we need to change the ordering so that the things that can't be predicted appear afterward
+
+                self.class_topred[label] = lclass
+                c = 0
+                d = 0
+                update = {}
+                mlength = len(lclass)
+                # import pdb
+
+                # pdb.set_trace()
+                mlength -= (
+                    1
+                    if self.mapped_dataset.unknown_class
+                    in self.mapped_dataset.encoders[label].keys()
+                    else 0
+                )
+                for k, v in self.mapped_dataset.encoders[label].items():
+                    if k in self.class_groupings[label].keys():
+                        update.update({k: mlength + c})
+                        c += 1
+                    elif k == self.mapped_dataset.unknown_class:
+                        update.update({k: v})
+                        d += 1
+                        self.class_topred[label] -= set([k])
+                    else:
+                        update.update({k: (v - c) - d})
+                self.mapped_dataset.encoders[label] = update
+
+
+class SimpleAnnDataset:
+    def __init__(
+        self,
+        adata: AnnData,
+        obs_to_output: Optional[list[str]] = [],
+        layer: Optional[str] = None,
+    ):
+        """
+        SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
+        scDataloader and with your model during inference.
+
+        Args:
+        ----
+            adata (anndata.AnnData): anndata object to use
+            obs_to_output (list[str]): list of observations to output from anndata.obs
+            layer (str): layer of the anndata to use
+        """
+        self.adata = adata
+        self.obs_to_output = obs_to_output
+        self.layer = layer
+
+    def __len__(self):
+        return self.adata.shape[0]
+
+    def __getitem__(self, idx):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=DeprecationWarning)
+            if self.layer is not None:
+                out = {"x": self.adata.layers[self.layer][idx].toarray().reshape(-1)}
+            else:
+                out = {"x": self.adata.X[idx].toarray().reshape(-1)}
+            for i in self.obs_to_output:
+                out.update({i: self.adata.obs.iloc[idx][i]})
+            return out
```
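A minimal inference-time sketch for the new SimpleAnnDataset. The h5ad path and obs column are hypothetical; note that `__getitem__` calls `.toarray()`, so `X` (or the chosen layer) is expected to be a sparse matrix:

```python
import scanpy as sc
from torch.utils.data import DataLoader
from scdataloader.data import SimpleAnnDataset

adata = sc.read_h5ad("my_dataset.h5ad")  # hypothetical file
dset = SimpleAnnDataset(adata, obs_to_output=["cell_type_ontology_term_id"])

loader = DataLoader(dset, batch_size=64, shuffle=False)  # keep cell order for inference
for batch in loader:
    x = batch["x"]  # dense expression values for the batch
    labels = batch["cell_type_ontology_term_id"]
    # feed x to your model here
    break
```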