scdataloader 1.6.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +2 -0
- scdataloader/__main__.py +98 -36
- scdataloader/collator.py +13 -7
- scdataloader/config.py +99 -0
- scdataloader/data.py +48 -35
- scdataloader/datamodule.py +138 -44
- scdataloader/mapped.py +656 -0
- scdataloader/preprocess.py +239 -91
- scdataloader/utils.py +71 -27
- {scdataloader-1.6.4.dist-info → scdataloader-1.8.0.dist-info}/METADATA +10 -8
- scdataloader-1.8.0.dist-info/RECORD +16 -0
- {scdataloader-1.6.4.dist-info → scdataloader-1.8.0.dist-info}/WHEEL +1 -1
- scdataloader-1.8.0.dist-info/entry_points.txt +2 -0
- scdataloader-1.6.4.dist-info/RECORD +0 -14
- {scdataloader-1.6.4.dist-info → scdataloader-1.8.0.dist-info}/licenses/LICENSE +0 -0
scdataloader/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
1.
|
|
1
|
+
1.8.0
|
scdataloader/__init__.py
CHANGED
scdataloader/__main__.py
CHANGED
|
@@ -10,157 +10,218 @@ from scdataloader.preprocess import (
|
|
|
10
10
|
)
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
# scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15"
|
|
13
|
+
# scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" \
|
|
14
|
+
# --description="scPRINT-V2 datasets" --new_name="scprint v2" --n_hvg_for_postp=4000 --cache=False \
|
|
15
|
+
# --filter_gene_by_counts=0 --filter_cell_by_counts=300 --min_valid_genes_id=500 \
|
|
16
|
+
# --min_nnz_genes=120 --min_dataset_size=100 --maxdropamount=90 \
|
|
17
|
+
# --organisms=["NCBITaxon:9606","NCBITaxon:9544","NCBITaxon:9483","NCBITaxon:10090"] \
|
|
18
|
+
# --start_at=0
|
|
14
19
|
def main():
|
|
15
20
|
"""
|
|
16
|
-
|
|
21
|
+
Main function to either preprocess datasets in a lamindb collection or populate ontologies.
|
|
17
22
|
"""
|
|
18
23
|
parser = argparse.ArgumentParser(
|
|
19
|
-
description="Preprocess datasets
|
|
24
|
+
description="Preprocess datasets or populate ontologies."
|
|
20
25
|
)
|
|
21
|
-
parser.
|
|
26
|
+
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
|
27
|
+
|
|
28
|
+
# Preprocess command
|
|
29
|
+
preprocess_parser = subparsers.add_parser("preprocess", help="Preprocess datasets")
|
|
30
|
+
preprocess_parser.add_argument(
|
|
22
31
|
"--name", type=str, required=True, help="Name of the input dataset"
|
|
23
32
|
)
|
|
24
|
-
|
|
33
|
+
preprocess_parser.add_argument(
|
|
25
34
|
"--new_name",
|
|
26
35
|
type=str,
|
|
27
36
|
default="preprocessed dataset",
|
|
28
37
|
help="Name of the preprocessed dataset.",
|
|
29
38
|
)
|
|
30
|
-
|
|
39
|
+
preprocess_parser.add_argument(
|
|
31
40
|
"--description",
|
|
32
41
|
type=str,
|
|
33
42
|
default="preprocessed by scDataLoader",
|
|
34
43
|
help="Description of the preprocessed dataset.",
|
|
35
44
|
)
|
|
36
|
-
|
|
45
|
+
preprocess_parser.add_argument(
|
|
37
46
|
"--start_at", type=int, default=0, help="Position to start preprocessing at."
|
|
38
47
|
)
|
|
39
|
-
|
|
48
|
+
preprocess_parser.add_argument(
|
|
40
49
|
"--new_version",
|
|
41
50
|
type=str,
|
|
42
51
|
default="2",
|
|
43
52
|
help="Version of the output dataset and files.",
|
|
44
53
|
)
|
|
45
|
-
|
|
54
|
+
preprocess_parser.add_argument(
|
|
46
55
|
"--instance",
|
|
47
56
|
type=str,
|
|
48
57
|
default=None,
|
|
49
58
|
help="Instance storing the input dataset, if not local",
|
|
50
59
|
)
|
|
51
|
-
|
|
60
|
+
preprocess_parser.add_argument(
|
|
52
61
|
"--version", type=str, default=None, help="Version of the input dataset."
|
|
53
62
|
)
|
|
54
|
-
|
|
63
|
+
preprocess_parser.add_argument(
|
|
55
64
|
"--filter_gene_by_counts",
|
|
56
65
|
type=int,
|
|
57
66
|
default=0,
|
|
58
67
|
help="Determines whether to filter genes by counts.",
|
|
59
68
|
)
|
|
60
|
-
|
|
69
|
+
preprocess_parser.add_argument(
|
|
61
70
|
"--filter_cell_by_counts",
|
|
62
71
|
type=int,
|
|
63
72
|
default=0,
|
|
64
73
|
help="Determines whether to filter cells by counts.",
|
|
65
74
|
)
|
|
66
|
-
|
|
75
|
+
preprocess_parser.add_argument(
|
|
67
76
|
"--normalize_sum",
|
|
68
77
|
type=float,
|
|
69
78
|
default=1e4,
|
|
70
79
|
help="Determines whether to normalize the total counts of each cell to a specific value.",
|
|
71
80
|
)
|
|
72
|
-
|
|
73
|
-
"--
|
|
81
|
+
preprocess_parser.add_argument(
|
|
82
|
+
"--n_hvg_for_postp",
|
|
74
83
|
type=int,
|
|
75
84
|
default=0,
|
|
76
85
|
help="Determines whether to subset highly variable genes.",
|
|
77
86
|
)
|
|
78
|
-
|
|
87
|
+
preprocess_parser.add_argument(
|
|
79
88
|
"--hvg_flavor",
|
|
80
89
|
type=str,
|
|
81
90
|
default="seurat_v3",
|
|
82
91
|
help="Specifies the flavor of highly variable genes selection.",
|
|
83
92
|
)
|
|
84
|
-
|
|
93
|
+
preprocess_parser.add_argument(
|
|
85
94
|
"--binning",
|
|
86
95
|
type=Optional[int],
|
|
87
96
|
default=None,
|
|
88
97
|
help="Determines whether to bin the data into discrete values of number of bins provided.",
|
|
89
98
|
)
|
|
90
|
-
|
|
99
|
+
preprocess_parser.add_argument(
|
|
91
100
|
"--result_binned_key",
|
|
92
101
|
type=str,
|
|
93
102
|
default="X_binned",
|
|
94
103
|
help="Specifies the key of AnnData to store the binned data.",
|
|
95
104
|
)
|
|
96
|
-
|
|
105
|
+
preprocess_parser.add_argument(
|
|
97
106
|
"--length_normalize",
|
|
98
107
|
type=bool,
|
|
99
108
|
default=False,
|
|
100
109
|
help="Determines whether to normalize the length.",
|
|
101
110
|
)
|
|
102
|
-
|
|
111
|
+
preprocess_parser.add_argument(
|
|
103
112
|
"--force_preprocess",
|
|
104
113
|
type=bool,
|
|
105
114
|
default=False,
|
|
106
115
|
help="Determines whether to force preprocessing.",
|
|
107
116
|
)
|
|
108
|
-
|
|
117
|
+
preprocess_parser.add_argument(
|
|
109
118
|
"--min_dataset_size",
|
|
110
119
|
type=int,
|
|
111
120
|
default=100,
|
|
112
121
|
help="Specifies the minimum dataset size.",
|
|
113
122
|
)
|
|
114
|
-
|
|
123
|
+
preprocess_parser.add_argument(
|
|
115
124
|
"--min_valid_genes_id",
|
|
116
125
|
type=int,
|
|
117
126
|
default=10_000,
|
|
118
127
|
help="Specifies the minimum valid genes id.",
|
|
119
128
|
)
|
|
120
|
-
|
|
129
|
+
preprocess_parser.add_argument(
|
|
121
130
|
"--min_nnz_genes",
|
|
122
131
|
type=int,
|
|
123
|
-
default=
|
|
132
|
+
default=200,
|
|
124
133
|
help="Specifies the minimum non-zero genes.",
|
|
125
134
|
)
|
|
126
|
-
|
|
135
|
+
preprocess_parser.add_argument(
|
|
127
136
|
"--maxdropamount",
|
|
128
137
|
type=int,
|
|
129
138
|
default=50,
|
|
130
139
|
help="Specifies the maximum drop amount.",
|
|
131
140
|
)
|
|
132
|
-
|
|
141
|
+
preprocess_parser.add_argument(
|
|
133
142
|
"--madoutlier", type=int, default=5, help="Specifies the MAD outlier."
|
|
134
143
|
)
|
|
135
|
-
|
|
144
|
+
preprocess_parser.add_argument(
|
|
136
145
|
"--pct_mt_outlier",
|
|
137
146
|
type=int,
|
|
138
147
|
default=8,
|
|
139
148
|
help="Specifies the percentage of MT outlier.",
|
|
140
149
|
)
|
|
141
|
-
|
|
142
|
-
"--
|
|
150
|
+
preprocess_parser.add_argument(
|
|
151
|
+
"--batch_keys",
|
|
152
|
+
type=list[str],
|
|
153
|
+
default=[
|
|
154
|
+
"assay_ontology_term_id",
|
|
155
|
+
"self_reported_ethnicity_ontology_term_id",
|
|
156
|
+
"sex_ontology_term_id",
|
|
157
|
+
"donor_id",
|
|
158
|
+
"suspension_type",
|
|
159
|
+
],
|
|
160
|
+
help="Specifies the batch keys.",
|
|
143
161
|
)
|
|
144
|
-
|
|
162
|
+
preprocess_parser.add_argument(
|
|
145
163
|
"--skip_validate",
|
|
146
164
|
type=bool,
|
|
147
165
|
default=False,
|
|
148
166
|
help="Determines whether to skip validation.",
|
|
149
167
|
)
|
|
150
|
-
|
|
168
|
+
preprocess_parser.add_argument(
|
|
151
169
|
"--do_postp",
|
|
152
170
|
type=bool,
|
|
153
|
-
default=
|
|
171
|
+
default=True,
|
|
154
172
|
help="Determines whether to do postprocessing.",
|
|
155
173
|
)
|
|
156
|
-
|
|
174
|
+
preprocess_parser.add_argument(
|
|
157
175
|
"--cache",
|
|
158
176
|
type=bool,
|
|
159
|
-
default=
|
|
177
|
+
default=False,
|
|
160
178
|
help="Determines whether to cache the dataset.",
|
|
161
179
|
)
|
|
180
|
+
preprocess_parser.add_argument(
|
|
181
|
+
"--organisms",
|
|
182
|
+
type=list,
|
|
183
|
+
default=[
|
|
184
|
+
"NCBITaxon:9606",
|
|
185
|
+
"NCBITaxon:10090",
|
|
186
|
+
],
|
|
187
|
+
help="Determines the organisms to keep.",
|
|
188
|
+
)
|
|
189
|
+
preprocess_parser.add_argument(
|
|
190
|
+
"--force_preloaded",
|
|
191
|
+
type=bool,
|
|
192
|
+
default=False,
|
|
193
|
+
help="Determines whether the dataset is preloaded.",
|
|
194
|
+
)
|
|
195
|
+
# Populate command
|
|
196
|
+
populate_parser = subparsers.add_parser("populate", help="Populate ontologies")
|
|
197
|
+
populate_parser.add_argument(
|
|
198
|
+
"what",
|
|
199
|
+
nargs="?",
|
|
200
|
+
default="all",
|
|
201
|
+
choices=[
|
|
202
|
+
"all",
|
|
203
|
+
"organisms",
|
|
204
|
+
"celltypes",
|
|
205
|
+
"diseases",
|
|
206
|
+
"tissues",
|
|
207
|
+
"assays",
|
|
208
|
+
"ethnicities",
|
|
209
|
+
"sex",
|
|
210
|
+
"dev_stages",
|
|
211
|
+
],
|
|
212
|
+
help="What ontologies to populate",
|
|
213
|
+
)
|
|
162
214
|
args = parser.parse_args()
|
|
163
215
|
|
|
216
|
+
if args.command == "populate":
|
|
217
|
+
from scdataloader.utils import populate_my_ontology
|
|
218
|
+
|
|
219
|
+
if args.what != "all":
|
|
220
|
+
raise ValueError("Only 'all' is supported for now")
|
|
221
|
+
else:
|
|
222
|
+
populate_my_ontology()
|
|
223
|
+
return
|
|
224
|
+
|
|
164
225
|
# Load the collection
|
|
165
226
|
# if not args.preprocess:
|
|
166
227
|
# print("Only preprocess is available for now")
|
|
@@ -182,7 +243,7 @@ def main():
|
|
|
182
243
|
filter_gene_by_counts=args.filter_gene_by_counts,
|
|
183
244
|
filter_cell_by_counts=args.filter_cell_by_counts,
|
|
184
245
|
normalize_sum=args.normalize_sum,
|
|
185
|
-
|
|
246
|
+
n_hvg_for_postp=args.n_hvg_for_postp,
|
|
186
247
|
hvg_flavor=args.hvg_flavor,
|
|
187
248
|
cache=args.cache,
|
|
188
249
|
binning=args.binning,
|
|
@@ -195,12 +256,13 @@ def main():
|
|
|
195
256
|
maxdropamount=args.maxdropamount,
|
|
196
257
|
madoutlier=args.madoutlier,
|
|
197
258
|
pct_mt_outlier=args.pct_mt_outlier,
|
|
198
|
-
|
|
259
|
+
batch_keys=args.batch_keys,
|
|
199
260
|
skip_validate=args.skip_validate,
|
|
200
261
|
do_postp=args.do_postp,
|
|
201
262
|
additional_preprocess=additional_preprocess,
|
|
202
263
|
additional_postprocess=additional_postprocess,
|
|
203
264
|
keep_files=False,
|
|
265
|
+
force_preloaded=args.force_preloaded,
|
|
204
266
|
)
|
|
205
267
|
|
|
206
268
|
# Preprocess the dataset
|
scdataloader/collator.py
CHANGED
|
@@ -23,7 +23,8 @@ class Collator:
|
|
|
23
23
|
class_names: list[str] = [],
|
|
24
24
|
genelist: list[str] = [],
|
|
25
25
|
downsample: Optional[float] = None, # don't use it for training!
|
|
26
|
-
save_output:
|
|
26
|
+
save_output: Optional[str] = None,
|
|
27
|
+
metacell_mode: bool = False,
|
|
27
28
|
):
|
|
28
29
|
"""
|
|
29
30
|
This class is responsible for collating data for the scPRINT model. It handles the
|
|
@@ -59,8 +60,9 @@ class Collator:
|
|
|
59
60
|
If [] all genes will be considered
|
|
60
61
|
downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None.
|
|
61
62
|
This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator
|
|
62
|
-
save_output (
|
|
63
|
+
save_output (str, optional): If not None, saves the output to a file. Defaults to None.
|
|
63
64
|
This is mainly for debugging purposes
|
|
65
|
+
metacell_mode (bool, optional): Whether to sample a metacell. Defaults to False.
|
|
64
66
|
"""
|
|
65
67
|
self.organisms = organisms
|
|
66
68
|
self.genedf = load_genes(organisms)
|
|
@@ -80,6 +82,7 @@ class Collator:
|
|
|
80
82
|
self.accepted_genes = {}
|
|
81
83
|
self.downsample = downsample
|
|
82
84
|
self.to_subset = {}
|
|
85
|
+
self.metacell_mode = metacell_mode
|
|
83
86
|
self._setup(org_to_id, valid_genes, genelist)
|
|
84
87
|
|
|
85
88
|
def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
|
|
@@ -131,6 +134,7 @@ class Collator:
|
|
|
131
134
|
tp = []
|
|
132
135
|
dataset = []
|
|
133
136
|
nnz_loc = []
|
|
137
|
+
is_meta = []
|
|
134
138
|
for elem in batch:
|
|
135
139
|
organism_id = elem[self.organism_name]
|
|
136
140
|
if organism_id not in self.organism_ids:
|
|
@@ -193,16 +197,16 @@ class Collator:
|
|
|
193
197
|
tp.append(elem[self.tp_name])
|
|
194
198
|
else:
|
|
195
199
|
tp.append(0)
|
|
196
|
-
|
|
200
|
+
if self.metacell_mode:
|
|
201
|
+
is_meta.append(elem["is_meta"])
|
|
197
202
|
other_classes.append([elem[i] for i in self.class_names])
|
|
198
|
-
|
|
199
203
|
expr = np.array(exprs)
|
|
200
204
|
tp = np.array(tp)
|
|
201
205
|
gene_locs = np.array(gene_locs)
|
|
202
206
|
total_count = np.array(total_count)
|
|
203
207
|
other_classes = np.array(other_classes)
|
|
204
208
|
dataset = np.array(dataset)
|
|
205
|
-
|
|
209
|
+
is_meta = np.array(is_meta)
|
|
206
210
|
# normalize counts
|
|
207
211
|
if self.norm_to is not None:
|
|
208
212
|
expr = (expr * self.norm_to) / total_count[:, None]
|
|
@@ -229,12 +233,14 @@ class Collator:
|
|
|
229
233
|
"tp": Tensor(tp),
|
|
230
234
|
"depth": Tensor(total_count),
|
|
231
235
|
}
|
|
236
|
+
if self.metacell_mode:
|
|
237
|
+
ret.update({"is_meta": Tensor(is_meta).int()})
|
|
232
238
|
if len(dataset) > 0:
|
|
233
239
|
ret.update({"dataset": Tensor(dataset).to(long)})
|
|
234
240
|
if self.downsample is not None:
|
|
235
241
|
ret["x"] = downsample_profile(ret["x"], self.downsample)
|
|
236
|
-
if self.save_output:
|
|
237
|
-
with open(
|
|
242
|
+
if self.save_output is not None:
|
|
243
|
+
with open(self.save_output, "a") as f:
|
|
238
244
|
np.savetxt(f, ret["x"].numpy())
|
|
239
245
|
return ret
|
|
240
246
|
|
scdataloader/config.py
CHANGED
|
@@ -110,3 +110,102 @@ COARSE_ASSAY = {
|
|
|
110
110
|
"TruDrop": "",
|
|
111
111
|
"Visium Spatial Gene Expression": "",
|
|
112
112
|
}
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
|
|
116
|
+
"HsapDv:0010000": [
|
|
117
|
+
"MmusDv:0000092", # postnatal stage
|
|
118
|
+
],
|
|
119
|
+
"HsapDv:0000258": [ # mature stage
|
|
120
|
+
"MmusDv:0000110", # mature stage
|
|
121
|
+
"HsapDv:0000204",
|
|
122
|
+
],
|
|
123
|
+
"HsapDv:0000227": [ # late adult stage
|
|
124
|
+
"MmusDv:0000091", # 20 month-old stage
|
|
125
|
+
"MmusDv:0000089", # 18 month-old stage
|
|
126
|
+
],
|
|
127
|
+
"HsapDv:0000272": [], # 60-79 year-old stage
|
|
128
|
+
"HsapDv:0000095": [], # 80 year-old and over stage
|
|
129
|
+
"HsapDv:0000267": [ # middle aged stage
|
|
130
|
+
"MmusDv:0000087", # 16 month-old stage
|
|
131
|
+
"UBERON:0018241", # prime adult stage
|
|
132
|
+
"MmusDv:0000083", # 12 month-old stage
|
|
133
|
+
"HsapDv:0000092", # same
|
|
134
|
+
],
|
|
135
|
+
"HsapDv:0000266": [ # young adult stage
|
|
136
|
+
"MmusDv:0000050", # 6 weeks
|
|
137
|
+
"HsapDv:0000089", # same
|
|
138
|
+
"MmusDv:0000051", # 7 weeks
|
|
139
|
+
"MmusDv:0000052", # 8 weeks
|
|
140
|
+
"MmusDv:0000053", # 9 weeks
|
|
141
|
+
"MmusDv:0000054", # 10 weeks
|
|
142
|
+
"MmusDv:0000055", # 11 weeks
|
|
143
|
+
"MmusDv:0000056", # 12 weeks
|
|
144
|
+
"MmusDv:0000057", # 13 weeks
|
|
145
|
+
"MmusDv:0000058", # 14 weeks
|
|
146
|
+
"MmusDv:0000059", # 15 weeks
|
|
147
|
+
"MmusDv:0000061", # early adult stage
|
|
148
|
+
"MmusDv:0000062", # 2 month-old stage
|
|
149
|
+
"MmusDv:0000063", # 3 month-old stage
|
|
150
|
+
"MmusDv:0000064", # 4 month-old stage
|
|
151
|
+
"MmusDv:0000065", # 16 weeks
|
|
152
|
+
"MmusDv:0000066", # 17 weeks
|
|
153
|
+
"MmusDv:0000067", # 18 weeks
|
|
154
|
+
"MmusDv:0000068", # 19 weeks
|
|
155
|
+
"MmusDv:0000070", # 20 weeks
|
|
156
|
+
"MmusDv:0000071", # 21 weeks
|
|
157
|
+
"MmusDv:0000072", # 22 weeks
|
|
158
|
+
"MmusDv:0000073", # 23 weeks
|
|
159
|
+
"MmusDv:0000074", # 24 weeks
|
|
160
|
+
"MmusDv:0000077", # 6 month-old stage
|
|
161
|
+
"MmusDv:0000079", # 8 month-old stage
|
|
162
|
+
"MmusDv:0000098", # 25 weeks
|
|
163
|
+
"MmusDv:0000099", # 26 weeks
|
|
164
|
+
"MmusDv:0000102", # 29 weeks
|
|
165
|
+
],
|
|
166
|
+
"HsapDv:0000265": [], # child stage (1-4 yo)
|
|
167
|
+
"HsapDv:0000271": [ # juvenile stage (5-14 yo)
|
|
168
|
+
"MmusDv:0000048", # 4 weeks
|
|
169
|
+
"MmusDv:0000049", # 5 weeks
|
|
170
|
+
],
|
|
171
|
+
"HsapDv:0000260": [ # infant stage
|
|
172
|
+
"MmusDv:0000046", # 2 weeks
|
|
173
|
+
"MmusDv:0000045", # 1 week
|
|
174
|
+
"MmusDv:0000047", # 3 weeks
|
|
175
|
+
"HsapDv:0000083",
|
|
176
|
+
],
|
|
177
|
+
"HsapDv:0000262": [ # newborn stage (0-28 days)
|
|
178
|
+
"MmusDv:0000036", # Theiler stage 27
|
|
179
|
+
"MmusDv:0000037", # Theiler stage 28
|
|
180
|
+
"MmusDv:0000113", # 4-7 days
|
|
181
|
+
],
|
|
182
|
+
"HsapDv:0000007": [], # Carnegie stage 03
|
|
183
|
+
"HsapDv:0000008": [], # Carnegie stage 04
|
|
184
|
+
"HsapDv:0000009": [], # Carnegie stage 05
|
|
185
|
+
"HsapDv:0000003": [], # Carnegie stage 01
|
|
186
|
+
"HsapDv:0000005": [], # Carnegie stage 02
|
|
187
|
+
"HsapDv:0000010": [], # gastrula stage
|
|
188
|
+
"HsapDv:0000012": [], # neurula stage
|
|
189
|
+
"HsapDv:0000015": [ # organogenesis stage
|
|
190
|
+
"MmusDv:0000019", # Theiler stage 13
|
|
191
|
+
"MmusDv:0000020", # Theiler stage 12
|
|
192
|
+
"MmusDv:0000021", # Theiler stage 14
|
|
193
|
+
"MmusDv:0000022", # Theiler stage 15
|
|
194
|
+
"MmusDv:0000023", # Theiler stage 16
|
|
195
|
+
"MmusDv:0000024", # Theiler stage 17
|
|
196
|
+
"MmusDv:0000025", # Theiler stage 18
|
|
197
|
+
"MmusDv:0000026", # Theiler stage 19
|
|
198
|
+
"MmusDv:0000027", # Theiler stage 20
|
|
199
|
+
"MmusDv:0000028", # Theiler stage 21
|
|
200
|
+
"MmusDv:0000029", # Theiler stage 22
|
|
201
|
+
],
|
|
202
|
+
"HsapDv:0000037": [ # fetal stage
|
|
203
|
+
"MmusDv:0000033", # Theiler stage 24
|
|
204
|
+
"MmusDv:0000034", # Theiler stage 25
|
|
205
|
+
"MmusDv:0000035", # Theiler stage 26
|
|
206
|
+
"MmusDv:0000032", # Theiler stage 23
|
|
207
|
+
],
|
|
208
|
+
"unknown": [
|
|
209
|
+
"MmusDv:0000041", # unknown
|
|
210
|
+
],
|
|
211
|
+
}
|
scdataloader/data.py
CHANGED
|
@@ -10,8 +10,6 @@ import lamindb as ln
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
import pandas as pd
|
|
12
12
|
from anndata import AnnData
|
|
13
|
-
from lamindb.core import MappedCollection
|
|
14
|
-
from lamindb.core._mapped_collection import _Connect
|
|
15
13
|
from lamindb.core.storage._anndata_accessor import _safer_read_index
|
|
16
14
|
from scipy.sparse import issparse
|
|
17
15
|
from torch.utils.data import Dataset as torchDataset
|
|
@@ -19,6 +17,7 @@ from torch.utils.data import Dataset as torchDataset
|
|
|
19
17
|
from scdataloader.utils import get_ancestry_mapping, load_genes
|
|
20
18
|
|
|
21
19
|
from .config import LABELS_TOADD
|
|
20
|
+
from .mapped import MappedCollection, _Connect
|
|
22
21
|
|
|
23
22
|
|
|
24
23
|
@dataclass
|
|
@@ -43,7 +42,7 @@ class Dataset(torchDataset):
|
|
|
43
42
|
organisms (list[str]): list of organisms to load
|
|
44
43
|
(for now only validates the the genes map to this organism)
|
|
45
44
|
obs (list[str]): list of observations to load from the Collection
|
|
46
|
-
|
|
45
|
+
clss_to_predict (list[str]): list of observations to encode
|
|
47
46
|
join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
|
|
48
47
|
hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
|
|
49
48
|
"""
|
|
@@ -53,37 +52,23 @@ class Dataset(torchDataset):
|
|
|
53
52
|
organisms: Optional[Union[list[str], str]] = field(
|
|
54
53
|
default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
|
|
55
54
|
)
|
|
56
|
-
obs: Optional[list[str]] = field(
|
|
57
|
-
default_factory=[
|
|
58
|
-
"self_reported_ethnicity_ontology_term_id",
|
|
59
|
-
"assay_ontology_term_id",
|
|
60
|
-
"development_stage_ontology_term_id",
|
|
61
|
-
"disease_ontology_term_id",
|
|
62
|
-
"cell_type_ontology_term_id",
|
|
63
|
-
"tissue_ontology_term_id",
|
|
64
|
-
"sex_ontology_term_id",
|
|
65
|
-
#'dataset_id',
|
|
66
|
-
#'cell_culture',
|
|
67
|
-
# "dpt_group",
|
|
68
|
-
# "heat_diff",
|
|
69
|
-
# "nnz",
|
|
70
|
-
]
|
|
71
|
-
)
|
|
72
55
|
# set of obs to prepare for prediction (encode)
|
|
73
|
-
|
|
56
|
+
clss_to_predict: Optional[list[str]] = field(default_factory=list)
|
|
74
57
|
# set of obs that need to be hierarchically prepared
|
|
75
58
|
hierarchical_clss: Optional[list[str]] = field(default_factory=list)
|
|
76
59
|
join_vars: Literal["inner", "outer"] | None = None
|
|
60
|
+
metacell_mode: float = 0.0
|
|
77
61
|
|
|
78
62
|
def __post_init__(self):
|
|
79
63
|
self.mapped_dataset = mapped(
|
|
80
64
|
self.lamin_dataset,
|
|
81
|
-
obs_keys=self.
|
|
65
|
+
obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
|
|
82
66
|
join=self.join_vars,
|
|
83
|
-
encode_labels=self.
|
|
67
|
+
encode_labels=self.clss_to_predict,
|
|
84
68
|
unknown_label="unknown",
|
|
85
69
|
stream=True,
|
|
86
70
|
parallel=True,
|
|
71
|
+
metacell_mode=self.metacell_mode,
|
|
87
72
|
)
|
|
88
73
|
print(
|
|
89
74
|
"won't do any check but we recommend to have your dataset coming from local storage"
|
|
@@ -93,8 +78,8 @@ class Dataset(torchDataset):
|
|
|
93
78
|
# generate tree from ontologies
|
|
94
79
|
if len(self.hierarchical_clss) > 0:
|
|
95
80
|
self.define_hierarchies(self.hierarchical_clss)
|
|
96
|
-
if len(self.
|
|
97
|
-
for clss in self.
|
|
81
|
+
if len(self.clss_to_predict) > 0:
|
|
82
|
+
for clss in self.clss_to_predict:
|
|
98
83
|
if clss not in self.hierarchical_clss:
|
|
99
84
|
# otherwise it's already been done
|
|
100
85
|
self.class_topred[clss] = set(
|
|
@@ -143,8 +128,7 @@ class Dataset(torchDataset):
|
|
|
143
128
|
+ "dataset contains:\n"
|
|
144
129
|
+ " {} cells\n".format(self.mapped_dataset.__len__())
|
|
145
130
|
+ " {} genes\n".format(self.genedf.shape[0])
|
|
146
|
-
+ " {}
|
|
147
|
-
+ " {} clss_to_pred\n".format(len(self.clss_to_pred))
|
|
131
|
+
+ " {} clss_to_predict\n".format(len(self.clss_to_predict))
|
|
148
132
|
+ " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
|
|
149
133
|
+ " {} organisms\n".format(len(self.organisms))
|
|
150
134
|
+ (
|
|
@@ -154,9 +138,16 @@ class Dataset(torchDataset):
|
|
|
154
138
|
if len(self.class_topred) > 0
|
|
155
139
|
else ""
|
|
156
140
|
)
|
|
141
|
+
+ " {} metacell_mode\n".format(self.metacell_mode)
|
|
157
142
|
)
|
|
158
143
|
|
|
159
|
-
def get_label_weights(
|
|
144
|
+
def get_label_weights(
|
|
145
|
+
self,
|
|
146
|
+
obs_keys: str | list[str],
|
|
147
|
+
scaler: int = 10,
|
|
148
|
+
return_categories=False,
|
|
149
|
+
bypass_label=["neuron"],
|
|
150
|
+
):
|
|
160
151
|
"""Get all weights for the given label keys."""
|
|
161
152
|
if isinstance(obs_keys, str):
|
|
162
153
|
obs_keys = [obs_keys]
|
|
@@ -167,16 +158,24 @@ class Dataset(torchDataset):
|
|
|
167
158
|
)
|
|
168
159
|
labels_list.append(labels_to_str)
|
|
169
160
|
if len(labels_list) > 1:
|
|
170
|
-
labels =
|
|
161
|
+
labels = ["___".join(labels_obs) for labels_obs in zip(*labels_list)]
|
|
171
162
|
else:
|
|
172
163
|
labels = labels_list[0]
|
|
173
164
|
|
|
174
165
|
counter = Counter(labels) # type: ignore
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
166
|
+
if return_categories:
|
|
167
|
+
rn = {n: i for i, n in enumerate(counter.keys())}
|
|
168
|
+
labels = np.array([rn[label] for label in labels])
|
|
169
|
+
counter = np.array(list(counter.values()))
|
|
170
|
+
weights = scaler / (counter + scaler)
|
|
171
|
+
return weights, labels
|
|
172
|
+
else:
|
|
173
|
+
counts = np.array([counter[label] for label in labels])
|
|
174
|
+
if scaler is None:
|
|
175
|
+
weights = 1.0 / counts
|
|
176
|
+
else:
|
|
177
|
+
weights = scaler / (counts + scaler)
|
|
178
|
+
return weights
|
|
180
179
|
|
|
181
180
|
def get_unseen_mapped_dataset_elements(self, idx: int):
|
|
182
181
|
"""
|
|
@@ -209,6 +208,8 @@ class Dataset(torchDataset):
|
|
|
209
208
|
"tissue_ontology_term_id",
|
|
210
209
|
"disease_ontology_term_id",
|
|
211
210
|
"development_stage_ontology_term_id",
|
|
211
|
+
"simplified_dev_stage",
|
|
212
|
+
"age_group",
|
|
212
213
|
"assay_ontology_term_id",
|
|
213
214
|
"self_reported_ethnicity_ontology_term_id",
|
|
214
215
|
]:
|
|
@@ -235,7 +236,11 @@ class Dataset(torchDataset):
|
|
|
235
236
|
.df(include=["parents__ontology_id"])
|
|
236
237
|
.set_index("ontology_id")
|
|
237
238
|
)
|
|
238
|
-
elif clss
|
|
239
|
+
elif clss in [
|
|
240
|
+
"development_stage_ontology_term_id",
|
|
241
|
+
"simplified_dev_stage",
|
|
242
|
+
"age_group",
|
|
243
|
+
]:
|
|
239
244
|
parentdf = (
|
|
240
245
|
bt.DevelopmentalStage.filter()
|
|
241
246
|
.df(include=["parents__ontology_id"])
|
|
@@ -268,7 +273,7 @@ class Dataset(torchDataset):
|
|
|
268
273
|
if len(j) == 0:
|
|
269
274
|
groupings.pop(i)
|
|
270
275
|
self.labels_groupings[clss] = groupings
|
|
271
|
-
if clss in self.
|
|
276
|
+
if clss in self.clss_to_predict:
|
|
272
277
|
# if we have added new clss, we need to update the encoder with them too.
|
|
273
278
|
mlength = len(self.mapped_dataset.encoders[clss])
|
|
274
279
|
|
|
@@ -354,6 +359,8 @@ class SimpleAnnDataset(torchDataset):
|
|
|
354
359
|
def mapped(
|
|
355
360
|
dataset,
|
|
356
361
|
obs_keys: list[str] | None = None,
|
|
362
|
+
obsm_keys: list[str] | None = None,
|
|
363
|
+
obs_filter: dict[str, str | tuple[str, ...]] | None = None,
|
|
357
364
|
join: Literal["inner", "outer"] | None = "inner",
|
|
358
365
|
encode_labels: bool | list[str] = True,
|
|
359
366
|
unknown_label: str | dict[str, str] | None = None,
|
|
@@ -362,6 +369,8 @@ def mapped(
|
|
|
362
369
|
dtype: str | None = None,
|
|
363
370
|
stream: bool = False,
|
|
364
371
|
is_run_input: bool | None = None,
|
|
372
|
+
metacell_mode: bool = False,
|
|
373
|
+
meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
|
|
365
374
|
) -> MappedCollection:
|
|
366
375
|
path_list = []
|
|
367
376
|
for artifact in dataset.artifacts.all():
|
|
@@ -378,11 +387,15 @@ def mapped(
|
|
|
378
387
|
ds = MappedCollection(
|
|
379
388
|
path_list=path_list,
|
|
380
389
|
obs_keys=obs_keys,
|
|
390
|
+
obsm_keys=obsm_keys,
|
|
391
|
+
obs_filter=obs_filter,
|
|
381
392
|
join=join,
|
|
382
393
|
encode_labels=encode_labels,
|
|
383
394
|
unknown_label=unknown_label,
|
|
384
395
|
cache_categories=cache_categories,
|
|
385
396
|
parallel=parallel,
|
|
386
397
|
dtype=dtype,
|
|
398
|
+
meta_assays=meta_assays,
|
|
399
|
+
metacell_mode=metacell_mode,
|
|
387
400
|
)
|
|
388
401
|
return ds
|