scdataloader 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/__init__.py +4 -0
- scdataloader/__main__.py +188 -0
- scdataloader/collator.py +263 -0
- scdataloader/data.py +142 -159
- scdataloader/dataloader.py +318 -0
- scdataloader/mapped.py +24 -25
- scdataloader/preprocess.py +126 -145
- scdataloader/utils.py +99 -76
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/METADATA +33 -7
- scdataloader-0.0.3.dist-info/RECORD +15 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/WHEEL +1 -1
- scdataloader-0.0.2.dist-info/RECORD +0 -12
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.2.dist-info → scdataloader-0.0.3.dist-info}/entry_points.txt +0 -0
scdataloader/__init__.py
CHANGED
scdataloader/__main__.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
from scdataloader.preprocess import LaminPreprocessor
|
|
3
|
+
import lamindb as ln
|
|
4
|
+
from typing import Optional, Union
|
|
5
|
+
|
|
6
|
+
def str_to_bool(value: str) -> bool:
    """argparse-friendly boolean converter.

    ``type=bool`` is a classic argparse pitfall: ``bool("False")`` is True
    because any non-empty string is truthy. Accept the usual textual
    spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def int_or_bool(value: str):
    """Converter for flags documented as Union[int, bool].

    typing constructs such as ``Union[int, bool]`` are not callable, so the
    original ``type=Union[int, bool]`` crashed argparse at parse time. This
    accepts either a boolean spelling or an integer count.
    """
    try:
        return str_to_bool(value)
    except argparse.ArgumentTypeError:
        return int(value)


def main():
    """Command-line entry point: preprocess a lamindb collection.

    Parses the command-line options, loads the input collection (from a
    remote lamindb instance when ``--instance`` is given, otherwise locally),
    runs :class:`LaminPreprocessor` over it, and reports where the
    preprocessed dataset was saved.
    """
    parser = argparse.ArgumentParser(
        description="Preprocess datasets in a given lamindb collection."
    )
    parser.add_argument(
        "--name", type=str, required=True, help="Name of the input dataset"
    )
    parser.add_argument(
        "--new_name", type=str, required=True, help="Name of the preprocessed dataset."
    )
    parser.add_argument(
        "--description",
        type=str,
        # NOTE: the original was missing the comma after this default (SyntaxError)
        default="preprocessed by scDataLoader",
        help="Description of the preprocessed dataset.",
    )
    parser.add_argument(
        "--start_at", type=int, default=0, help="Position to start preprocessing at."
    )
    parser.add_argument(
        "--new_version", type=str, default="2", help="Version of the output dataset and files."
    )
    parser.add_argument(
        "--instance", type=str, default=None, help="Instance storing the input dataset, if not local"
    )
    parser.add_argument(
        "--version", type=str, default=None, help="Version of the input dataset."
    )
    parser.add_argument(
        "--filter_gene_by_counts",
        type=int_or_bool,  # was Union[int, bool], which is not callable
        default=False,
        help="Determines whether to filter genes by counts.",
    )
    parser.add_argument(
        "--filter_cell_by_counts",
        type=int_or_bool,  # was Union[int, bool], which is not callable
        default=False,
        help="Determines whether to filter cells by counts.",
    )
    parser.add_argument(
        "--normalize_sum",
        type=float,
        default=1e4,
        help="Determines whether to normalize the total counts of each cell to a specific value.",
    )
    parser.add_argument(
        "--keep_norm_layer",
        type=str_to_bool,  # was type=bool: any non-empty string parsed as True
        default=False,
        help="Determines whether to keep the normalization layer.",
    )
    parser.add_argument(
        "--subset_hvg",
        type=int,
        default=0,
        help="Determines whether to subset highly variable genes.",
    )
    parser.add_argument(
        "--hvg_flavor",
        type=str,
        default="seurat_v3",
        help="Specifies the flavor of highly variable genes selection.",
    )
    parser.add_argument(
        "--binning",
        type=int,  # was Optional[int], which is not callable; None stays the default when absent
        default=None,
        help="Determines whether to bin the data into discrete values of number of bins provided.",
    )
    parser.add_argument(
        "--result_binned_key",
        type=str,
        default="X_binned",
        help="Specifies the key of AnnData to store the binned data.",
    )
    parser.add_argument(
        "--length_normalize",
        type=str_to_bool,  # was type=bool
        default=False,
        help="Determines whether to normalize the length.",
    )
    parser.add_argument(
        "--force_preprocess",
        type=str_to_bool,  # was type=bool
        default=False,
        help="Determines whether to force preprocessing.",
    )
    parser.add_argument(
        "--min_dataset_size",
        type=int,
        default=100,
        help="Specifies the minimum dataset size.",
    )
    parser.add_argument(
        "--min_valid_genes_id",
        type=int,
        default=10_000,
        help="Specifies the minimum valid genes id.",
    )
    parser.add_argument(
        "--min_nnz_genes",
        type=int,
        default=200,
        help="Specifies the minimum non-zero genes.",
    )
    parser.add_argument(
        "--maxdropamount",
        type=int,
        default=2,
        help="Specifies the maximum drop amount.",
    )
    parser.add_argument(
        "--madoutlier",
        type=int,
        default=5,
        help="Specifies the MAD outlier.",
    )
    parser.add_argument(
        "--pct_mt_outlier",
        type=int,
        default=8,
        help="Specifies the percentage of MT outlier.",
    )
    parser.add_argument(
        "--batch_key",
        type=str,  # was Optional[str], which is not callable
        default=None,
        help="Specifies the batch key.",
    )
    parser.add_argument(
        "--skip_validate",
        type=str_to_bool,  # was type=bool
        default=False,
        help="Determines whether to skip validation.",
    )
    args = parser.parse_args()

    # Load the collection: from the remote instance when one is given, otherwise
    # locally. The original fell through and always overwrote the remote lookup
    # with the local one.
    if args.instance is not None:
        collection = (
            ln.Collection.using(instance=args.instance)
            .filter(name=args.name, version=args.version)
            .first()
        )
    else:
        collection = ln.Collection.filter(name=args.name, version=args.version).first()

    print("using the dataset ", collection, " of size ", len(collection.artifacts.all()))
    # Initialize the preprocessor
    preprocessor = LaminPreprocessor(
        filter_gene_by_counts=args.filter_gene_by_counts,
        filter_cell_by_counts=args.filter_cell_by_counts,
        normalize_sum=args.normalize_sum,
        keep_norm_layer=args.keep_norm_layer,
        subset_hvg=args.subset_hvg,
        hvg_flavor=args.hvg_flavor,
        binning=args.binning,
        result_binned_key=args.result_binned_key,
        length_normalize=args.length_normalize,
        force_preprocess=args.force_preprocess,
        min_dataset_size=args.min_dataset_size,
        min_valid_genes_id=args.min_valid_genes_id,
        min_nnz_genes=args.min_nnz_genes,
        maxdropamount=args.maxdropamount,
        madoutlier=args.madoutlier,
        pct_mt_outlier=args.pct_mt_outlier,
        batch_key=args.batch_key,
        skip_validate=args.skip_validate,
    )

    # Preprocess the dataset
    preprocessed_dataset = preprocessor(
        collection,
        name=args.new_name,
        description=args.description,
        start_at=args.start_at,
        version=args.new_version,
    )

    # The output was saved under new_version (the original printed args.version,
    # which is the *input* dataset's version).
    print(
        f"Preprocessed dataset saved with version {args.new_version} and name {args.new_name}."
    )


if __name__ == "__main__":
    main()
|
scdataloader/collator.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from .utils import load_genes
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
# class SimpleCollator:
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Collator:
    def __init__(
        self,
        organisms: list,
        org_to_id: dict = None,
        valid_genes: list = None,
        max_len=2000,
        n_bins=0,
        add_zero_genes=200,
        logp1=False,
        norm_to=None,
        how="all",
        tp_name=None,
        organism_name="organism_ontology_term_id",
        class_names=None,
    ):
        """
        This class is responsible for collating data for the scPRINT model. It handles the
        organization and preparation of gene expression data from different organisms,
        allowing for various configurations such as maximum gene list length, normalization,
        and selection method for gene expression.

        Args:
            organisms (list): List of organisms to be considered for gene expression data.
            org_to_id (dict, optional): Dictionary mapping organisms to their respective IDs.
                When None, the organism names themselves are used as ids. Defaults to None.
            valid_genes (list, optional): List of genes from the datasets to be considered.
                Defaults to None, meaning all genes are kept. (Was a shared mutable
                default ``[]`` in the original signature.)
            max_len (int, optional): Maximum length of the gene list. Defaults to 2000.
            n_bins (int, optional): Number of bins for binning the data; 0 disables it.
                NOTE(review): binning is not implemented yet in ``__call__``. Defaults to 0.
            add_zero_genes (int, optional): Number of zero-expressed genes to add to the
                selection (ignored when how == "all"). Defaults to 200.
            logp1 (bool, optional): If True, log2(1 + x) normalization is applied.
                Defaults to False.
            norm_to (float, optional): Target total count each cell is normalized to.
                Defaults to None (no normalization).
            how (str, optional): Gene-selection method: "all", "most expr" or
                "random expr". Defaults to "all".
            tp_name (str, optional): Key of the time-point / heat-diff value in each batch
                element. Defaults to None.
            organism_name (str, optional): Key of the organism ontology term id in each
                batch element. Defaults to "organism_ontology_term_id".
            class_names (list, optional): Keys of extra classes to collate from each batch
                element. Defaults to None. (Was a shared mutable default ``[]``.)
        """
        # Normalize the mutable-default pitfall: never share list instances across calls.
        valid_genes = valid_genes if valid_genes is not None else []
        class_names = class_names if class_names is not None else []
        self.organisms = organisms
        self.valid_genes = valid_genes
        self.max_len = max_len
        self.n_bins = n_bins
        self.add_zero_genes = add_zero_genes
        self.logp1 = logp1
        self.norm_to = norm_to
        self.org_to_id = org_to_id
        self.how = how
        self.organism_ids = (
            set([org_to_id[k] for k in organisms])
            if org_to_id is not None
            else set(organisms)
        )
        self.organism_name = organism_name
        self.tp_name = tp_name
        self.class_names = class_names

        self.start_idx = {}
        self.accepted_genes = {}
        self.genedf = load_genes(organisms)
        for organism in set(self.genedf.organism):
            ogenedf = self.genedf[self.genedf.organism == organism]
            org = org_to_id[organism] if org_to_id is not None else organism
            # first row index of this organism's genes in the concatenated gene table
            self.start_idx.update(
                {org: np.where(self.genedf.organism == organism)[0][0]}
            )
            if len(valid_genes) > 0:
                self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})

    def __call__(self, batch):
        """
        Collate a batch of per-cell expression dicts into model-ready tensors.

        Args:
            batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
                the first list is for the different samples, the second list is for the different elements with
                elem["x"]: gene expression
                elem["organism_name"]: organism ontology term id
                elem["tp_name"]: heat diff
                elem["class_names.."]: other classes

        Returns:
            dict[str, Tensor]: "x" (expressions), "genes" (gene positions), "class"
            (extra classes), "tp" (time point) and "depth" (total counts per cell).
        """
        # do count selection
        # get the unseen info and don't add any unseen
        # get the I most expressed genes, add randomly some unexpressed genes that are not unseen
        exprs = []
        total_count = []
        other_classes = []
        gene_locs = []
        tp = []
        for elem in batch:
            organism_id = elem[self.organism_name]
            if organism_id not in self.organism_ids:
                # drop cells from organisms this collator was not configured for
                continue
            expr = np.array(elem["x"])
            total_count.append(expr.sum())
            if len(self.accepted_genes) > 0:
                expr = expr[self.accepted_genes[organism_id]]
            if self.how == "most expr":
                loc = np.argsort(expr)[-(self.max_len) :][::-1]
            elif self.how == "random expr":
                nnz_loc = np.where(expr > 0)[0]
                loc = nnz_loc[
                    np.random.choice(len(nnz_loc), self.max_len, replace=False)
                ]
            elif self.how == "all":
                loc = np.arange(len(expr))
            else:
                # original message omitted the valid "all" option
                raise ValueError("how must be either 'most expr', 'random expr' or 'all'")
            if self.add_zero_genes > 0 and self.how != "all":
                zero_loc = np.where(expr == 0)[0]
                zero_loc = [
                    np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
                ]
                loc = np.concatenate((loc, zero_loc), axis=None)
            exprs.append(expr[loc])
            # shift local gene positions into the global (multi-organism) gene table
            gene_locs.append(loc + self.start_idx[organism_id])

            if self.tp_name is not None:
                tp.append(elem[self.tp_name])
            else:
                tp.append(0)

            other_classes.append([elem[i] for i in self.class_names])

        expr = np.array(exprs)
        tp = np.array(tp)
        gene_locs = np.array(gene_locs)
        total_count = np.array(total_count)
        other_classes = np.array(other_classes)

        # normalize counts
        if self.norm_to is not None:
            expr = (expr * self.norm_to) / total_count[:, None]
        if self.logp1:
            expr = np.log2(1 + expr)

        # do binning of counts
        if self.n_bins:
            # TODO: binning is not implemented yet
            pass

        # find the associated gene ids (given the species)

        # get the NN cells

        # do encoding / selection a la scGPT

        # do encoding of graph location
        # encode all the edges in some sparse way
        # normalizing total counts between 0,1
        return {
            "x": Tensor(expr),
            "genes": Tensor(gene_locs).int(),
            "class": Tensor(other_classes).int(),
            "tp": Tensor(tp),
            "depth": Tensor(total_count),
        }
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class AnnDataCollator(Collator):
    def __init__(self, *args, **kwargs):
        """Collator variant that consumes AnnData views (one cell each) instead of dicts."""
        super().__init__(*args, **kwargs)

    def __call__(self, batch):
        """Collate a list of single-observation AnnData views into model-ready tensors.

        Args:
            batch (list): AnnData views with one observation each; ``.X`` holds the
                expression row and ``.obs`` the metadata columns.

        Returns:
            dict[str, Tensor]: "x", "genes", "depth" and "class" tensors
            (same meaning as in ``Collator.__call__``).
        """
        exprs = []
        total_count = []
        other_classes = []
        gene_locs = []
        tp = []
        for elem in batch:
            organism_id = elem.obs[self.organism_name]
            if organism_id.item() not in self.organism_ids:
                # skip cells from unsupported organisms, matching the base class;
                # the original only printed and fell through, which then failed
                # on the start_idx lookup below
                print(organism_id)
                continue
            expr = np.array(elem.X[0])

            total_count.append(expr.sum())
            if len(self.accepted_genes) > 0:
                # accepted_genes keys are scalar organism ids, so unwrap the Series
                # (the original indexed with the Series itself)
                expr = expr[self.accepted_genes[organism_id.item()]]
            if self.how == "most expr":
                loc = np.argsort(expr)[-(self.max_len) :][::-1]
            elif self.how == "random expr":
                nnz_loc = np.where(expr > 0)[0]
                loc = nnz_loc[
                    np.random.choice(len(nnz_loc), self.max_len, replace=False)
                ]
            else:
                raise ValueError("how must be either most expr or random expr")
            if self.add_zero_genes > 0:
                zero_loc = np.where(expr == 0)[0]
                zero_loc = [
                    np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
                ]
                loc = np.concatenate((loc, zero_loc), axis=None)
            exprs.append(expr[loc])
            gene_locs.append(loc + self.start_idx[organism_id.item()])

            if self.tp_name is not None:
                tp.append(elem.obs[self.tp_name])
            else:
                tp.append(0)

            other_classes.append([elem.obs[i].values[0] for i in self.class_names])

        expr = np.array(exprs)
        tp = np.array(tp)
        gene_locs = np.array(gene_locs)
        total_count = np.array(total_count)
        other_classes = np.array(other_classes)
        return {
            "x": Tensor(expr),
            "genes": Tensor(gene_locs).int(),
            "depth": Tensor(total_count),
            "class": Tensor(other_classes),
        }
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class SCVICollator(Collator):
    def __init__(self, *args, **kwargs):
        """Collator for scVI-style batches, which arrive pre-collated as a single dict."""
        super().__init__(*args, **kwargs)

    def __call__(self, batch):
        """Keep the ``max_len`` most-expressed genes of every cell in the batch.

        Args:
            batch (dict): must contain "x", a (cells, genes) expression matrix.

        Returns:
            dict[str, Tensor]: "x" (selected expressions), "genes" (their column
            indices, per cell, most-expressed first) and "depth" (total counts).

        Raises:
            ValueError: for any selection mode other than "most expr".
        """
        counts = batch["x"]
        depth = counts.sum(axis=1)
        if self.how != "most expr":
            raise ValueError("how must be either most expr or random expr")
        # per-row indices of the max_len largest values, in descending order
        top = np.argsort(counts)[:, -(self.max_len) :][:, ::-1]
        if self.logp1:
            counts = np.log2(1 + counts)
        rows = np.arange(counts.shape[0])[:, None]
        return {
            "x": Tensor(counts[rows, top]),
            # .copy() drops the negative stride left by the [::-1] reversal
            "genes": Tensor(top.copy()).int(),
            "depth": Tensor(depth),
        }
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class GeneformerCollator(Collator):
    # NOTE(review): work in progress — `__call__` discards the parent result and
    # implicitly returns None; the per-gene normalization and padding-token steps
    # sketched below are not implemented yet.
    def __init__(self, *args, gene_norm_list: list, **kwargs):
        """Collator for Geneformer-style models.

        Args:
            gene_norm_list (list): per-gene normalization factors (keyword-only).
            *args, **kwargs: forwarded unchanged to `Collator`.
        """
        super().__init__(*args, **kwargs)
        self.gene_norm_list = gene_norm_list

    def __call__(self, batch):
        # TODO: use the collated output; the value is currently dropped.
        super().__call__(batch)
        # normalization per gene

        # tokenize the empty locations
|
|
253
|
+
class scGPTCollator(Collator):
    # NOTE(review): work in progress — `__call__` discards the parent result and
    # implicitly returns None; the binning and padding-token steps sketched below
    # are not implemented yet.
    def __call__(self, batch):
        """Collate `batch` via the base `Collator`; scGPT-specific steps are TODO."""
        super().__call__(batch)
        # binning

        # tokenize the empty locations
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class scPRINTCollator(Collator):
    # NOTE(review): placeholder subclass — `__call__` discards the parent result
    # and implicitly returns None; no scPRINT-specific processing is added yet.
    def __call__(self, batch):
        """Collate `batch` via the base `Collator`; scPRINT-specific steps are TODO."""
        super().__call__(batch)
|