scdataloader 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff shows the changes between two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -8,8 +8,9 @@ import pandas as pd
 import scanpy as sc
 from anndata import AnnData
 from scipy.sparse import csr_matrix
-
+from anndata import read_h5ad
 from scdataloader import utils as data_utils
+from upath import UPath
 
 FULL_LENGTH_ASSAYS = [
     "EFO: 0700016",
@@ -31,7 +32,7 @@ class Preprocessor:
         filter_gene_by_counts: Union[int, bool] = False,
         filter_cell_by_counts: Union[int, bool] = False,
         normalize_sum: float = 1e4,
-        subset_hvg: int = 0,
+        n_hvg_for_postp: int = 0,
         use_layer: Optional[str] = None,
         is_symbol: bool = False,
         hvg_flavor: str = "seurat_v3",
@@ -45,7 +46,13 @@ class Preprocessor:
         maxdropamount: int = 50,
         madoutlier: int = 5,
         pct_mt_outlier: int = 8,
-        batch_key: Optional[str] = None,
+        batch_keys: list[str] = [
+            "assay_ontology_term_id",
+            "self_reported_ethnicity_ontology_term_id",
+            "sex_ontology_term_id",
+            "donor_id",
+            "suspension_type",
+        ],
         skip_validate: bool = False,
         additional_preprocess: Optional[Callable[[AnnData], AnnData]] = None,
         additional_postprocess: Optional[Callable[[AnnData], AnnData]] = None,
@@ -65,7 +72,7 @@ class Preprocessor:
                 Defaults to 1e4.
             log1p (bool, optional): Determines whether to apply log1p transform to the normalized data.
                 Defaults to True.
-            subset_hvg (int or bool, optional): Determines whether to subset highly variable genes.
+            n_hvg_for_postp (int or bool, optional): Determines whether to subset to highly variable genes for the PCA.
                 Defaults to False.
             hvg_flavor (str, optional): Specifies the flavor of highly variable genes selection.
                 See :func:`scanpy.pp.highly_variable_genes` for more details. Defaults to "seurat_v3".
@@ -96,7 +103,6 @@ class Preprocessor:
         self.filter_gene_by_counts = filter_gene_by_counts
         self.filter_cell_by_counts = filter_cell_by_counts
         self.normalize_sum = normalize_sum
-        self.subset_hvg = subset_hvg
         self.hvg_flavor = hvg_flavor
         self.binning = binning
         self.organisms = organisms
@@ -109,8 +115,9 @@ class Preprocessor:
         self.min_nnz_genes = min_nnz_genes
         self.maxdropamount = maxdropamount
         self.madoutlier = madoutlier
+        self.n_hvg_for_postp = n_hvg_for_postp
         self.pct_mt_outlier = pct_mt_outlier
-        self.batch_key = batch_key
+        self.batch_keys = batch_keys
         self.length_normalize = length_normalize
         self.skip_validate = skip_validate
         self.use_layer = use_layer
@@ -118,7 +125,7 @@ class Preprocessor:
         self.do_postp = do_postp
         self.use_raw = use_raw
 
-    def __call__(self, adata) -> AnnData:
+    def __call__(self, adata, dataset_id=None) -> AnnData:
         if adata[0].obs.organism_ontology_term_id.iloc[0] not in self.organisms:
             raise ValueError(
                 "we cannot work with this organism",
@@ -144,10 +151,6 @@ class Preprocessor:
             del adata.obsm
         if len(adata.obsp.keys()) > 0 and self.do_postp:
             del adata.obsp
-        if len(adata.uns.keys()) > 0:
-            del adata.uns
-        if len(adata.varp.keys()) > 0:
-            del adata.varp
         # check that it is a count
         print("checking raw counts")
         if np.abs(
@@ -209,9 +212,9 @@ class Preprocessor:
             )
         )
 
-        if self.is_symbol or not adata.var.index.str.contains("ENSG").any():
-            if not adata.var.index.str.contains("ENSG").any():
-                print("No ENSG genes found, assuming gene symbols...")
+        if self.is_symbol or not adata.var.index.str.contains("ENS").any():
+            if not adata.var.index.str.contains("ENS").any():
+                print("No ENS genes found, assuming gene symbols...")
             genesdf["ensembl_gene_id"] = genesdf.index
             var = (
                 adata.var.merge(
@@ -243,9 +246,13 @@ class Preprocessor:
         adata = ad.concat([adata, emptyda], axis=1, join="outer", merge="only")
         # do a validation function
         adata.uns["unseen_genes"] = list(unseen)
+        if dataset_id is not None:
+            adata.uns["dataset_id"] = dataset_id
         if not self.skip_validate:
             print("validating")
-            data_utils.validate(adata, organism=adata.obs.organism_ontology_term_id[0])
+            data_utils.validate(
+                adata, organism=adata.obs.organism_ontology_term_id[0], need_all=False
+            )
         # length normalization
         if (
             adata.obs["assay_ontology_term_id"].isin(FULL_LENGTH_ASSAYS).any()
@@ -310,38 +317,42 @@ class Preprocessor:
            )["X"]
        )
        # step 5: subset hvg
-        if self.subset_hvg:
-            sc.pp.highly_variable_genes(
-                adata,
-                n_top_genes=self.subset_hvg,
-                batch_key=self.batch_key,
-                flavor=self.hvg_flavor,
-                subset=False,
-            )
-        sc.pp.log1p(adata, layer="norm")
-        sc.pp.pca(
-            adata,
-            layer="norm",
-            n_comps=200 if adata.shape[0] > 200 else adata.shape[0] - 2,
-        )
-        sc.pp.neighbors(adata, use_rep="X_pca")
-        sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
-        sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
-        sc.tl.leiden(adata, key_added="leiden_0.5", resolution=0.5)
-        batches = [
-            "assay_ontology_term_id",
-            "self_reported_ethnicity_ontology_term_id",
-            "sex_ontology_term_id",
-            "development_stage_ontology_term_id",
-        ]
-        if "donor_id" in adata.obs.columns:
-            batches.append("donor_id")
-        if "suspension_type" in adata.obs.columns:
-            batches.append("suspension_type")
+        batches = []
+        for i in self.batch_keys:
+            if i in adata.obs.columns:
+                batches.append(i)
        adata.obs["batches"] = adata.obs[batches].apply(
            lambda x: ",".join(x.dropna().astype(str)), axis=1
        )
-        sc.tl.umap(adata)
+        if self.n_hvg_for_postp:
+            try:
+                sc.pp.highly_variable_genes(
+                    adata,
+                    n_top_genes=self.n_hvg_for_postp,
+                    batch_key="batches",
+                    flavor=self.hvg_flavor,
+                    subset=False,
+                    layer="norm",
+                )
+            except (ValueError, ZeroDivisionError) as e:
+                print("retrying with span")
+                sc.pp.highly_variable_genes(
+                    adata,
+                    n_top_genes=self.n_hvg_for_postp,
+                    # batch_key="batches",
+                    flavor=self.hvg_flavor,
+                    span=0.5,
+                    subset=False,
+                    layer="norm",
+                )
+
+        adata.obsm["X_pca"] = sc.pp.pca(
+            adata.layers["norm"][:, adata.var.highly_variable]
+            if "highly_variable" in adata.var.columns
+            else adata.layers["norm"],
+            n_comps=200 if adata.shape[0] > 200 else adata.shape[0] - 2,
+        )
+
        # additional
        if self.additional_postprocess is not None:
            adata = self.additional_postprocess(adata)
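
Net effect of the hunk above: the batch covariate used for post-processing is no longer a hard-coded list with ad-hoc `donor_id`/`suspension_type` checks. Instead, whichever of the configured `batch_keys` exist in `adata.obs` are joined into a single `batches` column, which is then passed to `sc.pp.highly_variable_genes` as `batch_key="batches"`. A minimal sketch of that construction on a toy obs table (column values are illustrative):

    import pandas as pd

    obs = pd.DataFrame(
        {
            "assay_ontology_term_id": ["EFO:0009922", "EFO:0009922"],
            "donor_id": ["d1", None],
        }
    )
    batch_keys = ["assay_ontology_term_id", "donor_id", "suspension_type"]
    # "suspension_type" is absent from obs, so it is skipped
    batches = [k for k in batch_keys if k in obs.columns]
    obs["batches"] = obs[batches].apply(
        lambda x: ",".join(x.dropna().astype(str)), axis=1
    )
    print(obs["batches"].tolist())  # ['EFO:0009922,d1', 'EFO:0009922']
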
@@ -393,6 +404,7 @@ class Preprocessor:
        adata.layers[self.result_binned_key] = np.stack(binned_rows)
        adata.obsm["bin_edges"] = np.stack(bin_edges)
        print("done")
+        print(adata)
        return adata
 
 
@@ -401,22 +413,22 @@ class LaminPreprocessor(Preprocessor):
        self,
        *args,
        cache: bool = True,
-        stream: bool = False,
        keep_files: bool = True,
+        force_preloaded: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cache = cache
-        self.stream = stream
        self.keep_files = keep_files
+        self.force_preloaded = force_preloaded
 
    def __call__(
        self,
        data: Union[ln.Collection, AnnData] = None,
-        name="preprocessed dataset",
-        description="preprocessed dataset using scprint",
-        start_at=0,
-        version=2,
+        name: str = "preprocessed dataset",
+        description: str = "preprocessed dataset using scprint",
+        start_at: int = 0,
+        version: str = "2",
    ):
        """
        format controls the different input value wrapping, including categorical
@@ -437,12 +449,15 @@ class LaminPreprocessor(Preprocessor):
        elif isinstance(data, ln.Collection):
            for i, file in enumerate(data.artifacts.all()[start_at:]):
                # use the counts matrix
-                print(i + start_at)
+                i = i + start_at
+                print(i)
                if file.stem_uid in all_ready_processed_keys:
                    print(f"{file.stem_uid} is already processed... not preprocessing")
                    continue
                print(file)
-                backed = file.open()
+
+                path = cache_path(file) if self.force_preloaded else file.cache()
+                backed = read_h5ad(path, backed="r")
                if backed.obs.is_primary_data.sum() == 0:
                    print(f"{file.key} only contains non primary cells.. dropping")
                    # Save the stem_uid to a file to avoid loading it again
@@ -455,12 +470,11 @@ class LaminPreprocessor(Preprocessor):
                    )
                    continue
                if file.size <= MAXFILESIZE:
-                    adata = file.load(stream=self.stream)
+                    adata = backed.to_memory()
                    print(adata)
                else:
                    badata = backed
                    print(badata)
-
                try:
                    if file.size > MAXFILESIZE:
                        print(
@@ -472,16 +486,26 @@ class LaminPreprocessor(Preprocessor):
                        )
                        print("num blocks ", num_blocks)
                        for j in range(num_blocks):
+                            if j == 0 and i == 390:
+                                continue
                            start_index = j * block_size
                            end_index = min((j + 1) * block_size, badata.shape[0])
                            block = badata[start_index:end_index].to_memory()
                            print(block)
-                            block = super().__call__(block)
-                            myfile = ln.from_anndata(
+                            block = super().__call__(
+                                block, dataset_id=file.stem_uid + "_p" + str(j)
+                            )
+                            myfile = ln.Artifact.from_anndata(
                                block,
-                                revises=file,
-                                description=description,
-                                version=str(version) + "_s" + str(j),
+                                description=description
+                                + " n"
+                                + str(i)
+                                + " p"
+                                + str(j)
+                                + " ( revises file "
+                                + str(file.key)
+                                + " )",
+                                version=version,
                            )
                            myfile.save()
                            if self.keep_files:
@@ -491,16 +515,12 @@ class LaminPreprocessor(Preprocessor):
                            del block
 
                    else:
-                        adata = super().__call__(adata)
-                        try:
-                            sc.pl.umap(adata, color=["cell_type"])
-                        except Exception:
-                            sc.pl.umap(adata, color=["cell_type_ontology_term_id"])
-                        myfile = ln.from_anndata(
+                        adata = super().__call__(adata, dataset_id=file.stem_uid)
+                        myfile = ln.Artifact.from_anndata(
                            adata,
                            revises=file,
-                            description=description,
-                            version=str(version),
+                            description=description + " p" + str(i),
+                            version=version,
                        )
                        myfile.save()
                        if self.keep_files:
@@ -672,35 +692,158 @@ def additional_preprocess(adata):
 
 
 def additional_postprocess(adata):
-    import palantir
+    # import palantir
 
     # define the "up to" 10 neighbors for each cells and add to obs
    # compute neighbors
    # need to be connectivities and same labels [cell type, assay, dataset, disease]
    # define the "neighbor" up to 10(N) cells and add to obs
    # define the "next time point" up to 5(M) cells and add to obs # step 1: filter genes
-    del adata.obsp["connectivities"]
-    del adata.obsp["distances"]
-    sc.external.pp.harmony_integrate(adata, key="batches")
-    sc.pp.neighbors(adata, use_rep="X_pca_harmony")
+    # if len(adata.obs["batches"].unique()) > 1:
+    #     sc.external.pp.harmony_integrate(adata, key="batches")
+    #     sc.pp.neighbors(adata, use_rep="X_pca_harmony")
+    # else:
+    sc.pp.neighbors(adata, use_rep="X_pca")
+    sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
+    sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
+    sc.tl.leiden(adata, key_added="leiden_0.5", resolution=0.5)
     sc.tl.umap(adata)
+    mid = adata.uns["dataset_id"] if "dataset_id" in adata.uns else "unknown_id"
     sc.pl.umap(
        adata,
+        ncols=1,
        color=["cell_type", "batches"],
+        save="_" + mid + ".png",
    )
-    palantir.utils.run_diffusion_maps(adata, n_components=20)
-    palantir.utils.determine_multiscale_space(adata)
-    terminal_states = palantir.utils.find_terminal_states(
-        adata,
-        celltypes=adata.obs.cell_type_ontology_term_id.unique(),
-        celltype_column="cell_type_ontology_term_id",
+    COL = "cell_type_ontology_term_id"
+    NEWOBS = "clust_cell_type"
+    MINCELLS = 10
+    MAXSIM = 0.94
+    from collections import Counter
+
+    from .config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP
+
+    adata.obs[NEWOBS] = (
+        adata.obs[COL].astype(str) + "_" + adata.obs["leiden_1"].astype(str)
    )
-    sc.tl.diffmap(adata)
-    adata.obs["heat_diff"] = 1
-    for terminal_state in terminal_states.index.tolist():
-        adata.uns["iroot"] = np.where(adata.obs.index == terminal_state)[0][0]
-        sc.tl.dpt(adata)
-        adata.obs["heat_diff"] = np.minimum(
-            adata.obs["heat_diff"], adata.obs["dpt_pseudotime"]
-        )
+    coun = Counter(adata.obs[NEWOBS])
+    relab = {}
+    for i in adata.obs[COL].unique():
+        num = 0
+        for n, c in sorted(coun.items(), key=lambda x: x[1], reverse=True):
+            if i in n:
+                if c < MINCELLS or num == 0:
+                    relab[n] = i
+                else:
+                    relab[n] = i + "_" + str(num)
+                num += 1
+
+    adata.obs[NEWOBS] = adata.obs[NEWOBS].map(relab)
+
+    cluster_means = pd.DataFrame(
+        np.array(
+            [
+                adata.X[adata.obs[NEWOBS] == i].mean(axis=0)
+                for i in adata.obs[NEWOBS].unique()
+            ]
+        )[:, 0, :],
+        index=adata.obs[NEWOBS].unique(),
+    )
+
+    # Calculate correlation matrix between clusters
+    cluster_similarity = cluster_means.T.corr()
+    cluster_similarity.values[np.tril_indices(len(cluster_similarity), -1)] = 0
+
+    # Get pairs with similarity > 0.95
+    high_sim_pairs = []
+    for i in range(len(cluster_similarity)):
+        for j in range(i + 1, len(cluster_similarity)):
+            if (
+                cluster_similarity.iloc[i, j] > MAXSIM
+                and cluster_similarity.columns[i].split("_")[0]
+                == cluster_similarity.columns[j].split("_")[0]
+            ):
+                high_sim_pairs.append(
+                    (
+                        cluster_similarity.index[i],
+                        cluster_similarity.columns[j],
+                    )
+                )
+    # Create mapping for merging similar clusters
+    merge_mapping = {}
+    for pair in high_sim_pairs:
+        if pair[0] not in merge_mapping:
+            merge_mapping[pair[1]] = pair[0]
+        else:
+            merge_mapping[pair[1]] = merge_mapping[pair[0]]
+
+    # Apply merging
+    adata.obs[NEWOBS] = adata.obs[NEWOBS].map(merge_mapping).fillna(adata.obs[NEWOBS])
+    adata.obs[NEWOBS] = adata.obs[NEWOBS].astype(str)
+    coun = Counter(adata.obs[NEWOBS]).most_common()
+    merge_mapping = {}
+    for i in adata.obs[COL].unique():
+        num = 0
+        for j, c in coun:
+            if i in j:
+                merge_mapping[j] = i + "_" + str(num) if num > 0 else i
+                num += 1
+    adata.obs[NEWOBS] = adata.obs[NEWOBS].map(merge_mapping).fillna(adata.obs[NEWOBS])
+
+    import bionty as bt
+
+    stages = adata.obs["development_stage_ontology_term_id"].unique()
+    if adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:9606"]:
+        relabel = {i: i for i in stages}
+        for stage in stages:
+            stage_obj = bt.DevelopmentalStage.filter(ontology_id=stage).first()
+            parents = set([i.ontology_id for i in stage_obj.parents.filter()])
+            parents = parents - set(
+                [
+                    "HsapDv:0010000",
+                    "HsapDv:0000204",
+                    "HsapDv:0000227",
+                ]
+            )
+            if len(parents) > 0:
+                for p in parents:
+                    if p in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP:
+                        relabel[stage] = p
+        adata.obs["simplified_dev_stage"] = adata.obs[
+            "development_stage_ontology_term_id"
+        ].map(relabel)
+    elif adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:10090"]:
+        rename_mapping = {
+            k: v for v, j in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for k in j
+        }
+        relabel = {i: "unknown" for i in stages}
+        for stage in stages:
+            if stage in rename_mapping:
+                relabel[stage] = rename_mapping[stage]
+        adata.obs["simplified_dev_stage"] = adata.obs[
+            "development_stage_ontology_term_id"
+        ].map(relabel)
+    else:
+        raise ValueError("organism not supported")
+    # palantir.utils.run_diffusion_maps(adata, n_components=20)
+    # palantir.utils.determine_multiscale_space(adata)
+    # terminal_states = palantir.utils.find_terminal_states(
+    #     adata,
+    #     celltypes=adata.obs.cell_type_ontology_term_id.unique(),
+    #     celltype_column="cell_type_ontology_term_id",
+    # )
+    # sc.tl.diffmap(adata)
+    # adata.obs["heat_diff"] = 1
+    # for terminal_state in terminal_states.index.tolist():
+    #     adata.uns["iroot"] = np.where(adata.obs.index == terminal_state)[0][0]
+    #     sc.tl.dpt(adata)
+    #     adata.obs["heat_diff"] = np.minimum(
+    #         adata.obs["heat_diff"], adata.obs["dpt_pseudotime"]
+    #     )
     return adata
+
+
+def cache_path(artifact):
+    cloud_path = UPath(artifact.storage.root) / artifact.key
+    cache_path = ln.setup.settings.paths.cloud_to_local_no_update(cloud_path)
+    return cache_path
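
Taken together, the user-facing changes in this version are: `Preprocessor` renames `subset_hvg` to `n_hvg_for_postp` and replaces the single `batch_key` with a `batch_keys` list, `Preprocessor.__call__` gains an optional `dataset_id` that is stored in `adata.uns`, and `LaminPreprocessor` drops `stream` in favor of `force_preloaded`. A minimal migration sketch, assuming the import path used by earlier scdataloader releases and an illustrative file name:

    from anndata import read_h5ad
    from scdataloader.preprocess import Preprocessor  # import path assumed from prior releases

    adata = read_h5ad("my_dataset.h5ad")  # illustrative file name

    # 1.6.3 equivalent: Preprocessor(subset_hvg=2000, batch_key="donor_id")
    pp = Preprocessor(
        n_hvg_for_postp=2000,  # was subset_hvg; now only drives the post-processing HVG/PCA
        batch_keys=["assay_ontology_term_id", "donor_id", "suspension_type"],
    )
    adata = pp(adata, dataset_id="my_dataset")  # recorded in adata.uns["dataset_id"]
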