scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +66 -42
- scdataloader/collator.py +136 -67
- scdataloader/config.py +112 -0
- scdataloader/data.py +160 -169
- scdataloader/datamodule.py +403 -0
- scdataloader/mapped.py +285 -109
- scdataloader/preprocess.py +240 -109
- scdataloader/utils.py +162 -70
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/METADATA +87 -18
- scdataloader-1.0.1.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/entry_points.txt +0 -0
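The module layout itself changes in this release: `dataloader.py` is gone, while `config.py` and `datamodule.py` are new. Of these, only `config.py` is visible in the `data.py` diff below, which now imports its `LABELS_TOADD` mapping instead of defining it inline. A minimal sketch of the import surface this diff does confirm (whatever `datamodule.py` exports is not shown here and is omitted):

```python
# Imports confirmed by the data.py diff below; what datamodule.py
# exports is not visible in this diff and is therefore left out.
from scdataloader import mapped                    # mapped.mapped(...) wrapper
from scdataloader.config import LABELS_TOADD       # ontology terms added per obs key
from scdataloader.data import Dataset, SimpleAnnDataset
```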
scdataloader/data.py
CHANGED
@@ -1,83 +1,57 @@
 from dataclasses import dataclass, field
 
 import lamindb as ln
-
+
+# ln.connect("scprint")
+
+import bionty as bt
 import pandas as pd
 from torch.utils.data import Dataset as torchDataset
-from typing import Union
+from typing import Union, Optional, Literal
 from scdataloader import mapped
 import warnings
 
-
-
+from anndata import AnnData
+from scipy.sparse import issparse
+
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
-LABELS_TOADD = {
-    "assay_ontology_term_id": {
-        "10x transcription profiling": "EFO:0030003",
-        "spatial transcriptomics": "EFO:0008994",
-        "10x 3' transcription profiling": "EFO:0030003",
-        "10x 5' transcription profiling": "EFO:0030004",
-    },
-    "disease_ontology_term_id": {
-        "metabolic disease": "MONDO:0005066",
-        "chronic kidney disease": "MONDO:0005300",
-        "chromosomal disorder": "MONDO:0019040",
-        "infectious disease": "MONDO:0005550",
-        "inflammatory disease": "MONDO:0021166",
-        # "immune system disease",
-        "disorder of development or morphogenesis": "MONDO:0021147",
-        "mitochondrial disease": "MONDO:0044970",
-        "psychiatric disorder": "MONDO:0002025",
-        "cancer or benign tumor": "MONDO:0002025",
-        "neoplasm": "MONDO:0005070",
-    },
-    "cell_type_ontology_term_id": {
-        "progenitor cell": "CL:0011026",
-        "hematopoietic cell": "CL:0000988",
-        "myoblast": "CL:0000056",
-        "myeloid cell": "CL:0000763",
-        "neuron": "CL:0000540",
-        "electrically active cell": "CL:0000211",
-        "epithelial cell": "CL:0000066",
-        "secretory cell": "CL:0000151",
-        "stem cell": "CL:0000034",
-        "non-terminally differentiated cell": "CL:0000055",
-        "supporting cell": "CL:0000630",
-    },
-}
+from .config import LABELS_TOADD
 
 
 @dataclass
 class Dataset(torchDataset):
     """
-    Dataset class to load a bunch of anndata from a lamin dataset in a memory efficient way.
+    Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory efficient way.
 
-
+    This serves as a wrapper around lamin's mappedCollection to provide more features,
+    mostly, the management of hierarchical labels, the encoding of labels, the management of multiple species
+
+    For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.
 
     .. note::
 
-        A
+        A related data loader exists `here
         <https://github.com/Genentech/scimilarity>`__.
 
-
+    Args:
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        gene_embedding: dataframe containing the gene embeddings
        organisms (list[str]): list of organisms to load
-
+            (for now only validates the the genes map to this organism)
+        obs (list[str]): list of observations to load from the Collection
        clss_to_pred (list[str]): list of observations to encode
-
+        join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
+        hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
     """
 
     lamin_dataset: ln.Collection
-    genedf: pd.DataFrame = None
-
-    organisms: Union[list[str], str] = field(
+    genedf: Optional[pd.DataFrame] = None
+    organisms: Optional[Union[list[str], str]] = field(
         default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
     )
-    obs: list[str] = field(
+    obs: Optional[list[str]] = field(
         default_factory=[
             "self_reported_ethnicity_ontology_term_id",
             "assay_ontology_term_id",
@@ -88,29 +62,32 @@ class Dataset(torchDataset):
             "sex_ontology_term_id",
             #'dataset_id',
             #'cell_culture',
-            "dpt_group",
-            "heat_diff",
-            "nnz",
+            # "dpt_group",
+            # "heat_diff",
+            # "nnz",
         ]
     )
     # set of obs to prepare for prediction (encode)
-    clss_to_pred: list[str] = field(default_factory=list)
+    clss_to_pred: Optional[list[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
-    hierarchical_clss: list[str] = field(default_factory=list)
-    join_vars:
+    hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+    join_vars: Literal["inner", "outer"] | None = None
 
     def __post_init__(self):
         self.mapped_dataset = mapped.mapped(
             self.lamin_dataset,
             label_keys=self.obs,
+            join=self.join_vars,
             encode_labels=self.clss_to_pred,
+            unknown_label="unknown",
             stream=True,
             parallel=True,
-            join_vars=self.join_vars,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
         )
+        self.labels_groupings = {}
+        self.class_topred = {}
         # generate tree from ontologies
         if len(self.hierarchical_clss) > 0:
             self.define_hierarchies(self.hierarchical_clss)
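A caveat in the field definitions above, unchanged between 0.0.3 and 1.0.1: `dataclasses.field(default_factory=...)` expects a zero-argument callable, but `organisms` and `obs` pass list literals, so constructing a `Dataset` without supplying those arguments raises `TypeError: 'list' object is not callable`. A standalone sketch of the callable form (not the released code):

```python
from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass
class Example:
    # default_factory must be a zero-argument callable; wrapping the
    # default list in a lambda lets the field be omitted safely.
    organisms: Optional[Union[list[str], str]] = field(
        default_factory=lambda: ["NCBITaxon:9606", "NCBITaxon:10090"]
    )


print(Example().organisms)  # ['NCBITaxon:9606', 'NCBITaxon:10090']
```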
@@ -121,24 +98,19 @@ class Dataset(torchDataset):
                 self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
                     clss
                 )
-
-
-
-
-
-
-
-                else:
-                    update.update({k: v - c})
-                self.mapped_dataset.encoders[clss] = update
+                if (
+                    self.mapped_dataset.unknown_label
+                    in self.mapped_dataset.encoders[clss].keys()
+                ):
+                    self.class_topred[clss] -= set(
+                        [self.mapped_dataset.unknown_label]
+                    )
 
         if self.genedf is None:
             self.genedf = load_genes(self.organisms)
 
         self.genedf.columns = self.genedf.columns.astype(str)
-
-        ogenedf = self.genedf[self.genedf.organism == organism]
-        self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())
+        self.mapped_dataset._check_aligned_vars(self.genedf.index.tolist())
 
     def __len__(self, **kwargs):
         return self.mapped_dataset.__len__(**kwargs)
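Taken together, `__post_init__`, `__len__` and the gene alignment above give the 1.0.1 construction path. A usage sketch (the collection name is a placeholder, and it assumes a populated lamindb instance; `obs` keeps its default, and `organisms` is passed explicitly, which also sidesteps the `default_factory` issue noted earlier):

```python
import lamindb as ln
from scdataloader.data import Dataset

# placeholder name; assumes a lamindb instance holding AnnData artifacts
collection = ln.Collection.filter(name="my-collection").one()

dataset = Dataset(
    lamin_dataset=collection,
    organisms=["NCBITaxon:9606"],                      # human only
    clss_to_pred=["cell_type_ontology_term_id"],       # labels to encode
    hierarchical_clss=["cell_type_ontology_term_id"],  # expand via the ontology
    join_vars="inner",
)
print(len(dataset))  # cells across all artifacts in the collection
```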
@@ -149,66 +121,66 @@ class Dataset(torchDataset):
 
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
-        #item.update({"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)})
-        # ret = {}
-        # ret["count"] = item[0]
-        # for i, val in enumerate(self.obs):
-        #     ret[val] = item[1][i]
-        ## mark unseen genes with a flag
-        ## send the associated
-        # print(item[0].shape)
         return item
 
     def __repr__(self):
-
-            "total dataset size is {} Gb".format(
+        return (
+            "total dataset size is {} Gb\n".format(
                 sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
             )
-
-
-
-
-
-
-
-
-
-
+            + "---\n"
+            + "dataset contains:\n"
+            + "     {} cells\n".format(self.mapped_dataset.__len__())
+            + "     {} genes\n".format(self.genedf.shape[0])
+            + "     {} labels\n".format(len(self.obs))
+            + "     {} clss_to_pred\n".format(len(self.clss_to_pred))
+            + "     {} hierarchical_clss\n".format(len(self.hierarchical_clss))
+            + "     {} organisms\n".format(len(self.organisms))
+            + (
+                "dataset contains {} classes to predict\n".format(
+                    sum([len(self.class_topred[i]) for i in self.class_topred])
+                )
+                if len(self.class_topred) > 0
+                else ""
             )
         )
-        # print("embedding size is {}".format(self.gene_embedding.shape[1]))
-        return ""
 
     def get_label_weights(self, *args, **kwargs):
+        """
+        get_label_weights is a wrapper around mappedDataset.get_label_weights
+
+        Returns:
+            dict: dictionary of weights for each label
+        """
         return self.mapped_dataset.get_label_weights(*args, **kwargs)
 
-    def get_unseen_mapped_dataset_elements(self, idx):
+    def get_unseen_mapped_dataset_elements(self, idx: int):
+        """
+        get_unseen_mapped_dataset_elements is a wrapper around mappedDataset.get_unseen_mapped_dataset_elements
+
+        Args:
+            idx (int): index of the element to get
+
+        Returns:
+            list[str]: list of unseen genes
+        """
         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]
 
-
-
-
-
-
-
-
-
-
-
-
-
-    #     genedf = pd.concat(
-    #         [genedf.set_index("ensembl_gene_id"), embedding], axis=1, join="inner"
-    #     )
-    #     genedf.columns = genedf.columns.astype(str)
-    #     embeddings.append(genedf)
-    #     return pd.concat(embeddings)
-
-    def define_hierarchies(self, labels):
-        self.class_groupings = {}
+    def define_hierarchies(self, clsses: list[str]):
+        """
+        define_hierarchies is a method to define the hierarchies for the classes to predict
+
+        Args:
+            clsses (list[str]): list of classes to predict
+
+        Raises:
+            ValueError: if the class is not in the accepted classes
+        """
+        # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
+        self.labels_groupings = {}
         self.class_topred = {}
-        for label in labels:
-            if label not in [
+        for clss in clsses:
+            if clss not in [
                 "cell_type_ontology_term_id",
                 "tissue_ontology_term_id",
                 "disease_ontology_term_id",
@@ -217,120 +189,139 @@ class Dataset(torchDataset):
                 "self_reported_ethnicity_ontology_term_id",
             ]:
                 raise ValueError(
-                    "
-
+                    "class {} not in accepted classes, for now only supported from bionty sources".format(
+                        clss
                     )
                 )
-            elif label == "cell_type_ontology_term_id":
+            elif clss == "cell_type_ontology_term_id":
                 parentdf = (
-
+                    bt.CellType.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif label == "tissue_ontology_term_id":
+            elif clss == "tissue_ontology_term_id":
                 parentdf = (
-
+                    bt.Tissue.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif label == "disease_ontology_term_id":
+            elif clss == "disease_ontology_term_id":
                 parentdf = (
-
+                    bt.Disease.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif label == "development_stage_ontology_term_id":
+            elif clss == "development_stage_ontology_term_id":
                 parentdf = (
-
+                    bt.DevelopmentalStage.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif label == "assay_ontology_term_id":
+            elif clss == "assay_ontology_term_id":
                 parentdf = (
-
+                    bt.ExperimentalFactor.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif label == "self_reported_ethnicity_ontology_term_id":
+            elif clss == "self_reported_ethnicity_ontology_term_id":
                 parentdf = (
-
+                    bt.Ethnicity.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
 
             else:
                 raise ValueError(
-                    "
-
+                    "class {} not in accepted classes, for now only supported from bionty sources".format(
+                        clss
                    )
                 )
-            cats = self.mapped_dataset.get_merged_categories(
-            addition = set(LABELS_TOADD.get(
+            cats = self.mapped_dataset.get_merged_categories(clss)
+            addition = set(LABELS_TOADD.get(clss, {}).values())
             cats |= addition
-            groupings, _,
+            groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
                     groupings.pop(i)
-            self.
-            if
-                # if we have added new
-                mlength = len(self.mapped_dataset.encoders[
+            self.labels_groupings[clss] = groupings
+            if clss in self.clss_to_pred:
+                # if we have added new clss, we need to update the encoder with them too.
+                mlength = len(self.mapped_dataset.encoders[clss])
+
                 mlength -= (
                     1
-                    if self.mapped_dataset.
-                    in self.mapped_dataset.encoders[
+                    if self.mapped_dataset.unknown_label
+                    in self.mapped_dataset.encoders[clss].keys()
                     else 0
                 )
 
                 for i, v in enumerate(
-                    addition - set(self.mapped_dataset.encoders[
+                    addition - set(self.mapped_dataset.encoders[clss].keys())
                 ):
-                    self.mapped_dataset.encoders[
+                    self.mapped_dataset.encoders[clss].update({v: mlength + i})
                 # we need to change the ordering so that the things that can't be predicted appear afterward
 
-                self.class_topred[
+                self.class_topred[clss] = leaf_labels
                 c = 0
-                d = 0
                 update = {}
-                mlength = len(
-                # import pdb
-
-                # pdb.set_trace()
+                mlength = len(leaf_labels)
                 mlength -= (
                     1
-                    if self.mapped_dataset.
-                    in self.mapped_dataset.encoders[
+                    if self.mapped_dataset.unknown_label
+                    in self.mapped_dataset.encoders[clss].keys()
                     else 0
                 )
-                for k, v in self.mapped_dataset.encoders[
-                    if k in self.
+                for k, v in self.mapped_dataset.encoders[clss].items():
+                    if k in self.labels_groupings[clss].keys():
                         update.update({k: mlength + c})
                         c += 1
-                    elif k == self.mapped_dataset.
+                    elif k == self.mapped_dataset.unknown_label:
                         update.update({k: v})
-
-                        self.class_topred[label] -= set([k])
+                        self.class_topred[clss] -= set([k])
                     else:
-                        update.update({k:
-                        self.mapped_dataset.encoders[
+                        update.update({k: v - c})
+                self.mapped_dataset.encoders[clss] = update
 
 
-class SimpleAnnDataset:
-    def __init__(
-        self
-
-
+class SimpleAnnDataset(torchDataset):
+    def __init__(
+        self,
+        adata: AnnData,
+        obs_to_output: Optional[list[str]] = [],
+        layer: Optional[str] = None,
+    ):
+        """
+        SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
+        scDataloader and with your model during inference.
+
+        Args:
+        ----
+            adata (anndata.AnnData): anndata object to use
+            obs_to_output (list[str]): list of observations to output from anndata.obs
+            layer (str): layer of the anndata to use
+        """
+        self.adataX = adata.layers[layer] if layer is not None else adata.X
+        self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+        self.obs_to_output = adata.obs[obs_to_output]
 
     def __len__(self):
-        return self.
+        return self.adataX.shape[0]
+
+    def __iter__(self):
+        for idx, obs in enumerate(self.adata.obs.itertuples(index=False)):
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=DeprecationWarning)
+                out = {"x": self.adataX[idx].reshape(-1)}
+                out.update(
+                    {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+                )
+            yield out
 
     def __getitem__(self, idx):
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=DeprecationWarning)
-
-
-
-
-            for i in self.obs_to_output:
-                out.update({i: self.adata.obs.iloc[idx][i]})
+            out = {"x": self.adataX[idx].reshape(-1)}
+            out.update(
+                {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+            )
         return out