scdataloader 1.9.0__tar.gz → 1.9.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -135,3 +135,5 @@ figures/*/*.png
 figures/*.png
 figures/add_postp_clust.py
 figures/age_relabel.py
+notebooks/figures/umap_*.png
+notebooks/data/
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: scdataloader
-Version: 1.9.0
+Version: 1.9.2
 Summary: a dataloader for single cell data in lamindb
 Project-URL: repository, https://github.com/jkobject/scDataLoader
 Author-email: jkobject <jkobject@gmail.com>
@@ -15,7 +15,7 @@ Requires-Dist: django>=4.0.0
 Requires-Dist: harmonypy>=0.0.10
 Requires-Dist: ipykernel>=6.20.0
 Requires-Dist: jupytext>=1.16.0
-Requires-Dist: lamindb[bionty,cellregistry,jupyter,ourprojects,zarr]<2,>=1.0.4
+Requires-Dist: lamindb[bionty,cellregistry,jupyter,zarr]<2,>=1.0.4
 Requires-Dist: leidenalg>=0.8.0
 Requires-Dist: matplotlib>=3.5.0
 Requires-Dist: numpy==1.26.0
@@ -71,7 +71,16 @@ It allows you to:
 3. create a more complex single cell dataset
 4. extend it to your need

-built on top of `lamindb` and the `.mapped()` function by Sergei: https://github.com/Koncopd
+built on top of `lamindb` and the `.mapped()` function by Sergei: https://github.com/Koncopd
+
+```
+Portions of the mapped.py file are derived from Lamin Labs
+Copyright 2024 Lamin Labs
+Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+The rest of the package is licensed under MIT License, see LICENSE for details
+Please see https://github.com/laminlabs/lamindb/blob/main/lamindb/core/_mapped_collection.py
+for the original implementation
+```

 The package has been designed together with the [scPRINT paper](https://doi.org/10.1101/2024.07.29.605556) and [model](https://github.com/cantinilab/scPRINT).

@@ -28,7 +28,16 @@ It allows you to:
 3. create a more complex single cell dataset
 4. extend it to your need

-built on top of `lamindb` and the `.mapped()` function by Sergei: https://github.com/Koncopd
+built on top of `lamindb` and the `.mapped()` function by Sergei: https://github.com/Koncopd
+
+```
+Portions of the mapped.py file are derived from Lamin Labs
+Copyright 2024 Lamin Labs
+Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+The rest of the package is licensed under MIT License, see LICENSE for details
+Please see https://github.com/laminlabs/lamindb/blob/main/lamindb/core/_mapped_collection.py
+for the original implementation
+```

 The package has been designed together with the [scPRINT paper](https://doi.org/10.1101/2024.07.29.605556) and [model](https://github.com/cantinilab/scPRINT).

@@ -1,6 +1,6 @@
 [project]
 name = "scdataloader"
-version = "1.9.0"
+version = "1.9.2"
 description = "a dataloader for single cell data in lamindb"
 authors = [
     {name = "jkobject", email = "jkobject@gmail.com"}
@@ -11,7 +11,7 @@ requires-python = ">=3.10,<3.14"
 keywords = ["scRNAseq", "dataloader", "pytorch", "lamindb", "scPRINT"]
 dependencies = [
     "numpy==1.26.0",
-    "lamindb[bionty,ourprojects,jupyter,cellregistry,zarr]>=1.0.4,<2",
+    "lamindb[bionty,jupyter,cellregistry,zarr]>=1.0.4,<2",
     "cellxgene-census>=0.1.0",
     "torch==2.2.0",
     "pytorch-lightning>=2.3.0",
@@ -0,0 +1 @@
+1.9.2
@@ -1,7 +1,8 @@
+from importlib.metadata import version
+
 from .collator import Collator
 from .data import Dataset, SimpleAnnDataset
 from .datamodule import DataModule
 from .preprocess import Preprocessor
-from importlib.metadata import version

 __version__ = version("scdataloader")
@@ -148,7 +148,6 @@ class Collator:
                     :, self.accepted_genes[organism_id]
                 ]
             if self.how == "most expr":
-                nnz_loc = np.where(expr > 0)[0]
                 if "knn_cells" in elem:
                     nnz_loc = np.where(expr + elem["knn_cells"].sum(0) > 0)[0]
                 ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
@@ -161,14 +160,18 @@ class Collator:
                 # loc = np.argsort(expr)[-(self.max_len) :][::-1]
             elif self.how == "random expr":
                 nnz_loc = np.where(expr > 0)[0]
-                loc = nnz_loc[
-                    np.random.choice(
-                        len(nnz_loc),
-                        self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc),
-                        replace=False,
-                        # p=(expr.max() + (expr[nnz_loc])*19) / expr.max(), # 20 at most times more likely to be selected
-                    )
-                ]
+                loc = (
+                    nnz_loc[
+                        np.random.choice(
+                            len(nnz_loc),
+                            self.max_len,
+                            replace=False,
+                            # p=(expr.max() + (expr[nnz_loc])*19) / expr.max(), # 20 at most times more likely to be selected
+                        )
+                    ]
+                    if self.max_len < len(nnz_loc)
+                    else nnz_loc
+                )
             elif self.how in ["all", "some"]:
                 loc = np.arange(len(expr))
             else:
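The `random expr` rewrite moves the size guard out of `np.random.choice`: when all expressed genes already fit within `max_len`, the new code returns `nnz_loc` directly instead of drawing a full without-replacement permutation of it. A minimal standalone sketch of the new selection logic (the helper name and the toy `expr` vector are illustrative):

```python
import numpy as np

def select_random_expr(expr: np.ndarray, max_len: int) -> np.ndarray:
    # indices of expressed (nonzero) genes
    nnz_loc = np.where(expr > 0)[0]
    if max_len < len(nnz_loc):
        # subsample without replacement only when there are too many
        return nnz_loc[np.random.choice(len(nnz_loc), max_len, replace=False)]
    # everything fits: keep the expressed genes in their original order,
    # skipping the needless full shuffle the 1.9.0 version performed
    return nnz_loc

expr = np.array([0.0, 2.5, 0.0, 1.1, 0.3])
print(select_random_expr(expr, max_len=2))  # two of the indices {1, 3, 4}
print(select_random_expr(expr, max_len=8))  # [1 3 4]
```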
@@ -179,23 +182,19 @@ class Collator:
                 "all",
                 "some",
             ]:
+                ma = self.add_zero_genes + (
+                    0 if self.max_len < len(nnz_loc) else self.max_len - len(nnz_loc)
+                )
                 if "knn_cells" in elem:
                     # we complete with genes expressed in the knn
-                    nnz_loc = np.where(elem["knn_cells"].sum(0) > 0)[0]
-                    ma = self.max_len if self.max_len < len(nnz_loc) else len(nnz_loc)
                     # which is not a zero_loc in this context
-                    zero_loc = np.argsort(elem["knn_cells"].sum(0))[-(ma):][::-1]
+                    zero_loc = np.argsort(elem["knn_cells"].sum(0))[-ma:][::-1]
                 else:
                     zero_loc = np.where(expr == 0)[0]
                     zero_loc = zero_loc[
                         np.random.choice(
                             len(zero_loc),
-                            self.add_zero_genes
-                            + (
-                                0
-                                if self.max_len < len(nnz_loc)
-                                else self.max_len - len(nnz_loc)
-                            ),
+                            ma,
                             replace=False,
                         )
                     ]
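Hoisting `ma` above the branch makes both paths agree on the padding count: `add_zero_genes` extra zero-expression genes, plus however many are needed to top the selection up to `max_len` when fewer than `max_len` genes are expressed. Sketched in isolation (same names as the diff, toy data):

```python
import numpy as np

def pick_zero_genes(expr, nnz_loc, max_len, add_zero_genes):
    # requested zero genes, plus filler when the expressed genes fall short of max_len
    ma = add_zero_genes + (0 if max_len < len(nnz_loc) else max_len - len(nnz_loc))
    zero_loc = np.where(expr == 0)[0]
    return zero_loc[np.random.choice(len(zero_loc), ma, replace=False)]

expr = np.array([0.0, 2.5, 0.0, 1.1, 0.0, 0.0])
nnz_loc = np.where(expr > 0)[0]  # only 2 expressed genes
# 1 requested zero gene + 1 filler to reach max_len=3 -> 2 indices from {0, 2, 4, 5}
print(pick_zero_genes(expr, nnz_loc, max_len=3, add_zero_genes=1))
```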
@@ -16,7 +16,7 @@ from torch.utils.data.sampler import (

 from .collator import Collator
 from .data import Dataset
-from .utils import getBiomartTable, slurm_restart_count
+from .utils import getBiomartTable

 FILE_DIR = os.path.dirname(os.path.abspath(__file__))

@@ -28,7 +28,7 @@ class DataModule(L.LightningDataModule):
         clss_to_weight: list = ["organism_ontology_term_id"],
         organisms: list = ["NCBITaxon:9606"],
         weight_scaler: int = 10,
-        train_oversampling_per_epoch: float = 0.1,
+        n_samples_per_epoch: int = 2_000_000,
         validation_split: float = 0.2,
         test_split: float = 0,
         gene_embeddings: str = "",
@@ -53,7 +53,6 @@ class DataModule(L.LightningDataModule):
         ],
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
-        modify_seed_on_requeue: bool = True,
         **kwargs,
     ):
         """
@@ -67,7 +66,7 @@ class DataModule(L.LightningDataModule):
            collection_name (str): The lamindb collection to be used.
            organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
            weight_scaler (int, optional): how much more you will see the most present vs less present category.
-           train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
+           n_samples_per_epoch (int, optional): The number of samples to include in the training set for each epoch. Defaults to 2_000_000.
            validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
            test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
                it will use a full dataset and will round to the nearest dataset's cell count.
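This is the release's main interface change: the per-epoch training size is now an absolute sample count (`n_samples_per_epoch`) instead of a fraction of the dataset (`train_oversampling_per_epoch`). Migrating an existing configuration is a one-line conversion (the dataset size here is illustrative):

```python
# 1.9.0 style: a fraction of the dataset per epoch
train_oversampling_per_epoch = 0.1
dataset_len = 20_000_000  # illustrative cell count

# 1.9.2 style: an absolute number of samples per epoch
n_samples_per_epoch = int(dataset_len * train_oversampling_per_epoch)
print(n_samples_per_epoch)  # 2_000_000, the new default
```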
@@ -88,7 +87,6 @@ class DataModule(L.LightningDataModule):
            hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
            metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
            clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
-           modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
            get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
            **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
            see @file data.py and @file collator.py for more details about some of the parameters
@@ -171,13 +169,11 @@ class DataModule(L.LightningDataModule):
         self.assays_to_drop = assays_to_drop
         self.n_samples = len(mdataset)
         self.weight_scaler = weight_scaler
-        self.train_oversampling_per_epoch = train_oversampling_per_epoch
+        self.n_samples_per_epoch = n_samples_per_epoch
         self.clss_to_weight = clss_to_weight
         self.train_weights = None
         self.train_labels = None
-        self.modify_seed_on_requeue = modify_seed_on_requeue
         self.nnz = None
-        self.restart_num = 0
         self.test_datasets = []
         self.test_idx = []
         super().__init__()
@@ -190,7 +186,7 @@ class DataModule(L.LightningDataModule):
             f"\ttest_split={self.test_split},\n"
             f"\tn_samples={self.n_samples},\n"
             f"\tweight_scaler={self.weight_scaler},\n"
-            f"\ttrain_oversampling_per_epoch={self.train_oversampling_per_epoch},\n"
+            f"\tn_samples_per_epoch={self.n_samples_per_epoch},\n"
             f"\tassays_to_drop={self.assays_to_drop},\n"
             f"\tnum_datasets={len(self.dataset.mapped_dataset.storages)},\n"
             f"\ttest datasets={str(self.test_datasets)},\n"
@@ -336,18 +332,16 @@ class DataModule(L.LightningDataModule):
     def train_dataloader(self, **kwargs):
         # train_sampler = WeightedRandomSampler(
         #     self.train_weights[self.train_labels],
-        #     int(self.n_samples*self.train_oversampling_per_epoch),
+        #     int(self.n_samples*self.n_samples_per_epoch),
         #     replacement=True,
         # )
         try:
             train_sampler = LabelWeightedSampler(
                 label_weights=self.train_weights,
                 labels=self.train_labels,
-                num_samples=int(self.n_samples * self.train_oversampling_per_epoch),
+                num_samples=int(self.n_samples_per_epoch),
                 element_weights=self.nnz,
                 replacement=self.replacement,
-                restart_num=self.restart_num,
-                modify_seed_on_requeue=self.modify_seed_on_requeue,
             )
         except ValueError as e:
             raise ValueError(e + "have you run `datamodule.setup()`?")
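Putting the signature changes together, a hedged construction example (the collection name and batch size are placeholders; `setup()` must run before `train_dataloader()`, as the `ValueError` hint above implies):

```python
from scdataloader import DataModule

datamodule = DataModule(
    collection_name="preprocessed dataset",  # placeholder lamindb collection
    organisms=["NCBITaxon:9606"],
    n_samples_per_epoch=2_000_000,  # replaces train_oversampling_per_epoch=0.1
    validation_split=0.2,
    batch_size=64,
)
datamodule.setup()
train_loader = datamodule.train_dataloader()
```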
@@ -397,8 +391,6 @@ class LabelWeightedSampler(Sampler[int]):
     num_samples: int
     nnz: Optional[Sequence[int]]
     replacement: bool
-    restart_num: int
-    modify_seed_on_requeue: bool
     # when we use, just set weights for each classes(here is: np.ones(num_classes)), and labels of a dataset.
     # this will result a class-balanced sampling, no matter how imbalance the labels are.

@@ -409,15 +401,12 @@ class LabelWeightedSampler(Sampler[int]):
         num_samples: int,
         replacement: bool = True,
         element_weights: Sequence[float] = None,
-        restart_num: int = 0,
-        modify_seed_on_requeue: bool = True,
     ) -> None:
         """

         :param label_weights: list(len=num_classes)[float], weights for each class.
         :param labels: list(len=dataset_len)[int], labels of a dataset.
         :param num_samples: number of samples.
-        :param restart_num: if we are continuing a previous run, we need to restart the sampler from the same point.
         """

         super(LabelWeightedSampler, self).__init__(None)
@@ -433,8 +422,6 @@ class LabelWeightedSampler(Sampler[int]):
         )
         self.replacement = replacement
         self.num_samples = num_samples
-        self.restart_num = slurm_restart_count(use_mine=True) + restart_num
-        self.modify_seed_on_requeue = modify_seed_on_requeue
         # list of tensor.
         self.klass_indices = [
             (self.labels == i_klass).nonzero().squeeze(1)
@@ -447,9 +434,6 @@ class LabelWeightedSampler(Sampler[int]):
             self.label_weights,
             num_samples=self.num_samples,
             replacement=True,
-            generator=None
-            if self.restart_num == 0 and not self.modify_seed_on_requeue
-            else torch.Generator().manual_seed(self.restart_num),
         )
         sample_indices = torch.empty_like(sample_labels)
         for i_klass, klass_index in enumerate(self.klass_indices):
@@ -465,17 +449,11 @@ class LabelWeightedSampler(Sampler[int]):
                     if not self.replacement and len(klass_index) < len(left_inds)
                     else len(left_inds),
                     replacement=self.replacement,
-                    generator=None
-                    if self.restart_num == 0 and not self.modify_seed_on_requeue
-                    else torch.Generator().manual_seed(self.restart_num),
                 )
             elif self.replacement:
                 right_inds = torch.randint(
                     len(klass_index),
                     size=(len(left_inds),),
-                    generator=None
-                    if self.restart_num == 0 and not self.modify_seed_on_requeue
-                    else torch.Generator().manual_seed(self.restart_num),
                 )
             else:
                 maxelem = (
  maxelem = (
@@ -485,6 +463,7 @@ class LabelWeightedSampler(Sampler[int]):
485
463
  )
486
464
  right_inds = torch.randperm(len(klass_index))[:maxelem]
487
465
  sample_indices[left_inds[: len(right_inds)]] = klass_index[right_inds]
466
+ # if there are more left_inds than right_inds, we need to drop the extra ones
488
467
  if len(right_inds) < len(left_inds):
489
468
  sample_indices[left_inds[len(right_inds) :]] = -1
490
469
  # drop all -1
@@ -492,7 +471,6 @@ class LabelWeightedSampler(Sampler[int]):
         # torch shuffle
         sample_indices = sample_indices[torch.randperm(len(sample_indices))]
         self.num_samples = len(sample_indices)
-        # raise Exception("stop")
         yield from iter(sample_indices.tolist())

     def __len__(self):
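With the requeue machinery removed, `LabelWeightedSampler` reduces to a two-stage weighted draw: pick a class per output slot from `label_weights`, then pick a concrete dataset index within that class. A self-contained sketch of that core idea (not the full scdataloader implementation, which also supports `element_weights` and `replacement=False`):

```python
import torch

def class_balanced_indices(labels: torch.Tensor, num_samples: int) -> torch.Tensor:
    num_classes = int(labels.max()) + 1
    label_weights = torch.ones(num_classes)  # uniform over classes
    klass_indices = [(labels == k).nonzero().squeeze(1) for k in range(num_classes)]
    # stage 1: draw a class label for every output slot
    sample_labels = torch.multinomial(label_weights, num_samples, replacement=True)
    sample_indices = torch.empty_like(sample_labels)
    for k, klass_index in enumerate(klass_indices):
        left_inds = (sample_labels == k).nonzero().squeeze(1)
        # stage 2: draw dataset indices within the chosen class
        right_inds = torch.randint(len(klass_index), size=(len(left_inds),))
        sample_indices[left_inds] = klass_index[right_inds]
    return sample_indices[torch.randperm(len(sample_indices))]

labels = torch.tensor([0] * 98 + [1] * 2)  # heavily imbalanced dataset
counts = torch.bincount(labels[class_balanced_indices(labels, 1000)])
print(counts)  # roughly 500/500 despite the 98/2 imbalance
```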
@@ -1,3 +1,10 @@
+# Portions of this file are derived from Lamin Labs
+# Copyright 2024 Lamin Labs
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+# The rest of this file is licensed under MIT
+# Please see https://github.com/laminlabs/lamindb/blob/main/lamindb/core/_mapped_collection.py
+# for the original implementation
+
 from __future__ import annotations

 from collections import Counter
@@ -1,3 +1,4 @@
+import gc
 from typing import Callable, Optional, Union
 from uuid import uuid4

@@ -9,7 +10,7 @@ import scanpy as sc
 from anndata import AnnData, read_h5ad
 from scipy.sparse import csr_matrix
 from upath import UPath
-import gc
+
 from scdataloader import utils as data_utils

 FULL_LENGTH_ASSAYS = [
@@ -270,10 +271,12 @@ class Preprocessor:
             var = ens_var

         adata = adata[:, var.index]
-        var = var.sort_values(by="ensembl_gene_id").set_index("ensembl_gene_id")
+        # var = var.sort_values(by="ensembl_gene_id").set_index("ensembl_gene_id")
         # Update adata with combined genes
-        adata.var = var
-        genesdf = genesdf.set_index("ensembl_gene_id")
+        if "ensembl_gene_id" in var.columns:
+            adata.var = var.set_index("ensembl_gene_id")
+        else:
+            adata.var = var
         # Drop duplicate genes, keeping first occurrence
         adata = adata[:, ~adata.var.index.duplicated(keep="first")]
@@ -503,7 +506,7 @@ class LaminPreprocessor(Preprocessor):
                 continue
             print(file)

-            path = cache_path(file) if self.force_preloaded else file.cache()
+            _ = cache_path(file) if self.force_preloaded else file.cache()
             backed = file.open()
             # backed = read_h5ad(path, backed="r")
             if "is_primary_data" in backed.obs.columns:
@@ -1 +0,0 @@
-1.9.0