scdataloader 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +63 -42
- scdataloader/collator.py +87 -43
- scdataloader/config.py +106 -0
- scdataloader/data.py +78 -98
- scdataloader/datamodule.py +375 -0
- scdataloader/mapped.py +22 -7
- scdataloader/preprocess.py +444 -109
- scdataloader/utils.py +106 -63
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/METADATA +46 -2
- scdataloader-0.0.4.dist-info/RECORD +16 -0
- scdataloader/dataloader.py +0 -318
- scdataloader-0.0.3.dist-info/RECORD +0 -15
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/LICENSE +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/WHEEL +0 -0
- {scdataloader-0.0.3.dist-info → scdataloader-0.0.4.dist-info}/entry_points.txt +0 -0
scdataloader/data.py
CHANGED
@@ -1,83 +1,53 @@
 from dataclasses import dataclass, field

 import lamindb as ln
-import
+import bionty as bt
 import pandas as pd
 from torch.utils.data import Dataset as torchDataset
-from typing import Union
+from typing import Union, Optional, Literal
 from scdataloader import mapped
 import warnings

-
-
+from anndata import AnnData
+
 from scdataloader.utils import get_ancestry_mapping, load_genes

-
-    "assay_ontology_term_id": {
-        "10x transcription profiling": "EFO:0030003",
-        "spatial transcriptomics": "EFO:0008994",
-        "10x 3' transcription profiling": "EFO:0030003",
-        "10x 5' transcription profiling": "EFO:0030004",
-    },
-    "disease_ontology_term_id": {
-        "metabolic disease": "MONDO:0005066",
-        "chronic kidney disease": "MONDO:0005300",
-        "chromosomal disorder": "MONDO:0019040",
-        "infectious disease": "MONDO:0005550",
-        "inflammatory disease": "MONDO:0021166",
-        # "immune system disease",
-        "disorder of development or morphogenesis": "MONDO:0021147",
-        "mitochondrial disease": "MONDO:0044970",
-        "psychiatric disorder": "MONDO:0002025",
-        "cancer or benign tumor": "MONDO:0002025",
-        "neoplasm": "MONDO:0005070",
-    },
-    "cell_type_ontology_term_id": {
-        "progenitor cell": "CL:0011026",
-        "hematopoietic cell": "CL:0000988",
-        "myoblast": "CL:0000056",
-        "myeloid cell": "CL:0000763",
-        "neuron": "CL:0000540",
-        "electrically active cell": "CL:0000211",
-        "epithelial cell": "CL:0000066",
-        "secretory cell": "CL:0000151",
-        "stem cell": "CL:0000034",
-        "non-terminally differentiated cell": "CL:0000055",
-        "supporting cell": "CL:0000630",
-    },
-}
+from .config import LABELS_TOADD


 @dataclass
 class Dataset(torchDataset):
     """
-    Dataset class to load a bunch of anndata from a lamin dataset in a memory efficient way.
+    Dataset class to load a bunch of anndata from a lamin dataset (Collection) in a memory efficient way.
+
+    This serves as a wrapper around lamin's mappedCollection to provide more features,
+    mostly, the management of hierarchical labels, the encoding of labels, the management of multiple species

-    For an example, see :meth:`~lamindb.Dataset.mapped`.
+    For an example of mappedDataset, see :meth:`~lamindb.Dataset.mapped`.

     .. note::

-        A
+        A related data loader exists `here
         <https://github.com/Genentech/scimilarity>`__.

-
+    Args:
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        gene_embedding: dataframe containing the gene embeddings
         organisms (list[str]): list of organisms to load
-
+            (for now only validates the the genes map to this organism)
+        obs (list[str]): list of observations to load from the Collection
         clss_to_pred (list[str]): list of observations to encode
-
+        join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
+        hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
     """

     lamin_dataset: ln.Collection
-    genedf: pd.DataFrame = None
-
-    organisms: Union[list[str], str] = field(
+    genedf: Optional[pd.DataFrame] = None
+    organisms: Optional[Union[list[str], str]] = field(
         default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
     )
-    obs: list[str] = field(
+    obs: Optional[list[str]] = field(
         default_factory=[
             "self_reported_ethnicity_ontology_term_id",
             "assay_ontology_term_id",
@@ -88,16 +58,16 @@ class Dataset(torchDataset):
             "sex_ontology_term_id",
             #'dataset_id',
             #'cell_culture',
-            "dpt_group",
-            "heat_diff",
-            "nnz",
+            #"dpt_group",
+            #"heat_diff",
+            #"nnz",
         ]
     )
     # set of obs to prepare for prediction (encode)
-    clss_to_pred: list[str] = field(default_factory=list)
+    clss_to_pred: Optional[list[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
-    hierarchical_clss: list[str] = field(default_factory=list)
-    join_vars:
+    hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+    join_vars: Optional[Literal["auto", "inner", "None"]] = "None"

     def __post_init__(self):
         self.mapped_dataset = mapped.mapped(
@@ -111,6 +81,8 @@ class Dataset(torchDataset):
             print(
                 "won't do any check but we recommend to have your dataset coming from local storage"
             )
+        self.class_groupings = {}
+        self.class_topred = {}
         # generate tree from ontologies
         if len(self.hierarchical_clss) > 0:
             self.define_hierarchies(self.hierarchical_clss)
@@ -149,7 +121,12 @@ class Dataset(torchDataset):

     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
-        #
+        # import pdb
+
+        # pdb.set_trace()
+        # item.update(
+        #     {"unseen_genes": self.get_unseen_mapped_dataset_elements(*args, **kwargs)}
+        # )
         # ret = {}
         # ret["count"] = item[0]
         # for i, val in enumerate(self.obs):
@@ -160,51 +137,36 @@ class Dataset(torchDataset):
         return item

     def __repr__(self):
-
-            "total dataset size is {} Gb".format(
+        return (
+            "total dataset size is {} Gb\n".format(
                 sum([file.size for file in self.lamin_dataset.artifacts.all()]) / 1e9
             )
-
-
-
-
-
-
-
-
-            "
-
+            + "---\n"
+            + "dataset contains:\n"
+            + "   {} cells\n".format(self.mapped_dataset.__len__())
+            + "   {} genes\n".format(self.genedf.shape[0])
+            + "   {} labels\n".format(len(self.obs))
+            + "   {} clss_to_pred\n".format(len(self.clss_to_pred))
+            + "   {} hierarchical_clss\n".format(len(self.hierarchical_clss))
+            + "   {} join_vars\n".format(len(self.join_vars))
+            + "   {} organisms\n".format(len(self.organisms))
+            + (
+                "dataset contains {} classes to predict\n".format(
+                    sum([len(self.class_topred[i]) for i in self.class_topred])
+                )
+                if len(self.class_topred) > 0
+                else ""
             )
         )
-        # print("embedding size is {}".format(self.gene_embedding.shape[1]))
-        return ""

     def get_label_weights(self, *args, **kwargs):
         return self.mapped_dataset.get_label_weights(*args, **kwargs)

-    def get_unseen_mapped_dataset_elements(self, idx):
+    def get_unseen_mapped_dataset_elements(self, idx: int):
         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]

-
-
-    # for o in self.organisms:
-    #     genedf = genedfs[genedfs.organism == o]
-    #     org_name = lb.Organism.filter(ontology_id=o).one().scientific_name
-    #     embedding = embed(
-    #         genedf=genedf,
-    #         organism=org_name,
-    #         cache=cache,
-    #         fasta_path="/tmp/data/fasta/",
-    #         embedding_size=embedding_size,
-    #     )
-    #     genedf = pd.concat(
-    #         [genedf.set_index("ensembl_gene_id"), embedding], axis=1, join="inner"
-    #     )
-    #     genedf.columns = genedf.columns.astype(str)
-    #     embeddings.append(genedf)
-    #     return pd.concat(embeddings)
-
-    def define_hierarchies(self, labels):
+    def define_hierarchies(self, labels: list[str]):
+        # TODO: use all possible hierarchies instead of just the ones for which we have a sample annotated with
         self.class_groupings = {}
         self.class_topred = {}
         for label in labels:
@@ -223,37 +185,37 @@ class Dataset(torchDataset):
                 )
             elif label == "cell_type_ontology_term_id":
                 parentdf = (
-
+                    bt.CellType.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "tissue_ontology_term_id":
                 parentdf = (
-
+                    bt.Tissue.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "disease_ontology_term_id":
                 parentdf = (
-
+                    bt.Disease.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "development_stage_ontology_term_id":
                 parentdf = (
-
+                    bt.DevelopmentalStage.filter()
                    .df(include=["parents__ontology_id"])
                    .set_index("ontology_id")
                 )
             elif label == "assay_ontology_term_id":
                 parentdf = (
-
+                    bt.ExperimentalFactor.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
             elif label == "self_reported_ethnicity_ontology_term_id":
                 parentdf = (
-
+                    bt.Ethnicity.filter()
                     .df(include=["parents__ontology_id"])
                     .set_index("ontology_id")
                 )
@@ -267,6 +229,9 @@ class Dataset(torchDataset):
             cats = self.mapped_dataset.get_merged_categories(label)
             addition = set(LABELS_TOADD.get(label, {}).values())
             cats |= addition
+            # import pdb
+
+            # pdb.set_trace()
             groupings, _, lclass = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
@@ -316,7 +281,22 @@ class Dataset(torchDataset):


 class SimpleAnnDataset:
-    def __init__(
+    def __init__(
+        self,
+        adata: AnnData,
+        obs_to_output: Optional[list[str]] = [],
+        layer: Optional[str] = None,
+    ):
+        """
+        SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
+        scDataloader and with your model during inference.
+
+        Args:
+        ----
+            adata (anndata.AnnData): anndata object to use
+            obs_to_output (list[str]): list of observations to output from anndata.obs
+            layer (str): layer of the anndata to use
+        """
         self.adata = adata
         self.obs_to_output = obs_to_output
         self.layer = layer
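For orientation, a minimal usage sketch of the reworked Dataset and the new SimpleAnnDataset (not taken from the package docs; the collection name, the label lists and the AnnData file are hypothetical placeholders and depend on your lamin instance):

import lamindb as ln
import anndata as ad
from scdataloader.data import Dataset, SimpleAnnDataset

# wrap a lamin Collection; labels in hierarchical_clss get ontology groupings through bionty
dataset = Dataset(
    ln.Collection.filter(name="my-collection").first(),  # hypothetical collection name
    organisms=["NCBITaxon:9606"],
    obs=["cell_type_ontology_term_id", "assay_ontology_term_id"],
    clss_to_pred=["cell_type_ontology_term_id"],
    hierarchical_clss=["cell_type_ontology_term_id"],
)
print(dataset)  # the new __repr__ reports cells, genes, labels and classes to predict

# for inference on a single AnnData, bypass lamin entirely
adata = ad.read_h5ad("my_cells.h5ad")  # hypothetical file
simple = SimpleAnnDataset(adata, obs_to_output=["cell_type_ontology_term_id"])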
scdataloader/datamodule.py
ADDED
@@ -0,0 +1,375 @@
+import numpy as np
+import pandas as pd
+import lamindb as ln
+
+from torch.utils.data.sampler import (
+    WeightedRandomSampler,
+    SubsetRandomSampler,
+    SequentialSampler,
+)
+import torch
+from torch.utils.data import DataLoader, Sampler
+import lightning as L
+
+from typing import Optional, Union, Sequence
+
+from .data import Dataset
+from .collator import Collator
+from .utils import getBiomartTable
+
+
+class DataModule(L.LightningDataModule):
+    def __init__(
+        self,
+        collection_name: str,
+        label_to_weight: list = ["organism_ontology_term_id"],
+        organisms: list = ["NCBITaxon:9606"],
+        weight_scaler: int = 10,
+        train_oversampling_per_epoch: float = 0.1,
+        validation_split: float = 0.2,
+        test_split: float = 0,
+        gene_embeddings: str = "",
+        use_default_col: bool = True,
+        gene_position_tolerance: int = 10_000,
+        # this is for the mappedCollection
+        label_to_pred: list = ["organism_ontology_term_id"],
+        all_labels: list = ["organism_ontology_term_id"],
+        hierarchical_labels: list = [],
+        # this is for the collator
+        how: str = "random expr",
+        organism_name: str = "organism_ontology_term_id",
+        max_len: int = 1000,
+        add_zero_genes: int = 100,
+        do_gene_pos: Union[bool, str] = True,
+        tp_name: Optional[str] = None,  # "heat_diff"
+        assays_to_drop: list = [
+            "EFO:0008853",
+            "EFO:0010961",
+            "EFO:0030007",
+            "EFO:0030062",
+        ],
+        **kwargs,
+    ):
+        """
+        DataModule a pytorch lighting datamodule directly from a lamin Collection.
+        it can work with bare pytorch too
+
+        It implements train / val / test dataloaders. the train is weighted random, val is random, test is one to many separated datasets.
+        This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
+
+        Args:
+            collection_name (str): The lamindb collection to be used.
+            weight_scaler (int, optional): how much more you will see the most present vs less present category.
+            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
+                any genes within this distance of each other will be considered at the same position.
+            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+                the file must have ensembl_gene_id as index.
+                This is used to subset the available genes further to the ones that have embeddings in your model.
+            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
+            label_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
+            test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
+                it will use a full dataset and will round to the nearest dataset's cell count.
+            **other args: see @file data.py and @file collator.py for more details
+            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
+        """
+        if collection_name is not None:
+            mdataset = Dataset(
+                ln.Collection.filter(name=collection_name).first(),
+                organisms=organisms,
+                obs=all_labels,
+                clss_to_pred=label_to_pred,
+                hierarchical_clss=hierarchical_labels,
+            )
+        print(mdataset)
+        # and location
+        if do_gene_pos:
+            if type(do_gene_pos) is str:
+                print("seeing a string: loading gene positions as biomart parquet file")
+                biomart = pd.read_parquet(do_gene_pos)
+            else:
+                # and annotations
+                biomart = getBiomartTable(
+                    attributes=["start_position", "chromosome_name"]
+                ).set_index("ensembl_gene_id")
+                biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
+                biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
+                c = []
+                i = 0
+                prev_position = -100000
+                prev_chromosome = None
+                for _, r in biomart.iterrows():
+                    if (
+                        r["chromosome_name"] != prev_chromosome
+                        or r["start_position"] - prev_position > gene_position_tolerance
+                    ):
+                        i += 1
+                    c.append(i)
+                    prev_position = r["start_position"]
+                    prev_chromosome = r["chromosome_name"]
+                print(f"reduced the size to {len(set(c))/len(biomart)}")
+                biomart["pos"] = c
+            mdataset.genedf = biomart.loc[mdataset.genedf.index]
+            self.gene_pos = mdataset.genedf["pos"].tolist()
+
+        if gene_embeddings != "":
+            mdataset.genedf = mdataset.genedf.join(
+                pd.read_parquet(gene_embeddings), how="inner"
+            )
+            if do_gene_pos:
+                self.gene_pos = mdataset.genedf["pos"].tolist()
+        self.labels = {k: len(v) for k, v in mdataset.class_topred.items()}
+        # we might want not to order the genes by expression (or do it?)
+        # we might want to not introduce zeros and
+        if use_default_col:
+            kwargs["collate_fn"] = Collator(
+                organisms=organisms,
+                how=how,
+                valid_genes=mdataset.genedf.index.tolist(),
+                max_len=max_len,
+                add_zero_genes=add_zero_genes,
+                org_to_id=mdataset.encoder[organism_name],
+                tp_name=tp_name,
+                organism_name=organism_name,
+                class_names=label_to_weight,
+            )
+        self.validation_split = validation_split
+        self.test_split = test_split
+        self.dataset = mdataset
+        self.kwargs = kwargs
+        self.assays_to_drop = assays_to_drop
+        self.n_samples = len(mdataset)
+        self.weight_scaler = weight_scaler
+        self.train_oversampling_per_epoch = train_oversampling_per_epoch
+        self.label_to_weight = label_to_weight
+        self.train_weights = None
+        self.train_labels = None
+        super().__init__()
+
+    def __repr__(self):
+        return (
+            f"DataLoader(\n"
+            f"\twith a dataset=({self.dataset.__repr__()}\n)\n"
+            f"\tvalidation_split={self.validation_split},\n"
+            f"\ttest_split={self.test_split},\n"
+            f"\tn_samples={self.n_samples},\n"
+            f"\tweight_scaler={self.weight_scaler},\n"
+            f"\train_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+            f"\tlabel_to_weight={self.label_to_weight}\n"
+            + (
+                "\twith train_dataset size of=("
+                + str((self.train_weights != 0).sum())
+                + ")\n)"
+            )
+            if self.train_weights is not None
+            else ")"
+        )
+
+    @property
+    def decoders(self):
+        """
+        decoders the decoders for any labels that would have been encoded
+
+        Returns:
+            dict[str, dict[int, str]]
+        """
+        decoders = {}
+        for k, v in self.dataset.encoder.items():
+            decoders[k] = {va: ke for ke, va in v.items()}
+        return decoders
+
+    @property
+    def cls_hierarchy(self):
+        """
+        cls_hierarchy the hierarchy of labels for any cls that would have a hierarchy
+
+        Returns:
+            dict[str, dict[str, str]]
+        """
+        cls_hierarchy = {}
+        for k, dic in self.dataset.class_groupings.items():
+            rdic = {}
+            for sk, v in dic.items():
+                rdic[self.dataset.encoder[k][sk]] = [
+                    self.dataset.encoder[k][i] for i in list(v)
+                ]
+            cls_hierarchy[k] = rdic
+        return cls_hierarchy
+
+    @property
+    def genes(self):
+        """
+        genes the genes used in this datamodule
+
+        Returns:
+            list
+        """
+        return self.dataset.genedf.index.tolist()
+
+    @property
+    def num_datasets(self):
+        return len(self.dataset.mapped_dataset.storages)
+
+    def setup(self, stage=None):
+        """
+        setup method is used to prepare the data for the training, validation, and test sets.
+        It shuffles the data, calculates weights for each set, and creates samplers for each set.
+
+        Args:
+            stage (str, optional): The stage of the model training process.
+                It can be either 'fit' or 'test'. Defaults to None.
+        """
+        if len(self.label_to_weight) > 0:
+            weights, labels = self.dataset.get_label_weights(
+                self.label_to_weight, scaler=self.weight_scaler
+            )
+        else:
+            weights = np.ones(1)
+            labels = np.zeros(self.n_samples)
+        if isinstance(self.validation_split, int):
+            len_valid = self.validation_split
+        else:
+            len_valid = int(self.n_samples * self.validation_split)
+        if isinstance(self.test_split, int):
+            len_test = self.test_split
+        else:
+            len_test = int(self.n_samples * self.test_split)
+        assert (
+            len_test + len_valid < self.n_samples
+        ), "test set + valid set size is configured to be larger than entire dataset."
+
+        idx_full = []
+        if len(self.assays_to_drop) > 0:
+            for i, a in enumerate(
+                self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id")
+            ):
+                if a not in self.assays_to_drop:
+                    idx_full.append(i)
+            idx_full = np.array(idx_full)
+        else:
+            idx_full = np.arange(self.n_samples)
+        test_datasets = []
+        if len_test > 0:
+            # this way we work on some never seen datasets
+            # keeping at least one
+            len_test = (
+                len_test
+                if len_test > self.dataset.mapped_dataset.n_obs_list[0]
+                else self.dataset.mapped_dataset.n_obs_list[0]
+            )
+            cs = 0
+            print("these files will be considered test datasets:")
+            for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
+                if cs + c > len_test:
+                    break
+                else:
+                    print("    " + self.dataset.mapped_dataset.path_list[i].path)
+                    test_datasets.append(self.dataset.mapped_dataset.path_list[i].path)
+                    cs += c
+
+            len_test = cs
+            print("perc test: ", len_test / self.n_samples)
+            self.test_idx = idx_full[:len_test]
+            idx_full = idx_full[len_test:]
+        else:
+            self.test_idx = None
+
+        np.random.shuffle(idx_full)
+        if len_valid > 0:
+            self.valid_idx = idx_full[:len_valid].copy()
+            idx_full = idx_full[len_valid:]
+        else:
+            self.valid_idx = None
+        weights = np.concatenate([weights, np.zeros(1)])
+        labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+
+        self.train_weights = weights
+        self.train_labels = labels
+        self.idx_full = idx_full
+
+        return test_datasets
+
+    def train_dataloader(self, **kwargs):
+        # train_sampler = WeightedRandomSampler(
+        #     self.train_weights[self.train_labels],
+        #     int(self.n_samples*self.train_oversampling_per_epoch),
+        #     replacement=True,
+        # )
+        train_sampler = LabelWeightedSampler(
+            self.train_weights,
+            self.train_labels,
+            num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+            # replacement=True,
+        )
+        return DataLoader(self.dataset, sampler=train_sampler, **self.kwargs, **kwargs)
+
+    def val_dataloader(self):
+        return (
+            DataLoader(
+                self.dataset, sampler=SubsetRandomSampler(self.valid_idx), **self.kwargs
+            )
+            if self.valid_idx is not None
+            else None
+        )
+
+    def test_dataloader(self):
+        return (
+            DataLoader(
+                self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
+            )
+            if self.test_idx is not None
+            else None
+        )
+
+    # def teardown(self):
+    # clean up state after the trainer stops, delete files...
+    # called on every process in DDP
+    # pass
+
+
+class LabelWeightedSampler(Sampler[int]):
+    label_weights: Sequence[float]
+    klass_indices: Sequence[Sequence[int]]
+    num_samples: int
+
+    # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
+    # this will result a class-balanced sampling, no matter how imbalance the labels are.
+    # NOTE: here we use replacement=True, you can change it if you don't upsample a class.
+    def __init__(
+        self, label_weights: Sequence[float], labels: Sequence[int], num_samples: int
+    ) -> None:
+        """
+
+        :param label_weights: list(len=num_classes)[float], weights for each class.
+        :param labels: list(len=dataset_len)[int], labels of a dataset.
+        :param num_samples: number of samples.
+        """
+
+        super(LabelWeightedSampler, self).__init__(None)
+        # reweight labels from counter otherwsie same weight to labels that have many elements vs a few
+        label_weights = np.array(label_weights) * np.bincount(labels)
+
+        self.label_weights = torch.as_tensor(label_weights, dtype=torch.float32)
+        self.labels = torch.as_tensor(labels, dtype=torch.int)
+        self.num_samples = num_samples
+        # list of tensor.
+        self.klass_indices = [
+            (self.labels == i_klass).nonzero().squeeze(1)
+            for i_klass in range(len(label_weights))
+        ]
+
+    def __iter__(self):
+        sample_labels = torch.multinomial(
+            self.label_weights, num_samples=self.num_samples, replacement=True
+        )
+        sample_indices = torch.empty_like(sample_labels)
+        for i_klass, klass_index in enumerate(self.klass_indices):
+            if klass_index.numel() == 0:
+                continue
+            left_inds = (sample_labels == i_klass).nonzero().squeeze(1)
+            right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
+            sample_indices[left_inds] = klass_index[right_inds]
+        yield from iter(sample_indices.tolist())
+
+    def __len__(self):
+        return self.num_samples
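To see how the new DataModule ties the mappedCollection, Dataset and Collator together, here is a hedged end-to-end sketch (the collection name, label choices and DataLoader kwargs are placeholders, and the Lightning model itself is outside the scope of this diff):

import lightning as L
from scdataloader.datamodule import DataModule

datamodule = DataModule(
    collection_name="my-collection",  # hypothetical lamin Collection name
    all_labels=["cell_type_ontology_term_id", "organism_ontology_term_id"],
    label_to_pred=["cell_type_ontology_term_id", "organism_ontology_term_id"],
    hierarchical_labels=["cell_type_ontology_term_id"],
    label_to_weight=["cell_type_ontology_term_id"],
    organisms=["NCBITaxon:9606"],
    validation_split=0.2,
    test_split=0.1,
    batch_size=64,   # forwarded to the pytorch DataLoader through **kwargs
    num_workers=4,   # forwarded to the pytorch DataLoader through **kwargs
)
test_files = datamodule.setup()  # returns the artifacts held out as test datasets

# hand it to a Lightning Trainer ...
# trainer = L.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=datamodule)

# ... or drive the class-weighted train loader with bare pytorch
for batch in datamodule.train_dataloader():
    break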