scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/__main__.py +4 -5
- scdataloader/collator.py +76 -78
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +134 -77
- scdataloader/datamodule.py +638 -245
- scdataloader/mapped.py +104 -43
- scdataloader/preprocess.py +136 -110
- scdataloader/utils.py +158 -52
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/METADATA +6 -7
- scdataloader-2.0.2.dist-info/RECORD +16 -0
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/WHEEL +1 -1
- scdataloader-2.0.2.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.2.dist-info/RECORD +0 -16
- scdataloader-1.9.2.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/entry_points.txt +0 -0
scdataloader/__main__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
-
from typing import Optional, Union
|
|
2
|
+
from typing import List, Optional, Union
|
|
3
3
|
|
|
4
4
|
import lamindb as ln
|
|
5
5
|
|
|
@@ -149,7 +149,7 @@ def main():
|
|
|
149
149
|
)
|
|
150
150
|
preprocess_parser.add_argument(
|
|
151
151
|
"--batch_keys",
|
|
152
|
-
type=
|
|
152
|
+
type=List[str],
|
|
153
153
|
default=[
|
|
154
154
|
"assay_ontology_term_id",
|
|
155
155
|
"self_reported_ethnicity_ontology_term_id",
|
|
@@ -229,11 +229,11 @@ def main():
|
|
|
229
229
|
if args.instance is not None:
|
|
230
230
|
collection = (
|
|
231
231
|
ln.Collection.using(instance=args.instance)
|
|
232
|
-
.filter(
|
|
232
|
+
.filter(key=args.name, version=args.version)
|
|
233
233
|
.first()
|
|
234
234
|
)
|
|
235
235
|
else:
|
|
236
|
-
collection = ln.Collection.filter(
|
|
236
|
+
collection = ln.Collection.filter(key=args.name, version=args.version).first()
|
|
237
237
|
|
|
238
238
|
print(
|
|
239
239
|
"using the dataset ", collection, " of size ", len(collection.artifacts.all())
|
|
@@ -262,7 +262,6 @@ def main():
|
|
|
262
262
|
additional_preprocess=additional_preprocess,
|
|
263
263
|
additional_postprocess=additional_postprocess,
|
|
264
264
|
keep_files=False,
|
|
265
|
-
force_preloaded=args.force_preloaded,
|
|
266
265
|
)
|
|
267
266
|
|
|
268
267
|
# Preprocess the dataset
|
scdataloader/collator.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import List, Optional
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
4
5
|
from torch import Tensor, long
|
|
5
6
|
|
|
6
|
-
from .
|
|
7
|
+
from .preprocess import _digitize
|
|
8
|
+
from .utils import load_genes
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class Collator:
|
|
10
12
|
def __init__(
|
|
11
13
|
self,
|
|
12
|
-
organisms:
|
|
14
|
+
organisms: List[str],
|
|
13
15
|
how: str = "all",
|
|
14
16
|
org_to_id: dict[str, int] = None,
|
|
15
|
-
valid_genes:
|
|
17
|
+
valid_genes: Optional[List[str]] = None,
|
|
16
18
|
max_len: int = 2000,
|
|
17
19
|
add_zero_genes: int = 0,
|
|
18
20
|
logp1: bool = False,
|
|
@@ -20,10 +22,9 @@ class Collator:
|
|
|
20
22
|
n_bins: int = 0,
|
|
21
23
|
tp_name: Optional[str] = None,
|
|
22
24
|
organism_name: str = "organism_ontology_term_id",
|
|
23
|
-
class_names:
|
|
24
|
-
genelist:
|
|
25
|
-
|
|
26
|
-
save_output: Optional[str] = None,
|
|
25
|
+
class_names: List[str] = [],
|
|
26
|
+
genelist: List[str] = [],
|
|
27
|
+
genedf: Optional[pd.DataFrame] = None,
|
|
27
28
|
):
|
|
28
29
|
"""
|
|
29
30
|
This class is responsible for collating data for the scPRINT model. It handles the
|
|
@@ -57,13 +58,8 @@ class Collator:
|
|
|
57
58
|
class_names (list, optional): List of other classes to be considered. Defaults to [].
|
|
58
59
|
genelist (list, optional): List of genes to be considered. Defaults to [].
|
|
59
60
|
If [] all genes will be considered
|
|
60
|
-
downsample (float, optional): Downsample the profile to a certain number of cells. Defaults to None.
|
|
61
|
-
This is usually done by the scPRINT model during training but this option allows you to do it directly from the collator
|
|
62
|
-
save_output (str, optional): If not None, saves the output to a file. Defaults to None.
|
|
63
|
-
This is mainly for debugging purposes
|
|
64
61
|
"""
|
|
65
62
|
self.organisms = organisms
|
|
66
|
-
self.genedf = load_genes(organisms)
|
|
67
63
|
self.max_len = max_len
|
|
68
64
|
self.n_bins = n_bins
|
|
69
65
|
self.add_zero_genes = add_zero_genes
|
|
@@ -75,32 +71,36 @@ class Collator:
|
|
|
75
71
|
self.organism_name = organism_name
|
|
76
72
|
self.tp_name = tp_name
|
|
77
73
|
self.class_names = class_names
|
|
78
|
-
self.save_output = save_output
|
|
79
74
|
self.start_idx = {}
|
|
80
75
|
self.accepted_genes = {}
|
|
81
|
-
self.downsample = downsample
|
|
82
76
|
self.to_subset = {}
|
|
83
|
-
self._setup(org_to_id, valid_genes, genelist)
|
|
77
|
+
self._setup(genedf, org_to_id, valid_genes, genelist)
|
|
84
78
|
|
|
85
|
-
def _setup(self, org_to_id=None, valid_genes=[], genelist=[]):
|
|
79
|
+
def _setup(self, genedf=None, org_to_id=None, valid_genes=[], genelist=[]):
|
|
80
|
+
if genedf is None:
|
|
81
|
+
genedf = load_genes(self.organisms)
|
|
82
|
+
self.organism_ids = (
|
|
83
|
+
set([org_to_id[k] for k in self.organisms])
|
|
84
|
+
if org_to_id is not None
|
|
85
|
+
else set(self.organisms)
|
|
86
|
+
)
|
|
86
87
|
self.org_to_id = org_to_id
|
|
87
88
|
self.to_subset = {}
|
|
88
89
|
self.accepted_genes = {}
|
|
89
90
|
self.start_idx = {}
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
if
|
|
93
|
-
|
|
94
|
-
|
|
91
|
+
|
|
92
|
+
if valid_genes is not None:
|
|
93
|
+
if len(set(valid_genes) - set(genedf.index)) > 0:
|
|
94
|
+
print("Some valid genes are not in the genedf!!!")
|
|
95
|
+
tot = genedf[genedf.index.isin(valid_genes)]
|
|
96
|
+
else:
|
|
97
|
+
tot = genedf
|
|
95
98
|
for organism in self.organisms:
|
|
96
|
-
ogenedf = self.genedf[self.genedf.organism == organism]
|
|
97
|
-
if len(valid_genes) > 0:
|
|
98
|
-
tot = self.genedf[self.genedf.index.isin(valid_genes)]
|
|
99
|
-
else:
|
|
100
|
-
tot = self.genedf
|
|
101
99
|
org = org_to_id[organism] if org_to_id is not None else organism
|
|
102
100
|
self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})
|
|
103
|
-
|
|
101
|
+
|
|
102
|
+
ogenedf = genedf[genedf.organism == organism]
|
|
103
|
+
if valid_genes is not None:
|
|
104
104
|
self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
|
|
105
105
|
if len(genelist) > 0:
|
|
106
106
|
df = ogenedf[ogenedf.index.isin(valid_genes)]
|
|
@@ -111,7 +111,7 @@ class Collator:
|
|
|
111
111
|
__call__ applies the collator to a minibatch of data
|
|
112
112
|
|
|
113
113
|
Args:
|
|
114
|
-
batch (
|
|
114
|
+
batch (List[dict[str: array]]): List of dicts of arrays containing gene expression data.
|
|
115
115
|
the first list is for the different samples, the second list is for the different elements with
|
|
116
116
|
elem["X"]: gene expression
|
|
117
117
|
elem["organism_name"]: organism ontology term id
|
|
@@ -119,7 +119,7 @@ class Collator:
|
|
|
119
119
|
elem["class_names.."]: other classes
|
|
120
120
|
|
|
121
121
|
Returns:
|
|
122
|
-
|
|
122
|
+
List[Tensor]: List of tensors containing the collated data.
|
|
123
123
|
"""
|
|
124
124
|
# do count selection
|
|
125
125
|
# get the unseen info and don't add any unseen
|
|
@@ -133,6 +133,7 @@ class Collator:
|
|
|
133
133
|
nnz_loc = []
|
|
134
134
|
is_meta = []
|
|
135
135
|
knn_cells = []
|
|
136
|
+
knn_cells_info = []
|
|
136
137
|
for elem in batch:
|
|
137
138
|
organism_id = elem[self.organism_name]
|
|
138
139
|
if organism_id not in self.organism_ids:
|
|
@@ -188,7 +189,14 @@ class Collator:
|
|
|
188
189
|
if "knn_cells" in elem:
|
|
189
190
|
# we complete with genes expressed in the knn
|
|
190
191
|
# which is not a zero_loc in this context
|
|
191
|
-
|
|
192
|
+
knn_expr = elem["knn_cells"].sum(0)
|
|
193
|
+
mask = np.ones(len(knn_expr), dtype=bool)
|
|
194
|
+
mask[loc] = False
|
|
195
|
+
available_indices = np.where(mask)[0]
|
|
196
|
+
available_knn_expr = knn_expr[available_indices]
|
|
197
|
+
sorted_indices = np.argsort(available_knn_expr)[::-1]
|
|
198
|
+
selected = min(ma, len(available_indices))
|
|
199
|
+
zero_loc = available_indices[sorted_indices[:selected]]
|
|
192
200
|
else:
|
|
193
201
|
zero_loc = np.where(expr == 0)[0]
|
|
194
202
|
zero_loc = zero_loc[
|
|
@@ -212,6 +220,8 @@ class Collator:
|
|
|
212
220
|
exprs.append(expr)
|
|
213
221
|
if "knn_cells" in elem:
|
|
214
222
|
knn_cells.append(elem["knn_cells"])
|
|
223
|
+
if "knn_cells_info" in elem:
|
|
224
|
+
knn_cells_info.append(elem["knn_cells_info"])
|
|
215
225
|
# then we need to add the start_idx to the loc to give it the correct index
|
|
216
226
|
# according to the model
|
|
217
227
|
gene_locs.append(loc + self.start_idx[organism_id])
|
|
@@ -231,15 +241,46 @@ class Collator:
|
|
|
231
241
|
dataset = np.array(dataset)
|
|
232
242
|
is_meta = np.array(is_meta)
|
|
233
243
|
knn_cells = np.array(knn_cells)
|
|
244
|
+
knn_cells_info = np.array(knn_cells_info)
|
|
245
|
+
|
|
234
246
|
# normalize counts
|
|
235
247
|
if self.norm_to is not None:
|
|
236
248
|
expr = (expr * self.norm_to) / total_count[:, None]
|
|
249
|
+
# TODO: solve issue here
|
|
250
|
+
knn_cells = (knn_cells * self.norm_to) / total_count[:, None]
|
|
237
251
|
if self.logp1:
|
|
238
252
|
expr = np.log2(1 + expr)
|
|
253
|
+
knn_cells = np.log2(1 + knn_cells)
|
|
239
254
|
|
|
240
255
|
# do binning of counts
|
|
241
|
-
if self.n_bins:
|
|
242
|
-
|
|
256
|
+
if self.n_bins > 0:
|
|
257
|
+
binned_rows = []
|
|
258
|
+
bin_edges = []
|
|
259
|
+
for row in expr:
|
|
260
|
+
if row.max() == 0:
|
|
261
|
+
print(
|
|
262
|
+
"The input data contains all zero rows. Please make sure "
|
|
263
|
+
"this is expected. You can use the `filter_cell_by_counts` "
|
|
264
|
+
"arg to filter out all zero rows."
|
|
265
|
+
)
|
|
266
|
+
binned_rows.append(np.zeros_like(row, dtype=np.int64))
|
|
267
|
+
bin_edges.append(np.array([0] * self.n_bins))
|
|
268
|
+
continue
|
|
269
|
+
non_zero_ids = row.nonzero()
|
|
270
|
+
non_zero_row = row[non_zero_ids]
|
|
271
|
+
bins = np.quantile(non_zero_row, np.linspace(0, 1, self.n_bins - 1))
|
|
272
|
+
# bins = np.sort(np.unique(bins))
|
|
273
|
+
# NOTE: comment this line for now, since this will make the each category
|
|
274
|
+
# has different relative meaning across datasets
|
|
275
|
+
non_zero_digits = _digitize(non_zero_row, bins)
|
|
276
|
+
assert non_zero_digits.min() >= 1
|
|
277
|
+
assert non_zero_digits.max() <= self.n_bins - 1
|
|
278
|
+
binned_row = np.zeros_like(row, dtype=np.int64)
|
|
279
|
+
binned_row[non_zero_ids] = non_zero_digits
|
|
280
|
+
binned_rows.append(binned_row)
|
|
281
|
+
bin_edges.append(np.concatenate([[0], bins]))
|
|
282
|
+
expr = np.stack(binned_rows)
|
|
283
|
+
# expr = np.digitize(expr, bins=self.bins)
|
|
243
284
|
|
|
244
285
|
ret = {
|
|
245
286
|
"x": Tensor(expr),
|
|
@@ -252,51 +293,8 @@ class Collator:
|
|
|
252
293
|
ret.update({"is_meta": Tensor(is_meta).int()})
|
|
253
294
|
if len(knn_cells) > 0:
|
|
254
295
|
ret.update({"knn_cells": Tensor(knn_cells)})
|
|
296
|
+
if len(knn_cells_info) > 0:
|
|
297
|
+
ret.update({"knn_cells_info": Tensor(knn_cells_info)})
|
|
255
298
|
if len(dataset) > 0:
|
|
256
299
|
ret.update({"dataset": Tensor(dataset).to(long)})
|
|
257
|
-
if self.downsample is not None:
|
|
258
|
-
ret["x"] = downsample_profile(ret["x"], self.downsample)
|
|
259
|
-
if self.save_output is not None:
|
|
260
|
-
with open(self.save_output, "a") as f:
|
|
261
|
-
np.savetxt(f, ret["x"].numpy())
|
|
262
|
-
with open(self.save_output + "_loc", "a") as f:
|
|
263
|
-
np.savetxt(f, gene_locs)
|
|
264
300
|
return ret
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
#############
|
|
268
|
-
#### WIP ####
|
|
269
|
-
#############
|
|
270
|
-
class GeneformerCollator(Collator):
|
|
271
|
-
def __init__(self, *args, gene_norm_list: list, **kwargs):
|
|
272
|
-
"""
|
|
273
|
-
GeneformerCollator to finish
|
|
274
|
-
|
|
275
|
-
Args:
|
|
276
|
-
gene_norm_list (list): the normalization of expression through all datasets, per gene.
|
|
277
|
-
"""
|
|
278
|
-
super().__init__(*args, **kwargs)
|
|
279
|
-
self.gene_norm_list = gene_norm_list
|
|
280
|
-
|
|
281
|
-
def __call__(self, batch):
|
|
282
|
-
super().__call__(batch)
|
|
283
|
-
# normlization per gene
|
|
284
|
-
|
|
285
|
-
# tokenize the empty locations
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
class scGPTCollator(Collator):
|
|
289
|
-
"""
|
|
290
|
-
scGPTCollator to finish
|
|
291
|
-
"""
|
|
292
|
-
|
|
293
|
-
def __call__(self, batch):
|
|
294
|
-
super().__call__(batch)
|
|
295
|
-
# binning
|
|
296
|
-
|
|
297
|
-
# tokenize the empty locations
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
class scPRINTCollator(Collator):
|
|
301
|
-
def __call__(self, batch):
|
|
302
|
-
super().__call__(batch)
|
scdataloader/config.py
CHANGED
|
@@ -113,26 +113,34 @@ COARSE_ASSAY = {
|
|
|
113
113
|
|
|
114
114
|
|
|
115
115
|
MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
|
|
116
|
-
"HsapDv:0010000": [
|
|
116
|
+
"HsapDv:0010000": [ # postnatal stage
|
|
117
117
|
"MmusDv:0000092", # postnatal stage
|
|
118
118
|
],
|
|
119
|
-
"HsapDv:0000258": [ # mature stage
|
|
119
|
+
"HsapDv:0000258": [ # mature stage >15
|
|
120
120
|
"MmusDv:0000110", # mature stage
|
|
121
|
-
"HsapDv:0000204",
|
|
121
|
+
"HsapDv:0000204", #
|
|
122
122
|
],
|
|
123
|
-
"HsapDv:
|
|
123
|
+
"HsapDv:0000087": [], # adult stage >19
|
|
124
|
+
"HsapDv:0000227": [ # late adult stage > 40
|
|
124
125
|
"MmusDv:0000091", # 20 month-old stage
|
|
125
126
|
"MmusDv:0000089", # 18 month-old stage
|
|
127
|
+
"HsapDv:0000091", # > 45
|
|
128
|
+
"HsapDv:0000093", # > 65
|
|
129
|
+
],
|
|
130
|
+
"HsapDv:0000272": [ # 60-79 year-old stage
|
|
131
|
+
"HsapDv:0000094", # 60-79 year-old stage
|
|
126
132
|
],
|
|
127
|
-
"HsapDv:0000272": [], # 60-79 year-old stage
|
|
128
133
|
"HsapDv:0000095": [], # 80 year-old and over stage
|
|
129
|
-
"HsapDv:0000267": [ # middle aged stage
|
|
134
|
+
"HsapDv:0000267": [ # middle aged stage >40 <60
|
|
130
135
|
"MmusDv:0000087", # 16 month-old stage
|
|
131
136
|
"UBERON:0018241", # prime adult stage
|
|
132
137
|
"MmusDv:0000083", # 12 month-old stage
|
|
133
138
|
"HsapDv:0000092", # same
|
|
134
139
|
],
|
|
135
|
-
"HsapDv:0000266": [ # young adult stage
|
|
140
|
+
"HsapDv:0000266": [ # young adult stage <40
|
|
141
|
+
"HsapDv:0000088", # mature stage
|
|
142
|
+
"HsapDv:0000090", # 25 - 44
|
|
143
|
+
"HsapDv:0000086", # adolescent stage
|
|
136
144
|
"MmusDv:0000050", # 6 weeks
|
|
137
145
|
"HsapDv:0000089", # same
|
|
138
146
|
"MmusDv:0000051", # 7 weeks
|
|
@@ -163,22 +171,30 @@ MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
|
|
|
163
171
|
"MmusDv:0000099", # 26 weeks
|
|
164
172
|
"MmusDv:0000102", # 29 weeks
|
|
165
173
|
],
|
|
166
|
-
"HsapDv:0000265": [
|
|
174
|
+
"HsapDv:0000265": [ # child stage (1-4 yo)
|
|
175
|
+
"HsapDv:0000084", # 2-5 yo
|
|
176
|
+
],
|
|
167
177
|
"HsapDv:0000271": [ # juvenile stage (5-14 yo)
|
|
168
178
|
"MmusDv:0000048", # 4 weeks
|
|
169
179
|
"MmusDv:0000049", # 5 weeks
|
|
180
|
+
"HsapDv:0000081", # child
|
|
181
|
+
"HsapDv:0000085", # 6-11 yo
|
|
170
182
|
],
|
|
171
|
-
"HsapDv:0000260": [ # infant stage
|
|
183
|
+
"HsapDv:0000260": [ # infant stage <2
|
|
172
184
|
"MmusDv:0000046", # 2 weeks
|
|
173
185
|
"MmusDv:0000045", # 1 week
|
|
174
186
|
"MmusDv:0000047", # 3 weeks
|
|
175
187
|
"HsapDv:0000083",
|
|
188
|
+
"HsapDv:0000256", # under 1 yo
|
|
176
189
|
],
|
|
177
190
|
"HsapDv:0000262": [ # newborn stage (0-28 days)
|
|
178
191
|
"MmusDv:0000036", # Theiler stage 27
|
|
179
192
|
"MmusDv:0000037", # Theiler stage 28
|
|
180
193
|
"MmusDv:0000113", # 4-7 days
|
|
194
|
+
"HsapDv:0000174", # 1 month-old stage
|
|
195
|
+
"HsapDv:0000082", # newborn stage
|
|
181
196
|
],
|
|
197
|
+
"HsapDv:0000002": [], # embryonic stage
|
|
182
198
|
"HsapDv:0000007": [], # Carnegie stage 03
|
|
183
199
|
"HsapDv:0000008": [], # Carnegie stage 04
|
|
184
200
|
"HsapDv:0000009": [], # Carnegie stage 05
|