scdataloader 2.0.0__tar.gz → 2.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scdataloader-2.0.0 → scdataloader-2.0.3}/.gitignore +1 -0
- {scdataloader-2.0.0 → scdataloader-2.0.3}/PKG-INFO +5 -5
- {scdataloader-2.0.0 → scdataloader-2.0.3}/pyproject.toml +6 -7
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/__main__.py +4 -5
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/collator.py +65 -56
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/data.py +38 -54
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/datamodule.py +139 -86
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/mapped.py +27 -25
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/preprocess.py +31 -16
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/utils.py +120 -20
- {scdataloader-2.0.0 → scdataloader-2.0.3}/LICENSE +0 -0
- {scdataloader-2.0.0 → scdataloader-2.0.3}/README.md +0 -0
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/__init__.py +0 -0
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/base.py +0 -0
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/config.py +0 -0
- {scdataloader-2.0.0 → scdataloader-2.0.3}/scdataloader/data.json +0 -0
--- scdataloader-2.0.0/PKG-INFO
+++ scdataloader-2.0.3/PKG-INFO
@@ -1,29 +1,29 @@
 Metadata-Version: 2.4
 Name: scdataloader
-Version: 2.0.0
+Version: 2.0.3
 Summary: a dataloader for single cell data in lamindb
 Project-URL: repository, https://github.com/jkobject/scDataLoader
 Author-email: jkobject <jkobject@gmail.com>
 License-Expression: MIT
 License-File: LICENSE
 Keywords: dataloader,lamindb,pytorch,scPRINT,scRNAseq
-Requires-Python: <3.
+Requires-Python: <3.13,>=3.10
 Requires-Dist: anndata>=0.9.0
 Requires-Dist: biomart>=0.9.0
 Requires-Dist: cellxgene-census>=0.1.0
 Requires-Dist: django>=4.0.0
 Requires-Dist: ipykernel>=6.20.0
 Requires-Dist: jupytext>=1.16.0
-Requires-Dist: lamindb[bionty,
+Requires-Dist: lamindb[bionty,jupyter,zarr]==1.6.2
 Requires-Dist: leidenalg>=0.8.0
 Requires-Dist: lightning>=2.3.0
 Requires-Dist: matplotlib>=3.5.0
-Requires-Dist: numpy
+Requires-Dist: numpy<=2.2.0
 Requires-Dist: pandas>=2.0.0
 Requires-Dist: pytorch-lightning>=2.3.0
 Requires-Dist: scikit-misc>=0.5.0
 Requires-Dist: seaborn>=0.11.0
-Requires-Dist: torch
+Requires-Dist: torch>=2.2.0
 Requires-Dist: torchdata>=0.5.0
 Requires-Dist: zarr>=2.10.0
 Provides-Extra: dev
--- scdataloader-2.0.0/pyproject.toml
+++ scdataloader-2.0.3/pyproject.toml
@@ -1,20 +1,21 @@
 [project]
 name = "scdataloader"
-version = "2.0.0"
+version = "2.0.3"
 description = "a dataloader for single cell data in lamindb"
 authors = [
     {name = "jkobject", email = "jkobject@gmail.com"}
 ]
 license = "MIT"
 readme = "README.md"
-requires-python = ">=3.10,<3.
+requires-python = ">=3.10,<3.13"
 keywords = ["scRNAseq", "dataloader", "pytorch", "lamindb", "scPRINT"]
 dependencies = [
-    "numpy
-    "lamindb[bionty,jupyter,
+    "numpy<=2.2.0",
+    "lamindb[bionty,jupyter,zarr]==1.6.2",
     "cellxgene-census>=0.1.0",
-    "torch
+    "torch>=2.2.0",
     "pytorch-lightning>=2.3.0",
+    "lightning>=2.3.0",
     "anndata>=0.9.0",
     "zarr>=2.10.0",
     "matplotlib>=3.5.0",
@@ -27,8 +28,6 @@ dependencies = [
     "django>=4.0.0",
     "scikit-misc>=0.5.0",
     "jupytext>=1.16.0",
-    "lightning>=2.3.0",
-    "pytorch-lightning>=2.3.0",
 ]

 [project.optional-dependencies]
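2.0.3 tightens the runtime constraints in both PKG-INFO and pyproject.toml: Python is capped below 3.13, numpy at 2.2.0, torch gains a >=2.2.0 floor, lamindb is pinned to exactly 1.6.2, and the duplicated lightning/pytorch-lightning entries are consolidated. A quick way to check an interpreter against the new requires-python bound, as a standalone sketch using the packaging library (not part of scdataloader itself):

    import sys

    from packaging.specifiers import SpecifierSet

    # the bound declared by scdataloader 2.0.3
    spec = SpecifierSet(">=3.10,<3.13")
    py = ".".join(map(str, sys.version_info[:3]))
    print(py, "is supported" if py in spec else "is not supported")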
--- scdataloader-2.0.0/scdataloader/__main__.py
+++ scdataloader-2.0.3/scdataloader/__main__.py
@@ -1,5 +1,5 @@
 import argparse
-from typing import Optional, Union
+from typing import List, Optional, Union

 import lamindb as ln

@@ -149,7 +149,7 @@ def main():
     )
     preprocess_parser.add_argument(
         "--batch_keys",
-        type=
+        type=List[str],
         default=[
             "assay_ontology_term_id",
             "self_reported_ethnicity_ontology_term_id",
@@ -229,11 +229,11 @@ def main():
     if args.instance is not None:
         collection = (
             ln.Collection.using(instance=args.instance)
-            .filter(
+            .filter(key=args.name, version=args.version)
             .first()
         )
     else:
-        collection = ln.Collection.filter(
+        collection = ln.Collection.filter(key=args.name, version=args.version).first()

     print(
         "using the dataset ", collection, " of size ", len(collection.artifacts.all())
@@ -262,7 +262,6 @@ def main():
         additional_preprocess=additional_preprocess,
         additional_postprocess=additional_postprocess,
         keep_files=False,
-        force_preloaded=args.force_preloaded,
     )

     # Preprocess the dataset
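The collection lookup now passes the name and version directly to filter() instead of building the query in pieces. A hedged usage sketch of the new call shape (the collection key and version below are placeholders, not values shipped with the package):

    import lamindb as ln

    # hypothetical key/version: substitute your own collection
    collection = ln.Collection.filter(key="my-collection", version="2").first()
    if collection is not None:
        print("using the dataset", collection, "of size", len(collection.artifacts.all()))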
--- scdataloader-2.0.0/scdataloader/collator.py
+++ scdataloader-2.0.3/scdataloader/collator.py
@@ -1,18 +1,20 @@
-from typing import Optional
+from typing import List, Optional

 import numpy as np
+import pandas as pd
 from torch import Tensor, long

+from .preprocess import _digitize
 from .utils import load_genes


 class Collator:
     def __init__(
         self,
-        organisms:
+        organisms: List[str],
         how: str = "all",
         org_to_id: dict[str, int] = None,
-        valid_genes:
+        valid_genes: Optional[List[str]] = None,
         max_len: int = 2000,
         add_zero_genes: int = 0,
         logp1: bool = False,
@@ -20,8 +22,9 @@ class Collator:
         n_bins: int = 0,
         tp_name: Optional[str] = None,
         organism_name: str = "organism_ontology_term_id",
-        class_names:
-        genelist:
+        class_names: List[str] = [],
+        genelist: List[str] = [],
+        genedf: Optional[pd.DataFrame] = None,
     ):
         """
         This class is responsible for collating data for the scPRINT model. It handles the
@@ -71,21 +74,22 @@ class Collator:
         self.start_idx = {}
         self.accepted_genes = {}
         self.to_subset = {}
-        self._setup(
+        self._setup(genedf, org_to_id, valid_genes, genelist)

     def _setup(self, genedf=None, org_to_id=None, valid_genes=[], genelist=[]):
         if genedf is None:
             genedf = load_genes(self.organisms)
+        self.organism_ids = (
+            set([org_to_id[k] for k in self.organisms])
+            if org_to_id is not None
+            else set(self.organisms)
+        )
         self.org_to_id = org_to_id
         self.to_subset = {}
         self.accepted_genes = {}
         self.start_idx = {}
-        self.organism_ids = (
-            set([org_to_id[k] for k in self.organisms])
-            if org_to_id is not None
-            else set(self.organisms)
-        )
-        if len(valid_genes) > 0:
+
+        if valid_genes is not None:
             if len(set(valid_genes) - set(genedf.index)) > 0:
                 print("Some valid genes are not in the genedf!!!")
             tot = genedf[genedf.index.isin(valid_genes)]
@@ -96,7 +100,7 @@ class Collator:
             self.start_idx.update({org: np.where(tot.organism == organism)[0][0]})

             ogenedf = genedf[genedf.organism == organism]
-            if len(valid_genes) > 0:
+            if valid_genes is not None:
                 self.accepted_genes.update({org: ogenedf.index.isin(valid_genes)})
             if len(genelist) > 0:
                 df = ogenedf[ogenedf.index.isin(valid_genes)]
@@ -107,7 +111,7 @@ class Collator:
        __call__ applies the collator to a minibatch of data

        Args:
-            batch (
+            batch (List[dict[str: array]]): List of dicts of arrays containing gene expression data.
                the first list is for the different samples, the second list is for the different elements with
                elem["X"]: gene expression
                elem["organism_name"]: organism ontology term id
@@ -115,7 +119,7 @@ class Collator:
            elem["class_names.."]: other classes

        Returns:
-
+            List[Tensor]: List of tensors containing the collated data.
        """
        # do count selection
        # get the unseen info and don't add any unseen
@@ -129,6 +133,7 @@ class Collator:
         nnz_loc = []
         is_meta = []
         knn_cells = []
+        knn_cells_info = []
         for elem in batch:
             organism_id = elem[self.organism_name]
             if organism_id not in self.organism_ids:
@@ -184,7 +189,14 @@ class Collator:
             if "knn_cells" in elem:
                 # we complete with genes expressed in the knn
                 # which is not a zero_loc in this context
-
+                knn_expr = elem["knn_cells"].sum(0)
+                mask = np.ones(len(knn_expr), dtype=bool)
+                mask[loc] = False
+                available_indices = np.where(mask)[0]
+                available_knn_expr = knn_expr[available_indices]
+                sorted_indices = np.argsort(available_knn_expr)[::-1]
+                selected = min(ma, len(available_indices))
+                zero_loc = available_indices[sorted_indices[:selected]]
             else:
                 zero_loc = np.where(expr == 0)[0]
                 zero_loc = zero_loc[
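The new branch pads the gene selection with the genes most expressed across the cell's k nearest neighbors instead of leaving those slots empty. The same selection logic in isolation, as a toy numpy sketch (`n_pad` stands in for the diff's `ma`):

    import numpy as np

    knn_cells = np.array([[0, 3, 1, 0, 5], [0, 2, 0, 0, 4]])  # 2 neighbors x 5 genes
    loc = np.array([1])  # genes already selected from the cell itself
    n_pad = 2            # padding slots to fill (the diff's `ma`)

    knn_expr = knn_cells.sum(0)            # pooled neighbor expression per gene
    mask = np.ones(len(knn_expr), dtype=bool)
    mask[loc] = False                      # exclude already-selected genes
    available = np.where(mask)[0]
    order = np.argsort(knn_expr[available])[::-1]
    print(available[order[:min(n_pad, len(available))]])  # -> [4 2]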
@@ -208,6 +220,8 @@ class Collator:
             exprs.append(expr)
             if "knn_cells" in elem:
                 knn_cells.append(elem["knn_cells"])
+            if "knn_cells_info" in elem:
+                knn_cells_info.append(elem["knn_cells_info"])
             # then we need to add the start_idx to the loc to give it the correct index
             # according to the model
             gene_locs.append(loc + self.start_idx[organism_id])
@@ -227,15 +241,46 @@ class Collator:
         dataset = np.array(dataset)
         is_meta = np.array(is_meta)
         knn_cells = np.array(knn_cells)
+        knn_cells_info = np.array(knn_cells_info)
+
         # normalize counts
         if self.norm_to is not None:
             expr = (expr * self.norm_to) / total_count[:, None]
+            # TODO: solve issue here
+            knn_cells = (knn_cells * self.norm_to) / total_count[:, None]
         if self.logp1:
             expr = np.log2(1 + expr)
+            knn_cells = np.log2(1 + knn_cells)

         # do binning of counts
-        if self.n_bins:
-
+        if self.n_bins > 0:
+            binned_rows = []
+            bin_edges = []
+            for row in expr:
+                if row.max() == 0:
+                    print(
+                        "The input data contains all zero rows. Please make sure "
+                        "this is expected. You can use the `filter_cell_by_counts` "
+                        "arg to filter out all zero rows."
+                    )
+                    binned_rows.append(np.zeros_like(row, dtype=np.int64))
+                    bin_edges.append(np.array([0] * self.n_bins))
+                    continue
+                non_zero_ids = row.nonzero()
+                non_zero_row = row[non_zero_ids]
+                bins = np.quantile(non_zero_row, np.linspace(0, 1, self.n_bins - 1))
+                # bins = np.sort(np.unique(bins))
+                # NOTE: comment this line for now, since this will make the each category
+                # has different relative meaning across datasets
+                non_zero_digits = _digitize(non_zero_row, bins)
+                assert non_zero_digits.min() >= 1
+                assert non_zero_digits.max() <= self.n_bins - 1
+                binned_row = np.zeros_like(row, dtype=np.int64)
+                binned_row[non_zero_ids] = non_zero_digits
+                binned_rows.append(binned_row)
+                bin_edges.append(np.concatenate([[0], bins]))
+            expr = np.stack(binned_rows)
+            # expr = np.digitize(expr, bins=self.bins)

         ret = {
             "x": Tensor(expr),
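`_digitize` comes from scdataloader/preprocess.py and its body is outside this diff; the asserts above only pin down its contract, namely that nonzero values map to integer codes in [1, n_bins - 1]. A minimal stand-in honoring that contract (an illustrative assumption, not the package's actual implementation):

    import numpy as np

    def digitize_sketch(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
        # assign each value to its quantile bin, clipping codes to [1, len(bins) - 1]
        codes = np.digitize(x, bins[1:-1], right=True) + 1
        return np.clip(codes, 1, len(bins) - 1)

    row = np.array([0.5, 2.0, 7.5])
    bins = np.quantile(row, np.linspace(0, 1, 4))  # n_bins=5 -> 4 quantile edges
    print(digitize_sketch(row, bins))  # -> [1 2 3]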
@@ -248,44 +293,8 @@ class Collator:
         ret.update({"is_meta": Tensor(is_meta).int()})
         if len(knn_cells) > 0:
             ret.update({"knn_cells": Tensor(knn_cells)})
+        if len(knn_cells_info) > 0:
+            ret.update({"knn_cells_info": Tensor(knn_cells_info)})
         if len(dataset) > 0:
             ret.update({"dataset": Tensor(dataset).to(long)})
         return ret
-
-
-#############
-#### WIP ####
-#############
-class GeneformerCollator(Collator):
-    def __init__(self, *args, gene_norm_list: list, **kwargs):
-        """
-        GeneformerCollator to finish
-
-        Args:
-            gene_norm_list (list): the normalization of expression through all datasets, per gene.
-        """
-        super().__init__(*args, **kwargs)
-        self.gene_norm_list = gene_norm_list
-
-    def __call__(self, batch):
-        super().__call__(batch)
-        # normlization per gene
-
-        # tokenize the empty locations
-
-
-class scGPTCollator(Collator):
-    """
-    scGPTCollator to finish
-    """
-
-    def __call__(self, batch):
-        super().__call__(batch)
-        # binning
-
-        # tokenize the empty locations
-
-
-class scPRINTCollator(Collator):
-    def __call__(self, batch):
-        super().__call__(batch)
--- scdataloader-2.0.0/scdataloader/data.py
+++ scdataloader-2.0.3/scdataloader/data.py
@@ -2,7 +2,7 @@ import warnings
 from collections import Counter
 from dataclasses import dataclass, field
 from functools import reduce
-from typing import Literal, Optional, Union
+from typing import List, Literal, Optional, Union

 # ln.connect("scprint")
 import bionty as bt
@@ -38,8 +38,8 @@ class Dataset(torchDataset):
     ----
     lamin_dataset (lamindb.Dataset): lamin dataset to load
     genedf (pd.Dataframe): dataframe containing the genes to load
-    obs (
-    clss_to_predict (
+    obs (List[str]): list of observations to load from the Collection
+    clss_to_predict (List[str]): list of observations to encode
     join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
     hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
     metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
@@ -51,9 +51,9 @@ class Dataset(torchDataset):
     lamin_dataset: ln.Collection
     genedf: Optional[pd.DataFrame] = None
     # set of obs to prepare for prediction (encode)
-    clss_to_predict: Optional[
+    clss_to_predict: Optional[List[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
-    hierarchical_clss: Optional[
+    hierarchical_clss: Optional[List[str]] = field(default_factory=list)
     join_vars: Literal["inner", "outer"] | None = None
     metacell_mode: float = 0.0
     get_knn_cells: bool = False
@@ -61,6 +61,7 @@ class Dataset(torchDataset):
     force_recompute_indices: bool = False

     def __post_init__(self):
+        # see at the end of the file for the mapped function
         self.mapped_dataset = mapped(
             self.lamin_dataset,
             obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
@@ -102,10 +103,10 @@ class Dataset(torchDataset):
                 "need 'organism_ontology_term_id' in the set of classes if you don't provide a genedf"
             )
             self.organisms = list(self.class_topred["organism_ontology_term_id"])
-            self.organisms.sort()
             self.genedf = load_genes(self.organisms)
         else:
-            self.organisms =
+            self.organisms = self.genedf["organism"].unique().tolist()
+            self.organisms.sort()

         self.genedf.columns = self.genedf.columns.astype(str)
         # self.check_aligned_vars()
@@ -160,13 +161,11 @@ class Dataset(torchDataset):
             + " {} metacell_mode\n".format(self.metacell_mode)
         )

-    def
+    def get_label_cats(
         self,
-        obs_keys: str
-        scaler: int = 10,
-        return_categories=False,
+        obs_keys: Union[str, List[str]],
     ):
-        """Get all
+        """Get all categories for the given label keys."""
         if isinstance(obs_keys, str):
             obs_keys = [obs_keys]
         labels = None
@@ -176,18 +175,7 @@ class Dataset(torchDataset):
             labels = labels_to_str
         else:
             labels = concat_categorical_codes([labels, labels_to_str])
-
-        if return_categories:
-            counter = np.array(list(counter.values()))
-            weights = scaler / (counter + scaler)
-            return weights, np.array(labels.codes)
-        else:
-            counts = np.array([counter[label] for label in labels.codes])
-            if scaler is None:
-                weights = 1.0 / counts
-            else:
-                weights = scaler / (counts + scaler)
-            return weights
+        return np.array(labels.codes)

     def get_unseen_mapped_dataset_elements(self, idx: int):
         """
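The method keeps only the category-code computation; the scaler-based weighting was dropped. Callers that relied on it can rebuild the same weights from the returned codes, as in this sketch that reuses the removed formula scaler / (counts + scaler) on hypothetical codes:

    import numpy as np

    codes = np.array([0, 0, 1, 2, 2, 2])  # hypothetical output of get_label_cats
    scaler = 10                           # the old default
    counts = np.bincount(codes)[codes]    # frequency of each sample's own label
    weights = scaler / (counts + scaler)  # the weighting removed in this diff
    print(weights)                        # rarer labels get weights closer to 1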
@@ -197,16 +185,16 @@ class Dataset(torchDataset):
            idx (int): index of the element to get

        Returns:
-
+            List[str]: list of unseen genes
        """
        return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

-    def define_hierarchies(self, clsses:
+    def define_hierarchies(self, clsses: List[str]):
        """
        define_hierarchies is a method to define the hierarchies for the classes to predict

        Args:
-            clsses (
+            clsses (List[str]): list of classes to predict

        Raises:
            ValueError: if the class is not in the accepted classes
@@ -233,19 +221,19 @@ class Dataset(torchDataset):
         elif clss == "cell_type_ontology_term_id":
             parentdf = (
                 bt.CellType.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "tissue_ontology_term_id":
             parentdf = (
                 bt.Tissue.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "disease_ontology_term_id":
             parentdf = (
                 bt.Disease.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss in [
@@ -255,19 +243,19 @@ class Dataset(torchDataset):
         ]:
             parentdf = (
                 bt.DevelopmentalStage.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "assay_ontology_term_id":
             parentdf = (
                 bt.ExperimentalFactor.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "self_reported_ethnicity_ontology_term_id":
             parentdf = (
                 bt.Ethnicity.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
@@ -279,6 +267,7 @@ class Dataset(torchDataset):
             )
             cats = set(self.mapped_dataset.encoders[clss].keys())
             groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
+            groupings.pop(None, None)
             for i, j in groupings.items():
                 if len(j) == 0:
                     # that should not happen
@@ -286,6 +275,7 @@ class Dataset(torchDataset):

                     pdb.set_trace()
                 groupings.pop(i)
+
             self.labels_groupings[clss] = groupings
             if clss in self.clss_to_predict:
                 # if we have added new clss, we need to update the encoder with them too.
@@ -331,9 +321,10 @@ class SimpleAnnDataset(torchDataset):
     def __init__(
         self,
         adata: AnnData,
-        obs_to_output: Optional[
+        obs_to_output: Optional[List[str]] = [],
         layer: Optional[str] = None,
         get_knn_cells: bool = False,
+        encoder: Optional[dict[str, dict]] = None,
     ):
         """
         SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
@@ -342,12 +333,14 @@ class SimpleAnnDataset(torchDataset):
         Args:
         ----
             adata (anndata.AnnData): anndata object to use
-            obs_to_output (
+            obs_to_output (List[str]): list of observations to output from anndata.obs
             layer (str): layer of the anndata to use
             get_knn_cells (bool): whether to get the knn cells
+            encoder (dict[str, dict]): dictionary of encoders for the observations.
         """
         self.adataX = adata.layers[layer] if layer is not None else adata.X
         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+        self.encoder = encoder if encoder is not None else {}

         self.obs_to_output = adata.obs[obs_to_output]
         self.get_knn_cells = get_knn_cells
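The new encoder argument lets SimpleAnnDataset translate raw obs values into integer codes as items are produced. A hedged usage sketch with a toy AnnData (the ontology term and its code are illustrative):

    import anndata as ad
    import numpy as np
    import pandas as pd

    from scdataloader.data import SimpleAnnDataset

    adata = ad.AnnData(
        X=np.random.poisson(1.0, size=(4, 10)).astype(np.float32),
        obs=pd.DataFrame({"cell_type_ontology_term_id": ["CL:0000057"] * 4}),
    )
    ds = SimpleAnnDataset(
        adata,
        obs_to_output=["cell_type_ontology_term_id"],
        encoder={"cell_type_ontology_term_id": {"CL:0000057": 0}},  # hypothetical mapping
    )
    print(ds[0]["cell_type_ontology_term_id"])  # -> 0 instead of the raw string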
@@ -361,23 +354,14 @@ class SimpleAnnDataset(torchDataset):

     def __iter__(self):
         for idx in range(self.adataX.shape[0]):
-            out =
-            out.update(
-                {name: val for name, val in self.obs_to_output.iloc[idx].items()}
-            )
-            if self.get_knn_cells:
-                distances = self.distances[idx].toarray()[0]
-                nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
-                out["knn_cells"] = np.array(
-                    [self.adataX[i].reshape(-1) for i in nn_idx],
-                    dtype=int,
-                )
-                out["distances"] = distances[nn_idx]
+            out = self.__getitem__(idx)
             yield out

     def __getitem__(self, idx):
         out = {"X": self.adataX[idx].reshape(-1)}
-
+        # put the observation into the output and encode if needed
+        for name, val in self.obs_to_output.iloc[idx].items():
+            out.update({name: self.encoder[name][val] if name in self.encoder else val})
         if self.get_knn_cells:
             distances = self.distances[idx].toarray()[0]
             nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
@@ -385,17 +369,17 @@ class SimpleAnnDataset(torchDataset):
                 [self.adataX[i].reshape(-1) for i in nn_idx],
                 dtype=int,
             )
-            out["distances"] = distances[nn_idx]
+            out["knn_cells_info"] = distances[nn_idx]
         return out


 def mapped(
     dataset,
-    obs_keys:
-    obsm_keys:
+    obs_keys: List[str] | None = None,
+    obsm_keys: List[str] | None = None,
     obs_filter: dict[str, str | tuple[str, ...]] | None = None,
     join: Literal["inner", "outer"] | None = "inner",
-    encode_labels: bool |
+    encode_labels: bool | List[str] = True,
     unknown_label: str | dict[str, str] | None = None,
     cache_categories: bool = True,
     parallel: bool = False,
@@ -403,7 +387,7 @@ def mapped(
     stream: bool = False,
     is_run_input: bool | None = None,
     metacell_mode: bool = False,
-    meta_assays:
+    meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
     get_knn_cells: bool = False,
     store_location: str | None = None,
     force_recompute_indices: bool = False,
@@ -440,7 +424,7 @@ def mapped(
     return ds


-def concat_categorical_codes(series_list:
+def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
     """Efficiently combine multiple categorical data using their codes,
     only creating categories for combinations that exist in the data.
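concat_categorical_codes' body falls outside this diff; per its docstring it combines several categoricals through their codes, creating categories only for combinations that actually occur. A stand-in with that behavior (an assumption for illustration, not the package's implementation):

    import pandas as pd

    def concat_codes_sketch(series_list: list) -> pd.Categorical:
        # pair up per-series codes, then factorize only observed combinations
        pairs = list(zip(*[pd.Categorical(s).codes for s in series_list]))
        codes, uniques = pd.factorize(pd.Series(pairs))
        return pd.Categorical.from_codes(codes, categories=[str(u) for u in uniques])

    a = pd.Categorical(["x", "x", "y"])
    b = pd.Categorical(["1", "2", "1"])
    print(concat_codes_sketch([a, b]).codes)  # -> [0 1 2]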