scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff shows the differences between the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
scdataloader/mapped.py CHANGED
@@ -7,12 +7,14 @@
7
7
 
8
8
  from __future__ import annotations
9
9
 
10
+ import os
10
11
  from collections import Counter
11
12
  from functools import reduce
12
- from typing import TYPE_CHECKING, Literal
13
+ from typing import TYPE_CHECKING, List, Literal
13
14
 
14
15
  import numpy as np
15
16
  import pandas as pd
17
+ import torch
16
18
  from lamindb.core.storage._anndata_accessor import (
17
19
  ArrayType,
18
20
  ArrayTypes,
@@ -24,10 +26,13 @@ from lamindb.core.storage._anndata_accessor import (
24
26
  registry,
25
27
  )
26
28
  from lamindb_setup.core.upath import UPath
29
+ from tqdm import tqdm
27
30
 
28
31
  if TYPE_CHECKING:
29
32
  from lamindb_setup.core.types import UPathStr
30
33
 
34
+ from pandas.api.types import union_categoricals
35
+
31
36
 
32
37
  class _Connect:
33
38
  def __init__(self, storage):
@@ -106,24 +111,28 @@ class MappedCollection:
106
111
  meta_assays: Assays that are already defined as metacells.
107
112
  metacell_mode: frequency at which to sample a metacell (an average of k-nearest neighbors).
108
113
  get_knn_cells: Whether to also dataload the k-nearest neighbors of each queried cells.
114
+ store_location: Path to a directory where klass_indices can be cached, or full path to the cache file.
115
+ force_recompute_indices: If True, recompute indices even if a cache file exists.
109
116
  """
110
117
 
111
118
  def __init__(
112
119
  self,
113
- path_list: list[UPathStr],
114
- layers_keys: str | list[str] | None = None,
115
- obs_keys: str | list[str] | None = None,
116
- obsm_keys: str | list[str] | None = None,
120
+ path_list: List[UPathStr],
121
+ layers_keys: str | List[str] | None = None,
122
+ obs_keys: str | List[str] | None = None,
123
+ obsm_keys: str | List[str] | None = None,
117
124
  obs_filter: dict[str, str | tuple[str, ...]] | None = None,
118
125
  join: Literal["inner", "outer"] | None = "inner",
119
- encode_labels: bool | list[str] = True,
126
+ encode_labels: bool | List[str] = True,
120
127
  unknown_label: str | dict[str, str] | None = None,
121
128
  cache_categories: bool = True,
122
129
  parallel: bool = False,
123
130
  dtype: str | None = None,
124
131
  metacell_mode: float = 0.0,
125
132
  get_knn_cells: bool = False,
126
- meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
133
+ meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
134
+ store_location: str | None = None,
135
+ force_recompute_indices: bool = False,
127
136
  ):
128
137
  if join not in {None, "inner", "outer"}: # pragma: nocover
129
138
  raise ValueError(
@@ -181,14 +190,28 @@ class MappedCollection:
181
190
  self._cache_cats: dict = {}
182
191
  if self.obs_keys is not None:
183
192
  if cache_categories:
184
- self._cache_categories(self.obs_keys)
193
+ if store_location is not None:
194
+ os.makedirs(store_location, exist_ok=True)
195
+ self.store_location = os.path.join(store_location, "categories")
196
+ if (
197
+ not os.path.exists(self.store_location)
198
+ or force_recompute_indices
199
+ ):
200
+ self._cache_categories(self.obs_keys)
201
+ torch.save(self._cache_cats, self.store_location)
202
+ else:
203
+ self._cache_cats = torch.load(
204
+ self.store_location, weights_only=False
205
+ )
206
+ print(f"Loaded categories from {self.store_location}")
185
207
  self.encoders: dict = {}
186
208
  if self.encode_labels:
187
209
  self._make_encoders(self.encode_labels) # type: ignore
188
-
189
210
  self.n_obs_list = []
190
211
  self.indices_list = []
191
- for i, storage in enumerate(self.storages):
212
+ for i, storage in tqdm(
213
+ enumerate(self.storages), desc="Checking datasets", total=len(self.storages)
214
+ ):
192
215
  with _Connect(storage) as store:
193
216
  X = store["X"]
194
217
  store_path = self.path_list[i]
@@ -263,13 +286,10 @@ class MappedCollection:
263
286
  self._cache_cats = {}
264
287
  for label in obs_keys:
265
288
  self._cache_cats[label] = []
266
- for storage in self.storages:
289
+ for storage in tqdm(self.storages, f"caching categories, {label}"):
267
290
  with _Connect(storage) as store:
268
291
  cats = self._get_categories(store, label)
269
- if cats is not None:
270
- cats = (
271
- _decode(cats) if isinstance(cats[0], bytes) else cats[...]
272
- )
292
+ cats = _decode(cats) if isinstance(cats[0], bytes) else cats[...]
273
293
  self._cache_cats[label].append(cats)
274
294
 
275
295
  def _make_encoders(self, encode_labels: list):
@@ -330,7 +350,7 @@ class MappedCollection:
330
350
  vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
331
351
  return all(vrs_sort_status)
332
352
 
333
- def check_vars_non_aligned(self, vars: pd.Index | list) -> list[int]:
353
+ def check_vars_non_aligned(self, vars: pd.Index | List) -> List[int]:
334
354
  """Returns indices of objects with non-aligned variables.
335
355
 
336
356
  Args:
@@ -362,7 +382,7 @@ class MappedCollection:
362
382
  return (self.n_obs, self.n_vars)
363
383
 
364
384
  @property
365
- def original_shapes(self) -> list[tuple[int, int]]:
385
+ def original_shapes(self) -> List[tuple[int, int]]:
366
386
  """Shapes of the underlying AnnData objects (with `obs_filter` applied)."""
367
387
  if self.n_vars_list is None:
368
388
  n_vars_list = [None] * len(self.n_obs_list)
@@ -403,10 +423,36 @@ class MappedCollection:
403
423
  cats = None
404
424
  label_idx = self._get_obs_idx(store, obs_idx, label, cats)
405
425
  if label in self.encoders:
406
- label_idx = self.encoders[label][label_idx]
407
- out[label] = label_idx
408
-
409
- if self.metacell_mode > 0:
426
+ try:
427
+ label_idx = self.encoders[label][label_idx]
428
+ except:
429
+ print(self.storages[storage_idx])
430
+ print(label, label_idx)
431
+ print(idx)
432
+ print(cats)
433
+ raise
434
+ try:
435
+ out[label] = label_idx
436
+ except:
437
+ print(self.storages[storage_idx])
438
+ print(label, label_idx)
439
+ print(out)
440
+ raise
441
+
442
+ if self.get_knn_cells:
443
+ distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
444
+ nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
445
+ out["knn_cells"] = np.array(
446
+ [
447
+ self._get_data_idx(
448
+ lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
449
+ )
450
+ for i in nn_idx
451
+ ],
452
+ dtype=int,
453
+ )
454
+ out["knn_cells_info"] = distances[nn_idx]
455
+ elif self.metacell_mode > 0:
410
456
  if (
411
457
  len(self.meta_assays) > 0
412
458
  and "assay_ontology_term_id" in self.obs_keys
@@ -423,19 +469,6 @@ class MappedCollection:
423
469
  out[layers_key] += self._get_data_idx(
424
470
  lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
425
471
  )
426
- elif self.get_knn_cells:
427
- distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
428
- nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
429
- out["knn_cells"] = np.array(
430
- [
431
- self._get_data_idx(
432
- lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
433
- )
434
- for i in nn_idx
435
- ],
436
- dtype=int,
437
- )
438
- out["distances"] = distances[nn_idx]
439
472
 
440
473
  return out
441
474
 
@@ -510,7 +543,7 @@ class MappedCollection:
510
543
 
511
544
  def get_label_weights(
512
545
  self,
513
- obs_keys: str | list[str],
546
+ obs_keys: str | List[str],
514
547
  scaler: float | None = None,
515
548
  return_categories: bool = False,
516
549
  ):
@@ -555,21 +588,41 @@ class MappedCollection:
555
588
  weights = (MAX / scaler) / ((1 + counts - MIN) + MAX / scaler)
556
589
  return weights
557
590
 
558
- def get_merged_labels(self, label_key: str):
591
+ def get_merged_labels(self, label_key: str, is_cat: bool = True):
559
592
  """Get merged labels for `label_key` from all `.obs`."""
560
593
  labels_merge = []
561
- for i, storage in enumerate(self.storages):
594
+ for i, storage in tqdm(
595
+ enumerate(self.storages), label_key, total=len(self.storages)
596
+ ):
562
597
  with _Connect(storage) as store:
563
- labels = self._get_labels(store, label_key, storage_idx=i)
598
+ labels = self._get_labels(
599
+ store, label_key, storage_idx=i, is_cat=is_cat
600
+ )
564
601
  if self.filtered:
565
602
  labels = labels[self.indices_list[i]]
566
603
  labels_merge.append(labels)
567
- return np.hstack(labels_merge)
604
+ if is_cat:
605
+ try:
606
+ return union_categoricals(labels_merge)
607
+ except TypeError:
608
+ typ = type(int)
609
+ for i in range(len(labels_merge)):
610
+ if typ != type(labels_merge[i][0]):
611
+ self.storages[i]
612
+ typ = type(labels_merge[i][0])
613
+ return []
614
+ else:
615
+ print("concatenating labels")
616
+ return np.concatenate(labels_merge)
568
617
 
569
618
  def get_merged_categories(self, label_key: str):
570
619
  """Get merged categories for `label_key` from all `.obs`."""
571
620
  cats_merge = set()
572
- for i, storage in enumerate(self.storages):
621
+ for i, storage in tqdm(
622
+ enumerate(self.storages),
623
+ total=len(self.storages),
624
+ desc="merging all " + label_key + " categories",
625
+ ):
573
626
  with _Connect(storage) as store:
574
627
  if label_key in self._cache_cats:
575
628
  cats = self._cache_cats[label_key][i]
@@ -609,8 +662,8 @@ class MappedCollection:
609
662
  else:
610
663
  if "categories" in labels.attrs:
611
664
  return labels.attrs["categories"]
612
- else:
613
- return None
665
+ elif labels.dtype == "bool":
666
+ return np.array(["True", "False"])
614
667
  return None
615
668
 
616
669
  def _get_codes(self, storage: StorageType, label_key: str):
@@ -626,11 +679,17 @@ class MappedCollection:
626
679
  return label["codes"][...]
627
680
 
628
681
  def _get_labels(
629
- self, storage: StorageType, label_key: str, storage_idx: int | None = None
682
+ self,
683
+ storage: StorageType,
684
+ label_key: str,
685
+ storage_idx: int | None = None,
686
+ is_cat: bool = True,
630
687
  ):
631
688
  """Get labels."""
632
689
  codes = self._get_codes(storage, label_key)
633
690
  labels = _decode(codes) if isinstance(codes[0], bytes) else codes
691
+ if labels.dtype == bool:
692
+ labels = labels.astype(int)
634
693
  if storage_idx is not None and label_key in self._cache_cats:
635
694
  cats = self._cache_cats[label_key][storage_idx]
636
695
  else:
@@ -638,6 +697,8 @@ class MappedCollection:
638
697
  if cats is not None:
639
698
  cats = _decode(cats) if isinstance(cats[0], bytes) else cats
640
699
  labels = cats[labels]
700
+ if is_cat:
701
+ labels = pd.Categorical(labels.astype(str))
641
702
  return labels
642
703
 
643
704
  def close(self):