scdataloader 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/VERSION CHANGED
@@ -1 +1 @@
1
-
1
+ 0.7.0
scdataloader/__init__.py CHANGED
@@ -0,0 +1,4 @@
1
+ from .data import Dataset
2
+ from .datamodule import DataModule
3
+ from .preprocess import Preprocessor
4
+ from .collator import *
@@ -0,0 +1,209 @@
1
+ import argparse
2
+ from scdataloader.preprocess import (
3
+ LaminPreprocessor,
4
+ additional_preprocess,
5
+ additional_postprocess,
6
+ )
7
+ import lamindb as ln
8
+ from typing import Optional, Union
9
+
10
+
11
+ # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
12
+ def main():
13
+ parser = argparse.ArgumentParser(
14
+ description="Preprocess datasets in a given lamindb collection."
15
+ )
16
+ parser.add_argument(
17
+ "--name", type=str, required=True, help="Name of the input dataset"
18
+ )
19
+ parser.add_argument(
20
+ "--new_name",
21
+ type=str,
22
+ default="preprocessed dataset",
23
+ help="Name of the preprocessed dataset.",
24
+ )
25
+ parser.add_argument(
26
+ "--description",
27
+ type=str,
28
+ default="preprocessed by scDataLoader",
29
+ help="Description of the preprocessed dataset.",
30
+ )
31
+ parser.add_argument(
32
+ "--start_at", type=int, default=0, help="Position to start preprocessing at."
33
+ )
34
+ parser.add_argument(
35
+ "--new_version",
36
+ type=str,
37
+ default="2",
38
+ help="Version of the output dataset and files.",
39
+ )
40
+ parser.add_argument(
41
+ "--instance",
42
+ type=str,
43
+ default=None,
44
+ help="Instance storing the input dataset, if not local",
45
+ )
46
+ parser.add_argument(
47
+ "--version", type=str, default=None, help="Version of the input dataset."
48
+ )
49
+ parser.add_argument(
50
+ "--filter_gene_by_counts",
51
+ type=Union[int, bool],
52
+ default=False,
53
+ help="Determines whether to filter genes by counts.",
54
+ )
55
+ parser.add_argument(
56
+ "--filter_cell_by_counts",
57
+ type=Union[int, bool],
58
+ default=False,
59
+ help="Determines whether to filter cells by counts.",
60
+ )
61
+ parser.add_argument(
62
+ "--normalize_sum",
63
+ type=float,
64
+ default=1e4,
65
+ help="Determines whether to normalize the total counts of each cell to a specific value.",
66
+ )
67
+ parser.add_argument(
68
+ "--subset_hvg",
69
+ type=int,
70
+ default=0,
71
+ help="Determines whether to subset highly variable genes.",
72
+ )
73
+ parser.add_argument(
74
+ "--hvg_flavor",
75
+ type=str,
76
+ default="seurat_v3",
77
+ help="Specifies the flavor of highly variable genes selection.",
78
+ )
79
+ parser.add_argument(
80
+ "--binning",
81
+ type=Optional[int],
82
+ default=None,
83
+ help="Determines whether to bin the data into discrete values of number of bins provided.",
84
+ )
85
+ parser.add_argument(
86
+ "--result_binned_key",
87
+ type=str,
88
+ default="X_binned",
89
+ help="Specifies the key of AnnData to store the binned data.",
90
+ )
91
+ parser.add_argument(
92
+ "--length_normalize",
93
+ type=bool,
94
+ default=False,
95
+ help="Determines whether to normalize the length.",
96
+ )
97
+ parser.add_argument(
98
+ "--force_preprocess",
99
+ type=bool,
100
+ default=False,
101
+ help="Determines whether to force preprocessing.",
102
+ )
103
+ parser.add_argument(
104
+ "--min_dataset_size",
105
+ type=int,
106
+ default=100,
107
+ help="Specifies the minimum dataset size.",
108
+ )
109
+ parser.add_argument(
110
+ "--min_valid_genes_id",
111
+ type=int,
112
+ default=10_000,
113
+ help="Specifies the minimum valid genes id.",
114
+ )
115
+ parser.add_argument(
116
+ "--min_nnz_genes",
117
+ type=int,
118
+ default=400,
119
+ help="Specifies the minimum non-zero genes.",
120
+ )
121
+ parser.add_argument(
122
+ "--maxdropamount",
123
+ type=int,
124
+ default=50,
125
+ help="Specifies the maximum drop amount.",
126
+ )
127
+ parser.add_argument(
128
+ "--madoutlier", type=int, default=5, help="Specifies the MAD outlier."
129
+ )
130
+ parser.add_argument(
131
+ "--pct_mt_outlier",
132
+ type=int,
133
+ default=8,
134
+ help="Specifies the percentage of MT outlier.",
135
+ )
136
+ parser.add_argument(
137
+ "--batch_key", type=Optional[str], default=None, help="Specifies the batch key."
138
+ )
139
+ parser.add_argument(
140
+ "--skip_validate",
141
+ type=bool,
142
+ default=False,
143
+ help="Determines whether to skip validation.",
144
+ )
145
+ parser.add_argument(
146
+ "--do_postp",
147
+ type=bool,
148
+ default=False,
149
+ help="Determines whether to do postprocessing.",
150
+ )
151
+ args = parser.parse_args()
152
+
153
+ # Load the collection
154
+ # if not args.preprocess:
155
+ # print("Only preprocess is available for now")
156
+ # return
157
+ if args.instance is not None:
158
+ collection = (
159
+ ln.Collection.using(instance=args.instance)
160
+ .filter(name=args.name, version=args.version)
161
+ .first()
162
+ )
163
+ else:
164
+ collection = ln.Collection.filter(name=args.name, version=args.version).first()
165
+
166
+ print(
167
+ "using the dataset ", collection, " of size ", len(collection.artifacts.all())
168
+ )
169
+ # Initialize the preprocessor
170
+ preprocessor = LaminPreprocessor(
171
+ filter_gene_by_counts=args.filter_gene_by_counts,
172
+ filter_cell_by_counts=args.filter_cell_by_counts,
173
+ normalize_sum=args.normalize_sum,
174
+ subset_hvg=args.subset_hvg,
175
+ hvg_flavor=args.hvg_flavor,
176
+ binning=args.binning,
177
+ result_binned_key=args.result_binned_key,
178
+ length_normalize=args.length_normalize,
179
+ force_preprocess=args.force_preprocess,
180
+ min_dataset_size=args.min_dataset_size,
181
+ min_valid_genes_id=args.min_valid_genes_id,
182
+ min_nnz_genes=args.min_nnz_genes,
183
+ maxdropamount=args.maxdropamount,
184
+ madoutlier=args.madoutlier,
185
+ pct_mt_outlier=args.pct_mt_outlier,
186
+ batch_key=args.batch_key,
187
+ skip_validate=args.skip_validate,
188
+ do_postp=args.do_postp,
189
+ additional_preprocess=additional_preprocess,
190
+ additional_postprocess=additional_postprocess,
191
+ keep_files=False,
192
+ )
193
+
194
+ # Preprocess the dataset
195
+ preprocessor(
196
+ collection,
197
+ name=args.new_name,
198
+ description=args.description,
199
+ start_at=args.start_at,
200
+ version=args.new_version,
201
+ )
202
+
203
+ print(
204
+ f"Preprocessed dataset saved with version {args.version} and name {args.new_name}."
205
+ )
206
+
207
+
208
+ if __name__ == "__main__":
209
+ main()
@@ -0,0 +1,307 @@
1
+ import numpy as np
2
+ from .utils import load_genes
3
+ from torch import Tensor, long
4
+
5
+ # class SimpleCollator:
6
+
7
+
8
+ class Collator:
9
+ def __init__(
10
+ self,
11
+ organisms: list,
12
+ how="all",
13
+ org_to_id: dict = None,
14
+ valid_genes: list = [],
15
+ max_len=2000,
16
+ add_zero_genes=0,
17
+ logp1=False,
18
+ norm_to=None,
19
+ n_bins=0,
20
+ tp_name=None,
21
+ organism_name="organism_ontology_term_id",
22
+ class_names=[],
23
+ genelist=[],
24
+ ):
25
+ """
26
+ This class is responsible for collating data for the scPRINT model. It handles the
27
+ organization and preparation of gene expression data from different organisms,
28
+ allowing for various configurations such as maximum gene list length, normalization,
29
+ and selection method for gene expression.
30
+
31
+ This Collator should work with scVI's dataloader as well!
32
+
33
+ Args:
34
+ organisms (list): List of organisms to be considered for gene expression data.
35
+ it will drop any other organism it sees (might lead to batches of different sizes!)
36
+ how (flag, optional): Method for selecting gene expression. Defaults to "all".
37
+ one of ["most expr", "random expr", "all", "some"]:
38
+ "most expr": selects the max_len most expressed genes,
39
+ if fewer genes are expressed, will sample random unexpressed genes,
40
+ "random expr": uses a random set of max_len expressed genes.
41
+ if fewer genes are expressed, will sample random unexpressed genes
42
+ "all": uses all genes
43
+ "some": uses only the genes provided through the genelist param
44
+ org_to_id (dict): Dictionary mapping organisms to their respective IDs.
45
+ valid_genes (list, optional): List of genes from the datasets, to be considered. Defaults to [].
46
+ it will drop any other genes from the input expression data (useful when your model only works on some genes)
47
+ max_len (int, optional): Maximum number of genes to use (for random expr and most expr). Defaults to 2000.
48
+ n_bins (int, optional): Number of bins for binning the data. Defaults to 0, meaning no binning of expression.
49
+ add_zero_genes (int, optional): Number of additional unexpressed genes to add to the input data. Defaults to 0.
50
+ logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
51
+ norm_to (str, optional): Normalization method to be applied. Defaults to None.
52
+ """
53
+ self.organisms = organisms
54
+ self.max_len = max_len
55
+ self.n_bins = n_bins
56
+ self.add_zero_genes = add_zero_genes
57
+ self.logp1 = logp1
58
+ self.norm_to = norm_to
59
+ self.org_to_id = org_to_id
60
+ self.how = how
61
+ self.organism_ids = (
62
+ set([org_to_id[k] for k in organisms])
63
+ if org_to_id is not None
64
+ else set(organisms)
65
+ )
66
+ if self.how == "some":
67
+ assert len(genelist) > 0, "if how is some, genelist must be provided"
68
+ self.organism_name = organism_name
69
+ self.tp_name = tp_name
70
+ self.class_names = class_names
71
+
72
+ self.start_idx = {}
73
+ self.accepted_genes = {}
74
+ self.genedf = load_genes(organisms)
75
+ self.to_subset = {}
76
+ for organism in set(self.genedf.organism):
77
+ ogenedf = self.genedf[self.genedf.organism == organism]
78
+ tot = self.genedf[self.genedf.index.isin(valid_genes)]
79
+ org = org_to_id[organism] if org_to_id is not None else organism
80
+ self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
81
+ if len(valid_genes) > 0:
82
+ self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
83
+ if len(genelist) > 0:
84
+ df = ogenedf[ogenedf.index.isin(valid_genes)]
85
+ self.to_subset.update({org: df.index.isin(genelist)})
86
+
87
+ def __call__(self, batch):
88
+ """
89
+ __call__ applies the collator to a minibatch of data
90
+
91
+ Args:
92
+ batch (list[dict[str: array]]): List of dicts of arrays containing gene expression data.
93
+ the first list is for the different samples, the second list is for the different elements with
94
+ elem["x"]: gene expression
95
+ elem["organism_name"]: organism ontology term id
96
+ elem["tp_name"]: heat diff
97
+ elem["class_names.."]: other classes
98
+
99
+ Returns:
100
+ list[Tensor]: List of tensors containing the collated data.
101
+ """
102
+ # do count selection
103
+ # get the unseen info and don't add any unseen
104
+ # get the I most expressed genes, add randomly some unexpressed genes that are not unseen
105
+ exprs = []
106
+ total_count = []
107
+ other_classes = []
108
+ gene_locs = []
109
+ tp = []
110
+ dataset = []
111
+ nnz_loc = []
112
+ for elem in batch:
113
+ organism_id = elem[self.organism_name]
114
+ if organism_id not in self.organism_ids:
115
+ continue
116
+ if "dataset" in elem:
117
+ dataset.append(elem["dataset"])
118
+ expr = np.array(elem["x"])
119
+ total_count.append(expr.sum())
120
+ if len(self.accepted_genes) > 0:
121
+ expr = expr[self.accepted_genes[organism_id]]
122
+ if self.how == "most expr":
123
+ nnz_loc = np.where(expr > 0)[0]
124
+ ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
125
+ loc = np.argsort(expr)[-(ma):][::-1]
126
+ # nnz_loc = [1] * 30_000
127
+ # loc = np.argsort(expr)[-(self.max_len) :][::-1]
128
+ elif self.how == "random expr":
129
+ nnz_loc = np.where(expr > 0)[0]
130
+ loc = nnz_loc[
131
+ np.random.choice(
132
+ len(nnz_loc),
133
+ self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc),
134
+ replace=False,
135
+ # p=(expr.max() + (expr[nnz_loc])*19) / expr.max(), # 20 at most times more likely to be selected
136
+ )
137
+ ]
138
+ elif self.how in ["all", "some"]:
139
+ loc = np.arange(len(expr))
140
+ else:
141
+ raise ValueError("how must be either most expr or random expr")
142
+ if (
143
+ (self.add_zero_genes > 0) or (self.max_len > len(nnz_loc))
144
+ ) and self.how not in ["all", "some"]:
145
+ zero_loc = np.where(expr == 0)[0]
146
+ zero_loc = zero_loc[
147
+ np.random.choice(
148
+ len(zero_loc),
149
+ self.add_zero_genes
150
+ + (
151
+ 0
152
+ if self.max_len < len(nnz_loc)
153
+ else self.max_len - len(nnz_loc)
154
+ ),
155
+ replace=False,
156
+ )
157
+ ]
158
+ loc = np.concatenate((loc, zero_loc), axis=None)
159
+ expr = expr[loc]
160
+ loc = loc + self.start_idx[organism_id]
161
+ if self.how == "some":
162
+ expr = expr[self.to_subset[organism_id]]
163
+ loc = loc[self.to_subset[organism_id]]
164
+ exprs.append(expr)
165
+ gene_locs.append(loc)
166
+
167
+ if self.tp_name is not None:
168
+ tp.append(elem[self.tp_name])
169
+ else:
170
+ tp.append(0)
171
+
172
+ other_classes.append([elem[i] for i in self.class_names])
173
+
174
+ expr = np.array(exprs)
175
+ tp = np.array(tp)
176
+ gene_locs = np.array(gene_locs)
177
+ total_count = np.array(total_count)
178
+ other_classes = np.array(other_classes)
179
+ dataset = np.array(dataset)
180
+
181
+ # normalize counts
182
+ if self.norm_to is not None:
183
+ expr = (expr * self.norm_to) / total_count[:, None]
184
+ if self.logp1:
185
+ expr = np.log2(1 + expr)
186
+
187
+ # do binning of counts
188
+ if self.n_bins:
189
+ pass
190
+
191
+ # find the associated gene ids (given the species)
192
+
193
+ # get the NN cells
194
+
195
+ # do encoding / selection a la scGPT
196
+
197
+ # do encoding of graph location
198
+ # encode all the edges in some sparse way
199
+ # normalizing total counts between 0,1
200
+ ret = {
201
+ "x": Tensor(expr),
202
+ "genes": Tensor(gene_locs).int(),
203
+ "class": Tensor(other_classes).int(),
204
+ "tp": Tensor(tp),
205
+ "depth": Tensor(total_count),
206
+ }
207
+ if len(dataset) > 0:
208
+ ret.update({"dataset": Tensor(dataset).to(long)})
209
+ return ret
210
+
211
+
212
+ class AnnDataCollator(Collator):
213
+ def __init__(self, *args, **kwargs):
214
+ """
215
+ AnnDataCollator Collator to use if working with AnnData's experimental dataloader (it is very slow!!!)
216
+
217
+ Args:
218
+ @see Collator
219
+ """
220
+ super().__init__(*args, **kwargs)
221
+
222
+ def __call__(self, batch):
223
+ exprs = []
224
+ total_count = []
225
+ other_classes = []
226
+ gene_locs = []
227
+ tp = []
228
+ for elem in batch:
229
+ organism_id = elem.obs[self.organism_name]
230
+ if organism_id.item() not in self.organism_ids:
231
+ print(organism_id)
232
+ expr = np.array(elem.X[0])
233
+
234
+ total_count.append(expr.sum())
235
+ if len(self.accepted_genes) > 0:
236
+ expr = expr[self.accepted_genes[organism_id]]
237
+ if self.how == "most expr":
238
+ loc = np.argsort(expr)[-(self.max_len) :][::-1]
239
+ elif self.how == "random expr":
240
+ nnz_loc = np.where(expr > 0)[0]
241
+ loc = nnz_loc[
242
+ np.random.choice(len(nnz_loc), self.max_len, replace=False)
243
+ ]
244
+ else:
245
+ raise ValueError("how must be either most expr or random expr")
246
+ if self.add_zero_genes > 0:
247
+ zero_loc = np.where(expr == 0)[0]
248
+ zero_loc = [
249
+ np.random.choice(len(zero_loc), self.add_zero_genes, replace=False)
250
+ ]
251
+ loc = np.concatenate((loc, zero_loc), axis=None)
252
+ exprs.append(expr[loc])
253
+ gene_locs.append(loc + self.start_idx[organism_id.item()])
254
+
255
+ if self.tp_name is not None:
256
+ tp.append(elem.obs[self.tp_name])
257
+ else:
258
+ tp.append(0)
259
+
260
+ other_classes.append([elem.obs[i].values[0] for i in self.class_names])
261
+
262
+ expr = np.array(exprs)
263
+ tp = np.array(tp)
264
+ gene_locs = np.array(gene_locs)
265
+ total_count = np.array(total_count)
266
+ other_classes = np.array(other_classes)
267
+ return {
268
+ "x": Tensor(expr),
269
+ "genes": Tensor(gene_locs).int(),
270
+ "depth": Tensor(total_count),
271
+ "class": Tensor(other_classes),
272
+ }
273
+
274
+
275
+ class GeneformerCollator(Collator):
276
+ def __init__(self, *args, gene_norm_list: list, **kwargs):
277
+ """
278
+ GeneformerCollator to finish
279
+
280
+ Args:
281
+ gene_norm_list (list): the normalization of expression through all datasets, per gene.
282
+ """
283
+ super().__init__(*args, **kwargs)
284
+ self.gene_norm_list = gene_norm_list
285
+
286
+ def __call__(self, batch):
287
+ super().__call__(batch)
288
+ # normalization per gene
289
+
290
+ # tokenize the empty locations
291
+
292
+
293
+ class scGPTCollator(Collator):
294
+ """
295
+ scGPTCollator to finish
296
+ """
297
+
298
+ def __call__(self, batch):
299
+ super().__call__(batch)
300
+ # binning
301
+
302
+ # tokenize the empty locations
303
+
304
+
305
+ class scPRINTCollator(Collator):
306
+ def __call__(self, batch):
307
+ super().__call__(batch)
scdataloader/config.py ADDED
@@ -0,0 +1,106 @@
1
+ LABELS_TOADD = {
2
+ "assay_ontology_term_id": {
3
+ "10x transcription profiling": "EFO:0030003",
4
+ "spatial transcriptomics": "EFO:0008994",
5
+ "10x 3' transcription profiling": "EFO:0030003",
6
+ "10x 5' transcription profiling": "EFO:0030004",
7
+ },
8
+ "disease_ontology_term_id": {
9
+ "metabolic disease": "MONDO:0005066",
10
+ "chronic kidney disease": "MONDO:0005300",
11
+ "chromosomal disorder": "MONDO:0019040",
12
+ "infectious disease": "MONDO:0005550",
13
+ "inflammatory disease": "MONDO:0021166",
14
+ # "immune system disease",
15
+ "disorder of development or morphogenesis": "MONDO:0021147",
16
+ "mitochondrial disease": "MONDO:0044970",
17
+ "psychiatric disorder": "MONDO:0002025",
18
+ "cancer or benign tumor": "MONDO:0002025",
19
+ "neoplasm": "MONDO:0005070",
20
+ },
21
+ "cell_type_ontology_term_id": {
22
+ "progenitor cell": "CL:0011026",
23
+ "hematopoietic cell": "CL:0000988",
24
+ "myoblast": "CL:0000056",
25
+ "myeloid cell": "CL:0000763",
26
+ "neuron": "CL:0000540",
27
+ "electrically active cell": "CL:0000211",
28
+ "epithelial cell": "CL:0000066",
29
+ "secretory cell": "CL:0000151",
30
+ "stem cell": "CL:0000034",
31
+ "non-terminally differentiated cell": "CL:0000055",
32
+ "supporting cell": "CL:0000630",
33
+ },
34
+ }
35
+
36
+ COARSE_TISSUE = {
37
+ "adipose tissue": "",
38
+ "bladder organ": "",
39
+ "blood": "",
40
+ "bone marrow": "",
41
+ "brain": "",
42
+ "breast": "",
43
+ "esophagus": "",
44
+ "eye": "",
45
+ "embryo": "",
46
+ "fallopian tube": "",
47
+ "gall bladder": "",
48
+ "heart": "",
49
+ "intestine": "",
50
+ "kidney": "",
51
+ "liver": "",
52
+ "lung": "",
53
+ "lymph node": "",
54
+ "musculature of body": "",
55
+ "nose": "",
56
+ "ovary": "",
57
+ "pancreas": "",
58
+ "placenta": "",
59
+ "skin of body": "",
60
+ "spinal cord": "",
61
+ "spleen": "",
62
+ "stomach": "",
63
+ "thymus": "",
64
+ "thyroid gland": "",
65
+ "tongue": "",
66
+ "uterus": "",
67
+ }
68
+
69
+ COARSE_ANCESTRY = {
70
+ "African": "",
71
+ "Chinese": "",
72
+ "East Asian": "",
73
+ "Eskimo": "",
74
+ "European": "",
75
+ "Greater Middle Eastern (Middle Eastern, North African or Persian)": "",
76
+ "Hispanic or Latin American": "",
77
+ "Native American": "",
78
+ "Oceanian": "",
79
+ "South Asian": "",
80
+ }
81
+
82
+ COARSE_DEVELOPMENT_STAGE = {
83
+ "Embryonic human": "",
84
+ "Fetal": "",
85
+ "Immature": "",
86
+ "Mature": "",
87
+ }
88
+
89
+ COARSE_ASSAY = {
90
+ "10x 3'": "",
91
+ "10x 5'": "",
92
+ "10x multiome": "",
93
+ "CEL-seq2": "",
94
+ "Drop-seq": "",
95
+ "GEXSCOPE technology": "",
96
+ "inDrop": "",
97
+ "microwell-seq": "",
98
+ "sci-Plex": "",
99
+ "sci-RNA-seq": "",
100
+ "Seq-Well": "",
101
+ "Slide-seq": "",
102
+ "Smart-seq": "",
103
+ "SPLiT-seq": "",
104
+ "TruDrop": "",
105
+ "Visium Spatial Gene Expression": "",
106
+ }