scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/__main__.py +4 -5
- scdataloader/collator.py +76 -78
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +134 -77
- scdataloader/datamodule.py +638 -245
- scdataloader/mapped.py +104 -43
- scdataloader/preprocess.py +136 -110
- scdataloader/utils.py +158 -52
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/METADATA +6 -7
- scdataloader-2.0.2.dist-info/RECORD +16 -0
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/WHEEL +1 -1
- scdataloader-2.0.2.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.2.dist-info/RECORD +0 -16
- scdataloader-1.9.2.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/entry_points.txt +0 -0
scdataloader/data.py
CHANGED
```diff
@@ -2,7 +2,7 @@ import warnings
 from collections import Counter
 from dataclasses import dataclass, field
 from functools import reduce
-from typing import Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
 # ln.connect("scprint")
 import bionty as bt
@@ -16,7 +16,6 @@ from torch.utils.data import Dataset as torchDataset
 
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
-from .config import LABELS_TOADD
 from .mapped import MappedCollection, _Connect
 
 
@@ -39,28 +38,30 @@ class Dataset(torchDataset):
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-
-
-        obs (list[str]): list of observations to load from the Collection
-        clss_to_predict (list[str]): list of observations to encode
+        obs (List[str]): list of observations to load from the Collection
+        clss_to_predict (List[str]): list of observations to encode
         join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
         hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
+        metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
+        get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each cell. Defaults to False.
+        store_location (str, optional): The location to store the sampler indices. Defaults to None.
+        force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
     """
 
     lamin_dataset: ln.Collection
     genedf: Optional[pd.DataFrame] = None
-    organisms: Optional[Union[list[str], str]] = field(
-        default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
-    )
     # set of obs to prepare for prediction (encode)
-    clss_to_predict: Optional[
+    clss_to_predict: Optional[List[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
-    hierarchical_clss: Optional[
+    hierarchical_clss: Optional[List[str]] = field(default_factory=list)
     join_vars: Literal["inner", "outer"] | None = None
     metacell_mode: float = 0.0
     get_knn_cells: bool = False
+    store_location: str | None = None
+    force_recompute_indices: bool = False
 
     def __post_init__(self):
+        # see at the end of the file for the mapped function
        self.mapped_dataset = mapped(
             self.lamin_dataset,
             obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
```
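The hard-coded `organisms` field is gone from the dataclass (its list `default_factory` was not even callable), and sampler-index caching is now a per-dataset knob; the organism list is instead derived in `__post_init__` (next hunks). A minimal construction sketch, assuming an existing lamindb Collection; the lookup and the store path below are placeholders, not part of this diff:

```python
# Minimal sketch; `col` is assumed to be a Collection you already track.
import lamindb as ln
from scdataloader.data import Dataset

col = ln.Collection.get(name="my-cellxgene-subset")  # placeholder lookup
ds = Dataset(
    lamin_dataset=col,
    # required when no genedf is passed (enforced in __post_init__):
    clss_to_predict=["organism_ontology_term_id", "cell_type_ontology_term_id"],
    store_location="./sampler_indices",  # cache sampler indices on disk
    force_recompute_indices=False,  # reuse cached indices when present
)
```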
```diff
@@ -71,6 +72,8 @@ class Dataset(torchDataset):
             parallel=True,
             metacell_mode=self.metacell_mode,
             get_knn_cells=self.get_knn_cells,
+            store_location=self.store_location,
+            force_recompute_indices=self.force_recompute_indices,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
@@ -85,7 +88,7 @@ class Dataset(torchDataset):
             if clss not in self.hierarchical_clss:
                 # otherwise it's already been done
                 self.class_topred[clss] = set(
-                    self.mapped_dataset.
+                    self.mapped_dataset.encoders[clss].keys()
                 )
                 if (
                     self.mapped_dataset.unknown_label
@@ -94,12 +97,19 @@ class Dataset(torchDataset):
                     self.class_topred[clss] -= set(
                         [self.mapped_dataset.unknown_label]
                     )
-
         if self.genedf is None:
+            if "organism_ontology_term_id" not in self.clss_to_predict:
+                raise ValueError(
+                    "need 'organism_ontology_term_id' in the set of classes if you don't provide a genedf"
+                )
+            self.organisms = list(self.class_topred["organism_ontology_term_id"])
             self.genedf = load_genes(self.organisms)
+        else:
+            self.organisms = self.genedf["organism"].unique().tolist()
+            self.organisms.sort()
 
         self.genedf.columns = self.genedf.columns.astype(str)
-        self.check_aligned_vars()
+        # self.check_aligned_vars()
 
     def check_aligned_vars(self):
         vars = self.genedf.index.tolist()
```
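With the field removed, `self.organisms` now comes from the data itself: from the `organism_ontology_term_id` encoder categories when `genedf` is absent (hence the new `ValueError`), or from the dataframe's `organism` column when it is supplied. A standalone sketch of the second branch, with a toy dataframe:

```python
# Standalone sketch of the organism-derivation branch above.
import pandas as pd

genedf = pd.DataFrame(
    {"organism": ["NCBITaxon:9606", "NCBITaxon:9606", "NCBITaxon:10090"]},
    index=["ENSG00000001", "ENSG00000002", "ENSMUSG00000001"],
)
organisms = genedf["organism"].unique().tolist()
organisms.sort()
print(organisms)  # ['NCBITaxon:10090', 'NCBITaxon:9606']
```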
```diff
@@ -117,6 +127,10 @@ class Dataset(torchDataset):
     def encoder(self):
         return self.mapped_dataset.encoders
 
+    @encoder.setter
+    def encoder(self, encoder):
+        self.mapped_dataset.encoders = encoder
+
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
         return item
```
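`encoder` used to be read-only; the new setter writes through to `mapped_dataset.encoders`, which makes it easy to reuse one label-to-id mapping across splits. A sketch, assuming `train_ds` and `val_ds` are two `Dataset` instances over matching collections:

```python
# Sketch: share one label encoding across splits via the new setter.
# train_ds and val_ds are assumed to exist; the setter proxies to
# mapped_dataset.encoders as shown in the hunk above.
val_ds.encoder = train_ds.encoder
assert (
    val_ds.encoder["cell_type_ontology_term_id"]
    == train_ds.encoder["cell_type_ontology_term_id"]
)
```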
```diff
@@ -132,7 +146,11 @@ class Dataset(torchDataset):
             + " {} genes\n".format(self.genedf.shape[0])
             + " {} clss_to_predict\n".format(len(self.clss_to_predict))
             + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
-            +
+            + (
+                " {} organisms\n".format(len(self.organisms))
+                if self.organisms is not None
+                else ""
+            )
             + (
                 "dataset contains {} classes to predict\n".format(
                     sum([len(self.class_topred[i]) for i in self.class_topred])
@@ -143,41 +161,21 @@ class Dataset(torchDataset):
             + " {} metacell_mode\n".format(self.metacell_mode)
         )
 
-    def
+    def get_label_cats(
         self,
-        obs_keys: str
-        scaler: int = 10,
-        return_categories=False,
-        bypass_label=["neuron"],
+        obs_keys: Union[str, List[str]],
     ):
-        """Get all
+        """Get all categories for the given label keys."""
         if isinstance(obs_keys, str):
             obs_keys = [obs_keys]
-
+        labels = None
         for label_key in obs_keys:
-            labels_to_str = (
-
-
-            labels_list.append(labels_to_str)
-        if len(labels_list) > 1:
-            labels = ["___".join(labels_obs) for labels_obs in zip(*labels_list)]
-        else:
-            labels = labels_list[0]
-
-        counter = Counter(labels)  # type: ignore
-        if return_categories:
-            rn = {n: i for i, n in enumerate(counter.keys())}
-            labels = np.array([rn[label] for label in labels])
-            counter = np.array(list(counter.values()))
-            weights = scaler / (counter + scaler)
-            return weights, labels
-        else:
-            counts = np.array([counter[label] for label in labels])
-            if scaler is None:
-                weights = 1.0 / counts
+            labels_to_str = self.mapped_dataset.get_merged_labels(label_key)
+            if labels is None:
+                labels = labels_to_str
             else:
-
-
+                labels = concat_categorical_codes([labels, labels_to_str])
+        return np.array(labels.codes)
 
     def get_unseen_mapped_dataset_elements(self, idx: int):
         """
```
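The old weighting method (with `scaler`, `return_categories`, and `bypass_label`) is replaced by `get_label_cats`: each key's labels come back as a pandas Categorical from `get_merged_labels`, successive keys are folded together with `concat_categorical_codes` (added at the bottom of this file), and only the integer codes are returned. A sketch of the call, reusing `ds` from the earlier sketch:

```python
# Sketch: one integer id per cell, unique per existing label combination.
codes = ds.get_label_cats(
    ["cell_type_ontology_term_id", "organism_ontology_term_id"]
)
print(codes[:5])  # e.g. array([12, 12, 3, 40, 7]) -- usable as sampler groups
```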
```diff
@@ -187,16 +185,16 @@ class Dataset(torchDataset):
             idx (int): index of the element to get
 
         Returns:
-
+            List[str]: list of unseen genes
         """
         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]
 
-    def define_hierarchies(self, clsses:
+    def define_hierarchies(self, clsses: List[str]):
         """
         define_hierarchies is a method to define the hierarchies for the classes to predict
 
         Args:
-            clsses (
+            clsses (List[str]): list of classes to predict
 
         Raises:
             ValueError: if the class is not in the accepted classes
@@ -223,19 +221,19 @@ class Dataset(torchDataset):
             elif clss == "cell_type_ontology_term_id":
                 parentdf = (
                     bt.CellType.filter()
-                    .df(include=["parents__ontology_id"])
+                    .df(include=["parents__ontology_id", "ontology_id"])
                     .set_index("ontology_id")
                 )
             elif clss == "tissue_ontology_term_id":
                 parentdf = (
                     bt.Tissue.filter()
-                    .df(include=["parents__ontology_id"])
+                    .df(include=["parents__ontology_id", "ontology_id"])
                     .set_index("ontology_id")
                 )
             elif clss == "disease_ontology_term_id":
                 parentdf = (
                     bt.Disease.filter()
-                    .df(include=["parents__ontology_id"])
+                    .df(include=["parents__ontology_id", "ontology_id"])
                     .set_index("ontology_id")
                 )
             elif clss in [
@@ -245,19 +243,19 @@ class Dataset(torchDataset):
             ]:
                 parentdf = (
                     bt.DevelopmentalStage.filter()
-                    .df(include=["parents__ontology_id"])
+                    .df(include=["parents__ontology_id", "ontology_id"])
                     .set_index("ontology_id")
                 )
             elif clss == "assay_ontology_term_id":
                 parentdf = (
                     bt.ExperimentalFactor.filter()
-                    .df(include=["parents__ontology_id"])
+                    .df(include=["parents__ontology_id", "ontology_id"])
                     .set_index("ontology_id")
                 )
             elif clss == "self_reported_ethnicity_ontology_term_id":
                 parentdf = (
                     bt.Ethnicity.filter()
-                    .df(include=["parents__ontology_id"])
+                    .df(include=["parents__ontology_id", "ontology_id"])
                     .set_index("ontology_id")
                 )
 
@@ -267,13 +265,17 @@ class Dataset(torchDataset):
                         clss
                     )
                 )
-            cats = set(self.mapped_dataset.
-            addition = set(LABELS_TOADD.get(clss, {}).values())
-            cats |= addition
+            cats = set(self.mapped_dataset.encoders[clss].keys())
             groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
+            groupings.pop(None, None)
             for i, j in groupings.items():
                 if len(j) == 0:
+                    # that should not happen
+                    import pdb
+
+                    pdb.set_trace()
                     groupings.pop(i)
+
             self.labels_groupings[clss] = groupings
             if clss in self.clss_to_predict:
                 # if we have added new clss, we need to update the encoder with them too.
@@ -287,11 +289,12 @@ class Dataset(torchDataset):
                 )
 
                 for i, v in enumerate(
-
+                    set(groupings.keys())
+                    - set(self.mapped_dataset.encoders[clss].keys())
                 ):
                     self.mapped_dataset.encoders[clss].update({v: mlength + i})
-                # we need to change the ordering so that the things that can't be predicted appear afterward
 
+                # we need to change the ordering so that the things that can't be predicted appear afterward
                 self.class_topred[clss] = leaf_labels
                 c = 0
                 update = {}
```
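The same one-line change repeats for every ontology above: `ontology_id` is now requested explicitly in `.df(include=...)`, presumably so the column is guaranteed to exist before `set_index` uses it. Written out once, the pattern is:

```python
# The repeated pattern from the hunks above, written out once.
import bionty as bt

parentdf = (
    bt.CellType.filter()  # likewise Tissue, Disease, DevelopmentalStage, ...
    .df(include=["parents__ontology_id", "ontology_id"])
    .set_index("ontology_id")
)
# parentdf maps each ontology term to its parents__ontology_id list,
# which get_ancestry_mapping then turns into groupings and leaf labels.
```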
```diff
@@ -318,8 +321,10 @@ class SimpleAnnDataset(torchDataset):
     def __init__(
         self,
         adata: AnnData,
-        obs_to_output: Optional[
+        obs_to_output: Optional[List[str]] = [],
         layer: Optional[str] = None,
+        get_knn_cells: bool = False,
+        encoder: Optional[dict[str, dict]] = None,
     ):
         """
         SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
@@ -328,43 +333,53 @@ class SimpleAnnDataset(torchDataset):
         Args:
         ----
             adata (anndata.AnnData): anndata object to use
-            obs_to_output (
+            obs_to_output (List[str]): list of observations to output from anndata.obs
             layer (str): layer of the anndata to use
+            get_knn_cells (bool): whether to get the knn cells
+            encoder (dict[str, dict]): dictionary of encoders for the observations.
         """
         self.adataX = adata.layers[layer] if layer is not None else adata.X
         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+        self.encoder = encoder if encoder is not None else {}
+
         self.obs_to_output = adata.obs[obs_to_output]
+        self.get_knn_cells = get_knn_cells
+        if get_knn_cells and "connectivities" not in adata.obsp:
+            raise ValueError("neighbors key not found in adata.obsm")
+        if get_knn_cells:
+            self.distances = adata.obsp["distances"]
 
     def __len__(self):
         return self.adataX.shape[0]
 
     def __iter__(self):
-        for idx
-
-
-            out = {"X": self.adataX[idx].reshape(-1)}
-            out.update(
-                {name: val for name, val in self.obs_to_output.iloc[idx].items()}
-            )
-            yield out
+        for idx in range(self.adataX.shape[0]):
+            out = self.__getitem__(idx)
+            yield out
 
     def __getitem__(self, idx):
-
-
-
-        out.update(
-
+        out = {"X": self.adataX[idx].reshape(-1)}
+        # put the observation into the output and encode if needed
+        for name, val in self.obs_to_output.iloc[idx].items():
+            out.update({name: self.encoder[name][val] if name in self.encoder else val})
+        if self.get_knn_cells:
+            distances = self.distances[idx].toarray()[0]
+            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+            out["knn_cells"] = np.array(
+                [self.adataX[i].reshape(-1) for i in nn_idx],
+                dtype=int,
             )
+            out["knn_cells_info"] = distances[nn_idx]
         return out
 
 
 def mapped(
     dataset,
-    obs_keys:
-    obsm_keys:
+    obs_keys: List[str] | None = None,
+    obsm_keys: List[str] | None = None,
     obs_filter: dict[str, str | tuple[str, ...]] | None = None,
     join: Literal["inner", "outer"] | None = "inner",
-    encode_labels: bool |
+    encode_labels: bool | List[str] = True,
     unknown_label: str | dict[str, str] | None = None,
     cache_categories: bool = True,
     parallel: bool = False,
```
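`SimpleAnnDataset.__iter__` now delegates to `__getitem__`, obs values are passed through the optional `encoder`, and each item can carry its nearest neighbors from a precomputed graph. A usage sketch; the dataset and preprocessing below are illustrative assumptions, not from this diff:

```python
# Illustrative sketch: get_knn_cells needs adata.obsp["distances"] and
# ["connectivities"], e.g. from scanpy's neighbors graph.
import scanpy as sc
from scdataloader.data import SimpleAnnDataset

adata = sc.datasets.pbmc3k()  # any AnnData works here
sc.pp.pca(adata)
sc.pp.neighbors(adata)  # fills obsp["distances"] / obsp["connectivities"]

ds = SimpleAnnDataset(adata, obs_to_output=[], get_knn_cells=True)
item = ds[0]
item["X"]               # the cell's own expression vector
item["knn_cells"]       # expression of its 6 nearest neighbors
item["knn_cells_info"]  # distances to those neighbors
```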
```diff
@@ -372,8 +387,10 @@ def mapped(
     stream: bool = False,
     is_run_input: bool | None = None,
     metacell_mode: bool = False,
-    meta_assays:
+    meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
     get_knn_cells: bool = False,
+    store_location: str | None = None,
+    force_recompute_indices: bool = False,
 ) -> MappedCollection:
     path_list = []
     for artifact in dataset.artifacts.all():
@@ -401,5 +418,45 @@ def mapped(
         meta_assays=meta_assays,
         metacell_mode=metacell_mode,
         get_knn_cells=get_knn_cells,
+        store_location=store_location,
+        force_recompute_indices=force_recompute_indices,
     )
     return ds
+
+
+def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
+    """Efficiently combine multiple categorical data using their codes,
+    only creating categories for combinations that exist in the data.
+
+    Args:
+        series_list: List of pandas Categorical data
+
+    Returns:
+        Combined Categorical with only existing combinations
+    """
+    # Get the codes for each categorical
+    codes_list = [s.codes.astype(np.int32) for s in series_list]
+    n_cats = [len(s.categories) for s in series_list]
+
+    # Calculate combined codes
+    combined_codes = codes_list[0]
+    multiplier = n_cats[0]
+    for codes, n_cat in zip(codes_list[1:], n_cats[1:]):
+        combined_codes = (combined_codes * n_cat) + codes
+        multiplier *= n_cat
+
+    # Find unique combinations that actually exist in the data
+    unique_existing_codes = np.unique(combined_codes)
+
+    # Create a mapping from old codes to new compressed codes
+    code_mapping = {old: new for new, old in enumerate(unique_existing_codes)}
+
+    # Map the combined codes to their new compressed values
+    combined_codes = np.array([code_mapping[code] for code in combined_codes])
+
+    # Create final categorical with only existing combinations
+    return pd.Categorical.from_codes(
+        codes=combined_codes,
+        categories=np.arange(len(unique_existing_codes)),
+        ordered=False,
+    )
```