scdataloader 2.0.0__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,18 @@
+ import math
  import multiprocessing as mp
  import os
  import random
  import time
  from concurrent.futures import ProcessPoolExecutor, as_completed
  from functools import partial
- from typing import Optional, Sequence, Union
+ from typing import List, Optional, Sequence, Union

  import lamindb as ln
  import lightning as L
  import numpy as np
  import pandas as pd
  import torch
- from torch.utils.data import DataLoader, Sampler
+ from torch.utils.data import DataLoader, Sampler, Subset
  from torch.utils.data.sampler import (
      RandomSampler,
      SequentialSampler,
@@ -25,32 +26,30 @@ from .data import Dataset
  from .utils import fileToList, getBiomartTable, listToFile

  FILE_DIR = os.path.dirname(os.path.abspath(__file__))
+ NNZ_SCALE = 1000


  class DataModule(L.LightningDataModule):
      def __init__(
          self,
          collection_name: str,
-         clss_to_weight: list = ["organism_ontology_term_id"],
+         clss_to_weight: List[str] = ["organism_ontology_term_id"],
          weight_scaler: int = 10,
          n_samples_per_epoch: int = 2_000_000,
          validation_split: float = 0.2,
          test_split: float = 0,
-         gene_embeddings: str = "",
          use_default_col: bool = True,
-         gene_position_tolerance: int = 10_000,
          # this is for the mappedCollection
-         clss_to_predict: list = ["organism_ontology_term_id"],
-         hierarchical_clss: list = [],
+         clss_to_predict: List[str] = ["organism_ontology_term_id"],
+         hierarchical_clss: List[str] = [],
          # this is for the collator
          how: str = "random expr",
-         organism_name: str = "organism_ontology_term_id",
+         organism_col: str = "organism_ontology_term_id",
          max_len: int = 1000,
-         add_zero_genes: int = 100,
          replacement: bool = True,
-         do_gene_pos: str = "",
+         gene_subset: Optional[list[str]] = None,
          tp_name: Optional[str] = None,  # "heat_diff"
-         assays_to_drop: list = [
+         assays_to_drop: List[str] = [
              # "EFO:0008853", #patch seq
              # "EFO:0010961", # visium
              "EFO:0030007",  # ATACseq
@@ -62,6 +61,11 @@ class DataModule(L.LightningDataModule):
          force_recompute_indices: bool = False,
          sampler_workers: int = None,
          sampler_chunk_size: int = None,
+         organisms: Optional[str] = None,
+         genedf: Optional[pd.DataFrame] = None,
+         n_bins: int = 0,
+         curiculum: int = 0,
+         start_at: int = 0,
          **kwargs,
      ):
          """
@@ -78,23 +82,19 @@ class DataModule(L.LightningDataModule):
              validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
              test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                  it will use a full dataset and will round to the nearest dataset's cell count.
-             gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
+             use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
+             clss_to_weight (List[str], optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
+             assays_to_drop (List[str], optional): List of assays to drop from the dataset. Defaults to [].
+             gene_pos_file (Union[bool, str], optional): The path to the gene positions file. Defaults to True.
                  the file must have ensembl_gene_id as index.
                  This is used to subset the available genes further to the ones that have embeddings in your model.
-             use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
-             gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
-                 any genes within this distance of each other will be considered at the same position.
-             clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
-             assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
-             do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
              max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
-             add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
              how (str, optional): The method to use for the collator. Defaults to "random expr".
-             organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
+             organism_col (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
              tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
-             hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
+             hierarchical_clss (List[str], optional): List of hierarchical classes. Defaults to [].
              metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
-             clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
+             clss_to_predict (List[str], optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
              get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
              store_location (str, optional): The location to store the sampler indices. Defaults to None.
              force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
@@ -107,44 +107,44 @@ class DataModule(L.LightningDataModule):
              raise ValueError(
                  "need 'organism_ontology_term_id' in the set of classes at least"
              )
+         if metacell_mode > 0 and get_knn_cells:
+             raise ValueError(
+                 "cannot use metacell mode and get_knn_cells at the same time"
+             )
          mdataset = Dataset(
-             ln.Collection.filter(name=collection_name, is_latest=True).first(),
+             ln.Collection.filter(key=collection_name, is_latest=True).first(),
              clss_to_predict=clss_to_predict,
              hierarchical_clss=hierarchical_clss,
              metacell_mode=metacell_mode,
              get_knn_cells=get_knn_cells,
              store_location=store_location,
              force_recompute_indices=force_recompute_indices,
+             genedf=genedf,
          )
          # and location
          self.metacell_mode = bool(metacell_mode)
          self.gene_pos = None
          self.collection_name = collection_name
-         if do_gene_pos:
-             biomart = pd.read_parquet(do_gene_pos)
-             mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
-             self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
-         if gene_embeddings != "":
-             mdataset.genedf = mdataset.genedf.join(
-                 pd.read_parquet(gene_embeddings).loc[:, :2], how="inner"
-             )
-             if do_gene_pos:
-                 self.gene_pos = mdataset.genedf["pos"].tolist()
+         if gene_subset is not None:
+             tokeep = set(mdataset.genedf.index.tolist())
+             gene_subset = [u for u in gene_subset if u in tokeep]
          self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
          # we might want not to order the genes by expression (or do it?)
          # we might want to not introduce zeros and
          if use_default_col:
              kwargs["collate_fn"] = Collator(
-                 organisms=mdataset.organisms,
+                 organisms=mdataset.organisms if organisms is None else organisms,
                  how=how,
-                 valid_genes=mdataset.genedf.index.tolist(),
+                 valid_genes=gene_subset,
                  max_len=max_len,
-                 add_zero_genes=add_zero_genes,
-                 org_to_id=mdataset.encoder[organism_name],
+                 org_to_id=mdataset.encoder[organism_col],
                  tp_name=tp_name,
-                 organism_name=organism_name,
+                 organism_name=organism_col,
                  class_names=list(self.classes.keys()),
+                 genedf=genedf,
+                 n_bins=n_bins,
              )
+         self.n_bins = n_bins
          self.validation_split = validation_split
          self.test_split = test_split
          self.dataset = mdataset
@@ -163,8 +163,13 @@ class DataModule(L.LightningDataModule):
          self.sampler_chunk_size = sampler_chunk_size
          self.store_location = store_location
          self.nnz = None
+         self.start_at = start_at
+         self.idx_full = None
+         self.max_len = max_len
          self.test_datasets = []
          self.force_recompute_indices = force_recompute_indices
+         self.curiculum = curiculum
+         self.valid_idx = []
          self.test_idx = []
          super().__init__()
          print("finished init")
@@ -183,9 +188,11 @@ class DataModule(L.LightningDataModule):
              f"\ttest datasets={str(self.test_datasets)},\n"
              f"perc test: {str(len(self.test_idx) / self.n_samples)},\n"
              f"\tclss_to_weight={self.clss_to_weight}\n"
-             + ("\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)")
-             if self.idx_full is not None
-             else ")"
+             + (
+                 "\twith train_dataset size of=(" + str(len(self.idx_full)) + ")\n)"
+                 if self.idx_full is not None
+                 else ")"
+             )
          )

      @property
@@ -229,12 +236,17 @@ class DataModule(L.LightningDataModule):
          """
          return self.dataset.genedf.index.tolist()

-     @genes.setter
-     def genes(self, genes):
-         self.dataset.genedf = self.dataset.genedf.loc[genes]
-         self.kwargs["collate_fn"].genes = genes
+     @property
+     def genes_dict(self):
+         return {
+             i: self.dataset.genedf.index[self.dataset.genedf.organism == i].tolist()
+             for i in self.dataset.organisms
+         }
+
+     def set_valid_genes_collator(self, genes):
          self.kwargs["collate_fn"]._setup(
-             genedf=self.dataset.genedf,
+             # cannot use genedf there since I am purposefully decreasing it...
+             # genedf=self.dataset.genedf,
              org_to_id=self.kwargs["collate_fn"].org_to_id,
              valid_genes=genes,
          )
@@ -280,14 +292,11 @@ class DataModule(L.LightningDataModule):
              stage (str, optional): The stage of the model training process.
                  It can be either 'fit' or 'test'. Defaults to None.
          """
-         SCALE = 10
          print("setting up the datamodule")
          start_time = time.time()
          if (
              self.store_location is None
-             or not os.path.exists(
-                 os.path.join(self.store_location, "train_weights.npy")
-             )
+             or not os.path.exists(os.path.join(self.store_location, "train_labels.npy"))
              or self.force_recompute_indices
          ):
              if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
@@ -295,18 +304,19 @@ class DataModule(L.LightningDataModule):
                      "nnz", is_cat=False
                  )
                  self.clss_to_weight.remove("nnz")
-                 (
-                     (self.nnz.max() / SCALE)
-                     / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
-                 ).min()
+                 # Sigmoid scaling with 2 parameters
+                 midpoint = 2000
+                 steepness = 0.003
+                 # Apply sigmoid transformation
+                 # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint)))
+                 # Then scale to [1, NNZ_SCALE] range
+                 sigmoid_values = 1 / (1 + np.exp(-steepness * (self.nnz - midpoint)))
+                 self.nnz = 1 + ((NNZ_SCALE - 1) * sigmoid_values)
              if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
-                 weights, labels = self.dataset.get_label_weights(
+                 labels = self.dataset.get_label_cats(
                      self.clss_to_weight,
-                     scaler=self.weight_scaler,
-                     return_categories=True,
                  )
              else:
-                 weights = np.ones(1)
                  labels = np.zeros(self.n_samples, dtype=int)
              if isinstance(self.validation_split, int):
                  len_valid = self.validation_split
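Aside: the sigmoid block added above maps each cell's non-zero gene count (nnz) onto the range [1, NNZ_SCALE] so it can later serve as a per-element sampling weight. A minimal standalone sketch of that transform, reusing the constants from the hunk (the example counts are invented):

    import numpy as np

    NNZ_SCALE = 1000                     # module-level constant added in this release
    midpoint, steepness = 2000, 0.003    # sigmoid parameters from the hunk above

    nnz = np.array([200, 1000, 2000, 5000, 12000])  # hypothetical non-zero counts per cell
    sigmoid = 1 / (1 + np.exp(-steepness * (nnz - midpoint)))
    weights = 1 + (NNZ_SCALE - 1) * sigmoid
    print(weights.round(1))  # low counts stay near the lower bound, the 2000-count midpoint lands near 500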
@@ -316,9 +326,9 @@ class DataModule(L.LightningDataModule):
                  len_test = self.test_split
              else:
                  len_test = int(self.n_samples * self.test_split)
-             assert len_test + len_valid < self.n_samples, (
-                 "test set + valid set size is configured to be larger than entire dataset."
-             )
+             assert (
+                 len_test + len_valid < self.n_samples
+             ), "test set + valid set size is configured to be larger than entire dataset."

              idx_full = []
              if len(self.assays_to_drop) > 0:
@@ -363,28 +373,22 @@ class DataModule(L.LightningDataModule):
                  idx_full = idx_full[len_valid:]
              else:
                  self.valid_idx = None
-             weights = np.concatenate([weights, np.zeros(1)])
-             labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
+             labels[~np.isin(np.arange(self.n_samples), idx_full)] = labels.max() + 1
              # some labels will now not exist anymore as replaced by len(weights) - 1.
              # this means that the associated weights should be 0.
              # by doing np.bincount(labels)*weights this will be taken into account
-             self.train_weights = weights
              self.train_labels = labels
              self.idx_full = idx_full
              if self.store_location is not None:
                  if (
                      not os.path.exists(
-                         os.path.join(self.store_location, "train_weights.npy")
+                         os.path.join(self.store_location, "train_labels.npy")
                      )
                      or self.force_recompute_indices
                  ):
                      os.makedirs(self.store_location, exist_ok=True)
                      if self.nnz is not None:
                          np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
-                     np.save(
-                         os.path.join(self.store_location, "train_weights.npy"),
-                         self.train_weights,
-                     )
                      np.save(
                          os.path.join(self.store_location, "train_labels.npy"),
                          self.train_labels,
@@ -411,9 +415,6 @@ class DataModule(L.LightningDataModule):
                  if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
                  else None
              )
-             self.train_weights = np.load(
-                 os.path.join(self.store_location, "train_weights.npy")
-             )
              self.train_labels = np.load(
                  os.path.join(self.store_location, "train_labels.npy")
              )
@@ -446,8 +447,8 @@ class DataModule(L.LightningDataModule):
                  # Create the optimized parallel sampler
                  print(f"Using {self.sampler_workers} workers for class indexing")
                  train_sampler = LabelWeightedSampler(
-                     label_weights=self.train_weights,
                      labels=self.train_labels,
+                     weight_scaler=self.weight_scaler,
                      num_samples=int(self.n_samples_per_epoch),
                      element_weights=self.nnz,
                      replacement=self.replacement,
@@ -455,15 +456,18 @@ class DataModule(L.LightningDataModule):
                      chunk_size=self.sampler_chunk_size,
                      store_location=self.store_location,
                      force_recompute_indices=self.force_recompute_indices,
+                     curiculum=self.curiculum,
                  )
              except ValueError as e:
                  raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
+             dataset = None
          else:
-             train_sampler = SubsetRandomSampler(self.idx_full)
+             dataset = Subset(self.dataset, self.idx_full)
+             train_sampler = RankShardSampler(len(dataset), start_at=self.start_at)
          current_loader_kwargs = kwargs.copy()
          current_loader_kwargs.update(self.kwargs)
          return DataLoader(
-             self.dataset,
+             self.dataset if dataset is None else dataset,
              sampler=train_sampler,
              **current_loader_kwargs,
          )
@@ -471,12 +475,11 @@ class DataModule(L.LightningDataModule):
      def val_dataloader(self):
          return (
              DataLoader(
-                 self.dataset,
-                 sampler=SubsetRandomSampler(self.valid_idx),
+                 Subset(self.dataset, self.valid_idx),
                  **self.kwargs,
              )
              if self.valid_idx is not None
-             else None
+             else []
          )

      def test_dataloader(self):
@@ -485,20 +488,21 @@ class DataModule(L.LightningDataModule):
                  self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
              )
              if self.test_idx is not None
-             else None
+             else []
          )

      def predict_dataloader(self):
+         subset = Subset(self.dataset, self.idx_full)
          return DataLoader(
              self.dataset,
-             sampler=SubsetRandomSampler(self.idx_full),
+             sampler=RankShardSampler(len(subset), start_at=self.start_at),
              **self.kwargs,
          )


  class LabelWeightedSampler(Sampler[int]):
      """
-     A weighted random sampler that samples from a dataset with respect t o both class weights and element weights.
+     A weighted random sampler that samples from a dataset with respect to both class weights and element weights.

      This sampler is designed to handle very large datasets efficiently, with optimizations for:
      1. Parallel building of class indices
@@ -515,21 +519,22 @@ class LabelWeightedSampler(Sampler[int]):

      def __init__(
          self,
-         label_weights: Sequence[float],
          labels: np.ndarray,
          num_samples: int,
          replacement: bool = True,
+         weight_scaler: Optional[float] = None,
          element_weights: Optional[Sequence[float]] = None,
          n_workers: int = None,
          chunk_size: int = None,  # Process 10M elements per chunk
          store_location: str = None,
          force_recompute_indices: bool = False,
+         curiculum: int = 0,
      ) -> None:
          """
          Initialize the sampler with parallel processing for large datasets.

          Args:
-             label_weights: Weights for each class (length = number of classes)
+             weight_scaler: Scaling factor for class weights (higher means less balanced sampling)
              labels: Class label for each dataset element (length = dataset size)
              num_samples: Number of samples to draw
              replacement: Whether to sample with replacement
@@ -539,10 +544,14 @@ class LabelWeightedSampler(Sampler[int]):
          """
          print("Initializing optimized parallel weighted sampler...")
          super(LabelWeightedSampler, self).__init__(None)
+         self.count = 0
+         self.curiculum = curiculum

          # Compute label weights (incorporating class frequencies)
          # Directly use labels as numpy array without conversion
-         label_weights = np.asarray(label_weights) * np.bincount(labels)
+         counts = np.bincount(labels)
+         counts[-1] = 0  # Ensure the weight for the 'NONE' class is zero
+         label_weights = (weight_scaler * counts) / (counts + weight_scaler)
          self.label_weights = torch.as_tensor(
              label_weights, dtype=torch.float32
          ).share_memory_()
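Aside: the new class-weight formula is a saturating function of the per-class counts, so classes with many cells all end up with a weight close to weight_scaler while rare classes keep a proportionally smaller one; per the updated docstring, a larger weight_scaler therefore means less re-balancing. A quick numeric illustration with invented counts:

    import numpy as np

    counts = np.array([5, 100, 10_000, 1_000_000])  # hypothetical cells per class
    weight_scaler = 10
    weights = (weight_scaler * counts) / (counts + weight_scaler)
    print(weights.round(2))  # [ 3.33  9.09  9.99 10.  ] -> large classes cap out near weight_scaler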
@@ -643,11 +652,16 @@ class LabelWeightedSampler(Sampler[int]):
          print(f"Done initializing sampler with {len(self.klass_offsets)} classes")

      def __iter__(self):
+         self.count += 1
          # Sample classes according to their weights
          print("sampling a new batch of size", self.num_samples)

          sample_labels = torch.multinomial(
-             self.label_weights,
+             (
+                 self.label_weights ** min(1, ((self.count + 5) / self.curiculum))
+                 if self.curiculum
+                 else self.label_weights
+             ),
              num_samples=self.num_samples,
              replacement=True,
          )
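Aside: when curiculum is non-zero, the exponent min(1, (count + 5) / curiculum) tempers the class weights for the first epochs (an exponent below 1 compresses their ratios, giving more uniform class sampling) and the full weighting is reached once count + 5 >= curiculum. A rough illustration with made-up weights:

    import torch

    label_weights = torch.tensor([3.3, 9.1, 10.0])  # e.g. saturated class weights
    curiculum = 20
    for count in (1, 5, 20):                        # epoch counter bumped in __iter__
        exponent = min(1, (count + 5) / curiculum)
        print(count, exponent, (label_weights ** exponent).round(decimals=2).tolist())
    # count=1 -> exponent 0.3, ratios compressed; count >= 15 -> exponent 1, full weights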
@@ -655,7 +669,9 @@ class LabelWeightedSampler(Sampler[int]):
          unique_samples, sample_counts = torch.unique(sample_labels, return_counts=True)

          # Initialize result tensor
-         result_indices_list = []  # Changed name to avoid conflict if you had result_indices elsewhere
+         result_indices_list = (
+             []
+         )  # Changed name to avoid conflict if you had result_indices elsewhere

          # Process only the classes that were actually sampled
          for i, (label, count) in tqdm(
@@ -675,6 +691,11 @@ class LabelWeightedSampler(Sampler[int]):
              # This is a critical point for memory
              current_element_weights_slice = self.element_weights[klass_index]

+             if current_element_weights_slice.shape[0] >= (2**24) - 1:
+                 ind = torch.randperm(len(klass_index))[: (2**24) - 10]
+                 klass_index = klass_index[ind]
+                 current_element_weights_slice = current_element_weights_slice[ind]
+
              if self.replacement:
                  right_inds = torch.multinomial(
                      current_element_weights_slice,
@@ -827,3 +848,35 @@ class LabelWeightedSampler(Sampler[int]):
              chunk_indices[int(label)] = indices[label_mask]

          return chunk_indices
+
+
+ class RankShardSampler(Sampler[int]):
+     """Shards a dataset contiguously across ranks without padding or duplicates.
+     Preserves the existing order (e.g., your pre-shuffled idx_full)."""
+
+     def __init__(self, data_len: int, start_at: int = 0) -> None:
+         self.data_len = data_len
+         self.start_at = start_at
+         if torch.distributed.is_available() and torch.distributed.is_initialized():
+             self.rank = torch.distributed.get_rank()
+             self.world_size = torch.distributed.get_world_size()
+         else:
+             self.rank, self.world_size = 0, 1
+
+         # contiguous chunk per rank (last rank may be shorter)
+         if self.start_at > 0:
+             print(
+                 "!!!!ATTTENTION: make sure that you are running on the exact same \
+                 number of GPU as your previous run!!!!!"
+             )
+         print(f"Sharding data of size {data_len} over {self.world_size} ranks")
+         per_rank = math.ceil((self.data_len - self.start_at) / self.world_size)
+         self.start = int((self.start_at / self.world_size) + (self.rank * per_rank))
+         self.end = min(self.start + per_rank, self.data_len)
+         print(f"Rank {self.rank} processing indices from {self.start} to {self.end}")
+
+     def __iter__(self):
+         return iter(range(self.start, self.end))
+
+     def __len__(self):
+         return self.end - self.start
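Aside: the new RankShardSampler gives every rank one contiguous, non-overlapping slice of the (pre-shuffled) subset, with no padding and no duplicated samples; only the last rank can come up short. A rough sketch of the sharding arithmetic for a hypothetical 4-rank run (outside torch.distributed it falls back to rank 0, world_size 1):

    import math

    data_len, world_size, start_at = 1_000_003, 4, 0          # invented sizes
    per_rank = math.ceil((data_len - start_at) / world_size)  # 250_001
    for rank in range(world_size):
        start = int(start_at / world_size + rank * per_rank)
        end = min(start + per_rank, data_len)
        print(rank, start, end)
    # 0..250001, 250001..500002, 500002..750003, 750003..1000003 (last rank is one sample short)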
scdataloader/mapped.py CHANGED
@@ -10,7 +10,7 @@ from __future__ import annotations
  import os
  from collections import Counter
  from functools import reduce
- from typing import TYPE_CHECKING, Literal
+ from typing import TYPE_CHECKING, List, Literal

  import numpy as np
  import pandas as pd
@@ -117,20 +117,20 @@ class MappedCollection:

      def __init__(
          self,
-         path_list: list[UPathStr],
-         layers_keys: str | list[str] | None = None,
-         obs_keys: str | list[str] | None = None,
-         obsm_keys: str | list[str] | None = None,
+         path_list: List[UPathStr],
+         layers_keys: str | List[str] | None = None,
+         obs_keys: str | List[str] | None = None,
+         obsm_keys: str | List[str] | None = None,
          obs_filter: dict[str, str | tuple[str, ...]] | None = None,
          join: Literal["inner", "outer"] | None = "inner",
-         encode_labels: bool | list[str] = True,
+         encode_labels: bool | List[str] = True,
          unknown_label: str | dict[str, str] | None = None,
          cache_categories: bool = True,
          parallel: bool = False,
          dtype: str | None = None,
          metacell_mode: float = 0.0,
          get_knn_cells: bool = False,
-         meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
+         meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
          store_location: str | None = None,
          force_recompute_indices: bool = False,
      ):
@@ -200,7 +200,9 @@ class MappedCollection:
                  self._cache_categories(self.obs_keys)
                  torch.save(self._cache_cats, self.store_location)
              else:
-                 self._cache_cats = torch.load(self.store_location)
+                 self._cache_cats = torch.load(
+                     self.store_location, weights_only=False
+                 )
                  print(f"Loaded categories from {self.store_location}")
          self.encoders: dict = {}
          if self.encode_labels:
@@ -348,7 +350,7 @@ class MappedCollection:
          vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
          return all(vrs_sort_status)

-     def check_vars_non_aligned(self, vars: pd.Index | list) -> list[int]:
+     def check_vars_non_aligned(self, vars: pd.Index | List) -> List[int]:
          """Returns indices of objects with non-aligned variables.

          Args:
@@ -380,7 +382,7 @@ class MappedCollection:
          return (self.n_obs, self.n_vars)

      @property
-     def original_shapes(self) -> list[tuple[int, int]]:
+     def original_shapes(self) -> List[tuple[int, int]]:
          """Shapes of the underlying AnnData objects (with `obs_filter` applied)."""
          if self.n_vars_list is None:
              n_vars_list = [None] * len(self.n_obs_list)
@@ -437,7 +439,20 @@ class MappedCollection:
              print(out)
              raise

-         if self.metacell_mode > 0:
+         if self.get_knn_cells:
+             distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
+             nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+             out["knn_cells"] = np.array(
+                 [
+                     self._get_data_idx(
+                         lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
+                     )
+                     for i in nn_idx
+                 ],
+                 dtype=int,
+             )
+             out["knn_cells_info"] = distances[nn_idx]
+         elif self.metacell_mode > 0:
              if (
                  len(self.meta_assays) > 0
                  and "assay_ontology_term_id" in self.obs_keys
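Aside: the np.argsort(-1 / (distances - 1e-6)) trick above works on what is presumably a densified row of the obsp["distances"] matrix: positive distances map to negative values (nearest first), while the zero entries of non-neighbors map to a large positive constant and sink to the end, so the first six positions are the nearest cells. A tiny sketch with invented distances:

    import numpy as np

    distances = np.array([0.0, 0.8, 0.0, 0.1, 0.3, 0.0, 0.5])  # one row, zeros = not a neighbor
    order = np.argsort(-1 / (distances - 1e-6))                # the hunk above keeps order[:6]
    print(order[:4])  # [3 4 6 1] -> real neighbors (0.1, 0.3, 0.5, 0.8), nearest first; zeros sort last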
@@ -454,19 +469,6 @@ class MappedCollection:
                      out[layers_key] += self._get_data_idx(
                          lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
                      )
-         elif self.get_knn_cells:
-             distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
-             nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
-             out["knn_cells"] = np.array(
-                 [
-                     self._get_data_idx(
-                         lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
-                     )
-                     for i in nn_idx
-                 ],
-                 dtype=int,
-             )
-             out["distances"] = distances[nn_idx]

          return out
@@ -541,7 +543,7 @@ class MappedCollection:

      def get_label_weights(
          self,
-         obs_keys: str | list[str],
+         obs_keys: str | List[str],
          scaler: float | None = None,
          return_categories: bool = False,
      ):