scdataloader-0.0.3-py3-none-any.whl → scdataloader-0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +63 -42
- scdataloader/collator.py +87 -43
- scdataloader/config.py +106 -0
- scdataloader/data.py +78 -98
- scdataloader/datamodule.py +375 -0
- scdataloader/mapped.py +22 -7
- scdataloader/preprocess.py +444 -109
- scdataloader/utils.py +106 -63
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/METADATA +46 -2
- scdataloader-0.0.4.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/entry_points.txt +0 -0
scdataloader/VERSION
CHANGED
@@ -1 +1 @@
-
+0.7.0
scdataloader/__init__.py
CHANGED
scdataloader/__main__.py
CHANGED
@@ -1,8 +1,14 @@
 import argparse
-from scdataloader.preprocess import
+from scdataloader.preprocess import (
+    LaminPreprocessor,
+    additional_preprocess,
+    additional_postprocess,
+)
 import lamindb as ln
 from typing import Optional, Union
 
+
+# scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
 def main():
     parser = argparse.ArgumentParser(
         description="Preprocess datasets in a given lamindb collection."
@@ -11,22 +17,31 @@ def main():
         "--name", type=str, required=True, help="Name of the input dataset"
     )
     parser.add_argument(
-        "--new_name",
+        "--new_name",
+        type=str,
+        default="preprocessed dataset",
+        help="Name of the preprocessed dataset.",
     )
     parser.add_argument(
         "--description",
         type=str,
-        default="preprocessed by scDataLoader"
+        default="preprocessed by scDataLoader",
         help="Description of the preprocessed dataset.",
     )
     parser.add_argument(
         "--start_at", type=int, default=0, help="Position to start preprocessing at."
     )
     parser.add_argument(
-        "--new_version",
+        "--new_version",
+        type=str,
+        default="2",
+        help="Version of the output dataset and files.",
     )
     parser.add_argument(
-        "--instance",
+        "--instance",
+        type=str,
+        default=None,
+        help="Instance storing the input dataset, if not local",
     )
     parser.add_argument(
         "--version", type=str, default=None, help="Version of the input dataset."
@@ -35,125 +50,127 @@ def main():
         "--filter_gene_by_counts",
         type=Union[int, bool],
         default=False,
-        help="Determines whether to filter genes by counts."
+        help="Determines whether to filter genes by counts.",
     )
     parser.add_argument(
         "--filter_cell_by_counts",
         type=Union[int, bool],
         default=False,
-        help="Determines whether to filter cells by counts."
+        help="Determines whether to filter cells by counts.",
     )
     parser.add_argument(
         "--normalize_sum",
         type=float,
         default=1e4,
-        help="Determines whether to normalize the total counts of each cell to a specific value."
-    )
-    parser.add_argument(
-        "--keep_norm_layer",
-        type=bool,
-        default=False,
-        help="Determines whether to keep the normalization layer."
+        help="Determines whether to normalize the total counts of each cell to a specific value.",
     )
     parser.add_argument(
         "--subset_hvg",
         type=int,
         default=0,
-        help="Determines whether to subset highly variable genes."
+        help="Determines whether to subset highly variable genes.",
    )
     parser.add_argument(
         "--hvg_flavor",
         type=str,
         default="seurat_v3",
-        help="Specifies the flavor of highly variable genes selection."
+        help="Specifies the flavor of highly variable genes selection.",
     )
     parser.add_argument(
         "--binning",
         type=Optional[int],
         default=None,
-        help="Determines whether to bin the data into discrete values of number of bins provided."
+        help="Determines whether to bin the data into discrete values of number of bins provided.",
     )
     parser.add_argument(
         "--result_binned_key",
         type=str,
         default="X_binned",
-        help="Specifies the key of AnnData to store the binned data."
+        help="Specifies the key of AnnData to store the binned data.",
     )
     parser.add_argument(
         "--length_normalize",
         type=bool,
         default=False,
-        help="Determines whether to normalize the length."
+        help="Determines whether to normalize the length.",
     )
     parser.add_argument(
         "--force_preprocess",
         type=bool,
         default=False,
-        help="Determines whether to force preprocessing."
+        help="Determines whether to force preprocessing.",
     )
     parser.add_argument(
         "--min_dataset_size",
         type=int,
         default=100,
-        help="Specifies the minimum dataset size."
+        help="Specifies the minimum dataset size.",
     )
     parser.add_argument(
         "--min_valid_genes_id",
         type=int,
         default=10_000,
-        help="Specifies the minimum valid genes id."
+        help="Specifies the minimum valid genes id.",
     )
     parser.add_argument(
         "--min_nnz_genes",
         type=int,
-        default=
-        help="Specifies the minimum non-zero genes."
+        default=400,
+        help="Specifies the minimum non-zero genes.",
     )
     parser.add_argument(
         "--maxdropamount",
         type=int,
-        default=
-        help="Specifies the maximum drop amount."
+        default=50,
+        help="Specifies the maximum drop amount.",
     )
     parser.add_argument(
-        "--madoutlier",
-        type=int,
-        default=5,
-        help="Specifies the MAD outlier."
+        "--madoutlier", type=int, default=5, help="Specifies the MAD outlier."
     )
     parser.add_argument(
         "--pct_mt_outlier",
         type=int,
         default=8,
-        help="Specifies the percentage of MT outlier."
+        help="Specifies the percentage of MT outlier.",
     )
     parser.add_argument(
-        "--batch_key",
-        type=Optional[str],
-        default=None,
-        help="Specifies the batch key."
+        "--batch_key", type=Optional[str], default=None, help="Specifies the batch key."
    )
     parser.add_argument(
         "--skip_validate",
         type=bool,
         default=False,
-        help="Determines whether to skip validation."
+        help="Determines whether to skip validation.",
+    )
+    parser.add_argument(
+        "--do_postp",
+        type=bool,
+        default=False,
+        help="Determines whether to do postprocessing.",
     )
     args = parser.parse_args()
 
     # Load the collection
+    # if not args.preprocess:
+    #     print("Only preprocess is available for now")
+    #     return
     if args.instance is not None:
-        collection =
-
-
+        collection = (
+            ln.Collection.using(instance=args.instance)
+            .filter(name=args.name, version=args.version)
+            .first()
+        )
+    else:
+        collection = ln.Collection.filter(name=args.name, version=args.version).first()
 
-    print(
+    print(
+        "using the dataset ", collection, " of size ", len(collection.artifacts.all())
+    )
     # Initialize the preprocessor
     preprocessor = LaminPreprocessor(
         filter_gene_by_counts=args.filter_gene_by_counts,
         filter_cell_by_counts=args.filter_cell_by_counts,
         normalize_sum=args.normalize_sum,
-        keep_norm_layer=args.keep_norm_layer,
         subset_hvg=args.subset_hvg,
         hvg_flavor=args.hvg_flavor,
         binning=args.binning,
@@ -168,10 +185,14 @@ def main():
         pct_mt_outlier=args.pct_mt_outlier,
         batch_key=args.batch_key,
         skip_validate=args.skip_validate,
+        do_postp=args.do_postp,
+        additional_preprocess=additional_preprocess,
+        additional_postprocess=additional_postprocess,
+        keep_files=False,
     )
 
     # Preprocess the dataset
-
+    preprocessor(
         collection,
         name=args.new_name,
         description=args.description,
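Note: taken together, the changes above wire the new CLI flags into LaminPreprocessor and then call it on a lamindb Collection. Below is a minimal sketch of the equivalent programmatic call; the collection name, version, and keyword values are placeholders, and the keyword set is assumed to match the signature used in the diff.

    import lamindb as ln
    from scdataloader.preprocess import (
        LaminPreprocessor,
        additional_preprocess,
        additional_postprocess,
    )

    # placeholder collection name and version
    collection = ln.Collection.filter(name="my-collection", version="1").first()

    preprocessor = LaminPreprocessor(
        normalize_sum=1e4,
        do_postp=False,
        additional_preprocess=additional_preprocess,
        additional_postprocess=additional_postprocess,
        keep_files=False,
    )
    preprocessor(
        collection,
        name="preprocessed dataset",
        description="preprocessed by scDataLoader",
    )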
scdataloader/collator.py
CHANGED
@@ -1,6 +1,6 @@
 import numpy as np
 from .utils import load_genes
-from torch import Tensor
+from torch import Tensor, long
 
 # class SimpleCollator:
 
@@ -9,17 +9,18 @@ class Collator:
     def __init__(
         self,
         organisms: list,
+        how="all",
         org_to_id: dict = None,
         valid_genes: list = [],
         max_len=2000,
-
-        add_zero_genes=200,
+        add_zero_genes=0,
         logp1=False,
         norm_to=None,
-
+        n_bins=0,
         tp_name=None,
         organism_name="organism_ontology_term_id",
         class_names=[],
+        genelist=[],
     ):
         """
         This class is responsible for collating data for the scPRINT model. It handles the
@@ -27,20 +28,29 @@ class Collator:
         allowing for various configurations such as maximum gene list length, normalization,
         and selection method for gene expression.
 
+        This Collator should work with scVI's dataloader as well!
+
         Args:
             organisms (list): List of organisms to be considered for gene expression data.
+                it will drop any other organism it sees (might lead to batches of different sizes!)
+            how (flag, optional): Method for selecting gene expression. Defaults to "most expr".
+                one of ["most expr", "random expr", "all", "some"]:
+                "most expr": selects the max_len most expressed genes,
+                    if less genes are expressed, will sample random unexpressed genes,
+                "random expr": uses a random set of max_len expressed genes.
+                    if less genes are expressed, will sample random unexpressed genes
+                "all": uses all genes
+                "some": uses only the genes provided through the genelist param
             org_to_id (dict): Dictionary mapping organisms to their respective IDs.
-            labels (list, optional): List of labels for the data. Defaults to [].
             valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
-
-
-
+                it will drop any other genes from the input expression data (usefull when your model only works on some genes)
+            max_len (int, optional): Maximum number of genes to use (for random expr and most expr). Defaults to 2000.
+            n_bins (int, optional): Number of bins for binning the data. Defaults to 0. meaning, no binning of expression.
+            add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
             logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
             norm_to (str, optional): Normalization method to be applied. Defaults to None.
-            how (str, optional): Method for selecting gene expression. Defaults to "most expr".
         """
         self.organisms = organisms
-        self.valid_genes = valid_genes
         self.max_len = max_len
         self.n_bins = n_bins
         self.add_zero_genes = add_zero_genes
@@ -53,6 +63,8 @@ class Collator:
             if org_to_id is not None
             else set(organisms)
         )
+        if self.how == "some":
+            assert len(genelist) > 0, "if how is some, genelist must be provided"
         self.organism_name = organism_name
         self.tp_name = tp_name
         self.class_names = class_names
@@ -60,18 +72,21 @@ class Collator:
         self.start_idx = {}
         self.accepted_genes = {}
         self.genedf = load_genes(organisms)
+        self.to_subset = {}
         for organism in set(self.genedf.organism):
             ogenedf = self.genedf[self.genedf.organism == organism]
+            tot = self.genedf[self.genedf.index.isin(valid_genes)]
             org = org_to_id[organism] if org_to_id is not None else organism
-            self.start_idx.update(
-                {org: np.where(self.genedf.organism == organism)[0][0]}
-            )
+            self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
             if len(valid_genes) > 0:
                 self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
+            if len(genelist) > 0:
+                df = ogenedf[ogenedf.index.isin(valid_genes)]
+                self.to_subset.update({org: df.index.isin(genelist)})
 
     def __call__(self, batch):
         """
-        __call__
+        __call__ applies the collator to a minibatch of data
 
         Args:
             batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
@@ -92,33 +107,62 @@ class Collator:
         other_classes = []
         gene_locs = []
         tp = []
+        dataset = []
+        nnz_loc = []
         for elem in batch:
             organism_id = elem[self.organism_name]
             if organism_id not in self.organism_ids:
                 continue
+            if "dataset" in elem:
+                dataset.append(elem["dataset"])
             expr = np.array(elem["x"])
             total_count.append(expr.sum())
             if len(self.accepted_genes) > 0:
                 expr = expr[self.accepted_genes[organism_id]]
             if self.how == "most expr":
-
+                nnz_loc = np.where(expr > 0)[0]
+                ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
+                loc = np.argsort(expr)[-(ma):][::-1]
+                # nnz_loc = [1] * 30_000
+                # loc = np.argsort(expr)[-(self.max_len) :][::-1]
            elif self.how == "random expr":
                 nnz_loc = np.where(expr > 0)[0]
                 loc = nnz_loc[
-                    np.random.choice(
+                    np.random.choice(
+                        len(nnz_loc),
+                        self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc),
+                        replace=False,
+                        # p=(expr.max() + (expr[nnz_loc])*19) / expr.max(), # 20 at most times more likely to be selected
+                    )
                 ]
-            elif self.how
+            elif self.how in ["all", "some"]:
                 loc = np.arange(len(expr))
             else:
                 raise ValueError("how must be either most expr or random expr")
-            if
+            if (
+                (self.add_zero_genes > 0) or (self.max_len > len(nnz_loc))
+            ) and self.how not in ["all", "some"]:
                 zero_loc = np.where(expr == 0)[0]
-                zero_loc = [
-                    np.random.choice(
+                zero_loc = zero_loc[
+                    np.random.choice(
+                        len(zero_loc),
+                        self.add_zero_genes
+                        + (
+                            0
+                            if self.max_len < len(nnz_loc)
+                            else self.max_len - len(nnz_loc)
+                        ),
+                        replace=False,
+                    )
                 ]
                 loc = np.concatenate((loc, zero_loc), axis=None)
-
-
+            expr = expr[loc]
+            loc = loc + self.start_idx[organism_id]
+            if self.how == "some":
+                expr = expr[self.to_subset[organism_id]]
+                loc = loc[self.to_subset[organism_id]]
+            exprs.append(expr)
+            gene_locs.append(loc)
 
             if self.tp_name is not None:
                 tp.append(elem[self.tp_name])
@@ -132,6 +176,7 @@ class Collator:
         gene_locs = np.array(gene_locs)
         total_count = np.array(total_count)
         other_classes = np.array(other_classes)
+        dataset = np.array(dataset)
 
         # normalize counts
         if self.norm_to is not None:
@@ -152,17 +197,26 @@ class Collator:
         # do encoding of graph location
         # encode all the edges in some sparse way
         # normalizing total counts between 0,1
-
+        ret = {
             "x": Tensor(expr),
             "genes": Tensor(gene_locs).int(),
             "class": Tensor(other_classes).int(),
             "tp": Tensor(tp),
             "depth": Tensor(total_count),
         }
+        if len(dataset) > 0:
+            ret.update({"dataset": Tensor(dataset).to(long)})
+        return ret
 
 
 class AnnDataCollator(Collator):
     def __init__(self, *args, **kwargs):
+        """
+        AnnDataCollator Collator to use if working with AnnData's experimental dataloader (it is very slow!!!)
+
+        Args:
+            @see Collator
+        """
         super().__init__(*args, **kwargs)
 
     def __call__(self, batch):
@@ -218,28 +272,14 @@ class AnnDataCollator(Collator):
         }
 
 
-class SCVICollator(Collator):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def __call__(self, batch):
-        expr = batch["x"]
-        total_count = expr.sum(axis=1)
-        if self.how == "most expr":
-            loc = np.argsort(expr)[:, -(self.max_len) :][:, ::-1]
-        else:
-            raise ValueError("how must be either most expr or random expr")
-        if self.logp1:
-            expr = np.log2(1 + expr)
-        return {
-            "x": Tensor(expr[np.arange(expr.shape[0])[:, None], loc]),
-            "genes": Tensor(loc.copy()).int(),
-            "depth": Tensor(total_count),
-        }
-
-
 class GeneformerCollator(Collator):
     def __init__(self, *args, gene_norm_list: list, **kwargs):
+        """
+        GeneformerCollator to finish
+
+        Args:
+            gene_norm_list (list): the normalization of expression through all datasets, per gene.
+        """
         super().__init__(*args, **kwargs)
         self.gene_norm_list = gene_norm_list
 
@@ -251,6 +291,10 @@ class GeneformerCollator(Collator):
 
 
 class scGPTCollator(Collator):
+    """
+    scGPTCollator to finish
+    """
+
     def __call__(self, batch):
         super().__call__(batch)
         # binning
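Note: per its updated docstring, the reworked Collator is meant to be used directly as a collate function. A rough usage sketch based on the constructor shown above; the dataset object and the organism ontology term id are assumptions, and the Collator expects each item to be a dict with an "x" expression vector and an organism field.

    from torch.utils.data import DataLoader
    from scdataloader.collator import Collator

    collator = Collator(
        organisms=["NCBITaxon:9606"],  # assumed organism ontology term id
        how="random expr",
        max_len=2000,
        add_zero_genes=0,
    )
    # `my_dataset` is a hypothetical map-style dataset yielding dicts with
    # "x" (expression values) and "organism_ontology_term_id"
    loader = DataLoader(my_dataset, batch_size=64, collate_fn=collator)
    batch = next(iter(loader))
    # batch["x"]: selected expression values, batch["genes"]: gene index positions,
    # batch["depth"]: per-cell total counts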
scdataloader/config.py
ADDED
@@ -0,0 +1,106 @@
+LABELS_TOADD = {
+    "assay_ontology_term_id": {
+        "10x transcription profiling": "EFO:0030003",
+        "spatial transcriptomics": "EFO:0008994",
+        "10x 3' transcription profiling": "EFO:0030003",
+        "10x 5' transcription profiling": "EFO:0030004",
+    },
+    "disease_ontology_term_id": {
+        "metabolic disease": "MONDO:0005066",
+        "chronic kidney disease": "MONDO:0005300",
+        "chromosomal disorder": "MONDO:0019040",
+        "infectious disease": "MONDO:0005550",
+        "inflammatory disease": "MONDO:0021166",
+        # "immune system disease",
+        "disorder of development or morphogenesis": "MONDO:0021147",
+        "mitochondrial disease": "MONDO:0044970",
+        "psychiatric disorder": "MONDO:0002025",
+        "cancer or benign tumor": "MONDO:0002025",
+        "neoplasm": "MONDO:0005070",
+    },
+    "cell_type_ontology_term_id": {
+        "progenitor cell": "CL:0011026",
+        "hematopoietic cell": "CL:0000988",
+        "myoblast": "CL:0000056",
+        "myeloid cell": "CL:0000763",
+        "neuron": "CL:0000540",
+        "electrically active cell": "CL:0000211",
+        "epithelial cell": "CL:0000066",
+        "secretory cell": "CL:0000151",
+        "stem cell": "CL:0000034",
+        "non-terminally differentiated cell": "CL:0000055",
+        "supporting cell": "CL:0000630",
+    },
+}
+
+COARSE_TISSUE = {
+    "adipose tissue": "",
+    "bladder organ": "",
+    "blood": "",
+    "bone marrow": "",
+    "brain": "",
+    "breast": "",
+    "esophagus": "",
+    "eye": "",
+    "embryo": "",
+    "fallopian tube": "",
+    "gall bladder": "",
+    "heart": "",
+    "intestine": "",
+    "kidney": "",
+    "liver": "",
+    "lung": "",
+    "lymph node": "",
+    "musculature of body": "",
+    "nose": "",
+    "ovary": "",
+    "pancreas": "",
+    "placenta": "",
+    "skin of body": "",
+    "spinal cord": "",
+    "spleen": "",
+    "stomach": "",
+    "thymus": "",
+    "thyroid gland": "",
+    "tongue": "",
+    "uterus": "",
+}
+
+COARSE_ANCESTRY = {
+    "African": "",
+    "Chinese": "",
+    "East Asian": "",
+    "Eskimo": "",
+    "European": "",
+    "Greater Middle Eastern (Middle Eastern, North African or Persian)": "",
+    "Hispanic or Latin American": "",
+    "Native American": "",
+    "Oceanian": "",
+    "South Asian": "",
+}
+
+COARSE_DEVELOPMENT_STAGE = {
+    "Embryonic human": "",
+    "Fetal": "",
+    "Immature": "",
+    "Mature": "",
+}
+
+COARSE_ASSAY = {
+    "10x 3'": "",
+    "10x 5'": "",
+    "10x multiome": "",
+    "CEL-seq2": "",
+    "Drop-seq": "",
+    "GEXSCOPE technology": "",
+    "inDrop": "",
+    "microwell-seq": "",
+    "sci-Plex": "",
+    "sci-RNA-seq": "",
+    "Seq-Well": "",
+    "Slide-seq": "",
+    "Smart-seq": "",
+    "SPLiT-seq": "",
+    "TruDrop": "",
+    "Visium Spatial Gene Expression": "",
+}
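Note: the new config module only defines constant dictionaries. A small sketch of how they can be read (the usage itself is assumed; the module just exposes the constants, and the COARSE_* dicts currently map coarse group names to empty strings):

    from scdataloader.config import LABELS_TOADD, COARSE_TISSUE

    # coarse label -> ontology term id, e.g. for diseases
    print(LABELS_TOADD["disease_ontology_term_id"]["chronic kidney disease"])  # MONDO:0005300

    # coarse tissue groups are keys only, with empty-string values for now
    print(list(COARSE_TISSUE)[:3])  # ['adipose tissue', 'bladder organ', 'blood']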