scdataloader 1.6.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,8 +6,9 @@ import lamindb as ln
  import numpy as np
  import pandas as pd
  import scanpy as sc
- from anndata import AnnData
+ from anndata import AnnData, read_h5ad
  from scipy.sparse import csr_matrix
+ from upath import UPath

  from scdataloader import utils as data_utils

@@ -31,7 +32,7 @@ class Preprocessor:
  filter_gene_by_counts: Union[int, bool] = False,
  filter_cell_by_counts: Union[int, bool] = False,
  normalize_sum: float = 1e4,
- subset_hvg: int = 0,
+ n_hvg_for_postp: int = 0,
  use_layer: Optional[str] = None,
  is_symbol: bool = False,
  hvg_flavor: str = "seurat_v3",
@@ -45,13 +46,20 @@ class Preprocessor:
  maxdropamount: int = 50,
  madoutlier: int = 5,
  pct_mt_outlier: int = 8,
- batch_key: Optional[str] = None,
+ batch_keys: list[str] = [
+ "assay_ontology_term_id",
+ "self_reported_ethnicity_ontology_term_id",
+ "sex_ontology_term_id",
+ "donor_id",
+ "suspension_type",
+ ],
  skip_validate: bool = False,
  additional_preprocess: Optional[Callable[[AnnData], AnnData]] = None,
  additional_postprocess: Optional[Callable[[AnnData], AnnData]] = None,
  do_postp: bool = True,
  organisms: list[str] = ["NCBITaxon:9606", "NCBITaxon:10090"],
  use_raw: bool = True,
+ keepdata: bool = False,
  ) -> None:
  """
  Initializes the preprocessor and configures the workflow steps.
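The constructor replaces the single batch_key with a batch_keys list and adds keepdata alongside the renamed n_hvg_for_postp. A minimal usage sketch; the module path and argument values are illustrative, not taken from the diff:

from scdataloader.preprocess import Preprocessor  # import path assumed

preprocessor = Preprocessor(
    n_hvg_for_postp=2000,  # renamed from subset_hvg; 0 disables the HVG/PCA post-processing
    batch_keys=["donor_id", "suspension_type"],  # only columns actually present in adata.obs are used
    keepdata=True,  # keep existing layers/varm/obsm/obsp instead of dropping them
)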
@@ -65,7 +73,7 @@ class Preprocessor:
  Defaults to 1e4.
  log1p (bool, optional): Determines whether to apply log1p transform to the normalized data.
  Defaults to True.
- subset_hvg (int or bool, optional): Determines whether to subset highly variable genes.
+ n_hvg_for_postp (int or bool, optional): Determines whether to subset to highly variable genes for the PCA.
  Defaults to False.
  hvg_flavor (str, optional): Specifies the flavor of highly variable genes selection.
  See :func:`scanpy.pp.highly_variable_genes` for more details. Defaults to "seurat_v3".
@@ -92,11 +100,12 @@ class Preprocessor:
  This arg is used in the highly variable gene selection step.
  skip_validate (bool, optional): Determines whether to skip the validation step.
  Defaults to False.
+ keepdata (bool, optional): Determines whether to keep the data in the AnnData object.
+ Defaults to False.
  """
  self.filter_gene_by_counts = filter_gene_by_counts
  self.filter_cell_by_counts = filter_cell_by_counts
  self.normalize_sum = normalize_sum
- self.subset_hvg = subset_hvg
  self.hvg_flavor = hvg_flavor
  self.binning = binning
  self.organisms = organisms
@@ -109,16 +118,18 @@ class Preprocessor:
  self.min_nnz_genes = min_nnz_genes
  self.maxdropamount = maxdropamount
  self.madoutlier = madoutlier
+ self.n_hvg_for_postp = n_hvg_for_postp
  self.pct_mt_outlier = pct_mt_outlier
- self.batch_key = batch_key
+ self.batch_keys = batch_keys
  self.length_normalize = length_normalize
  self.skip_validate = skip_validate
  self.use_layer = use_layer
  self.is_symbol = is_symbol
  self.do_postp = do_postp
  self.use_raw = use_raw
+ self.keepdata = keepdata

- def __call__(self, adata) -> AnnData:
+ def __call__(self, adata, dataset_id=None) -> AnnData:
  if adata[0].obs.organism_ontology_term_id.iloc[0] not in self.organisms:
  raise ValueError(
  "we cannot work with this organism",
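__call__ now also accepts an optional dataset_id, which is stored in adata.uns later in the pipeline. A hedged sketch of a direct call; the id string is made up:

# hypothetical direct call on an in-memory AnnData holding raw counts
adata = preprocessor(adata, dataset_id="my_dataset_uid")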
@@ -137,17 +148,14 @@ class Preprocessor:
  print("X was not raw counts, using 'counts' layer")
  adata.X = adata.layers["counts"].copy()
  print("Dropping layers: ", adata.layers.keys())
- del adata.layers
- if len(adata.varm.keys()) > 0:
+ if not self.keepdata:
+ del adata.layers
+ if len(adata.varm.keys()) > 0 and not self.keepdata:
  del adata.varm
- if len(adata.obsm.keys()) > 0 and self.do_postp:
+ if len(adata.obsm.keys()) > 0 and self.do_postp and not self.keepdata:
  del adata.obsm
- if len(adata.obsp.keys()) > 0 and self.do_postp:
+ if len(adata.obsp.keys()) > 0 and self.do_postp and not self.keepdata:
  del adata.obsp
- if len(adata.uns.keys()) > 0:
- del adata.uns
- if len(adata.varp.keys()) > 0:
- del adata.varp
  # check that it is a count
  print("checking raw counts")
  if np.abs(
@@ -209,9 +217,9 @@ class Preprocessor:
  )
  )

- if self.is_symbol or not adata.var.index.str.contains("ENSG").any():
- if not adata.var.index.str.contains("ENSG").any():
- print("No ENSG genes found, assuming gene symbols...")
+ if self.is_symbol or not adata.var.index.str.contains("ENS").any():
+ if not adata.var.index.str.contains("ENS").any():
+ print("No ENS genes found, assuming gene symbols...")
  genesdf["ensembl_gene_id"] = genesdf.index
  var = (
  adata.var.merge(
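Loosening the check from "ENSG" to "ENS" lets non-human Ensembl identifiers (e.g. mouse ENSMUSG... IDs) also count as Ensembl-indexed genes instead of being treated as symbols. A tiny sketch with made-up, Ensembl-style IDs:

import pandas as pd

idx = pd.Index(["ENSG00000123456", "ENSMUSG00000012345"])  # illustrative human and mouse IDs
idx.str.contains("ENSG").any()  # True only because of the human ID
idx.str.contains("ENS").any()   # True for both human and mouse prefixes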
@@ -243,9 +251,13 @@ class Preprocessor:
  adata = ad.concat([adata, emptyda], axis=1, join="outer", merge="only")
  # do a validation function
  adata.uns["unseen_genes"] = list(unseen)
+ if dataset_id is not None:
+ adata.uns["dataset_id"] = dataset_id
  if not self.skip_validate:
  print("validating")
- data_utils.validate(adata, organism=adata.obs.organism_ontology_term_id[0])
+ data_utils.validate(
+ adata, organism=adata.obs.organism_ontology_term_id[0], need_all=False
+ )
  # length normalization
  if (
  adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS).any()
@@ -310,38 +322,42 @@ class Preprocessor:
  )["X"]
  )
  # step 5: subset hvg
- if self.subset_hvg:
- sc.pp.highly_variable_genes(
- adata,
- n_top_genes=self.subset_hvg,
- batch_key=self.batch_key,
- flavor=self.hvg_flavor,
- subset=False,
- )
- sc.pp.log1p(adata, layer="norm")
- sc.pp.pca(
- adata,
- layer="norm",
- n_comps=200 if adata.shape[0] > 200 else adata.shape[0] - 2,
- )
- sc.pp.neighbors(adata, use_rep="X_pca")
- sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
- sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
- sc.tl.leiden(adata, key_added="leiden_0.5", resolution=0.5)
- batches = [
- "assay_ontology_term_id",
- "self_reported_ethnicity_ontology_term_id",
- "sex_ontology_term_id",
- "development_stage_ontology_term_id",
- ]
- if "donor_id" in adata.obs.columns:
- batches.append("donor_id")
- if "suspension_type" in adata.obs.columns:
- batches.append("suspension_type")
+ batches = []
+ for i in self.batch_keys:
+ if i in adata.obs.columns:
+ batches.append(i)
  adata.obs["batches"] = adata.obs[batches].apply(
  lambda x: ",".join(x.dropna().astype(str)), axis=1
  )
- sc.tl.umap(adata)
+ if self.n_hvg_for_postp:
+ try:
+ sc.pp.highly_variable_genes(
+ adata,
+ n_top_genes=self.n_hvg_for_postp,
+ batch_key="batches",
+ flavor=self.hvg_flavor,
+ subset=False,
+ layer="norm",
+ )
+ except (ValueError, ZeroDivisionError):
+ print("retrying with span")
+ sc.pp.highly_variable_genes(
+ adata,
+ n_top_genes=self.n_hvg_for_postp,
+ # batch_key="batches",
+ flavor=self.hvg_flavor,
+ span=0.5,
+ subset=False,
+ layer="norm",
+ )
+
+ adata.obsm["X_pca"] = sc.pp.pca(
+ adata.layers["norm"][:, adata.var.highly_variable]
+ if "highly_variable" in adata.var.columns
+ else adata.layers["norm"],
+ n_comps=200 if adata.shape[0] > 200 else adata.shape[0] - 2,
+ )
+
  # additional
  if self.additional_postprocess is not None:
  adata = self.additional_postprocess(adata)
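The hard-coded batch column list is gone: post-processing now joins whichever configured batch_keys exist in obs into a single "batches" column, used as the HVG batch_key (with a fallback run without it and span=0.5 when the per-batch fit fails), and the PCA is computed directly on the norm layer restricted to the selected HVGs. A tiny sketch of how the combined column is built; column names and values are illustrative:

import pandas as pd

obs = pd.DataFrame({"donor_id": ["d1", "d2"], "suspension_type": ["cell", "nucleus"]})
batch_keys = ["assay_ontology_term_id", "donor_id", "suspension_type"]
present = [k for k in batch_keys if k in obs.columns]
# row-wise join of the available batch columns -> "d1,cell" and "d2,nucleus"
obs["batches"] = obs[present].apply(lambda x: ",".join(x.dropna().astype(str)), axis=1)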
@@ -393,6 +409,7 @@ class Preprocessor:
  adata.layers[self.result_binned_key] = np.stack(binned_rows)
  adata.obsm["bin_edges"] = np.stack(bin_edges)
  print("done")
+ print(adata)
  return adata


@@ -401,22 +418,22 @@ class LaminPreprocessor(Preprocessor):
  self,
  *args,
  cache: bool = True,
- stream: bool = False,
  keep_files: bool = True,
+ force_preloaded: bool = False,
  **kwargs,
  ):
  super().__init__(*args, **kwargs)
  self.cache = cache
- self.stream = stream
  self.keep_files = keep_files
+ self.force_preloaded = force_preloaded

  def __call__(
  self,
  data: Union[ln.Collection, AnnData] = None,
- name="preprocessed dataset",
- description="preprocessed dataset using scprint",
- start_at=0,
- version=2,
+ name: str = "preprocessed dataset",
+ description: str = "preprocessed dataset using scprint",
+ start_at: int = 0,
+ version: str = "2",
  ):
  """
  format controls the different input value wrapping, including categorical
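LaminPreprocessor drops the stream option in favor of force_preloaded, and __call__ now has typed keywords with version as a string. A hedged sketch of running it over a lamindb Collection; the collection lookup and names are illustrative, not from the diff:

import lamindb as ln

collection = ln.Collection.filter(name="my collection").first()  # hypothetical collection
preprocessor = LaminPreprocessor(force_preloaded=False, keep_files=True)
preprocessor(collection, name="preprocessed dataset", version="2", start_at=0)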
@@ -437,12 +454,15 @@ class LaminPreprocessor(Preprocessor):
  elif isinstance(data, ln.Collection):
  for i, file in enumerate(data.artifacts.all()[start_at:]):
  # use the counts matrix
- print(i + start_at)
+ i = i + start_at
+ print(i)
  if file.stem_uid in all_ready_processed_keys:
  print(f"{file.stem_uid} is already processed... not preprocessing")
  continue
  print(file)
- backed = file.open()
+
+ path = cache_path(file) if self.force_preloaded else file.cache()
+ backed = read_h5ad(path, backed="r")
  if backed.obs.is_primary_data.sum() == 0:
  print(f"{file.key} only contains non primary cells.. dropping")
  # Save the stem_uid to a file to avoid loading it again
@@ -455,16 +475,15 @@ class LaminPreprocessor(Preprocessor):
  )
  continue
  if file.size <= MAXFILESIZE:
- adata = file.load(stream=self.stream)
+ adata = backed.to_memory()
  print(adata)
  else:
  badata = backed
  print(badata)
-
  try:
  if file.size > MAXFILESIZE:
  print(
- f"dividing the dataset as it is too large: {file.size//1_000_000_000}Gb"
+ f"dividing the dataset as it is too large: {file.size // 1_000_000_000}Gb"
  )
  num_blocks = int(np.ceil(file.size / (MAXFILESIZE / 2)))
  block_size = int(
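Artifacts are no longer opened with file.open() or loaded with stream; they are materialized to a local .h5ad path (the pre-existing cache location when force_preloaded is set, otherwise via file.cache()), opened backed, and only pulled fully into memory when small enough. A minimal sketch of that pattern, assuming file is a lamindb Artifact backed by an .h5ad:

from anndata import read_h5ad

path = file.cache()                   # download to the local cache if not already there
backed = read_h5ad(path, backed="r")  # lazy, on-disk backed AnnData
adata = backed.to_memory()            # load fully only for files under the size limit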
@@ -472,16 +491,26 @@ class LaminPreprocessor(Preprocessor):
  )
  print("num blocks ", num_blocks)
  for j in range(num_blocks):
+ if j == 0 and i == 390:
+ continue
  start_index = j * block_size
  end_index = min((j + 1) * block_size, badata.shape[0])
  block = badata[start_index:end_index].to_memory()
  print(block)
- block = super().__call__(block)
- myfile = ln.from_anndata(
+ block = super().__call__(
+ block, dataset_id=file.stem_uid + "_p" + str(j)
+ )
+ myfile = ln.Artifact.from_anndata(
  block,
- revises=file,
- description=description,
- version=str(version) + "_s" + str(j),
+ description=description
+ + " n"
+ + str(i)
+ + " p"
+ + str(j)
+ + " ( revises file "
+ + str(file.key)
+ + " )",
+ version=version,
  )
  myfile.save()
  if self.keep_files:
@@ -491,16 +520,12 @@ class LaminPreprocessor(Preprocessor):
  del block

  else:
- adata = super().__call__(adata)
- try:
- sc.pl.umap(adata, color=["cell_type"])
- except Exception:
- sc.pl.umap(adata, color=["cell_type_ontology_term_id"])
- myfile = ln.from_anndata(
+ adata = super().__call__(adata, dataset_id=file.stem_uid)
+ myfile = ln.Artifact.from_anndata(
  adata,
  revises=file,
- description=description,
- version=str(version),
+ description=description + " p" + str(i),
+ version=version,
  )
  myfile.save()
  if self.keep_files:
@@ -672,35 +697,158 @@ def additional_preprocess(adata):


  def additional_postprocess(adata):
- import palantir
+ # import palantir

  # define the "up to" 10 neighbors for each cells and add to obs
  # compute neighbors
  # need to be connectivities and same labels [cell type, assay, dataset, disease]
  # define the "neighbor" up to 10(N) cells and add to obs
  # define the "next time point" up to 5(M) cells and add to obs # step 1: filter genes
- del adata.obsp["connectivities"]
- del adata.obsp["distances"]
- sc.external.pp.harmony_integrate(adata, key="batches")
- sc.pp.neighbors(adata, use_rep="X_pca_harmony")
+ # if len(adata.obs["batches"].unique()) > 1:
+ # sc.external.pp.harmony_integrate(adata, key="batches")
+ # sc.pp.neighbors(adata, use_rep="X_pca_harmony")
+ # else:
+ sc.pp.neighbors(adata, use_rep="X_pca")
+ sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
+ sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
+ sc.tl.leiden(adata, key_added="leiden_0.5", resolution=0.5)
  sc.tl.umap(adata)
+ mid = adata.uns["dataset_id"] if "dataset_id" in adata.uns else "unknown_id"
  sc.pl.umap(
  adata,
+ ncols=1,
  color=["cell_type", "batches"],
+ save="_" + mid + ".png",
  )
- palantir.utils.run_diffusion_maps(adata, n_components=20)
- palantir.utils.determine_multiscale_space(adata)
- terminal_states = palantir.utils.find_terminal_states(
- adata,
- celltypes=adata.obs.cell_type_ontology_term_id.unique(),
- celltype_column="cell_type_ontology_term_id",
+ COL = "cell_type_ontology_term_id"
+ NEWOBS = "clust_cell_type"
+ MINCELLS = 10
+ MAXSIM = 0.94
+ from collections import Counter
+
+ from .config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP
+
+ adata.obs[NEWOBS] = (
+ adata.obs[COL].astype(str) + "_" + adata.obs["leiden_1"].astype(str)
  )
- sc.tl.diffmap(adata)
- adata.obs["heat_diff"] = 1
- for terminal_state in terminal_states.index.tolist():
- adata.uns["iroot"] = np.where(adata.obs.index == terminal_state)[0][0]
- sc.tl.dpt(adata)
- adata.obs["heat_diff"] = np.minimum(
- adata.obs["heat_diff"], adata.obs["dpt_pseudotime"]
- )
+ coun = Counter(adata.obs[NEWOBS])
+ relab = {}
+ for i in adata.obs[COL].unique():
+ num = 0
+ for n, c in sorted(coun.items(), key=lambda x: x[1], reverse=True):
+ if i in n:
+ if c < MINCELLS or num == 0:
+ relab[n] = i
+ else:
+ relab[n] = i + "_" + str(num)
+ num += 1
+
+ adata.obs[NEWOBS] = adata.obs[NEWOBS].map(relab)
+
+ cluster_means = pd.DataFrame(
+ np.array(
+ [
+ adata.X[adata.obs[NEWOBS] == i].mean(axis=0)
+ for i in adata.obs[NEWOBS].unique()
+ ]
+ )[:, 0, :],
+ index=adata.obs[NEWOBS].unique(),
+ )
+
+ # Calculate correlation matrix between clusters
+ cluster_similarity = cluster_means.T.corr()
+ cluster_similarity.values[np.tril_indices(len(cluster_similarity), -1)] = 0
+
+ # Get pairs with similarity > 0.95
+ high_sim_pairs = []
+ for i in range(len(cluster_similarity)):
+ for j in range(i + 1, len(cluster_similarity)):
+ if (
+ cluster_similarity.iloc[i, j] > MAXSIM
+ and cluster_similarity.columns[i].split("_")[0]
+ == cluster_similarity.columns[j].split("_")[0]
+ ):
+ high_sim_pairs.append(
+ (
+ cluster_similarity.index[i],
+ cluster_similarity.columns[j],
+ )
+ )
+ # Create mapping for merging similar clusters
+ merge_mapping = {}
+ for pair in high_sim_pairs:
+ if pair[0] not in merge_mapping:
+ merge_mapping[pair[1]] = pair[0]
+ else:
+ merge_mapping[pair[1]] = merge_mapping[pair[0]]
+
+ # Apply merging
+ adata.obs[NEWOBS] = adata.obs[NEWOBS].map(merge_mapping).fillna(adata.obs[NEWOBS])
+ adata.obs[NEWOBS] = adata.obs[NEWOBS].astype(str)
+ coun = Counter(adata.obs[NEWOBS]).most_common()
+ merge_mapping = {}
+ for i in adata.obs[COL].unique():
+ num = 0
+ for j, c in coun:
+ if i in j:
+ merge_mapping[j] = i + "_" + str(num) if num > 0 else i
+ num += 1
+ adata.obs[NEWOBS] = adata.obs[NEWOBS].map(merge_mapping).fillna(adata.obs[NEWOBS])
+
+ import bionty as bt
+
+ stages = adata.obs["development_stage_ontology_term_id"].unique()
+ if adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:9606"]:
+ relabel = {i: i for i in stages}
+ for stage in stages:
+ stage_obj = bt.DevelopmentalStage.filter(ontology_id=stage).first()
+ parents = set([i.ontology_id for i in stage_obj.parents.filter()])
+ parents = parents - set(
+ [
+ "HsapDv:0010000",
+ "HsapDv:0000204",
+ "HsapDv:0000227",
+ ]
+ )
+ if len(parents) > 0:
+ for p in parents:
+ if p in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP:
+ relabel[stage] = p
+ adata.obs["simplified_dev_stage"] = adata.obs[
+ "development_stage_ontology_term_id"
+ ].map(relabel)
+ elif adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:10090"]:
+ rename_mapping = {
+ k: v for v, j in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for k in j
+ }
+ relabel = {i: "unknown" for i in stages}
+ for stage in stages:
+ if stage in rename_mapping:
+ relabel[stage] = rename_mapping[stage]
+ adata.obs["simplified_dev_stage"] = adata.obs[
+ "development_stage_ontology_term_id"
+ ].map(relabel)
+ else:
+ raise ValueError("organism not supported")
+ # palantir.utils.run_diffusion_maps(adata, n_components=20)
+ # palantir.utils.determine_multiscale_space(adata)
+ # terminal_states = palantir.utils.find_terminal_states(
+ # adata,
+ # celltypes=adata.obs.cell_type_ontology_term_id.unique(),
+ # celltype_column="cell_type_ontology_term_id",
+ # )
+ # sc.tl.diffmap(adata)
+ # adata.obs["heat_diff"] = 1
+ # for terminal_state in terminal_states.index.tolist():
+ # adata.uns["iroot"] = np.where(adata.obs.index == terminal_state)[0][0]
+ # sc.tl.dpt(adata)
+ # adata.obs["heat_diff"] = np.minimum(
+ # adata.obs["heat_diff"], adata.obs["dpt_pseudotime"]
+ # )
  return adata
+
+
+ def cache_path(artifact):
+ cloud_path = UPath(artifact.storage.root) / artifact.key
+ cache_path = ln.setup.settings.paths.cloud_to_local_no_update(cloud_path)
+ return cache_path
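The new cache_path helper resolves an artifact's cloud location to its expected local cache path without triggering a sync, which is what force_preloaded relies on to skip re-downloading. A hedged usage sketch, assuming the artifact has already been cached locally:

# hypothetical: open an already-downloaded artifact straight from the lamindb cache
local = cache_path(file)                  # local path derived from storage root + key, no download
backed = read_h5ad(local, backed="r")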