scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/__main__.py +4 -5
- scdataloader/collator.py +76 -78
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +134 -77
- scdataloader/datamodule.py +638 -245
- scdataloader/mapped.py +104 -43
- scdataloader/preprocess.py +136 -110
- scdataloader/utils.py +158 -52
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/METADATA +6 -7
- scdataloader-2.0.2.dist-info/RECORD +16 -0
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/WHEEL +1 -1
- scdataloader-2.0.2.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.2.dist-info/RECORD +0 -16
- scdataloader-1.9.2.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.2.dist-info}/entry_points.txt +0 -0
scdataloader/mapped.py
CHANGED
|
@@ -7,12 +7,14 @@
|
|
|
7
7
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
+
import os
|
|
10
11
|
from collections import Counter
|
|
11
12
|
from functools import reduce
|
|
12
|
-
from typing import TYPE_CHECKING, Literal
|
|
13
|
+
from typing import TYPE_CHECKING, List, Literal
|
|
13
14
|
|
|
14
15
|
import numpy as np
|
|
15
16
|
import pandas as pd
|
|
17
|
+
import torch
|
|
16
18
|
from lamindb.core.storage._anndata_accessor import (
|
|
17
19
|
ArrayType,
|
|
18
20
|
ArrayTypes,
|
|
@@ -24,10 +26,13 @@ from lamindb.core.storage._anndata_accessor import (
|
|
|
24
26
|
registry,
|
|
25
27
|
)
|
|
26
28
|
from lamindb_setup.core.upath import UPath
|
|
29
|
+
from tqdm import tqdm
|
|
27
30
|
|
|
28
31
|
if TYPE_CHECKING:
|
|
29
32
|
from lamindb_setup.core.types import UPathStr
|
|
30
33
|
|
|
34
|
+
from pandas.api.types import union_categoricals
|
|
35
|
+
|
|
31
36
|
|
|
32
37
|
class _Connect:
|
|
33
38
|
def __init__(self, storage):
|
|
@@ -106,24 +111,28 @@ class MappedCollection:
|
|
|
106
111
|
meta_assays: Assays that are already defined as metacells.
|
|
107
112
|
metacell_mode: frequency at which to sample a metacell (an average of k-nearest neighbors).
|
|
108
113
|
get_knn_cells: Whether to also dataload the k-nearest neighbors of each queried cells.
|
|
114
|
+
store_location: Path to a directory where klass_indices can be cached, or full path to the cache file.
|
|
115
|
+
force_recompute_indices: If True, recompute indices even if a cache file exists.
|
|
109
116
|
"""
|
|
110
117
|
|
|
111
118
|
def __init__(
|
|
112
119
|
self,
|
|
113
|
-
path_list:
|
|
114
|
-
layers_keys: str |
|
|
115
|
-
obs_keys: str |
|
|
116
|
-
obsm_keys: str |
|
|
120
|
+
path_list: List[UPathStr],
|
|
121
|
+
layers_keys: str | List[str] | None = None,
|
|
122
|
+
obs_keys: str | List[str] | None = None,
|
|
123
|
+
obsm_keys: str | List[str] | None = None,
|
|
117
124
|
obs_filter: dict[str, str | tuple[str, ...]] | None = None,
|
|
118
125
|
join: Literal["inner", "outer"] | None = "inner",
|
|
119
|
-
encode_labels: bool |
|
|
126
|
+
encode_labels: bool | List[str] = True,
|
|
120
127
|
unknown_label: str | dict[str, str] | None = None,
|
|
121
128
|
cache_categories: bool = True,
|
|
122
129
|
parallel: bool = False,
|
|
123
130
|
dtype: str | None = None,
|
|
124
131
|
metacell_mode: float = 0.0,
|
|
125
132
|
get_knn_cells: bool = False,
|
|
126
|
-
meta_assays:
|
|
133
|
+
meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
|
|
134
|
+
store_location: str | None = None,
|
|
135
|
+
force_recompute_indices: bool = False,
|
|
127
136
|
):
|
|
128
137
|
if join not in {None, "inner", "outer"}: # pragma: nocover
|
|
129
138
|
raise ValueError(
|
|
@@ -181,14 +190,28 @@ class MappedCollection:
|
|
|
181
190
|
self._cache_cats: dict = {}
|
|
182
191
|
if self.obs_keys is not None:
|
|
183
192
|
if cache_categories:
|
|
184
|
-
|
|
193
|
+
if store_location is not None:
|
|
194
|
+
os.makedirs(store_location, exist_ok=True)
|
|
195
|
+
self.store_location = os.path.join(store_location, "categories")
|
|
196
|
+
if (
|
|
197
|
+
not os.path.exists(self.store_location)
|
|
198
|
+
or force_recompute_indices
|
|
199
|
+
):
|
|
200
|
+
self._cache_categories(self.obs_keys)
|
|
201
|
+
torch.save(self._cache_cats, self.store_location)
|
|
202
|
+
else:
|
|
203
|
+
self._cache_cats = torch.load(
|
|
204
|
+
self.store_location, weights_only=False
|
|
205
|
+
)
|
|
206
|
+
print(f"Loaded categories from {self.store_location}")
|
|
185
207
|
self.encoders: dict = {}
|
|
186
208
|
if self.encode_labels:
|
|
187
209
|
self._make_encoders(self.encode_labels) # type: ignore
|
|
188
|
-
|
|
189
210
|
self.n_obs_list = []
|
|
190
211
|
self.indices_list = []
|
|
191
|
-
for i, storage in
|
|
212
|
+
for i, storage in tqdm(
|
|
213
|
+
enumerate(self.storages), desc="Checking datasets", total=len(self.storages)
|
|
214
|
+
):
|
|
192
215
|
with _Connect(storage) as store:
|
|
193
216
|
X = store["X"]
|
|
194
217
|
store_path = self.path_list[i]
|
|
@@ -263,13 +286,10 @@ class MappedCollection:
|
|
|
263
286
|
self._cache_cats = {}
|
|
264
287
|
for label in obs_keys:
|
|
265
288
|
self._cache_cats[label] = []
|
|
266
|
-
for storage in self.storages:
|
|
289
|
+
for storage in tqdm(self.storages, f"caching categories, {label}"):
|
|
267
290
|
with _Connect(storage) as store:
|
|
268
291
|
cats = self._get_categories(store, label)
|
|
269
|
-
if cats
|
|
270
|
-
cats = (
|
|
271
|
-
_decode(cats) if isinstance(cats[0], bytes) else cats[...]
|
|
272
|
-
)
|
|
292
|
+
cats = _decode(cats) if isinstance(cats[0], bytes) else cats[...]
|
|
273
293
|
self._cache_cats[label].append(cats)
|
|
274
294
|
|
|
275
295
|
def _make_encoders(self, encode_labels: list):
|
|
@@ -330,7 +350,7 @@ class MappedCollection:
|
|
|
330
350
|
vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
|
|
331
351
|
return all(vrs_sort_status)
|
|
332
352
|
|
|
333
|
-
def check_vars_non_aligned(self, vars: pd.Index |
|
|
353
|
+
def check_vars_non_aligned(self, vars: pd.Index | List) -> List[int]:
|
|
334
354
|
"""Returns indices of objects with non-aligned variables.
|
|
335
355
|
|
|
336
356
|
Args:
|
|
@@ -362,7 +382,7 @@ class MappedCollection:
|
|
|
362
382
|
return (self.n_obs, self.n_vars)
|
|
363
383
|
|
|
364
384
|
@property
|
|
365
|
-
def original_shapes(self) ->
|
|
385
|
+
def original_shapes(self) -> List[tuple[int, int]]:
|
|
366
386
|
"""Shapes of the underlying AnnData objects (with `obs_filter` applied)."""
|
|
367
387
|
if self.n_vars_list is None:
|
|
368
388
|
n_vars_list = [None] * len(self.n_obs_list)
|
|
@@ -403,10 +423,36 @@ class MappedCollection:
|
|
|
403
423
|
cats = None
|
|
404
424
|
label_idx = self._get_obs_idx(store, obs_idx, label, cats)
|
|
405
425
|
if label in self.encoders:
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
426
|
+
try:
|
|
427
|
+
label_idx = self.encoders[label][label_idx]
|
|
428
|
+
except:
|
|
429
|
+
print(self.storages[storage_idx])
|
|
430
|
+
print(label, label_idx)
|
|
431
|
+
print(idx)
|
|
432
|
+
print(cats)
|
|
433
|
+
raise
|
|
434
|
+
try:
|
|
435
|
+
out[label] = label_idx
|
|
436
|
+
except:
|
|
437
|
+
print(self.storages[storage_idx])
|
|
438
|
+
print(label, label_idx)
|
|
439
|
+
print(out)
|
|
440
|
+
raise
|
|
441
|
+
|
|
442
|
+
if self.get_knn_cells:
|
|
443
|
+
distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
|
|
444
|
+
nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
|
|
445
|
+
out["knn_cells"] = np.array(
|
|
446
|
+
[
|
|
447
|
+
self._get_data_idx(
|
|
448
|
+
lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
|
|
449
|
+
)
|
|
450
|
+
for i in nn_idx
|
|
451
|
+
],
|
|
452
|
+
dtype=int,
|
|
453
|
+
)
|
|
454
|
+
out["knn_cells_info"] = distances[nn_idx]
|
|
455
|
+
elif self.metacell_mode > 0:
|
|
410
456
|
if (
|
|
411
457
|
len(self.meta_assays) > 0
|
|
412
458
|
and "assay_ontology_term_id" in self.obs_keys
|
|
@@ -423,19 +469,6 @@ class MappedCollection:
|
|
|
423
469
|
out[layers_key] += self._get_data_idx(
|
|
424
470
|
lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
|
|
425
471
|
)
|
|
426
|
-
elif self.get_knn_cells:
|
|
427
|
-
distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
|
|
428
|
-
nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
|
|
429
|
-
out["knn_cells"] = np.array(
|
|
430
|
-
[
|
|
431
|
-
self._get_data_idx(
|
|
432
|
-
lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
|
|
433
|
-
)
|
|
434
|
-
for i in nn_idx
|
|
435
|
-
],
|
|
436
|
-
dtype=int,
|
|
437
|
-
)
|
|
438
|
-
out["distances"] = distances[nn_idx]
|
|
439
472
|
|
|
440
473
|
return out
|
|
441
474
|
|
|
@@ -510,7 +543,7 @@ class MappedCollection:
|
|
|
510
543
|
|
|
511
544
|
def get_label_weights(
|
|
512
545
|
self,
|
|
513
|
-
obs_keys: str |
|
|
546
|
+
obs_keys: str | List[str],
|
|
514
547
|
scaler: float | None = None,
|
|
515
548
|
return_categories: bool = False,
|
|
516
549
|
):
|
|
@@ -555,21 +588,41 @@ class MappedCollection:
|
|
|
555
588
|
weights = (MAX / scaler) / ((1 + counts - MIN) + MAX / scaler)
|
|
556
589
|
return weights
|
|
557
590
|
|
|
558
|
-
def get_merged_labels(self, label_key: str):
|
|
591
|
+
def get_merged_labels(self, label_key: str, is_cat: bool = True):
|
|
559
592
|
"""Get merged labels for `label_key` from all `.obs`."""
|
|
560
593
|
labels_merge = []
|
|
561
|
-
for i, storage in
|
|
594
|
+
for i, storage in tqdm(
|
|
595
|
+
enumerate(self.storages), label_key, total=len(self.storages)
|
|
596
|
+
):
|
|
562
597
|
with _Connect(storage) as store:
|
|
563
|
-
labels = self._get_labels(
|
|
598
|
+
labels = self._get_labels(
|
|
599
|
+
store, label_key, storage_idx=i, is_cat=is_cat
|
|
600
|
+
)
|
|
564
601
|
if self.filtered:
|
|
565
602
|
labels = labels[self.indices_list[i]]
|
|
566
603
|
labels_merge.append(labels)
|
|
567
|
-
|
|
604
|
+
if is_cat:
|
|
605
|
+
try:
|
|
606
|
+
return union_categoricals(labels_merge)
|
|
607
|
+
except TypeError:
|
|
608
|
+
typ = type(int)
|
|
609
|
+
for i in range(len(labels_merge)):
|
|
610
|
+
if typ != type(labels_merge[i][0]):
|
|
611
|
+
self.storages[i]
|
|
612
|
+
typ = type(labels_merge[i][0])
|
|
613
|
+
return []
|
|
614
|
+
else:
|
|
615
|
+
print("concatenating labels")
|
|
616
|
+
return np.concatenate(labels_merge)
|
|
568
617
|
|
|
569
618
|
def get_merged_categories(self, label_key: str):
|
|
570
619
|
"""Get merged categories for `label_key` from all `.obs`."""
|
|
571
620
|
cats_merge = set()
|
|
572
|
-
for i, storage in
|
|
621
|
+
for i, storage in tqdm(
|
|
622
|
+
enumerate(self.storages),
|
|
623
|
+
total=len(self.storages),
|
|
624
|
+
desc="merging all " + label_key + " categories",
|
|
625
|
+
):
|
|
573
626
|
with _Connect(storage) as store:
|
|
574
627
|
if label_key in self._cache_cats:
|
|
575
628
|
cats = self._cache_cats[label_key][i]
|
|
@@ -609,8 +662,8 @@ class MappedCollection:
|
|
|
609
662
|
else:
|
|
610
663
|
if "categories" in labels.attrs:
|
|
611
664
|
return labels.attrs["categories"]
|
|
612
|
-
|
|
613
|
-
return
|
|
665
|
+
elif labels.dtype == "bool":
|
|
666
|
+
return np.array(["True", "False"])
|
|
614
667
|
return None
|
|
615
668
|
|
|
616
669
|
def _get_codes(self, storage: StorageType, label_key: str):
|
|
@@ -626,11 +679,17 @@ class MappedCollection:
|
|
|
626
679
|
return label["codes"][...]
|
|
627
680
|
|
|
628
681
|
def _get_labels(
|
|
629
|
-
self,
|
|
682
|
+
self,
|
|
683
|
+
storage: StorageType,
|
|
684
|
+
label_key: str,
|
|
685
|
+
storage_idx: int | None = None,
|
|
686
|
+
is_cat: bool = True,
|
|
630
687
|
):
|
|
631
688
|
"""Get labels."""
|
|
632
689
|
codes = self._get_codes(storage, label_key)
|
|
633
690
|
labels = _decode(codes) if isinstance(codes[0], bytes) else codes
|
|
691
|
+
if labels.dtype == bool:
|
|
692
|
+
labels = labels.astype(int)
|
|
634
693
|
if storage_idx is not None and label_key in self._cache_cats:
|
|
635
694
|
cats = self._cache_cats[label_key][storage_idx]
|
|
636
695
|
else:
|
|
@@ -638,6 +697,8 @@ class MappedCollection:
|
|
|
638
697
|
if cats is not None:
|
|
639
698
|
cats = _decode(cats) if isinstance(cats[0], bytes) else cats
|
|
640
699
|
labels = cats[labels]
|
|
700
|
+
if is_cat:
|
|
701
|
+
labels = pd.Categorical(labels.astype(str))
|
|
641
702
|
return labels
|
|
642
703
|
|
|
643
704
|
def close(self):
|