scdataloader 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,18 @@
+import math
 import multiprocessing as mp
 import os
 import random
 import time
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from functools import partial
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union

 import lamindb as ln
 import lightning as L
 import numpy as np
 import pandas as pd
 import torch
-from torch.utils.data import DataLoader, Sampler
+from torch.utils.data import DataLoader, Sampler, Subset
 from torch.utils.data.sampler import (
     RandomSampler,
     SequentialSampler,
@@ -25,32 +26,30 @@ from .data import Dataset
 from .utils import fileToList, getBiomartTable, listToFile

 FILE_DIR = os.path.dirname(os.path.abspath(__file__))
+NNZ_SCALE = 1000


 class DataModule(L.LightningDataModule):
     def __init__(
         self,
         collection_name: str,
-        clss_to_weight: list = ["organism_ontology_term_id"],
+        clss_to_weight: List[str] = ["organism_ontology_term_id"],
         weight_scaler: int = 10,
         n_samples_per_epoch: int = 2_000_000,
         validation_split: float = 0.2,
         test_split: float = 0,
-        gene_embeddings: str = "",
         use_default_col: bool = True,
-        gene_position_tolerance: int = 10_000,
         # this is for the mappedCollection
-        clss_to_predict: list = ["organism_ontology_term_id"],
-        hierarchical_clss: list = [],
+        clss_to_predict: List[str] = ["organism_ontology_term_id"],
+        hierarchical_clss: List[str] = [],
         # this is for the collator
         how: str = "random expr",
-        organism_name: str = "organism_ontology_term_id",
+        organism_col: str = "organism_ontology_term_id",
         max_len: int = 1000,
-        add_zero_genes: int = 100,
         replacement: bool = True,
-        do_gene_pos: str = "",
+        gene_subset: Optional[list[str]] = None,
         tp_name: Optional[str] = None,  # "heat_diff"
-        assays_to_drop: list = [
+        assays_to_drop: List[str] = [
            # "EFO:0008853", # patch seq
            # "EFO:0010961", # visium
            "EFO:0030007",  # ATACseq
@@ -62,6 +61,10 @@ class DataModule(L.LightningDataModule):
         force_recompute_indices: bool = False,
         sampler_workers: int = None,
         sampler_chunk_size: int = None,
+        organisms: Optional[str] = None,
+        genedf: Optional[pd.DataFrame] = None,
+        n_bins: int = 0,
+        curiculum: int = 0,
         **kwargs,
     ):
         """
@@ -78,23 +81,19 @@ class DataModule(L.LightningDataModule):
             validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
             test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                 it will use a full dataset and will round to the nearest dataset's cell count.
-            gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+            clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+            assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
+            gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
                 the file must have ensembl_gene_id as index.
                 This is used to subset the available genes further to the ones that have embeddings in your model.
-            use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
-            gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
-                any genes within this distance of each other will be considered at the same position.
-            clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
-            assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
-            do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
             max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
-            add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
             how (str, optional): The method to use for the collator. Defaults to "random expr".
-            organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
+            organism_col (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
             tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
-            hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
+            hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
             metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
-            clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
+            clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
             get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
             store_location (str, optional): The location to store the sampler indices. Defaults to None.
             force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
@@ -107,44 +106,44 @@ class DataModule(L.LightningDataModule):
             raise ValueError(
                 "need 'organism_ontology_term_id' in the set of classes at least"
             )
+        if metacell_mode > 0 and get_knn_cells:
+            raise ValueError(
+                "cannot use metacell mode and get_knn_cells at the same time"
+            )
         mdataset = Dataset(
-            ln.Collection.filter(name=collection_name, is_latest=True).first(),
+            ln.Collection.filter(key=collection_name, is_latest=True).first(),
             clss_to_predict=clss_to_predict,
             hierarchical_clss=hierarchical_clss,
             metacell_mode=metacell_mode,
             get_knn_cells=get_knn_cells,
             store_location=store_location,
             force_recompute_indices=force_recompute_indices,
+            genedf=genedf,
         )
         # and location
         self.metacell_mode = bool(metacell_mode)
         self.gene_pos = None
         self.collection_name = collection_name
-        if do_gene_pos:
-            biomart = pd.read_parquet(do_gene_pos)
-            mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
-            self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
-        if gene_embeddings != "":
-            mdataset.genedf = mdataset.genedf.join(
-                pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
-            )
-            if do_gene_pos:
-                self.gene_pos = mdataset.genedf["pos"].tolist()
+        if gene_subset is not None:
+            tokeep = set(mdataset.genedf.index.tolist())
+            gene_subset = [u for u in gene_subset if u in tokeep]
         self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
         # we might want not to order the genes by expression (or do it?)
         # we might want to not introduce zeros and
         if use_default_col:
             kwargs["collate_fn"] = Collator(
-                organisms=mdataset.organisms,
+                organisms=mdataset.organisms if organisms is None else organisms,
                 how=how,
-                valid_genes=mdataset.genedf.index.tolist(),
+                valid_genes=gene_subset,
                 max_len=max_len,
-                add_zero_genes=add_zero_genes,
-                org_to_id=mdataset.encoder[organism_name],
+                org_to_id=mdataset.encoder[organism_col],
                 tp_name=tp_name,
-                organism_name=organism_name,
+                organism_name=organism_col,
                 class_names=list(self.classes.keys()),
+                genedf=genedf,
+                n_bins=n_bins,
             )
+        self.n_bins = n_bins
         self.validation_split = validation_split
         self.test_split = test_split
         self.dataset = mdataset
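
Note: this hunk drops the gene-position/embedding plumbing (`do_gene_pos`, `gene_embeddings`, `gene_position_tolerance`, `add_zero_genes`) in favor of a single pre-filtered `gene_subset`, and threads `organisms`, `genedf`, and `n_bins` through to the `Dataset` and `Collator`. A minimal usage sketch of the new constructor surface; the collection key and gene IDs are hypothetical placeholders, not values from the diff:

```python
# Sketch only: assumes a LaminDB instance that holds a collection under this key.
from scdataloader import DataModule

dm = DataModule(
    collection_name="my-collection",  # hypothetical LaminDB collection key
    clss_to_weight=["organism_ontology_term_id"],
    gene_subset=["ENSG00000139618", "ENSG00000141510"],  # kept only if present in genedf
    organism_col="organism_ontology_term_id",
    n_bins=0,        # 0 leaves expression unbinned in the collator
    curiculum=0,     # 0 disables the sampler's curriculum ramp (see below)
)
dm.setup()
```
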
@@ -163,8 +162,12 @@ class DataModule(L.LightningDataModule):
         self.sampler_chunk_size = sampler_chunk_size
         self.store_location = store_location
         self.nnz = None
+        self.idx_full = None
+        self.max_len = max_len
         self.test_datasets = []
         self.force_recompute_indices = force_recompute_indices
+        self.curiculum = curiculum
+        self.valid_idx = []
         self.test_idx = []
         super().__init__()
         print("finished init")
@@ -183,9 +186,11 @@ class DataModule(L.LightningDataModule):
             f"\ttest datasets={str(self.test_datasets)},\n"
             f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
             f"\tclss_to_weight={self.clss_to_weight}\n"
-            + ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
-            if self.idx_full is not None
-            else ")"
+            + (
+                "\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)"
+                if self.idx_full is not None
+                else ")"
+            )
         )

     @property
@@ -229,12 +234,17 @@ class DataModule(L.LightningDataModule):
         """
         return self.dataset.genedf.index.tolist()

-    @genes.setter
-    def genes(self, genes):
-        self.dataset.genedf = self.dataset.genedf.loc[genes]
-        self.kwargs["collate_fn"].genes = genes
+    @property
+    def genes_dict(self):
+        return {
+            i: self.dataset.genedf.index[self.dataset.genedf.organism == i].tolist()
+            for i in self.dataset.organisms
+        }
+
+    def set_valid_genes_collator(self, genes):
         self.kwargs["collate_fn"]._setup(
-            genedf=self.dataset.genedf,
+            # cannot use genedf there since I am purposefully decreasing it...
+            # genedf=self.dataset.genedf,
             org_to_id=self.kwargs["collate_fn"].org_to_id,
             valid_genes=genes,
         )
@@ -280,14 +290,11 @@ class DataModule(L.LightningDataModule):
             stage (str, optional): The stage of the model training process.
                 It can be either 'fit' or 'test'. Defaults to None.
         """
-        SCALE = 10
         print("setting up the datamodule")
         start_time = time.time()
         if (
             self.store_location is None
-            or not os.path.exists(
-                os.path.join(self.store_location, "train_weights.npy")
-            )
+            or not os.path.exists(os.path.join(self.store_location, "train_labels.npy"))
             or self.force_recompute_indices
         ):
             if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
@@ -295,18 +302,19 @@ class DataModule(L.LightningDataModule):
                     "nnz", is_cat=False
                 )
                 self.clss_to_weight.remove("nnz")
-                (
-                    (self.nnz.max() / SCALE)
-                    / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
-                ).min()
+                # Sigmoid scaling with 2 parameters
+                midpoint = 2000
+                steepness = 0.003
+                # Apply sigmoid transformation
+                # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint)))
+                # Then scale to [1, NNZ_SCALE] range
+                sigmoid_values = 1 / (1 + np.exp(-steepness * (self.nnz - midpoint)))
+                self.nnz = 1 + ((NNZ_SCALE - 1) * sigmoid_values)
             if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
-                weights, labels = self.dataset.get_label_weights(
+                labels = self.dataset.get_label_cats(
                     self.clss_to_weight,
-                    scaler=self.weight_scaler,
-                    return_categories=True,
                 )
             else:
-                weights = np.ones(1)
                 labels = np.zeros(self.n_samples, dtype=int)
             if isinstance(self.validation_split, int):
                 len_valid = self.validation_split
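
Note: the removed `SCALE` expression computed a ratio and discarded it; the replacement maps raw per-cell gene counts (`nnz`) through a sigmoid into element weights in `[1, NNZ_SCALE]`, centered at 2000 detected genes. A standalone check of the fixed parameters (plain NumPy, values rounded):

```python
import numpy as np

NNZ_SCALE = 1000
midpoint, steepness = 2000, 0.003

nnz = np.array([500, 2000, 8000])                        # genes detected per cell
sigmoid = 1 / (1 + np.exp(-steepness * (nnz - midpoint)))
weights = 1 + (NNZ_SCALE - 1) * sigmoid
print(weights.round(1))                                  # [12.0, 500.5, 1000.0]
```

So shallow cells end up sampled roughly two orders of magnitude less often than deep ones.
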
@@ -363,28 +371,22 @@ class DataModule(L.LightningDataModule):
                 idx_full = idx_full[len_valid:]
             else:
                 self.valid_idx = None
-            weights = np.concatenate([weights, np.zeros(1)])
-            labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+            labels[~np.isin(np.arange(self.n_samples), idx_full)] = labels.max() + 1
             # some labels will now not exist anymore as replaced by len(weights) - 1.
             # this means that the associated weights should be 0.
             # by doing np.bincount(labels)*weights this will be taken into account
-            self.train_weights = weights
             self.train_labels = labels
             self.idx_full = idx_full
             if self.store_location is not None:
                 if (
                     not os.path.exists(
-                        os.path.join(self.store_location, "train_weights.npy")
+                        os.path.join(self.store_location, "train_labels.npy")
                     )
                     or self.force_recompute_indices
                 ):
                     os.makedirs(self.store_location, exist_ok=True)
                     if self.nnz is not None:
                         np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
-                    np.save(
-                        os.path.join(self.store_location, "train_weights.npy"),
-                        self.train_weights,
-                    )
                     np.save(
                         os.path.join(self.store_location, "train_labels.npy"),
                         self.train_labels,
@@ -411,9 +413,6 @@ class DataModule(L.LightningDataModule):
                 if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
                 else None
             )
-            self.train_weights = np.load(
-                os.path.join(self.store_location, "train_weights.npy")
-            )
             self.train_labels = np.load(
                 os.path.join(self.store_location, "train_labels.npy")
             )
@@ -446,8 +445,8 @@ class DataModule(L.LightningDataModule):
                 # Create the optimized parallel sampler
                 print(f"Using {self.sampler_workers} workers for class indexing")
                 train_sampler = LabelWeightedSampler(
-                    label_weights=self.train_weights,
                     labels=self.train_labels,
+                    weight_scaler=self.weight_scaler,
                     num_samples=int(self.n_samples_per_epoch),
                     element_weights=self.nnz,
                     replacement=self.replacement,
@@ -455,15 +454,18 @@ class DataModule(L.LightningDataModule):
                     chunk_size=self.sampler_chunk_size,
                     store_location=self.store_location,
                     force_recompute_indices=self.force_recompute_indices,
+                    curiculum=self.curiculum,
                 )
             except ValueError as e:
                 raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
+            dataset = None
         else:
-            train_sampler = SubsetRandomSampler(self.idx_full)
+            dataset = Subset(self.dataset, self.idx_full)
+            train_sampler = RankShardSampler(len(dataset))
         current_loader_kwargs = kwargs.copy()
         current_loader_kwargs.update(self.kwargs)
         return DataLoader(
-            self.dataset,
+            self.dataset if dataset is None else dataset,
             sampler=train_sampler,
             **current_loader_kwargs,
         )
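
Note: when class weighting is off, the loader now wraps the pre-shuffled `idx_full` in a `Subset` and shards it with the new `RankShardSampler` instead of using `SubsetRandomSampler`. The `Subset` does the index remapping: the sampler can emit a contiguous range of positions and still walk `idx_full`'s order. A small self-contained check of that behavior:

```python
import torch
from torch.utils.data import Subset

base = list(range(100))                  # stand-in for the mapped dataset
idx_full = torch.randperm(100).tolist()  # pre-shuffled training indices
shard = Subset(base, idx_full)
assert shard[0] == base[idx_full[0]]     # position i resolves to idx_full[i]
```
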
@@ -471,12 +473,11 @@ class DataModule(L.LightningDataModule):
     def val_dataloader(self):
         return (
             DataLoader(
-                self.dataset,
-                sampler=SubsetRandomSampler(self.valid_idx),
+                Subset(self.dataset, self.valid_idx),
                 **self.kwargs,
             )
             if self.valid_idx is not None
-            else None
+            else []
         )

     def test_dataloader(self):
@@ -485,20 +486,21 @@ class DataModule(L.LightningDataModule):
                 self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
             )
             if self.test_idx is not None
-            else None
+            else []
         )

     def predict_dataloader(self):
+        subset = Subset(self.dataset, self.idx_full)
         return DataLoader(
-            self.dataset,
-            sampler=SubsetRandomSampler(self.idx_full),
+            subset,
+            sampler=RankShardSampler(len(subset)),
             **self.kwargs,
         )


 class LabelWeightedSampler(Sampler[int]):
     """
-    A weighted random sampler that samples from a dataset with respect t o both class weights and element weights.
+    A weighted random sampler that samples from a dataset with respect to both class weights and element weights.

     This sampler is designed to handle very large datasets efficiently, with optimizations for:
     1. Parallel building of class indices
@@ -515,21 +517,22 @@ class LabelWeightedSampler(Sampler[int]):

     def __init__(
         self,
-        label_weights: Sequence[float],
         labels: np.ndarray,
         num_samples: int,
         replacement: bool = True,
+        weight_scaler: Optional[float] = None,
         element_weights: Optional[Sequence[float]] = None,
         n_workers: int = None,
         chunk_size: int = None,  # Process 10M elements per chunk
         store_location: str = None,
         force_recompute_indices: bool = False,
+        curiculum: int = 0,
     ) -> None:
         """
         Initialize the sampler with parallel processing for large datasets.

         Args:
-            label_weights: Weights for each class (length = number of classes)
+            weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
             labels: Class label for each dataset element (length = dataset size)
             num_samples: Number of samples to draw
             replacement: Whether to sample with replacement
@@ -539,10 +542,14 @@ class LabelWeightedSampler(Sampler[int]):
         """
         print("Initializing optimized parallel weighted sampler...")
         super(LabelWeightedSampler, self).__init__(None)
+        self.count = 0
+        self.curiculum = curiculum

         # Compute label weights (incorporating class frequencies)
         # Directly use labels as numpy array without conversion
-        label_weights = np.asarray(label_weights) * np.bincount(labels)
+        counts = np.bincount(labels)
+        counts[-1] = 0  # Ensure the weight for the 'NONE' class is zero
+        label_weights = (weight_scaler * counts) / (counts + weight_scaler)
         self.label_weights = torch.as_tensor(
             label_weights, dtype=torch.float32
         ).share_memory_()
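
Note: instead of receiving precomputed `label_weights`, the sampler now derives them from the label counts as `(weight_scaler * counts) / (counts + weight_scaler)`. Rare classes keep weights close to their raw counts while abundant classes saturate at `weight_scaler`, which is why the docstring says higher values mean less balanced (more frequency-proportional) sampling. A quick numeric check:

```python
import numpy as np

weight_scaler = 10
counts = np.array([5, 100, 100_000])
weights = (weight_scaler * counts) / (counts + weight_scaler)
print(weights.round(2))  # [3.33, 9.09, 10.0] -- large classes are capped near 10
```
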
@@ -643,11 +650,16 @@ class LabelWeightedSampler(Sampler[int]):
         print(f"Done initializing sampler with {len(self.klass_offsets)} classes")

     def __iter__(self):
+        self.count += 1
         # Sample classes according to their weights
         print("sampling a new batch of size", self.num_samples)

         sample_labels = torch.multinomial(
-            self.label_weights,
+            (
+                self.label_weights ** min(1, ((self.count + 5) / self.curiculum))
+                if self.curiculum
+                else self.label_weights
+            ),
             num_samples=self.num_samples,
             replacement=True,
         )
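
Note: the new `curiculum` knob raises the class weights to the power `min(1, (count + 5) / curiculum)`, where `count` is the number of completed `__iter__` calls. Exponents below 1 compress the weights toward each other, so early epochs sample classes more uniformly and the schedule converges on the true weights. The ramp for `curiculum=20` (plain arithmetic):

```python
curiculum = 20
for count in (1, 5, 10, 15, 20):
    exponent = min(1, (count + 5) / curiculum)
    print(count, exponent)  # 0.3, 0.5, 0.75, 1.0, 1.0 -- full weights from epoch 15
```
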
@@ -675,6 +687,11 @@ class LabelWeightedSampler(Sampler[int]):
             # This is a critical point for memory
             current_element_weights_slice = self.element_weights[klass_index]

+            if current_element_weights_slice.shape[0] >= (2**24) - 1:
+                ind = torch.randperm(len(klass_index))[: (2**24) - 10]
+                klass_index = klass_index[ind]
+                current_element_weights_slice = current_element_weights_slice[ind]
+
             if self.replacement:
                 right_inds = torch.multinomial(
                     current_element_weights_slice,
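
Note: the `(2**24) - 1` guard works around a hard PyTorch constraint: `torch.multinomial` cannot sample from more than 2^24 categories, so oversized classes are randomly truncated first. A scaled-down sketch of the same pattern, with `LIMIT` standing in for `(2**24) - 1`:

```python
import torch

LIMIT = 1000  # stand-in for (2**24) - 1, torch.multinomial's category limit
klass_index = torch.arange(5000)
weights = torch.rand(5000)
if weights.shape[0] >= LIMIT:
    ind = torch.randperm(len(klass_index))[: LIMIT - 10]  # random truncation
    klass_index = klass_index[ind]
    weights = weights[ind]
picked = klass_index[torch.multinomial(weights, 5, replacement=True)]
```
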
@@ -827,3 +844,27 @@ class LabelWeightedSampler(Sampler[int]):
             chunk_indices[int(label)] = indices[label_mask]

         return chunk_indices
+
+
+class RankShardSampler(Sampler[int]):
+    """Shards a dataset contiguously across ranks without padding or duplicates.
+    Preserves the existing order (e.g., your pre-shuffled idx_full)."""
+
+    def __init__(self, data_len: int):
+        self.data_len = data_len
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            self.rank = torch.distributed.get_rank()
+            self.world_size = torch.distributed.get_world_size()
+        else:
+            self.rank, self.world_size = 0, 1
+
+        # contiguous chunk per rank (last rank may be shorter)
+        per_rank = math.ceil(self.data_len / self.world_size)
+        self.start = self.rank * per_rank
+        self.end = min(self.start + per_rank, self.data_len)
+
+    def __iter__(self):
+        return iter(range(self.start, self.end))
+
+    def __len__(self):
+        return self.end - self.start
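
Note: unlike `torch.utils.data.DistributedSampler`, this sampler neither pads nor duplicates, so the last rank may simply receive fewer items; that is fine for prediction, but iteration counts can differ across ranks. Shard boundaries for a toy size (plain arithmetic):

```python
import math

data_len, world_size = 10, 4
per_rank = math.ceil(data_len / world_size)  # 3
for rank in range(world_size):
    start = rank * per_rank
    end = min(start + per_rank, data_len)
    print(rank, list(range(start, end)))     # [0,1,2] [3,4,5] [6,7,8] [9]
```
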
scdataloader/mapped.py CHANGED
@@ -10,7 +10,7 @@ from __future__ import annotations
 import os
 from collections import Counter
 from functools import reduce
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, List, Literal

 import numpy as np
 import pandas as pd
@@ -117,20 +117,20 @@ class MappedCollection:

     def __init__(
         self,
-        path_list: list[UPathStr],
-        layers_keys: str | list[str] | None = None,
-        obs_keys: str | list[str] | None = None,
-        obsm_keys: str | list[str] | None = None,
+        path_list: List[UPathStr],
+        layers_keys: str | List[str] | None = None,
+        obs_keys: str | List[str] | None = None,
+        obsm_keys: str | List[str] | None = None,
         obs_filter: dict[str, str | tuple[str, ...]] | None = None,
         join: Literal["inner", "outer"] | None = "inner",
-        encode_labels: bool | list[str] = True,
+        encode_labels: bool | List[str] = True,
         unknown_label: str | dict[str, str] | None = None,
         cache_categories: bool = True,
         parallel: bool = False,
         dtype: str | None = None,
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
-        meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
+        meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
         store_location: str | None = None,
         force_recompute_indices: bool = False,
     ):
@@ -200,7 +200,9 @@ class MappedCollection:
                 self._cache_categories(self.obs_keys)
                 torch.save(self._cache_cats, self.store_location)
             else:
-                self._cache_cats = torch.load(self.store_location)
+                self._cache_cats = torch.load(
+                    self.store_location, weights_only=False
+                )
                 print(f"Loaded categories from {self.store_location}")
         self.encoders: dict = {}
         if self.encode_labels:
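
Note: the explicit `weights_only=False` is needed because PyTorch 2.6 flipped `torch.load`'s default to `weights_only=True`, whose restricted unpickler can reject non-tensor objects like the category cache saved above. Sketch of the opt-in (the path is a hypothetical placeholder):

```python
import torch

# Opt back into full unpickling for checkpoints that hold arbitrary Python
# objects (here, the cached category mappings). Only do this for trusted files.
cache = torch.load("store/categories.pt", weights_only=False)  # hypothetical path
```
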
@@ -348,7 +350,7 @@ class MappedCollection:
         vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
         return all(vrs_sort_status)

-    def check_vars_non_aligned(self, vars: pd.Index | list) -> list[int]:
+    def check_vars_non_aligned(self, vars: pd.Index | List) -> List[int]:
         """Returns indices of objects with non-aligned variables.

         Args:
@@ -380,7 +382,7 @@ class MappedCollection:
         return (self.n_obs, self.n_vars)

     @property
-    def original_shapes(self) -> list[tuple[int, int]]:
+    def original_shapes(self) -> List[tuple[int, int]]:
         """Shapes of the underlying AnnData objects (with `obs_filter` applied)."""
         if self.n_vars_list is None:
             n_vars_list = [None] * len(self.n_obs_list)
@@ -437,7 +439,20 @@ class MappedCollection:
                 print(out)
                 raise

-        if self.metacell_mode > 0:
+        if self.get_knn_cells:
+            distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
+            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+            out["knn_cells"] = np.array(
+                [
+                    self._get_data_idx(
+                        lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
+                    )
+                    for i in nn_idx
+                ],
+                dtype=int,
+            )
+            out["knn_cells_info"] = distances[nn_idx]
+        elif self.metacell_mode > 0:
             if (
                 len(self.meta_assays) > 0
                 and "assay_ontology_term_id" in self.obs_keys
@@ -454,19 +469,6 @@ class MappedCollection:
                     out[layers_key] += self._get_data_idx(
                         lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
                     )
-        elif self.get_knn_cells:
-            distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
-            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
-            out["knn_cells"] = np.array(
-                [
-                    self._get_data_idx(
-                        lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
-                    )
-                    for i in nn_idx
-                ],
-                dtype=int,
-            )
-            out["distances"] = distances[nn_idx]

         return out
@@ -541,7 +543,7 @@ class MappedCollection:

     def get_label_weights(
         self,
-        obs_keys: str | list[str],
+        obs_keys: str | List[str],
         scaler: float | None = None,
         return_categories: bool = False,
     ):