scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +66 -42
- scdataloader/collator.py +136 -67
- scdataloader/config.py +112 -0
- scdataloader/data.py +160 -169
- scdataloader/datamodule.py +403 -0
- scdataloader/mapped.py +285 -109
- scdataloader/preprocess.py +240 -109
- scdataloader/utils.py +162 -70
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/METADATA +87 -18
- scdataloader-1.0.1.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-1.0.1.dist-info}/entry_points.txt +0 -0
scdataloader/VERSION
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
|
|
1
|
+
1.0.0
|
scdataloader/__init__.py
CHANGED
scdataloader/__main__.py
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
1
1
|
import argparse
|
|
2
|
-
from scdataloader.preprocess import
|
|
2
|
+
from scdataloader.preprocess import (
|
|
3
|
+
LaminPreprocessor,
|
|
4
|
+
additional_preprocess,
|
|
5
|
+
additional_postprocess,
|
|
6
|
+
)
|
|
3
7
|
import lamindb as ln
|
|
4
8
|
from typing import Optional, Union
|
|
5
9
|
|
|
10
|
+
|
|
11
|
+
# scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
|
|
6
12
|
def main():
|
|
13
|
+
"""
|
|
14
|
+
main function to preprocess datasets in a given lamindb collection.
|
|
15
|
+
"""
|
|
7
16
|
parser = argparse.ArgumentParser(
|
|
8
17
|
description="Preprocess datasets in a given lamindb collection."
|
|
9
18
|
)
|
|
@@ -11,22 +20,31 @@ def main():
|
|
|
11
20
|
"--name", type=str, required=True, help="Name of the input dataset"
|
|
12
21
|
)
|
|
13
22
|
parser.add_argument(
|
|
14
|
-
"--new_name",
|
|
23
|
+
"--new_name",
|
|
24
|
+
type=str,
|
|
25
|
+
default="preprocessed dataset",
|
|
26
|
+
help="Name of the preprocessed dataset.",
|
|
15
27
|
)
|
|
16
28
|
parser.add_argument(
|
|
17
29
|
"--description",
|
|
18
30
|
type=str,
|
|
19
|
-
default="preprocessed by scDataLoader"
|
|
31
|
+
default="preprocessed by scDataLoader",
|
|
20
32
|
help="Description of the preprocessed dataset.",
|
|
21
33
|
)
|
|
22
34
|
parser.add_argument(
|
|
23
35
|
"--start_at", type=int, default=0, help="Position to start preprocessing at."
|
|
24
36
|
)
|
|
25
37
|
parser.add_argument(
|
|
26
|
-
"--new_version",
|
|
38
|
+
"--new_version",
|
|
39
|
+
type=str,
|
|
40
|
+
default="2",
|
|
41
|
+
help="Version of the output dataset and files.",
|
|
27
42
|
)
|
|
28
43
|
parser.add_argument(
|
|
29
|
-
"--instance",
|
|
44
|
+
"--instance",
|
|
45
|
+
type=str,
|
|
46
|
+
default=None,
|
|
47
|
+
help="Instance storing the input dataset, if not local",
|
|
30
48
|
)
|
|
31
49
|
parser.add_argument(
|
|
32
50
|
"--version", type=str, default=None, help="Version of the input dataset."
|
|
@@ -35,125 +53,127 @@ def main():
|
|
|
35
53
|
"--filter_gene_by_counts",
|
|
36
54
|
type=Union[int, bool],
|
|
37
55
|
default=False,
|
|
38
|
-
help="Determines whether to filter genes by counts."
|
|
56
|
+
help="Determines whether to filter genes by counts.",
|
|
39
57
|
)
|
|
40
58
|
parser.add_argument(
|
|
41
59
|
"--filter_cell_by_counts",
|
|
42
60
|
type=Union[int, bool],
|
|
43
61
|
default=False,
|
|
44
|
-
help="Determines whether to filter cells by counts."
|
|
62
|
+
help="Determines whether to filter cells by counts.",
|
|
45
63
|
)
|
|
46
64
|
parser.add_argument(
|
|
47
65
|
"--normalize_sum",
|
|
48
66
|
type=float,
|
|
49
67
|
default=1e4,
|
|
50
|
-
help="Determines whether to normalize the total counts of each cell to a specific value."
|
|
51
|
-
)
|
|
52
|
-
parser.add_argument(
|
|
53
|
-
"--keep_norm_layer",
|
|
54
|
-
type=bool,
|
|
55
|
-
default=False,
|
|
56
|
-
help="Determines whether to keep the normalization layer."
|
|
68
|
+
help="Determines whether to normalize the total counts of each cell to a specific value.",
|
|
57
69
|
)
|
|
58
70
|
parser.add_argument(
|
|
59
71
|
"--subset_hvg",
|
|
60
72
|
type=int,
|
|
61
73
|
default=0,
|
|
62
|
-
help="Determines whether to subset highly variable genes."
|
|
74
|
+
help="Determines whether to subset highly variable genes.",
|
|
63
75
|
)
|
|
64
76
|
parser.add_argument(
|
|
65
77
|
"--hvg_flavor",
|
|
66
78
|
type=str,
|
|
67
79
|
default="seurat_v3",
|
|
68
|
-
help="Specifies the flavor of highly variable genes selection."
|
|
80
|
+
help="Specifies the flavor of highly variable genes selection.",
|
|
69
81
|
)
|
|
70
82
|
parser.add_argument(
|
|
71
83
|
"--binning",
|
|
72
84
|
type=Optional[int],
|
|
73
85
|
default=None,
|
|
74
|
-
help="Determines whether to bin the data into discrete values of number of bins provided."
|
|
86
|
+
help="Determines whether to bin the data into discrete values of number of bins provided.",
|
|
75
87
|
)
|
|
76
88
|
parser.add_argument(
|
|
77
89
|
"--result_binned_key",
|
|
78
90
|
type=str,
|
|
79
91
|
default="X_binned",
|
|
80
|
-
help="Specifies the key of AnnData to store the binned data."
|
|
92
|
+
help="Specifies the key of AnnData to store the binned data.",
|
|
81
93
|
)
|
|
82
94
|
parser.add_argument(
|
|
83
95
|
"--length_normalize",
|
|
84
96
|
type=bool,
|
|
85
97
|
default=False,
|
|
86
|
-
help="Determines whether to normalize the length."
|
|
98
|
+
help="Determines whether to normalize the length.",
|
|
87
99
|
)
|
|
88
100
|
parser.add_argument(
|
|
89
101
|
"--force_preprocess",
|
|
90
102
|
type=bool,
|
|
91
103
|
default=False,
|
|
92
|
-
help="Determines whether to force preprocessing."
|
|
104
|
+
help="Determines whether to force preprocessing.",
|
|
93
105
|
)
|
|
94
106
|
parser.add_argument(
|
|
95
107
|
"--min_dataset_size",
|
|
96
108
|
type=int,
|
|
97
109
|
default=100,
|
|
98
|
-
help="Specifies the minimum dataset size."
|
|
110
|
+
help="Specifies the minimum dataset size.",
|
|
99
111
|
)
|
|
100
112
|
parser.add_argument(
|
|
101
113
|
"--min_valid_genes_id",
|
|
102
114
|
type=int,
|
|
103
115
|
default=10_000,
|
|
104
|
-
help="Specifies the minimum valid genes id."
|
|
116
|
+
help="Specifies the minimum valid genes id.",
|
|
105
117
|
)
|
|
106
118
|
parser.add_argument(
|
|
107
119
|
"--min_nnz_genes",
|
|
108
120
|
type=int,
|
|
109
|
-
default=
|
|
110
|
-
help="Specifies the minimum non-zero genes."
|
|
121
|
+
default=400,
|
|
122
|
+
help="Specifies the minimum non-zero genes.",
|
|
111
123
|
)
|
|
112
124
|
parser.add_argument(
|
|
113
125
|
"--maxdropamount",
|
|
114
126
|
type=int,
|
|
115
|
-
default=
|
|
116
|
-
help="Specifies the maximum drop amount."
|
|
127
|
+
default=50,
|
|
128
|
+
help="Specifies the maximum drop amount.",
|
|
117
129
|
)
|
|
118
130
|
parser.add_argument(
|
|
119
|
-
"--madoutlier",
|
|
120
|
-
type=int,
|
|
121
|
-
default=5,
|
|
122
|
-
help="Specifies the MAD outlier."
|
|
131
|
+
"--madoutlier", type=int, default=5, help="Specifies the MAD outlier."
|
|
123
132
|
)
|
|
124
133
|
parser.add_argument(
|
|
125
134
|
"--pct_mt_outlier",
|
|
126
135
|
type=int,
|
|
127
136
|
default=8,
|
|
128
|
-
help="Specifies the percentage of MT outlier."
|
|
137
|
+
help="Specifies the percentage of MT outlier.",
|
|
129
138
|
)
|
|
130
139
|
parser.add_argument(
|
|
131
|
-
"--batch_key",
|
|
132
|
-
type=Optional[str],
|
|
133
|
-
default=None,
|
|
134
|
-
help="Specifies the batch key."
|
|
140
|
+
"--batch_key", type=Optional[str], default=None, help="Specifies the batch key."
|
|
135
141
|
)
|
|
136
142
|
parser.add_argument(
|
|
137
143
|
"--skip_validate",
|
|
138
144
|
type=bool,
|
|
139
145
|
default=False,
|
|
140
|
-
help="Determines whether to skip validation."
|
|
146
|
+
help="Determines whether to skip validation.",
|
|
147
|
+
)
|
|
148
|
+
parser.add_argument(
|
|
149
|
+
"--do_postp",
|
|
150
|
+
type=bool,
|
|
151
|
+
default=False,
|
|
152
|
+
help="Determines whether to do postprocessing.",
|
|
141
153
|
)
|
|
142
154
|
args = parser.parse_args()
|
|
143
155
|
|
|
144
156
|
# Load the collection
|
|
157
|
+
# if not args.preprocess:
|
|
158
|
+
# print("Only preprocess is available for now")
|
|
159
|
+
# return
|
|
145
160
|
if args.instance is not None:
|
|
146
|
-
collection =
|
|
147
|
-
|
|
148
|
-
|
|
161
|
+
collection = (
|
|
162
|
+
ln.Collection.using(instance=args.instance)
|
|
163
|
+
.filter(name=args.name, version=args.version)
|
|
164
|
+
.first()
|
|
165
|
+
)
|
|
166
|
+
else:
|
|
167
|
+
collection = ln.Collection.filter(name=args.name, version=args.version).first()
|
|
149
168
|
|
|
150
|
-
print(
|
|
169
|
+
print(
|
|
170
|
+
"using the dataset ", collection, " of size ", len(collection.artifacts.all())
|
|
171
|
+
)
|
|
151
172
|
# Initialize the preprocessor
|
|
152
173
|
preprocessor = LaminPreprocessor(
|
|
153
174
|
filter_gene_by_counts=args.filter_gene_by_counts,
|
|
154
175
|
filter_cell_by_counts=args.filter_cell_by_counts,
|
|
155
176
|
normalize_sum=args.normalize_sum,
|
|
156
|
-
keep_norm_layer=args.keep_norm_layer,
|
|
157
177
|
subset_hvg=args.subset_hvg,
|
|
158
178
|
hvg_flavor=args.hvg_flavor,
|
|
159
179
|
binning=args.binning,
|
|
@@ -168,10 +188,14 @@ def main():
|
|
|
168
188
|
pct_mt_outlier=args.pct_mt_outlier,
|
|
169
189
|
batch_key=args.batch_key,
|
|
170
190
|
skip_validate=args.skip_validate,
|
|
191
|
+
do_postp=args.do_postp,
|
|
192
|
+
additional_preprocess=additional_preprocess,
|
|
193
|
+
additional_postprocess=additional_postprocess,
|
|
194
|
+
keep_files=False,
|
|
171
195
|
)
|
|
172
196
|
|
|
173
197
|
# Preprocess the dataset
|
|
174
|
-
|
|
198
|
+
preprocessor(
|
|
175
199
|
collection,
|
|
176
200
|
name=args.new_name,
|
|
177
201
|
description=args.description,
|
scdataloader/collator.py
CHANGED
|
@@ -1,25 +1,27 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
-
from .utils import load_genes
|
|
3
|
-
from torch import Tensor
|
|
4
|
-
|
|
5
|
-
# class SimpleCollator:
|
|
2
|
+
from .utils import load_genes, downsample_profile
|
|
3
|
+
from torch import Tensor, long
|
|
4
|
+
from typing import Optional
|
|
6
5
|
|
|
7
6
|
|
|
8
7
|
class Collator:
|
|
9
8
|
def __init__(
|
|
10
9
|
self,
|
|
11
|
-
organisms: list,
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
add_zero_genes=
|
|
17
|
-
logp1=False,
|
|
18
|
-
norm_to=None,
|
|
19
|
-
|
|
20
|
-
tp_name=None,
|
|
21
|
-
organism_name="organism_ontology_term_id",
|
|
22
|
-
class_names=[],
|
|
10
|
+
organisms: list[str],
|
|
11
|
+
how: str = "all",
|
|
12
|
+
org_to_id: dict[str, int] = None,
|
|
13
|
+
valid_genes: list[str] = [],
|
|
14
|
+
max_len: int = 2000,
|
|
15
|
+
add_zero_genes: int = 0,
|
|
16
|
+
logp1: bool = False,
|
|
17
|
+
norm_to: Optional[float] = None,
|
|
18
|
+
n_bins: int = 0,
|
|
19
|
+
tp_name: Optional[str] = None,
|
|
20
|
+
organism_name: str = "organism_ontology_term_id",
|
|
21
|
+
class_names: list[str] = [],
|
|
22
|
+
genelist: list[str] = [],
|
|
23
|
+
downsample: Optional[float] = None, # don't use it for training!
|
|
24
|
+
save_output: bool = False,
|
|
23
25
|
):
|
|
24
26
|
"""
|
|
25
27
|
This class is responsible for collating data for the scPRINT model. It handles the
|
|
@@ -27,51 +29,81 @@ class Collator:
|
|
|
27
29
|
allowing for various configurations such as maximum gene list length, normalization,
|
|
28
30
|
and selection method for gene expression.
|
|
29
31
|
|
|
32
|
+
This Collator should work with scVI's dataloader as well!
|
|
33
|
+
|
|
30
34
|
Args:
|
|
31
35
|
organisms (list): List of organisms to be considered for gene expression data.
|
|
36
|
+
it will drop any other organism it sees (might lead to batches of different sizes!)
|
|
37
|
+
how (str, optional): Method for selecting gene expression. Defaults to "all".
|
|
38
|
+
one of ["most expr", "random expr", "all", "some"]:
|
|
39
|
+
"most expr": selects the max_len most expressed genes,
|
|
40
|
+
if fewer genes are expressed, will sample random unexpressed genes,
|
|
41
|
+
"random expr": uses a random set of max_len expressed genes.
|
|
42
|
+
if fewer genes are expressed, will sample random unexpressed genes
|
|
43
|
+
"all": uses all genes
|
|
44
|
+
"some": uses only the genes provided through the genelist param
|
|
32
45
|
org_to_id (dict): Dictionary mapping organisms to their respective IDs.
|
|
33
|
-
labels (list, optional): List of labels for the data. Defaults to [].
|
|
34
46
|
valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
47
|
+
it will drop any other genes from the input expression data (useful when your model only works on some genes)
|
|
48
|
+
max_len (int, optional): Total number of genes to use (for random expr and most expr). Defaults to 2000.
|
|
49
|
+
n_bins (int, optional): Number of bins for binning the data. Defaults to 0. meaning, no binning of expression.
|
|
50
|
+
add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
|
|
38
51
|
logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
|
|
39
|
-
norm_to (
|
|
40
|
-
|
|
52
|
+
norm_to (float, optional): Rescaling value of the normalization to be applied. Defaults to None.
|
|
53
|
+
organism_name (str, optional): Name of the organism ontology term id. Defaults to "organism_ontology_term_id".
|
|
54
|
+
tp_name (str, optional): Name of the heat diff. Defaults to None.
|
|
55
|
+
class_names (list, optional): List of other classes to be considered. Defaults to [].
|
|
56
|
+
genelist (list, optional): List of genes to be considered. Defaults to [].
|
|
57
|
+
If [] all genes will be considered
|
|
58
|
+
downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None.
|
|
59
|
+
This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator
|
|
60
|
+
save_output (bool, optional): If True, saves the output to a file. Defaults to False.
|
|
61
|
+
This is mainly for debugging purposes
|
|
41
62
|
"""
|
|
42
63
|
self.organisms = organisms
|
|
43
|
-
self.
|
|
64
|
+
self.genedf = load_genes(organisms)
|
|
44
65
|
self.max_len = max_len
|
|
45
66
|
self.n_bins = n_bins
|
|
46
67
|
self.add_zero_genes = add_zero_genes
|
|
47
68
|
self.logp1 = logp1
|
|
48
69
|
self.norm_to = norm_to
|
|
49
|
-
self.org_to_id = org_to_id
|
|
50
70
|
self.how = how
|
|
51
|
-
self.
|
|
52
|
-
|
|
53
|
-
if org_to_id is not None
|
|
54
|
-
else set(organisms)
|
|
55
|
-
)
|
|
71
|
+
if self.how == "some":
|
|
72
|
+
assert len(genelist) > 0, "if how is some, genelist must be provided"
|
|
56
73
|
self.organism_name = organism_name
|
|
57
74
|
self.tp_name = tp_name
|
|
58
75
|
self.class_names = class_names
|
|
59
|
-
|
|
76
|
+
self.save_output = save_output
|
|
60
77
|
self.start_idx = {}
|
|
61
78
|
self.accepted_genes = {}
|
|
62
|
-
self.
|
|
63
|
-
|
|
79
|
+
self.downsample = downsample
|
|
80
|
+
self.to_subset = {}
|
|
81
|
+
self._setup(org_to_id, valid_genes, genelist)
|
|
82
|
+
|
|
83
|
+
def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
|
|
84
|
+
self.org_to_id = org_to_id
|
|
85
|
+
self.to_subset = {}
|
|
86
|
+
self.accepted_genes = {}
|
|
87
|
+
self.start_idx = {}
|
|
88
|
+
self.organism_ids = (
|
|
89
|
+
set([org_to_id[k] for k in self.organisms])
|
|
90
|
+
if org_to_id is not None
|
|
91
|
+
else set(self.organisms)
|
|
92
|
+
)
|
|
93
|
+
for organism in self.organisms:
|
|
64
94
|
ogenedf = self.genedf[self.genedf.organism == organism]
|
|
95
|
+
tot = self.genedf[self.genedf.index.isin(valid_genes)]
|
|
65
96
|
org = org_to_id[organism] if org_to_id is not None else organism
|
|
66
|
-
self.start_idx.update(
|
|
67
|
-
{org: np.where(self.genedf.organism == organism)[0][0]}
|
|
68
|
-
)
|
|
97
|
+
self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
|
|
69
98
|
if len(valid_genes) > 0:
|
|
70
99
|
self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
|
|
100
|
+
if len(genelist) > 0:
|
|
101
|
+
df = ogenedf[ogenedf.index.isin(valid_genes)]
|
|
102
|
+
self.to_subset.update({org: df.index.isin(genelist)})
|
|
71
103
|
|
|
72
|
-
def __call__(self, batch):
|
|
104
|
+
def __call__(self, batch) -> dict[str, Tensor]:
|
|
73
105
|
"""
|
|
74
|
-
__call__
|
|
106
|
+
__call__ applies the collator to a minibatch of data
|
|
75
107
|
|
|
76
108
|
Args:
|
|
77
109
|
batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
|
|
@@ -92,33 +124,62 @@ class Collator:
|
|
|
92
124
|
other_classes = []
|
|
93
125
|
gene_locs = []
|
|
94
126
|
tp = []
|
|
127
|
+
dataset = []
|
|
128
|
+
nnz_loc = []
|
|
95
129
|
for elem in batch:
|
|
96
130
|
organism_id = elem[self.organism_name]
|
|
97
131
|
if organism_id not in self.organism_ids:
|
|
98
132
|
continue
|
|
133
|
+
if "_storage_idx" in elem:
|
|
134
|
+
dataset.append(elem["_storage_idx"])
|
|
99
135
|
expr = np.array(elem["x"])
|
|
100
136
|
total_count.append(expr.sum())
|
|
101
137
|
if len(self.accepted_genes) > 0:
|
|
102
138
|
expr = expr[self.accepted_genes[organism_id]]
|
|
103
139
|
if self.how == "most expr":
|
|
104
|
-
|
|
140
|
+
nnz_loc = np.where(expr > 0)[0]
|
|
141
|
+
ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
|
|
142
|
+
loc = np.argsort(expr)[-(ma):][::-1]
|
|
143
|
+
# nnz_loc = [1] * 30_000
|
|
144
|
+
# loc = np.argsort(expr)[-(self.max_len) :][::-1]
|
|
105
145
|
elif self.how == "random expr":
|
|
106
146
|
nnz_loc = np.where(expr > 0)[0]
|
|
107
147
|
loc = nnz_loc[
|
|
108
|
-
np.random.choice(
|
|
148
|
+
np.random.choice(
|
|
149
|
+
len(nnz_loc),
|
|
150
|
+
self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc),
|
|
151
|
+
replace=False,
|
|
152
|
+
# p=(expr.max() + (expr[nnz_loc])*19) / expr.max(), # 20 at most times more likely to be selected
|
|
153
|
+
)
|
|
109
154
|
]
|
|
110
|
-
elif self.how
|
|
155
|
+
elif self.how in ["all", "some"]:
|
|
111
156
|
loc = np.arange(len(expr))
|
|
112
157
|
else:
|
|
113
158
|
raise ValueError("how must be either most expr or random expr")
|
|
114
|
-
if
|
|
159
|
+
if (
|
|
160
|
+
(self.add_zero_genes > 0) or (self.max_len > len(nnz_loc))
|
|
161
|
+
) and self.how not in ["all", "some"]:
|
|
115
162
|
zero_loc = np.where(expr == 0)[0]
|
|
116
|
-
zero_loc = [
|
|
117
|
-
np.random.choice(
|
|
163
|
+
zero_loc = zero_loc[
|
|
164
|
+
np.random.choice(
|
|
165
|
+
len(zero_loc),
|
|
166
|
+
self.add_zero_genes
|
|
167
|
+
+ (
|
|
168
|
+
0
|
|
169
|
+
if self.max_len < len(nnz_loc)
|
|
170
|
+
else self.max_len - len(nnz_loc)
|
|
171
|
+
),
|
|
172
|
+
replace=False,
|
|
173
|
+
)
|
|
118
174
|
]
|
|
119
175
|
loc = np.concatenate((loc, zero_loc), axis=None)
|
|
120
|
-
|
|
121
|
-
|
|
176
|
+
expr = expr[loc]
|
|
177
|
+
loc = loc + self.start_idx[organism_id]
|
|
178
|
+
if self.how == "some":
|
|
179
|
+
expr = expr[self.to_subset[organism_id]]
|
|
180
|
+
loc = loc[self.to_subset[organism_id]]
|
|
181
|
+
exprs.append(expr)
|
|
182
|
+
gene_locs.append(loc)
|
|
122
183
|
|
|
123
184
|
if self.tp_name is not None:
|
|
124
185
|
tp.append(elem[self.tp_name])
|
|
@@ -132,6 +193,7 @@ class Collator:
|
|
|
132
193
|
gene_locs = np.array(gene_locs)
|
|
133
194
|
total_count = np.array(total_count)
|
|
134
195
|
other_classes = np.array(other_classes)
|
|
196
|
+
dataset = np.array(dataset)
|
|
135
197
|
|
|
136
198
|
# normalize counts
|
|
137
199
|
if self.norm_to is not None:
|
|
@@ -152,20 +214,34 @@ class Collator:
|
|
|
152
214
|
# do encoding of graph location
|
|
153
215
|
# encode all the edges in some sparse way
|
|
154
216
|
# normalizing total counts between 0,1
|
|
155
|
-
|
|
217
|
+
ret = {
|
|
156
218
|
"x": Tensor(expr),
|
|
157
219
|
"genes": Tensor(gene_locs).int(),
|
|
158
220
|
"class": Tensor(other_classes).int(),
|
|
159
221
|
"tp": Tensor(tp),
|
|
160
222
|
"depth": Tensor(total_count),
|
|
161
223
|
}
|
|
224
|
+
if len(dataset) > 0:
|
|
225
|
+
ret.update({"dataset": Tensor(dataset).to(long)})
|
|
226
|
+
if self.downsample is not None:
|
|
227
|
+
ret["x"] = downsample_profile(ret["x"], self.downsample)
|
|
228
|
+
if self.save_output:
|
|
229
|
+
with open("collator_output.txt", "a") as f:
|
|
230
|
+
np.savetxt(f, ret["x"].numpy())
|
|
231
|
+
return ret
|
|
162
232
|
|
|
163
233
|
|
|
164
234
|
class AnnDataCollator(Collator):
|
|
165
235
|
def __init__(self, *args, **kwargs):
|
|
236
|
+
"""
|
|
237
|
+
AnnDataCollator Collator to use if working with AnnData's experimental dataloader (it is very slow!!!)
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
@see Collator
|
|
241
|
+
"""
|
|
166
242
|
super().__init__(*args, **kwargs)
|
|
167
243
|
|
|
168
|
-
def __call__(self, batch):
|
|
244
|
+
def __call__(self, batch) -> dict[str, Tensor]:
|
|
169
245
|
exprs = []
|
|
170
246
|
total_count = []
|
|
171
247
|
other_classes = []
|
|
@@ -218,28 +294,17 @@ class AnnDataCollator(Collator):
|
|
|
218
294
|
}
|
|
219
295
|
|
|
220
296
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
def __call__(self, batch):
|
|
226
|
-
expr = batch["x"]
|
|
227
|
-
total_count = expr.sum(axis=1)
|
|
228
|
-
if self.how == "most expr":
|
|
229
|
-
loc = np.argsort(expr)[:, -(self.max_len) :][:, ::-1]
|
|
230
|
-
else:
|
|
231
|
-
raise ValueError("how must be either most expr or random expr")
|
|
232
|
-
if self.logp1:
|
|
233
|
-
expr = np.log2(1 + expr)
|
|
234
|
-
return {
|
|
235
|
-
"x": Tensor(expr[np.arange(expr.shape[0])[:, None], loc]),
|
|
236
|
-
"genes": Tensor(loc.copy()).int(),
|
|
237
|
-
"depth": Tensor(total_count),
|
|
238
|
-
}
|
|
239
|
-
|
|
240
|
-
|
|
297
|
+
#############
|
|
298
|
+
#### WIP ####
|
|
299
|
+
#############
|
|
241
300
|
class GeneformerCollator(Collator):
|
|
242
301
|
def __init__(self, *args, gene_norm_list: list, **kwargs):
|
|
302
|
+
"""
|
|
303
|
+
GeneformerCollator to finish
|
|
304
|
+
|
|
305
|
+
Args:
|
|
306
|
+
gene_norm_list (list): the normalization of expression through all datasets, per gene.
|
|
307
|
+
"""
|
|
243
308
|
super().__init__(*args, **kwargs)
|
|
244
309
|
self.gene_norm_list = gene_norm_list
|
|
245
310
|
|
|
@@ -251,6 +316,10 @@ class GeneformerCollator(Collator):
|
|
|
251
316
|
|
|
252
317
|
|
|
253
318
|
class scGPTCollator(Collator):
|
|
319
|
+
"""
|
|
320
|
+
scGPTCollator to finish
|
|
321
|
+
"""
|
|
322
|
+
|
|
254
323
|
def __call__(self, batch):
|
|
255
324
|
super().__call__(batch)
|
|
256
325
|
# binning
|
scdataloader/config.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration file for scDataLoader
|
|
3
|
+
|
|
4
|
+
Missing labels are added to the dataset to complete a better hierarchical tree
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
LABELS_TOADD = {
|
|
8
|
+
"assay_ontology_term_id": {
|
|
9
|
+
"10x transcription profiling": "EFO:0030003",
|
|
10
|
+
"spatial transcriptomics": "EFO:0008994",
|
|
11
|
+
"10x 3' transcription profiling": "EFO:0030003",
|
|
12
|
+
"10x 5' transcription profiling": "EFO:0030004",
|
|
13
|
+
},
|
|
14
|
+
"disease_ontology_term_id": {
|
|
15
|
+
"metabolic disease": "MONDO:0005066",
|
|
16
|
+
"chronic kidney disease": "MONDO:0005300",
|
|
17
|
+
"chromosomal disorder": "MONDO:0019040",
|
|
18
|
+
"infectious disease": "MONDO:0005550",
|
|
19
|
+
"inflammatory disease": "MONDO:0021166",
|
|
20
|
+
# "immune system disease",
|
|
21
|
+
"disorder of development or morphogenesis": "MONDO:0021147",
|
|
22
|
+
"mitochondrial disease": "MONDO:0044970",
|
|
23
|
+
"psychiatric disorder": "MONDO:0002025",
|
|
24
|
+
"cancer or benign tumor": "MONDO:0002025",
|
|
25
|
+
"neoplasm": "MONDO:0005070",
|
|
26
|
+
},
|
|
27
|
+
"cell_type_ontology_term_id": {
|
|
28
|
+
"progenitor cell": "CL:0011026",
|
|
29
|
+
"hematopoietic cell": "CL:0000988",
|
|
30
|
+
"myoblast": "CL:0000056",
|
|
31
|
+
"myeloid cell": "CL:0000763",
|
|
32
|
+
"neuron": "CL:0000540",
|
|
33
|
+
"electrically active cell": "CL:0000211",
|
|
34
|
+
"epithelial cell": "CL:0000066",
|
|
35
|
+
"secretory cell": "CL:0000151",
|
|
36
|
+
"stem cell": "CL:0000034",
|
|
37
|
+
"non-terminally differentiated cell": "CL:0000055",
|
|
38
|
+
"supporting cell": "CL:0000630",
|
|
39
|
+
},
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
COARSE_TISSUE = {
|
|
43
|
+
"adipose tissue": "",
|
|
44
|
+
"bladder organ": "",
|
|
45
|
+
"blood": "",
|
|
46
|
+
"bone marrow": "",
|
|
47
|
+
"brain": "",
|
|
48
|
+
"breast": "",
|
|
49
|
+
"esophagus": "",
|
|
50
|
+
"eye": "",
|
|
51
|
+
"embryo": "",
|
|
52
|
+
"fallopian tube": "",
|
|
53
|
+
"gall bladder": "",
|
|
54
|
+
"heart": "",
|
|
55
|
+
"intestine": "",
|
|
56
|
+
"kidney": "",
|
|
57
|
+
"liver": "",
|
|
58
|
+
"lung": "",
|
|
59
|
+
"lymph node": "",
|
|
60
|
+
"musculature of body": "",
|
|
61
|
+
"nose": "",
|
|
62
|
+
"ovary": "",
|
|
63
|
+
"pancreas": "",
|
|
64
|
+
"placenta": "",
|
|
65
|
+
"skin of body": "",
|
|
66
|
+
"spinal cord": "",
|
|
67
|
+
"spleen": "",
|
|
68
|
+
"stomach": "",
|
|
69
|
+
"thymus": "",
|
|
70
|
+
"thyroid gland": "",
|
|
71
|
+
"tongue": "",
|
|
72
|
+
"uterus": "",
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
COARSE_ANCESTRY = {
|
|
76
|
+
"African": "",
|
|
77
|
+
"Chinese": "",
|
|
78
|
+
"East Asian": "",
|
|
79
|
+
"Eskimo": "",
|
|
80
|
+
"European": "",
|
|
81
|
+
"Greater Middle Eastern (Middle Eastern, North African or Persian)": "",
|
|
82
|
+
"Hispanic or Latin American": "",
|
|
83
|
+
"Native American": "",
|
|
84
|
+
"Oceanian": "",
|
|
85
|
+
"South Asian": "",
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
COARSE_DEVELOPMENT_STAGE = {
|
|
89
|
+
"Embryonic human": "",
|
|
90
|
+
"Fetal": "",
|
|
91
|
+
"Immature": "",
|
|
92
|
+
"Mature": "",
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
COARSE_ASSAY = {
|
|
96
|
+
"10x 3'": "",
|
|
97
|
+
"10x 5'": "",
|
|
98
|
+
"10x multiome": "",
|
|
99
|
+
"CEL-seq2": "",
|
|
100
|
+
"Drop-seq": "",
|
|
101
|
+
"GEXSCOPE technology": "",
|
|
102
|
+
"inDrop": "",
|
|
103
|
+
"microwell-seq": "",
|
|
104
|
+
"sci-Plex": "",
|
|
105
|
+
"sci-RNA-seq": "",
|
|
106
|
+
"Seq-Well": "",
|
|
107
|
+
"Slide-seq": "",
|
|
108
|
+
"Smart-seq": "",
|
|
109
|
+
"SPLiT-seq": "",
|
|
110
|
+
"TruDrop": "",
|
|
111
|
+
"Visium Spatial Gene Expression": "",
|
|
112
|
+
}
|