pertpy 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -120,8 +120,10 @@ class Tasccoda(CompositionalModel2):
120
120
  covariate_df=covariate_df,
121
121
  )
122
122
  mdata = MuData({modality_key_1: adata, modality_key_2: adata_coda})
123
- else:
123
+ elif type == "sample_level":
124
124
  mdata = MuData({modality_key_1: AnnData(), modality_key_2: adata})
125
+ else:
126
+ raise ValueError(f'{type} is not a supported type, expected "cell_level" or "sample_level".')
125
127
  import_tree(
126
128
  data=mdata,
127
129
  modality_1=modality_key_1,
@@ -464,7 +466,7 @@ class Tasccoda(CompositionalModel2):
464
466
  self,
465
467
  data: AnnData | MuData,
466
468
  modality_key: str = "coda",
467
- rng_key=None,
469
+ rng_key: int | None = None,
468
470
  num_prior_samples: int = 500,
469
471
  use_posterior_predictive: bool = True,
470
472
  ) -> az.InferenceData:
@@ -547,6 +549,8 @@ class Tasccoda(CompositionalModel2):
547
549
  if rng_key is None:
548
550
  rng = np.random.default_rng()
549
551
  rng_key = random.key(rng.integers(0, 10000))
552
+ else:
553
+ rng_key = random.key(rng_key)
550
554
 
551
555
  if use_posterior_predictive:
552
556
  posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
@@ -557,6 +561,15 @@ class Tasccoda(CompositionalModel2):
557
561
  ref_index=ref_index,
558
562
  sample_adata=sample_adata,
559
563
  )
564
+ # Remove problematic posterior predictive arrays with wrong dimensions
565
+ if posterior_predictive and "counts" in posterior_predictive:
566
+ counts_shape = posterior_predictive["counts"].shape
567
+ expected_dims = 2 # ['sample', 'cell_type']
568
+ if len(counts_shape) != expected_dims:
569
+ posterior_predictive = {k: v for k, v in posterior_predictive.items() if k != "counts"}
570
+ logger.warning(
571
+ f"Removed 'counts' from posterior_predictive due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
572
+ )
560
573
  else:
561
574
  posterior_predictive = None
562
575
 
@@ -569,6 +582,15 @@ class Tasccoda(CompositionalModel2):
569
582
  ref_index=ref_index,
570
583
  sample_adata=sample_adata,
571
584
  )
585
+ # Remove problematic prior arrays with wrong dimensions
586
+ if prior and "counts" in prior:
587
+ counts_shape = prior["counts"].shape
588
+ expected_dims = 2 # ['sample', 'cell_type']
589
+ if len(counts_shape) != expected_dims:
590
+ prior = {k: v for k, v in prior.items() if k != "counts"}
591
+ logger.warning(
592
+ f"Removed 'counts' from prior due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
593
+ )
572
594
  else:
573
595
  prior = None
574
596
 
@@ -592,80 +614,88 @@ class Tasccoda(CompositionalModel2):
592
614
  *args,
593
615
  **kwargs,
594
616
  ):
595
- """Examples:
596
- >>> import pertpy as pt
597
- >>> adata = pt.dt.tasccoda_example()
598
- >>> tasccoda = pt.tl.Tasccoda()
599
- >>> mdata = tasccoda.load(
600
- >>> adata, type="sample_level",
601
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
602
- >>> key_added="lineage", add_level_name=True
603
- >>> )
604
- >>> mdata = tasccoda.prepare(
605
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
606
- >>> )
607
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42).
608
- """ # noqa: D205
617
+ """
618
+
619
+ Examples:
620
+ >>> import pertpy as pt
621
+ >>> adata = pt.dt.tasccoda_example()
622
+ >>> tasccoda = pt.tl.Tasccoda()
623
+ >>> mdata = tasccoda.load(
624
+ >>> adata, type="sample_level",
625
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
626
+ >>> key_added="lineage", add_level_name=True
627
+ >>> )
628
+ >>> mdata = tasccoda.prepare(
629
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
630
+ >>> )
631
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42).
632
+ """ # noqa: D205, D212
609
633
  return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs)
610
634
 
611
635
  run_nuts.__doc__ = CompositionalModel2.run_nuts.__doc__ + run_nuts.__doc__
612
636
 
613
637
  def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: str = "coda", *args, **kwargs):
614
- """Examples:
615
- >>> import pertpy as pt
616
- >>> adata = pt.dt.tasccoda_example()
617
- >>> tasccoda = pt.tl.Tasccoda()
618
- >>> mdata = tasccoda.load(
619
- >>> adata, type="sample_level",
620
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
621
- >>> key_added="lineage", add_level_name=True
622
- >>> )
623
- >>> mdata = tasccoda.prepare(
624
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
625
- >>> )
626
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
627
- >>> tasccoda.summary(mdata).
628
- """ # noqa: D205
638
+ """
639
+
640
+ Examples:
641
+ >>> import pertpy as pt
642
+ >>> adata = pt.dt.tasccoda_example()
643
+ >>> tasccoda = pt.tl.Tasccoda()
644
+ >>> mdata = tasccoda.load(
645
+ >>> adata, type="sample_level",
646
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
647
+ >>> key_added="lineage", add_level_name=True
648
+ >>> )
649
+ >>> mdata = tasccoda.prepare(
650
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
651
+ >>> )
652
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
653
+ >>> tasccoda.summary(mdata).
654
+ """ # noqa: D205, D212
629
655
  return super().summary(data, extended, modality_key, *args, **kwargs)
630
656
 
631
657
  summary.__doc__ = CompositionalModel2.summary.__doc__ + summary.__doc__
632
658
 
633
659
  def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
634
- """Examples:
635
- >>> import pertpy as pt
636
- >>> adata = pt.dt.tasccoda_example()
637
- >>> tasccoda = pt.tl.Tasccoda()
638
- >>> mdata = tasccoda.load(
639
- >>> adata, type="sample_level",
640
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
641
- >>> key_added="lineage", add_level_name=True
642
- >>> )
643
- >>> mdata = tasccoda.prepare(
644
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
645
- >>> )
646
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
647
- >>> tasccoda.credible_effects(mdata).
648
- """ # noqa: D205
660
+ """
661
+
662
+ Examples:
663
+ >>> import pertpy as pt
664
+ >>> adata = pt.dt.tasccoda_example()
665
+ >>> tasccoda = pt.tl.Tasccoda()
666
+ >>> mdata = tasccoda.load(
667
+ >>> adata, type="sample_level",
668
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
669
+ >>> key_added="lineage", add_level_name=True
670
+ >>> )
671
+ >>> mdata = tasccoda.prepare(
672
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
673
+ >>> )
674
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
675
+ >>> tasccoda.credible_effects(mdata).
676
+ """ # noqa: D205, D212
649
677
  return super().credible_effects(data, modality_key, est_fdr)
650
678
 
651
679
  credible_effects.__doc__ = CompositionalModel2.credible_effects.__doc__ + credible_effects.__doc__
652
680
 
653
681
  def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
654
- """Examples:
655
- >>> import pertpy as pt
656
- >>> adata = pt.dt.tasccoda_example()
657
- >>> tasccoda = pt.tl.Tasccoda()
658
- >>> mdata = tasccoda.load(
659
- >>> adata, type="sample_level",
660
- >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
661
- >>> key_added="lineage", add_level_name=True
662
- >>> )
663
- >>> mdata = tasccoda.prepare(
664
- >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
665
- >>> )
666
- >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
667
- >>> tasccoda.set_fdr(mdata, est_fdr=0.4).
668
- """ # noqa: D205
682
+ """
683
+
684
+ Examples:
685
+ >>> import pertpy as pt
686
+ >>> adata = pt.dt.tasccoda_example()
687
+ >>> tasccoda = pt.tl.Tasccoda()
688
+ >>> mdata = tasccoda.load(
689
+ >>> adata, type="sample_level",
690
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
691
+ >>> key_added="lineage", add_level_name=True
692
+ >>> )
693
+ >>> mdata = tasccoda.prepare(
694
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
695
+ >>> )
696
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
697
+ >>> tasccoda.set_fdr(mdata, est_fdr=0.4).
698
+ """ # noqa: D205, D212
669
699
  return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs)
670
700
 
671
701
  set_fdr.__doc__ = CompositionalModel2.set_fdr.__doc__ + set_fdr.__doc__
pertpy/tools/_dialogue.py CHANGED
@@ -882,9 +882,9 @@ class Dialogue:
882
882
  if len(conditions_compare) != 2:
883
883
  raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
884
884
 
885
- pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
886
- tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
887
- pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
885
+ pvals = pd.DataFrame(1.0, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
886
+ tstats = pd.DataFrame(1.0, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
887
+ pvals_adj = pd.DataFrame(1.0, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
888
888
 
889
889
  response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
890
890
  for celltype in adata.obs[celltype_label].unique():
@@ -1,9 +1,52 @@
1
+ import contextlib
2
+ from importlib import import_module
3
+ from importlib.util import find_spec
4
+
1
5
  from ._base import LinearModelBase, MethodBase
2
6
  from ._dge_comparison import DGEEVAL
3
7
  from ._edger import EdgeR
4
- from ._pydeseq2 import PyDESeq2
5
8
  from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
6
- from ._statsmodels import Statsmodels
9
+
10
+
11
+ def __getattr__(name: str):
12
+ deps = {
13
+ "PyDESeq2": ["pydeseq2", "formulaic_contrasts", "formulaic"],
14
+ "EdgeR": ["rpy2", "formulaic_contrasts", "formulaic"],
15
+ "Statsmodels": ["formulaic_contrasts", "formulaic"],
16
+ }
17
+
18
+ if name in deps:
19
+ for dep in deps[name]:
20
+ if find_spec(dep) is None:
21
+ raise ImportError(f"{dep} is required but not installed")
22
+
23
+ module_map = {
24
+ "PyDESeq2": "pertpy.tools._differential_gene_expression._pydeseq2",
25
+ "EdgeR": "pertpy.tools._differential_gene_expression._edger",
26
+ "Statsmodels": "pertpy.tools._differential_gene_expression._statsmodels",
27
+ }
28
+
29
+ module = import_module(module_map[name])
30
+ return getattr(module, name)
31
+
32
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
33
+
34
+
35
+ def _get_available_methods():
36
+ methods = [WilcoxonTest, TTest]
37
+ from importlib.util import find_spec
38
+
39
+ for name in ["Statsmodels", "PyDESeq2", "EdgeR"]:
40
+ with contextlib.suppress(ImportError):
41
+ methods.append(__getattr__(name))
42
+
43
+ return methods
44
+
45
+
46
+ AVAILABLE_METHODS = _get_available_methods()
47
+
48
+
49
+ AVAILABLE_METHODS = _get_available_methods()
7
50
 
8
51
  __all__ = [
9
52
  "MethodBase",
@@ -15,5 +58,3 @@ __all__ = [
15
58
  "WilcoxonTest",
16
59
  "TTest",
17
60
  ]
18
-
19
- AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
@@ -12,7 +12,6 @@ import matplotlib.pyplot as plt
12
12
  import numpy as np
13
13
  import pandas as pd
14
14
  import seaborn as sns
15
- from formulaic_contrasts import FormulaicContrasts
16
15
  from lamin_utils import logger
17
16
  from matplotlib.pyplot import Figure
18
17
  from matplotlib.ticker import MaxNLocator
@@ -881,6 +880,8 @@ class LinearModelBase(MethodBase):
881
880
  super().__init__(adata, mask=mask, layer=layer)
882
881
  self._check_counts()
883
882
 
883
+ from formulaic_contrasts import FormulaicContrasts
884
+
884
885
  self.formulaic_contrasts = None
885
886
  if isinstance(design, str):
886
887
  self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
@@ -23,9 +23,6 @@ class EdgeR(LinearModelBase):
23
23
  Args:
24
24
  **kwargs: Keyword arguments specific to glmQLFit()
25
25
  """
26
- # For running in notebook
27
- # pandas2ri.activate()
28
- # rpy2.robjects.numpy2ri.activate()
29
26
  try:
30
27
  from rpy2 import robjects as ro
31
28
  from rpy2.robjects import numpy2ri, pandas2ri
@@ -47,17 +44,17 @@ class EdgeR(LinearModelBase):
47
44
  expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
48
45
  expr = expr.T.toarray() if issparse(expr) else expr.T
49
46
 
50
- with localconverter(get_conversion() + pandas2ri.converter):
51
- expr_r = ro.conversion.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
52
- samples_r = ro.conversion.py2rpy(self.adata.obs)
47
+ with localconverter(get_conversion() + pandas2ri.converter) as cv:
48
+ expr_r = cv.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
49
+ samples_r = cv.py2rpy(self.adata.obs)
53
50
 
54
51
  dge = edger.DGEList(counts=expr_r, samples=samples_r)
55
52
 
56
53
  logger.info("Calculating NormFactors")
57
54
  dge = edger.calcNormFactors(dge)
58
55
 
59
- with localconverter(get_conversion() + numpy2ri.converter):
60
- design_r = ro.conversion.py2rpy(self.design.values)
56
+ with localconverter(get_conversion() + numpy2ri.converter) as cv:
57
+ design_r = cv.py2rpy(self.design.values)
61
58
 
62
59
  logger.info("Estimating Dispersions")
63
60
  dge = edger.estimateDisp(dge, design=design_r)
@@ -100,8 +97,8 @@ class EdgeR(LinearModelBase):
100
97
  ) from None
101
98
 
102
99
  # Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
103
- with localconverter(get_conversion() + numpy2ri.converter):
104
- contrast_vec_r = ro.conversion.py2rpy(np.asarray(contrast))
100
+ with localconverter(get_conversion() + numpy2ri.converter) as cv:
101
+ contrast_vec_r = cv.py2rpy(np.asarray(contrast))
105
102
  ro.globalenv["contrast_vec"] = contrast_vec_r
106
103
 
107
104
  # Test contrast with R
@@ -121,8 +118,8 @@ class EdgeR(LinearModelBase):
121
118
  return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
122
119
 
123
120
  # Convert to Pandas DataFrame if still an R object
124
- with localconverter(get_conversion() + pandas2ri.converter):
125
- de_res = ro.conversion.rpy2py(de_res)
121
+ with localconverter(get_conversion() + pandas2ri.converter) as cv:
122
+ de_res = cv.rpy2py(de_res)
126
123
 
127
124
  de_res.index.name = "variable"
128
125
  de_res = de_res.reset_index()
@@ -1,6 +1,4 @@
1
1
  import os
2
- import re
3
- import warnings
4
2
 
5
3
  import numpy as np
6
4
  import pandas as pd
@@ -8,7 +8,7 @@ from rich.progress import track
8
8
  from sklearn.metrics import pairwise_distances
9
9
  from statsmodels.stats.multitest import multipletests
10
10
 
11
- from ._distances import Distance
11
+ from ._distances import Distance, Metric
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from anndata import AnnData
@@ -43,7 +43,7 @@ class DistanceTest:
43
43
 
44
44
  def __init__(
45
45
  self,
46
- metric: str,
46
+ metric: Metric,
47
47
  n_perms: int = 1000,
48
48
  layer_key: str = None,
49
49
  obsm_key: str = None,
@@ -34,6 +34,31 @@ class MeanVar(NamedTuple):
34
34
  variance: float
35
35
 
36
36
 
37
+ Metric = Literal[
38
+ "edistance",
39
+ "euclidean",
40
+ "root_mean_squared_error",
41
+ "mse",
42
+ "mean_absolute_error",
43
+ "pearson_distance",
44
+ "spearman_distance",
45
+ "kendalltau_distance",
46
+ "cosine_distance",
47
+ "r2_distance",
48
+ "mean_pairwise",
49
+ "mmd",
50
+ "wasserstein",
51
+ "sym_kldiv",
52
+ "t_test",
53
+ "ks_test",
54
+ "nb_ll",
55
+ "classifier_proba",
56
+ "classifier_cp",
57
+ "mean_var_distribution",
58
+ "mahalanobis",
59
+ ]
60
+
61
+
37
62
  class Distance:
38
63
  """Distance class, used to compute distances between groups of cells.
39
64
 
@@ -112,7 +137,7 @@ class Distance:
112
137
 
113
138
  def __init__(
114
139
  self,
115
- metric: str = "edistance",
140
+ metric: Metric = "edistance",
116
141
  agg_fct: Callable = np.mean,
117
142
  layer_key: str = None,
118
143
  obsm_key: str = None,
@@ -660,19 +685,19 @@ class MMD(AbstractDistance):
660
685
  super().__init__()
661
686
  self.accepts_precomputed = False
662
687
 
663
- def __call__(self, X: np.ndarray, Y: np.ndarray, kernel="linear", **kwargs) -> float:
688
+ def __call__(self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs) -> float:
664
689
  if kernel == "linear":
665
690
  XX = np.dot(X, X.T)
666
691
  YY = np.dot(Y, Y.T)
667
692
  XY = np.dot(X, Y.T)
668
693
  elif kernel == "rbf":
669
- XX = rbf_kernel(X, X, gamma=1.0)
670
- YY = rbf_kernel(Y, Y, gamma=1.0)
671
- XY = rbf_kernel(X, Y, gamma=1.0)
694
+ XX = rbf_kernel(X, X, gamma=gamma)
695
+ YY = rbf_kernel(Y, Y, gamma=gamma)
696
+ XY = rbf_kernel(X, Y, gamma=gamma)
672
697
  elif kernel == "poly":
673
- XX = polynomial_kernel(X, X, degree=2, gamma=1.0, coef0=0)
674
- YY = polynomial_kernel(Y, Y, degree=2, gamma=1.0, coef0=0)
675
- XY = polynomial_kernel(X, Y, degree=2, gamma=1.0, coef0=0)
698
+ XX = polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=0)
699
+ YY = polynomial_kernel(Y, Y, degree=degree, gamma=gamma, coef0=0)
700
+ XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=0)
676
701
  else:
677
702
  raise ValueError(f"Kernel {kernel} not recognized.")
678
703
 
pertpy/tools/_milo.py CHANGED
@@ -411,6 +411,8 @@ class Milo:
411
411
  res = base.as_data_frame(
412
412
  edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf)
413
413
  )
414
+ if res is None:
415
+ raise ValueError("Unable to generate results with edgeR. Is your installation correct?")
414
416
  if not isinstance(res, pd.DataFrame):
415
417
  res = pd.DataFrame(res)
416
418
  # The columns of res looks like e.g. table.A, table.B, so remove the prefix
@@ -530,7 +532,7 @@ class Milo:
530
532
 
531
533
  anno_frac_dataframe = pd.DataFrame(anno_frac, columns=anno_dummies.columns, index=sample_adata.var_names)
532
534
  sample_adata.varm["frac_annotation"] = anno_frac_dataframe.values
533
- sample_adata.uns["annotation_labels"] = anno_frac_dataframe.columns
535
+ sample_adata.uns["annotation_labels"] = anno_frac_dataframe.columns.to_list()
534
536
  sample_adata.uns["annotation_obs"] = anno_col
535
537
  sample_adata.var["nhood_annotation"] = anno_frac_dataframe.idxmax(1)
536
538
  sample_adata.var["nhood_annotation_frac"] = anno_frac_dataframe.max(1)
@@ -1,13 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
- import warnings
4
-
5
3
  import anndata
6
4
  import numpy as np
7
- import pandas as pd
8
5
  import scipy
9
6
  import torch
10
7
  from anndata import AnnData
8
+ from fast_array_utils.conv import to_dense
11
9
  from pytorch_lightning import LightningModule, Trainer
12
10
  from pytorch_lightning.callbacks import EarlyStopping
13
11
  from sklearn.linear_model import LogisticRegression
@@ -112,18 +110,6 @@ class LRClassifierSpace(PerturbationSpace):
112
110
  return pert_adata
113
111
 
114
112
 
115
- # Ensure backward compatibility with DiscriminatorClassifierSpace
116
- def DiscriminatorClassifierSpace():
117
- warnings.warn(
118
- "The DiscriminatorClassifierSpace class is deprecated and will be removed in the future."
119
- "Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
120
- DeprecationWarning,
121
- stacklevel=2,
122
- )
123
-
124
- return MLPClassifierSpace()
125
-
126
-
127
113
  class MLPClassifierSpace(PerturbationSpace):
128
114
  """Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.
129
115
 
@@ -202,7 +188,7 @@ class MLPClassifierSpace(PerturbationSpace):
202
188
  labels = adata.obs[target_col].values.reshape(-1, 1)
203
189
  encoder = OneHotEncoder()
204
190
  encoded_labels = encoder.fit_transform(labels).toarray()
205
- adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]
191
+ adata.obsm["encoded_perturbations"] = encoded_labels.astype(np.float32)
206
192
 
207
193
  # Split the data in train, test and validation
208
194
  X = list(range(adata.n_obs))
@@ -226,7 +212,7 @@ class MLPClassifierSpace(PerturbationSpace):
226
212
  # Fix class unbalance (likely to happen in perturbation datasets)
227
213
  # Usually control cells are overrepresented such that predicting control all time would give good results
228
214
  # Cells with rare perturbations are sampled more
229
- train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels.to_list()), dim=1))
215
+ train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
230
216
  train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
231
217
 
232
218
  self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
@@ -278,11 +264,10 @@ class MLPClassifierSpace(PerturbationSpace):
278
264
  pert_adata.obs = pert_adata.obs.reset_index(drop=True)
279
265
  if "perturbations" in self.adata_obs.columns:
280
266
  self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
281
- pert_adata.obs = pd.concat([pert_adata.obs, self.adata_obs], axis=1)
282
-
283
- # Drop the 'encoded_perturbations' colums, since this stores the one-hot encoded labels as numpy arrays,
284
- # which would cause errors in the downstream processing of the AnnData object (e.g. when plotting)
285
- pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)
267
+ obs_subset = self.adata_obs.iloc[: len(pert_adata.obs)].copy()
268
+ for col in obs_subset.columns:
269
+ if col not in ["perturbations", "encoded_perturbations"]:
270
+ pert_adata.obs[col] = obs_subset[col].values
286
271
 
287
272
  return pert_adata
288
273
 
@@ -397,7 +382,13 @@ class PLDataset(Dataset):
397
382
  else:
398
383
  self.data = adata.X
399
384
 
400
- self.labels = adata.obs[target_col]
385
+ if target_col in adata.obs.columns:
386
+ self.labels = adata.obs[target_col]
387
+ elif target_col in adata.obsm:
388
+ self.labels = adata.obsm[target_col]
389
+ else:
390
+ raise ValueError(f"Target column {target_col} not found in obs or obsm")
391
+
401
392
  self.pert_labels = adata.obs[label_col]
402
393
 
403
394
  def __len__(self):
@@ -405,8 +396,8 @@ class PLDataset(Dataset):
405
396
 
406
397
  def __getitem__(self, idx):
407
398
  """Returns a sample and corresponding perturbations applied (labels)."""
408
- sample = self.data[idx].toarray().squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
409
- num_label = self.labels.iloc[idx]
399
+ sample = to_dense(self.data[idx]).squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
400
+ num_label = self.labels.iloc[idx] if hasattr(self.labels, "iloc") else self.labels[idx]
410
401
  str_label = self.pert_labels.iloc[idx]
411
402
 
412
403
  return sample, num_label, str_label
@@ -161,12 +161,20 @@ class PseudobulkSpace(PerturbationSpace):
161
161
  adata = adata_emb
162
162
 
163
163
  adata.obs[target_col] = adata.obs[target_col].astype("category")
164
+ grouping_cols = [target_col] if groups_col is None else [target_col, groups_col]
165
+ original_obs = adata.obs.copy()
164
166
  ps_adata = sc.get.aggregate(
165
167
  adata, by=[target_col] if groups_col is None else [target_col, groups_col], func=mode, layer=layer_key
166
168
  )
169
+
167
170
  if mode in ps_adata.layers:
168
171
  ps_adata.X = ps_adata.layers[mode]
169
172
 
173
+ for col in original_obs.columns:
174
+ if col not in ps_adata.obs.columns:
175
+ grouped_values = original_obs.groupby(grouping_cols)[col].first()
176
+ ps_adata.obs[col] = grouped_values.reindex(ps_adata.obs.index).values
177
+
170
178
  ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
171
179
 
172
180
  return ps_adata