scdataloader 0.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +2 -2
- scdataloader/__main__.py +3 -0
- scdataloader/collator.py +61 -96
- scdataloader/config.py +6 -0
- scdataloader/data.py +138 -90
- scdataloader/datamodule.py +67 -39
- scdataloader/mapped.py +302 -120
- scdataloader/preprocess.py +4 -213
- scdataloader/utils.py +128 -92
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/METADATA +82 -26
- scdataloader-1.0.5.dist-info/RECORD +16 -0
- scdataloader-0.0.4.dist-info/RECORD +0 -16
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.4.dist-info → scdataloader-1.0.5.dist-info}/entry_points.txt +0 -0
scdataloader/VERSION
CHANGED
@@ -1 +1 @@
-0.0.4
+1.0.5
scdataloader/__init__.py
CHANGED
scdataloader/__main__.py
CHANGED
@@ -10,6 +10,9 @@ from typing import Optional, Union
 
 # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
 def main():
+    """
+    main function to preprocess datasets in a given lamindb collection.
+    """
     parser = argparse.ArgumentParser(
         description="Preprocess datasets in a given lamindb collection."
     )
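As a hedged aside (not part of the package itself), the comment above main() is the reference CLI invocation. Assuming main() lets its argparse parser read sys.argv (the diff shows the parser being built but not the parse_args call), the same run can be reproduced in-process:

import sys
from scdataloader.__main__ import main

# Hypothetical in-process equivalent of the CLI comment above; the flag
# names are copied verbatim from that comment, everything else is assumed.
sys.argv = [
    "scdataloader",
    "--instance=laminlabs/cellxgene",
    "--name=cellxgene-census",
    "--version=2023-12-15",
    "--description=preprocessed for scprint",
    "--new_name=scprint main",
    "--start_at=39",
]
main()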
scdataloader/collator.py
CHANGED
@@ -1,26 +1,27 @@
 import numpy as np
-from .utils import load_genes
+from .utils import load_genes, downsample_profile
 from torch import Tensor, long
-
-# class SimpleCollator:
+from typing import Optional
 
 
 class Collator:
     def __init__(
         self,
-        organisms: list,
-        how="all",
-        org_to_id: dict = None,
-        valid_genes: list = [],
-        max_len=2000,
-        add_zero_genes=0,
-        logp1=False,
-        norm_to=None,
-        n_bins=0,
-        tp_name=None,
-        organism_name="organism_ontology_term_id",
-        class_names=[],
-        genelist=[],
+        organisms: list[str],
+        how: str = "all",
+        org_to_id: dict[str, int] = None,
+        valid_genes: list[str] = [],
+        max_len: int = 2000,
+        add_zero_genes: int = 0,
+        logp1: bool = False,
+        norm_to: Optional[float] = None,
+        n_bins: int = 0,
+        tp_name: Optional[str] = None,
+        organism_name: str = "organism_ontology_term_id",
+        class_names: list[str] = [],
+        genelist: list[str] = [],
+        downsample: Optional[float] = None,  # don't use it for training!
+        save_output: bool = False,
     ):
         """
         This class is responsible for collating data for the scPRINT model. It handles the
@@ -44,38 +45,57 @@ class Collator:
             org_to_id (dict): Dictionary mapping organisms to their respective IDs.
             valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
                 it will drop any other genes from the input expression data (useful when your model only works on some genes)
-            max_len (int, optional):
+            max_len (int, optional): Total number of genes to use (for random expr and most expr). Defaults to 2000.
             n_bins (int, optional): Number of bins for binning the data. Defaults to 0, meaning no binning of expression.
             add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
             logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
-            norm_to (
+            norm_to (float, optional): Rescaling value of the normalization to be applied. Defaults to None.
+            organism_name (str, optional): Name of the organism ontology term id. Defaults to "organism_ontology_term_id".
+            tp_name (str, optional): Name of the heat diff. Defaults to None.
+            class_names (list, optional): List of other classes to be considered. Defaults to [].
+            genelist (list, optional): List of genes to be considered. Defaults to [].
+                If [] all genes will be considered.
+            downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None.
+                This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator.
+            save_output (bool, optional): If True, saves the output to a file. Defaults to False.
+                This is mainly for debugging purposes.
         """
         self.organisms = organisms
+        self.genedf = load_genes(organisms)
         self.max_len = max_len
         self.n_bins = n_bins
         self.add_zero_genes = add_zero_genes
         self.logp1 = logp1
         self.norm_to = norm_to
-        self.org_to_id = org_to_id
         self.how = how
-        self.organism_ids = (
-            set([org_to_id[k] for k in organisms])
-            if org_to_id is not None
-            else set(organisms)
-        )
         if self.how == "some":
             assert len(genelist) > 0, "if how is some, genelist must be provided"
         self.organism_name = organism_name
         self.tp_name = tp_name
         self.class_names = class_names
-
+        self.save_output = save_output
         self.start_idx = {}
         self.accepted_genes = {}
-        self.
+        self.downsample = downsample
         self.to_subset = {}
-
+        self._setup(org_to_id, valid_genes, genelist)
+
+    def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
+        self.org_to_id = org_to_id
+        self.to_subset = {}
+        self.accepted_genes = {}
+        self.start_idx = {}
+        self.organism_ids = (
+            set([org_to_id[k] for k in self.organisms])
+            if org_to_id is not None
+            else set(self.organisms)
+        )
+        for organism in self.organisms:
             ogenedf = self.genedf[self.genedf.organism == organism]
-
+            if len(valid_genes) > 0:
+                tot = self.genedf[self.genedf.index.isin(valid_genes)]
+            else:
+                tot = self.genedf
             org = org_to_id[organism] if org_to_id is not None else organism
             self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
             if len(valid_genes) > 0:
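Taken together, the widened signature and the new _setup helper make construction look roughly as follows; a minimal sketch, assuming a working lamindb/bionty environment behind load_genes (the NCBITaxon term is a placeholder organism id):

from scdataloader.collator import Collator

collator = Collator(
    organisms=["NCBITaxon:9606"],  # placeholder; must be known to load_genes
    how="random expr",
    max_len=2000,
    downsample=0.5,     # new in 1.0.5; the signature comment warns against training use
    save_output=False,  # new in 1.0.5; appends each collated "x" to collator_output.txt
)
# _setup() already ran inside __init__; calling it again re-points the gene
# bookkeeping, e.g. once you know which genes a checkpoint actually accepts:
# collator._setup(org_to_id=None, valid_genes=accepted_gene_list, genelist=[])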
@@ -84,14 +104,14 @@ class Collator:
             df = ogenedf[ogenedf.index.isin(valid_genes)]
             self.to_subset.update({org: df.index.isin(genelist)})
 
-    def __call__(self, batch):
+    def __call__(self, batch) -> dict[str, Tensor]:
         """
         __call__ applies the collator to a minibatch of data
 
         Args:
             batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
                 the first list is for the different samples, the second list is for the different elements with
-                elem["
+                elem["X"]: gene expression
                 elem["organism_name"]: organism ontology term id
                 elem["tp_name"]: heat diff
                 elem["class_names.."]: other classes
@@ -113,9 +133,9 @@ class Collator:
             organism_id = elem[self.organism_name]
             if organism_id not in self.organism_ids:
                 continue
-            if "
-                dataset.append(elem["
-            expr = np.array(elem["
+            if "_storage_idx" in elem:
+                dataset.append(elem["_storage_idx"])
+            expr = np.array(elem["X"])
             total_count.append(expr.sum())
             if len(self.accepted_genes) > 0:
                 expr = expr[self.accepted_genes[organism_id]]
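This hunk pins down the element schema __call__ expects: expression under "X", the organism under the organism_name key, and an optional "_storage_idx" back-reference. Continuing the construction sketch above with a synthetic minibatch (output keys are inferred from the rest of this diff, not from documentation):

import numpy as np

# genedf was built in __init__ via load_genes; its per-organism slice gives
# the expected input gene dimension for the synthetic profiles below.
n_genes = int((collator.genedf.organism == "NCBITaxon:9606").sum())
batch = [
    {
        "X": np.random.poisson(1.0, size=n_genes),
        "organism_ontology_term_id": "NCBITaxon:9606",
        "_storage_idx": 0,  # optional, see the hunk above
    }
    for _ in range(4)
]
out = collator(batch)  # expected keys include "x", "genes", "depth", "dataset"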
@@ -206,72 +226,17 @@ class Collator:
         }
         if len(dataset) > 0:
             ret.update({"dataset": Tensor(dataset).to(long)})
+        if self.downsample is not None:
+            ret["x"] = downsample_profile(ret["x"], self.downsample)
+        if self.save_output:
+            with open("collator_output.txt", "a") as f:
+                np.savetxt(f, ret["x"].numpy())
         return ret
 
 
-
-
-
-        AnnDataCollator Collator to use if working with AnnData's experimental dataloader (it is very slow!!!)
-
-        Args:
-            @see Collator
-        """
-        super().__init__(*args, **kwargs)
-
-    def __call__(self, batch):
-        exprs = []
-        total_count = []
-        other_classes = []
-        gene_locs = []
-        tp = []
-        for elem in batch:
-            organism_id = elem.obs[self.organism_name]
-            if organism_id.item() not in self.organism_ids:
-                print(organism_id)
-            expr = np.array(elem.X[0])
-
-            total_count.append(expr.sum())
-            if len(self.accepted_genes) > 0:
-                expr = expr[self.accepted_genes[organism_id]]
-            if self.how == "most expr":
-                loc = np.argsort(expr)[-(self.max_len) :][::-1]
-            elif self.how == "random expr":
-                nnz_loc = np.where(expr > 0)[0]
-                loc = nnz_loc[
-                    np.random.choice(len(nnz_loc), self.max_len, replace=False)
-                ]
-            else:
-                raise ValueError("how must be either most expr or random expr")
-            if self.add_zero_genes > 0:
-                zero_loc = np.where(expr == 0)[0]
-                zero_loc = [
-                    np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
-                ]
-                loc = np.concatenate((loc, zero_loc), axis=None)
-            exprs.append(expr[loc])
-            gene_locs.append(loc + self.start_idx[organism_id.item()])
-
-            if self.tp_name is not None:
-                tp.append(elem.obs[self.tp_name])
-            else:
-                tp.append(0)
-
-            other_classes.append([elem.obs[i].values[0] for i in self.class_names])
-
-        expr = np.array(exprs)
-        tp = np.array(tp)
-        gene_locs = np.array(gene_locs)
-        total_count = np.array(total_count)
-        other_classes = np.array(other_classes)
-        return {
-            "x": Tensor(expr),
-            "genes": Tensor(gene_locs).int(),
-            "depth": Tensor(total_count),
-            "class": Tensor(other_classes),
-        }
-
-
+#############
+#### WIP ####
+#############
 class GeneformerCollator(Collator):
     def __init__(self, *args, gene_norm_list: list, **kwargs):
         """
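On the new downsample branch above, __call__ hands the collated count tensor and the constructor's fraction straight to scdataloader.utils.downsample_profile. A minimal standalone sketch of that call on synthetic data (the exact downsampling semantics live in utils.py, which this diff only summarizes):

import torch
from scdataloader.utils import downsample_profile

# A tiny fake (cells x genes) count matrix.
x = torch.tensor([[4.0, 0.0, 10.0],
                  [2.0, 6.0, 0.0]])
x_down = downsample_profile(x, 0.5)  # same call shape as in __call__ above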
scdataloader/config.py
CHANGED
scdataloader/data.py
CHANGED
@@ -1,14 +1,18 @@
 from dataclasses import dataclass, field
 
 import lamindb as ln
+
+# ln.connect("scprint")
+
 import bionty as bt
 import pandas as pd
 from torch.utils.data import Dataset as torchDataset
 from typing import Union, Optional, Literal
-from scdataloader import
+from scdataloader.mapped import MappedCollection
 import warnings
 
 from anndata import AnnData
+from scipy.sparse import issparse
 
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
@@ -58,30 +62,31 @@ class Dataset(torchDataset):
             "sex_ontology_term_id",
             #'dataset_id',
             #'cell_culture',
-            #"dpt_group",
-            #"heat_diff",
-            #"nnz",
+            # "dpt_group",
+            # "heat_diff",
+            # "nnz",
         ]
     )
     # set of obs to prepare for prediction (encode)
     clss_to_pred: Optional[list[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
     hierarchical_clss: Optional[list[str]] = field(default_factory=list)
-    join_vars:
+    join_vars: Literal["inner", "outer"] | None = None
 
     def __post_init__(self):
-        self.mapped_dataset = mapped
+        self.mapped_dataset = mapped(
             self.lamin_dataset,
-
+            obs_keys=self.obs,
+            join=self.join_vars,
             encode_labels=self.clss_to_pred,
+            unknown_label="unknown",
             stream=True,
             parallel=True,
-            join_vars=self.join_vars,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
         )
-        self.
+        self.labels_groupings = {}
         self.class_topred = {}
         # generate tree from ontologies
         if len(self.hierarchical_clss) > 0:
@@ -93,24 +98,19 @@ class Dataset(torchDataset):
                 self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
                     clss
                 )
-
-
-
-
-
-
-
-            else:
-                update.update({k: v - c})
-            self.mapped_dataset.encoders[clss] = update
+                if (
+                    self.mapped_dataset.unknown_label
+                    in self.mapped_dataset.encoders[clss].keys()
+                ):
+                    self.class_topred[clss] -= set(
+                        [self.mapped_dataset.unknown_label]
+                    )
 
         if self.genedf is None:
             self.genedf = load_genes(self.organisms)
 
         self.genedf.columns = self.genedf.columns.astype(str)
-
-        ogenedf = self.genedf[self.genedf.organism == organism]
-        self.mapped_dataset._check_aligned_vars(ogenedf.index.tolist())
+        self.mapped_dataset._check_aligned_vars(self.genedf.index.tolist())
 
     def __len__(self, **kwargs):
         return self.mapped_dataset.__len__(**kwargs)
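With the reworked __post_init__, constructing a Dataset wires a lamindb collection straight into the module-level mapped() helper added at the bottom of this file. A hedged sketch using only the dataclass fields visible in this diff (the collection name is a placeholder):

import lamindb as ln
from scdataloader.data import Dataset

collection = ln.Collection.filter(name="my-preprocessed-collection").one()  # placeholder
dataset = Dataset(
    lamin_dataset=collection,
    clss_to_pred=["cell_type_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
    join_vars="inner",
)
print(len(dataset))  # delegates to MappedCollection.__len__
item = dataset[0]    # delegates to MappedCollection.__getitem__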
@@ -121,19 +121,6 @@ class Dataset(torchDataset):
 
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
-        # import pdb
-
-        # pdb.set_trace()
-        # item.update(
-        #     {"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)}
-        # )
-        # ret = {}
-        # ret["count"] = item[0]
-        # for i, val in enumerate(self.obs):
-        #     ret[val] = item[1][i]
-        ## mark unseen genes with a flag
-        ## send the associated
-        # print(item[0].shape)
         return item
 
     def __repr__(self):
@@ -148,7 +135,6 @@ class Dataset(torchDataset):
             + " {} labels\n".format(len(self.obs))
             + " {} clss_to_pred\n".format(len(self.clss_to_pred))
             + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
-            + " {} join_vars\n".format(len(self.join_vars))
             + " {} organisms\n".format(len(self.organisms))
             + (
                 "dataset contains {} classes to predict\n".format(
@@ -160,17 +146,41 @@ class Dataset(torchDataset):
         )
 
     def get_label_weights(self, *args, **kwargs):
+        """
+        get_label_weights is a wrapper around mappedDataset.get_label_weights
+
+        Returns:
+            dict: dictionary of weights for each label
+        """
         return self.mapped_dataset.get_label_weights(*args, **kwargs)
 
     def get_unseen_mapped_dataset_elements(self, idx: int):
+        """
+        get_unseen_mapped_dataset_elements is a wrapper around mappedDataset.get_unseen_mapped_dataset_elements
+
+        Args:
+            idx (int): index of the element to get
+
+        Returns:
+            list[str]: list of unseen genes
+        """
         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]
 
-    def define_hierarchies(self,
+    def define_hierarchies(self, clsses: list[str]):
+        """
+        define_hierarchies is a method to define the hierarchies for the classes to predict
+
+        Args:
+            clsses (list[str]): list of classes to predict
+
+        Raises:
+            ValueError: if the class is not in the accepted classes
+        """
         # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
-        self.
+        self.labels_groupings = {}
         self.class_topred = {}
-        for
-        if
+        for clss in clsses:
+            if clss not in [
                 "cell_type_ontology_term_id",
                 "tissue_ontology_term_id",
                 "disease_ontology_term_id",
@@ -179,41 +189,41 @@ class Dataset(torchDataset):
                 "self_reported_ethnicity_ontology_term_id",
             ]:
                 raise ValueError(
-                    "
-
+                    "class {} not in accepted classes, for now only supported from bionty sources".format(
+                        clss
                     )
                 )
-            elif
+            elif clss == "cell_type_ontology_term_id":
                 parentdf = (
                     bt.CellType.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif
+            elif clss == "tissue_ontology_term_id":
                 parentdf = (
                     bt.Tissue.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif
+            elif clss == "disease_ontology_term_id":
                 parentdf = (
                     bt.Disease.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif
+            elif clss == "development_stage_ontology_term_id":
                 parentdf = (
                     bt.DevelopmentalStage.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif
+            elif clss == "assay_ontology_term_id":
                 parentdf = (
                     bt.ExperimentalFactor.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
-            elif
+            elif clss == "self_reported_ethnicity_ontology_term_id":
                 parentdf = (
                     bt.Ethnicity.filter()
                     .df(include=["parents__ontology_id"])
@@ -222,65 +232,58 @@ class Dataset(torchDataset):
 
             else:
                 raise ValueError(
-                    "
-
+                    "class {} not in accepted classes, for now only supported from bionty sources".format(
+                        clss
                     )
                 )
-            cats = self.mapped_dataset.get_merged_categories(
-            addition = set(LABELS_TOADD.get(
+            cats = self.mapped_dataset.get_merged_categories(clss)
+            addition = set(LABELS_TOADD.get(clss, {}).values())
             cats |= addition
-
-
-            # pdb.set_trace()
-            groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
+            groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
                     groupings.pop(i)
-            self.
-            if
-                # if we have added new
-                mlength = len(self.mapped_dataset.encoders[
+            self.labels_groupings[clss] = groupings
+            if clss in self.clss_to_pred:
+                # if we have added new clss, we need to update the encoder with them too.
+                mlength = len(self.mapped_dataset.encoders[clss])
+
                 mlength -= (
                     1
-                    if self.mapped_dataset.
-                    in self.mapped_dataset.encoders[
+                    if self.mapped_dataset.unknown_label
+                    in self.mapped_dataset.encoders[clss].keys()
                     else 0
                 )
 
                 for i, v in enumerate(
-                    addition - set(self.mapped_dataset.encoders[
+                    addition - set(self.mapped_dataset.encoders[clss].keys())
                 ):
-                    self.mapped_dataset.encoders[
+                    self.mapped_dataset.encoders[clss].update({v: mlength + i})
                 # we need to change the ordering so that the things that can't be predicted appear afterward
 
-                self.class_topred[
+                self.class_topred[clss] = leaf_labels
                 c = 0
-                d = 0
                 update = {}
-                mlength = len(
-                # import pdb
-
-                # pdb.set_trace()
+                mlength = len(leaf_labels)
                 mlength -= (
                     1
-                    if self.mapped_dataset.
-                    in self.mapped_dataset.encoders[
+                    if self.mapped_dataset.unknown_label
+                    in self.mapped_dataset.encoders[clss].keys()
                     else 0
                 )
-                for k, v in self.mapped_dataset.encoders[
-                    if k in self.
+                for k, v in self.mapped_dataset.encoders[clss].items():
+                    if k in self.labels_groupings[clss].keys():
                         update.update({k: mlength + c})
                         c += 1
-                    elif k == self.mapped_dataset.
+                    elif k == self.mapped_dataset.unknown_label:
                         update.update({k: v})
-
-                        self.class_topred[label] -= set([k])
+                        self.class_topred[clss] -= set([k])
                     else:
-                        update.update({k:
-                        self.mapped_dataset.encoders[
+                        update.update({k: v - c})
+                self.mapped_dataset.encoders[clss] = update
 
 
-class SimpleAnnDataset:
+class SimpleAnnDataset(torchDataset):
     def __init__(
         self,
         adata: AnnData,
@@ -297,20 +300,65 @@ class SimpleAnnDataset:
             obs_to_output (list[str]): list of observations to output from anndata.obs
             layer (str): layer of the anndata to use
         """
-        self.
-        self.
-        self.
+        self.adataX = adata.layers[layer] if layer is not None else adata.X
+        self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+        self.obs_to_output = adata.obs[obs_to_output]
 
     def __len__(self):
-        return self.
+        return self.adataX.shape[0]
+
+    def __iter__(self):
+        for idx, obs in enumerate(self.adata.obs.itertuples(index=False)):
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=DeprecationWarning)
+                out = {"X": self.adataX[idx].reshape(-1)}
+                out.update(
+                    {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+                )
+            yield out
 
     def __getitem__(self, idx):
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=DeprecationWarning)
-
-
-
-
-            for i in self.obs_to_output:
-                out.update({i: self.adata.obs.iloc[idx][i]})
+            out = {"X": self.adataX[idx].reshape(-1)}
+            out.update(
+                {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+            )
         return out
+
+
+def mapped(
+    dataset,
+    obs_keys: list[str] | None = None,
+    join: Literal["inner", "outer"] | None = "inner",
+    encode_labels: bool | list[str] = True,
+    unknown_label: str | dict[str, str] | None = None,
+    cache_categories: bool = True,
+    parallel: bool = False,
+    dtype: str | None = None,
+    stream: bool = False,
+    is_run_input: bool | None = None,
+) -> MappedCollection:
+    path_list = []
+    for artifact in dataset.artifacts.all():
+        if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
+            print(f"Ignoring artifact with suffix {artifact.suffix}")
+            continue
+        elif not artifact.path.exists():
+            print(f"Path does not exist for artifact with suffix {artifact.suffix}")
+            continue
+        elif not stream:
+            path_list.append(artifact.stage())
+        else:
+            path_list.append(artifact.path)
+    ds = MappedCollection(
+        path_list=path_list,
+        obs_keys=obs_keys,
+        join=join,
+        encode_labels=encode_labels,
+        unknown_label=unknown_label,
+        cache_categories=cache_categories,
+        parallel=parallel,
+        dtype=dtype,
+    )
+    return ds