pertpy 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/data/_dataloader.py +61 -58
- pertpy/metadata/_cell_line.py +9 -3
- pertpy/tools/__init__.py +18 -27
- pertpy/tools/_coda/_base_coda.py +10 -4
- pertpy/tools/_coda/_sccoda.py +84 -56
- pertpy/tools/_coda/_tasccoda.py +91 -61
- pertpy/tools/_dialogue.py +3 -3
- pertpy/tools/_differential_gene_expression/__init__.py +45 -4
- pertpy/tools/_differential_gene_expression/_base.py +2 -1
- pertpy/tools/_differential_gene_expression/_edger.py +9 -12
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +0 -2
- pertpy/tools/_distances/_distance_tests.py +2 -2
- pertpy/tools/_distances/_distances.py +33 -8
- pertpy/tools/_milo.py +3 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +16 -25
- pertpy/tools/_perturbation_space/_simple.py +8 -0
- {pertpy-1.0.0.dist-info → pertpy-1.0.2.dist-info}/METADATA +51 -72
- {pertpy-1.0.0.dist-info → pertpy-1.0.2.dist-info}/RECORD +21 -21
- {pertpy-1.0.0.dist-info → pertpy-1.0.2.dist-info}/WHEEL +1 -1
- {pertpy-1.0.0.dist-info → pertpy-1.0.2.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_coda/_tasccoda.py
CHANGED
@@ -120,8 +120,10 @@ class Tasccoda(CompositionalModel2):
|
|
120
120
|
covariate_df=covariate_df,
|
121
121
|
)
|
122
122
|
mdata = MuData({modality_key_1: adata, modality_key_2: adata_coda})
|
123
|
-
|
123
|
+
elif type == "sample_level":
|
124
124
|
mdata = MuData({modality_key_1: AnnData(), modality_key_2: adata})
|
125
|
+
else:
|
126
|
+
raise ValueError(f'{type} is not a supported type, expected "cell_level" or "sample_level".')
|
125
127
|
import_tree(
|
126
128
|
data=mdata,
|
127
129
|
modality_1=modality_key_1,
|
@@ -464,7 +466,7 @@ class Tasccoda(CompositionalModel2):
|
|
464
466
|
self,
|
465
467
|
data: AnnData | MuData,
|
466
468
|
modality_key: str = "coda",
|
467
|
-
rng_key=None,
|
469
|
+
rng_key: int | None = None,
|
468
470
|
num_prior_samples: int = 500,
|
469
471
|
use_posterior_predictive: bool = True,
|
470
472
|
) -> az.InferenceData:
|
@@ -547,6 +549,8 @@ class Tasccoda(CompositionalModel2):
|
|
547
549
|
if rng_key is None:
|
548
550
|
rng = np.random.default_rng()
|
549
551
|
rng_key = random.key(rng.integers(0, 10000))
|
552
|
+
else:
|
553
|
+
rng_key = random.key(rng_key)
|
550
554
|
|
551
555
|
if use_posterior_predictive:
|
552
556
|
posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
|
@@ -557,6 +561,15 @@ class Tasccoda(CompositionalModel2):
|
|
557
561
|
ref_index=ref_index,
|
558
562
|
sample_adata=sample_adata,
|
559
563
|
)
|
564
|
+
# Remove problematic posterior predictive arrays with wrong dimensions
|
565
|
+
if posterior_predictive and "counts" in posterior_predictive:
|
566
|
+
counts_shape = posterior_predictive["counts"].shape
|
567
|
+
expected_dims = 2 # ['sample', 'cell_type']
|
568
|
+
if len(counts_shape) != expected_dims:
|
569
|
+
posterior_predictive = {k: v for k, v in posterior_predictive.items() if k != "counts"}
|
570
|
+
logger.warning(
|
571
|
+
f"Removed 'counts' from posterior_predictive due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
|
572
|
+
)
|
560
573
|
else:
|
561
574
|
posterior_predictive = None
|
562
575
|
|
@@ -569,6 +582,15 @@ class Tasccoda(CompositionalModel2):
|
|
569
582
|
ref_index=ref_index,
|
570
583
|
sample_adata=sample_adata,
|
571
584
|
)
|
585
|
+
# Remove problematic prior arrays with wrong dimensions
|
586
|
+
if prior and "counts" in prior:
|
587
|
+
counts_shape = prior["counts"].shape
|
588
|
+
expected_dims = 2 # ['sample', 'cell_type']
|
589
|
+
if len(counts_shape) != expected_dims:
|
590
|
+
prior = {k: v for k, v in prior.items() if k != "counts"}
|
591
|
+
logger.warning(
|
592
|
+
f"Removed 'counts' from prior due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
|
593
|
+
)
|
572
594
|
else:
|
573
595
|
prior = None
|
574
596
|
|
@@ -592,80 +614,88 @@ class Tasccoda(CompositionalModel2):
|
|
592
614
|
*args,
|
593
615
|
**kwargs,
|
594
616
|
):
|
595
|
-
"""
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
617
|
+
"""
|
618
|
+
|
619
|
+
Examples:
|
620
|
+
>>> import pertpy as pt
|
621
|
+
>>> adata = pt.dt.tasccoda_example()
|
622
|
+
>>> tasccoda = pt.tl.Tasccoda()
|
623
|
+
>>> mdata = tasccoda.load(
|
624
|
+
>>> adata, type="sample_level",
|
625
|
+
>>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
|
626
|
+
>>> key_added="lineage", add_level_name=True
|
627
|
+
>>> )
|
628
|
+
>>> mdata = tasccoda.prepare(
|
629
|
+
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
630
|
+
>>> )
|
631
|
+
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42).
|
632
|
+
""" # noqa: D205, D212
|
609
633
|
return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs)
|
610
634
|
|
611
635
|
run_nuts.__doc__ = CompositionalModel2.run_nuts.__doc__ + run_nuts.__doc__
|
612
636
|
|
613
637
|
def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: str = "coda", *args, **kwargs):
|
614
|
-
"""
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
638
|
+
"""
|
639
|
+
|
640
|
+
Examples:
|
641
|
+
>>> import pertpy as pt
|
642
|
+
>>> adata = pt.dt.tasccoda_example()
|
643
|
+
>>> tasccoda = pt.tl.Tasccoda()
|
644
|
+
>>> mdata = tasccoda.load(
|
645
|
+
>>> adata, type="sample_level",
|
646
|
+
>>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
|
647
|
+
>>> key_added="lineage", add_level_name=True
|
648
|
+
>>> )
|
649
|
+
>>> mdata = tasccoda.prepare(
|
650
|
+
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
651
|
+
>>> )
|
652
|
+
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
653
|
+
>>> tasccoda.summary(mdata).
|
654
|
+
""" # noqa: D205, D212
|
629
655
|
return super().summary(data, extended, modality_key, *args, **kwargs)
|
630
656
|
|
631
657
|
summary.__doc__ = CompositionalModel2.summary.__doc__ + summary.__doc__
|
632
658
|
|
633
659
|
def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
|
634
|
-
"""
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
660
|
+
"""
|
661
|
+
|
662
|
+
Examples:
|
663
|
+
>>> import pertpy as pt
|
664
|
+
>>> adata = pt.dt.tasccoda_example()
|
665
|
+
>>> tasccoda = pt.tl.Tasccoda()
|
666
|
+
>>> mdata = tasccoda.load(
|
667
|
+
>>> adata, type="sample_level",
|
668
|
+
>>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
|
669
|
+
>>> key_added="lineage", add_level_name=True
|
670
|
+
>>> )
|
671
|
+
>>> mdata = tasccoda.prepare(
|
672
|
+
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
673
|
+
>>> )
|
674
|
+
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
675
|
+
>>> tasccoda.credible_effects(mdata).
|
676
|
+
""" # noqa: D205, D212
|
649
677
|
return super().credible_effects(data, modality_key, est_fdr)
|
650
678
|
|
651
679
|
credible_effects.__doc__ = CompositionalModel2.credible_effects.__doc__ + credible_effects.__doc__
|
652
680
|
|
653
681
|
def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
|
654
|
-
"""
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
682
|
+
"""
|
683
|
+
|
684
|
+
Examples:
|
685
|
+
>>> import pertpy as pt
|
686
|
+
>>> adata = pt.dt.tasccoda_example()
|
687
|
+
>>> tasccoda = pt.tl.Tasccoda()
|
688
|
+
>>> mdata = tasccoda.load(
|
689
|
+
>>> adata, type="sample_level",
|
690
|
+
>>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
|
691
|
+
>>> key_added="lineage", add_level_name=True
|
692
|
+
>>> )
|
693
|
+
>>> mdata = tasccoda.prepare(
|
694
|
+
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
695
|
+
>>> )
|
696
|
+
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
697
|
+
>>> tasccoda.set_fdr(mdata, est_fdr=0.4).
|
698
|
+
""" # noqa: D205, D212
|
669
699
|
return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs)
|
670
700
|
|
671
701
|
set_fdr.__doc__ = CompositionalModel2.set_fdr.__doc__ + set_fdr.__doc__
|
pertpy/tools/_dialogue.py
CHANGED
@@ -882,9 +882,9 @@ class Dialogue:
|
|
882
882
|
if len(conditions_compare) != 2:
|
883
883
|
raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
|
884
884
|
|
885
|
-
pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
886
|
-
tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
887
|
-
pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
885
|
+
pvals = pd.DataFrame(1.0, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
886
|
+
tstats = pd.DataFrame(1.0, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
887
|
+
pvals_adj = pd.DataFrame(1.0, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(n_mcps)])
|
888
888
|
|
889
889
|
response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
|
890
890
|
for celltype in adata.obs[celltype_label].unique():
|
@@ -1,9 +1,52 @@
|
|
1
|
+
import contextlib
|
2
|
+
from importlib import import_module
|
3
|
+
from importlib.util import find_spec
|
4
|
+
|
1
5
|
from ._base import LinearModelBase, MethodBase
|
2
6
|
from ._dge_comparison import DGEEVAL
|
3
7
|
from ._edger import EdgeR
|
4
|
-
from ._pydeseq2 import PyDESeq2
|
5
8
|
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
|
6
|
-
|
9
|
+
|
10
|
+
|
11
|
+
def __getattr__(name: str):
|
12
|
+
deps = {
|
13
|
+
"PyDESeq2": ["pydeseq2", "formulaic_contrasts", "formulaic"],
|
14
|
+
"EdgeR": ["rpy2", "formulaic_contrasts", "formulaic"],
|
15
|
+
"Statsmodels": ["formulaic_contrasts", "formulaic"],
|
16
|
+
}
|
17
|
+
|
18
|
+
if name in deps:
|
19
|
+
for dep in deps[name]:
|
20
|
+
if find_spec(dep) is None:
|
21
|
+
raise ImportError(f"{dep} is required but not installed")
|
22
|
+
|
23
|
+
module_map = {
|
24
|
+
"PyDESeq2": "pertpy.tools._differential_gene_expression._pydeseq2",
|
25
|
+
"EdgeR": "pertpy.tools._differential_gene_expression._edger",
|
26
|
+
"Statsmodels": "pertpy.tools._differential_gene_expression._statsmodels",
|
27
|
+
}
|
28
|
+
|
29
|
+
module = import_module(module_map[name])
|
30
|
+
return getattr(module, name)
|
31
|
+
|
32
|
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
33
|
+
|
34
|
+
|
35
|
+
def _get_available_methods():
|
36
|
+
methods = [WilcoxonTest, TTest]
|
37
|
+
from importlib.util import find_spec
|
38
|
+
|
39
|
+
for name in ["Statsmodels", "PyDESeq2", "EdgeR"]:
|
40
|
+
with contextlib.suppress(ImportError):
|
41
|
+
methods.append(__getattr__(name))
|
42
|
+
|
43
|
+
return methods
|
44
|
+
|
45
|
+
|
46
|
+
AVAILABLE_METHODS = _get_available_methods()
|
47
|
+
|
48
|
+
|
49
|
+
AVAILABLE_METHODS = _get_available_methods()
|
7
50
|
|
8
51
|
__all__ = [
|
9
52
|
"MethodBase",
|
@@ -15,5 +58,3 @@ __all__ = [
|
|
15
58
|
"WilcoxonTest",
|
16
59
|
"TTest",
|
17
60
|
]
|
18
|
-
|
19
|
-
AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
|
@@ -12,7 +12,6 @@ import matplotlib.pyplot as plt
|
|
12
12
|
import numpy as np
|
13
13
|
import pandas as pd
|
14
14
|
import seaborn as sns
|
15
|
-
from formulaic_contrasts import FormulaicContrasts
|
16
15
|
from lamin_utils import logger
|
17
16
|
from matplotlib.pyplot import Figure
|
18
17
|
from matplotlib.ticker import MaxNLocator
|
@@ -881,6 +880,8 @@ class LinearModelBase(MethodBase):
|
|
881
880
|
super().__init__(adata, mask=mask, layer=layer)
|
882
881
|
self._check_counts()
|
883
882
|
|
883
|
+
from formulaic_contrasts import FormulaicContrasts
|
884
|
+
|
884
885
|
self.formulaic_contrasts = None
|
885
886
|
if isinstance(design, str):
|
886
887
|
self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
|
@@ -23,9 +23,6 @@ class EdgeR(LinearModelBase):
|
|
23
23
|
Args:
|
24
24
|
**kwargs: Keyword arguments specific to glmQLFit()
|
25
25
|
"""
|
26
|
-
# For running in notebook
|
27
|
-
# pandas2ri.activate()
|
28
|
-
# rpy2.robjects.numpy2ri.activate()
|
29
26
|
try:
|
30
27
|
from rpy2 import robjects as ro
|
31
28
|
from rpy2.robjects import numpy2ri, pandas2ri
|
@@ -47,17 +44,17 @@ class EdgeR(LinearModelBase):
|
|
47
44
|
expr = self.adata.X if self.layer is None else self.adata.layers[self.layer]
|
48
45
|
expr = expr.T.toarray() if issparse(expr) else expr.T
|
49
46
|
|
50
|
-
with localconverter(get_conversion() + pandas2ri.converter):
|
51
|
-
expr_r =
|
52
|
-
samples_r =
|
47
|
+
with localconverter(get_conversion() + pandas2ri.converter) as cv:
|
48
|
+
expr_r = cv.py2rpy(pd.DataFrame(expr, index=self.adata.var_names, columns=self.adata.obs_names))
|
49
|
+
samples_r = cv.py2rpy(self.adata.obs)
|
53
50
|
|
54
51
|
dge = edger.DGEList(counts=expr_r, samples=samples_r)
|
55
52
|
|
56
53
|
logger.info("Calculating NormFactors")
|
57
54
|
dge = edger.calcNormFactors(dge)
|
58
55
|
|
59
|
-
with localconverter(get_conversion() + numpy2ri.converter):
|
60
|
-
design_r =
|
56
|
+
with localconverter(get_conversion() + numpy2ri.converter) as cv:
|
57
|
+
design_r = cv.py2rpy(self.design.values)
|
61
58
|
|
62
59
|
logger.info("Estimating Dispersions")
|
63
60
|
dge = edger.estimateDisp(dge, design=design_r)
|
@@ -100,8 +97,8 @@ class EdgeR(LinearModelBase):
|
|
100
97
|
) from None
|
101
98
|
|
102
99
|
# Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
|
103
|
-
with localconverter(get_conversion() + numpy2ri.converter):
|
104
|
-
contrast_vec_r =
|
100
|
+
with localconverter(get_conversion() + numpy2ri.converter) as cv:
|
101
|
+
contrast_vec_r = cv.py2rpy(np.asarray(contrast))
|
105
102
|
ro.globalenv["contrast_vec"] = contrast_vec_r
|
106
103
|
|
107
104
|
# Test contrast with R
|
@@ -121,8 +118,8 @@ class EdgeR(LinearModelBase):
|
|
121
118
|
return de_res.reset_index().rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})
|
122
119
|
|
123
120
|
# Convert to Pandas DataFrame if still an R object
|
124
|
-
with localconverter(get_conversion() + pandas2ri.converter):
|
125
|
-
de_res =
|
121
|
+
with localconverter(get_conversion() + pandas2ri.converter) as cv:
|
122
|
+
de_res = cv.rpy2py(de_res)
|
126
123
|
|
127
124
|
de_res.index.name = "variable"
|
128
125
|
de_res = de_res.reset_index()
|
@@ -8,7 +8,7 @@ from rich.progress import track
|
|
8
8
|
from sklearn.metrics import pairwise_distances
|
9
9
|
from statsmodels.stats.multitest import multipletests
|
10
10
|
|
11
|
-
from ._distances import Distance
|
11
|
+
from ._distances import Distance, Metric
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from anndata import AnnData
|
@@ -43,7 +43,7 @@ class DistanceTest:
|
|
43
43
|
|
44
44
|
def __init__(
|
45
45
|
self,
|
46
|
-
metric:
|
46
|
+
metric: Metric,
|
47
47
|
n_perms: int = 1000,
|
48
48
|
layer_key: str = None,
|
49
49
|
obsm_key: str = None,
|
@@ -34,6 +34,31 @@ class MeanVar(NamedTuple):
|
|
34
34
|
variance: float
|
35
35
|
|
36
36
|
|
37
|
+
Metric = Literal[
|
38
|
+
"edistance",
|
39
|
+
"euclidean",
|
40
|
+
"root_mean_squared_error",
|
41
|
+
"mse",
|
42
|
+
"mean_absolute_error",
|
43
|
+
"pearson_distance",
|
44
|
+
"spearman_distance",
|
45
|
+
"kendalltau_distance",
|
46
|
+
"cosine_distance",
|
47
|
+
"r2_distance",
|
48
|
+
"mean_pairwise",
|
49
|
+
"mmd",
|
50
|
+
"wasserstein",
|
51
|
+
"sym_kldiv",
|
52
|
+
"t_test",
|
53
|
+
"ks_test",
|
54
|
+
"nb_ll",
|
55
|
+
"classifier_proba",
|
56
|
+
"classifier_cp",
|
57
|
+
"mean_var_distribution",
|
58
|
+
"mahalanobis",
|
59
|
+
]
|
60
|
+
|
61
|
+
|
37
62
|
class Distance:
|
38
63
|
"""Distance class, used to compute distances between groups of cells.
|
39
64
|
|
@@ -112,7 +137,7 @@ class Distance:
|
|
112
137
|
|
113
138
|
def __init__(
|
114
139
|
self,
|
115
|
-
metric:
|
140
|
+
metric: Metric = "edistance",
|
116
141
|
agg_fct: Callable = np.mean,
|
117
142
|
layer_key: str = None,
|
118
143
|
obsm_key: str = None,
|
@@ -660,19 +685,19 @@ class MMD(AbstractDistance):
|
|
660
685
|
super().__init__()
|
661
686
|
self.accepts_precomputed = False
|
662
687
|
|
663
|
-
def __call__(self, X: np.ndarray, Y: np.ndarray, kernel="linear", **kwargs) -> float:
|
688
|
+
def __call__(self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs) -> float:
|
664
689
|
if kernel == "linear":
|
665
690
|
XX = np.dot(X, X.T)
|
666
691
|
YY = np.dot(Y, Y.T)
|
667
692
|
XY = np.dot(X, Y.T)
|
668
693
|
elif kernel == "rbf":
|
669
|
-
XX = rbf_kernel(X, X, gamma=
|
670
|
-
YY = rbf_kernel(Y, Y, gamma=
|
671
|
-
XY = rbf_kernel(X, Y, gamma=
|
694
|
+
XX = rbf_kernel(X, X, gamma=gamma)
|
695
|
+
YY = rbf_kernel(Y, Y, gamma=gamma)
|
696
|
+
XY = rbf_kernel(X, Y, gamma=gamma)
|
672
697
|
elif kernel == "poly":
|
673
|
-
XX = polynomial_kernel(X, X, degree=
|
674
|
-
YY = polynomial_kernel(Y, Y, degree=
|
675
|
-
XY = polynomial_kernel(X, Y, degree=
|
698
|
+
XX = polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=0)
|
699
|
+
YY = polynomial_kernel(Y, Y, degree=degree, gamma=gamma, coef0=0)
|
700
|
+
XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=0)
|
676
701
|
else:
|
677
702
|
raise ValueError(f"Kernel {kernel} not recognized.")
|
678
703
|
|
pertpy/tools/_milo.py
CHANGED
@@ -411,6 +411,8 @@ class Milo:
|
|
411
411
|
res = base.as_data_frame(
|
412
412
|
edgeR.topTags(edgeR.glmQLFTest(fit, coef=n_coef), sort_by="none", n=np.inf)
|
413
413
|
)
|
414
|
+
if res is None:
|
415
|
+
raise ValueError("Unable to generate results with edgeR. Is your installation correct?")
|
414
416
|
if not isinstance(res, pd.DataFrame):
|
415
417
|
res = pd.DataFrame(res)
|
416
418
|
# The columns of res looks like e.g. table.A, table.B, so remove the prefix
|
@@ -530,7 +532,7 @@ class Milo:
|
|
530
532
|
|
531
533
|
anno_frac_dataframe = pd.DataFrame(anno_frac, columns=anno_dummies.columns, index=sample_adata.var_names)
|
532
534
|
sample_adata.varm["frac_annotation"] = anno_frac_dataframe.values
|
533
|
-
sample_adata.uns["annotation_labels"] = anno_frac_dataframe.columns
|
535
|
+
sample_adata.uns["annotation_labels"] = anno_frac_dataframe.columns.to_list()
|
534
536
|
sample_adata.uns["annotation_obs"] = anno_col
|
535
537
|
sample_adata.var["nhood_annotation"] = anno_frac_dataframe.idxmax(1)
|
536
538
|
sample_adata.var["nhood_annotation_frac"] = anno_frac_dataframe.max(1)
|
@@ -1,13 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import warnings
|
4
|
-
|
5
3
|
import anndata
|
6
4
|
import numpy as np
|
7
|
-
import pandas as pd
|
8
5
|
import scipy
|
9
6
|
import torch
|
10
7
|
from anndata import AnnData
|
8
|
+
from fast_array_utils.conv import to_dense
|
11
9
|
from pytorch_lightning import LightningModule, Trainer
|
12
10
|
from pytorch_lightning.callbacks import EarlyStopping
|
13
11
|
from sklearn.linear_model import LogisticRegression
|
@@ -112,18 +110,6 @@ class LRClassifierSpace(PerturbationSpace):
|
|
112
110
|
return pert_adata
|
113
111
|
|
114
112
|
|
115
|
-
# Ensure backward compatibility with DiscriminatorClassifierSpace
|
116
|
-
def DiscriminatorClassifierSpace():
|
117
|
-
warnings.warn(
|
118
|
-
"The DiscriminatorClassifierSpace class is deprecated and will be removed in the future."
|
119
|
-
"Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
|
120
|
-
DeprecationWarning,
|
121
|
-
stacklevel=2,
|
122
|
-
)
|
123
|
-
|
124
|
-
return MLPClassifierSpace()
|
125
|
-
|
126
|
-
|
127
113
|
class MLPClassifierSpace(PerturbationSpace):
|
128
114
|
"""Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.
|
129
115
|
|
@@ -202,7 +188,7 @@ class MLPClassifierSpace(PerturbationSpace):
|
|
202
188
|
labels = adata.obs[target_col].values.reshape(-1, 1)
|
203
189
|
encoder = OneHotEncoder()
|
204
190
|
encoded_labels = encoder.fit_transform(labels).toarray()
|
205
|
-
adata.
|
191
|
+
adata.obsm["encoded_perturbations"] = encoded_labels.astype(np.float32)
|
206
192
|
|
207
193
|
# Split the data in train, test and validation
|
208
194
|
X = list(range(adata.n_obs))
|
@@ -226,7 +212,7 @@ class MLPClassifierSpace(PerturbationSpace):
|
|
226
212
|
# Fix class unbalance (likely to happen in perturbation datasets)
|
227
213
|
# Usually control cells are overrepresented such that predicting control all time would give good results
|
228
214
|
# Cells with rare perturbations are sampled more
|
229
|
-
train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels
|
215
|
+
train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
|
230
216
|
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
|
231
217
|
|
232
218
|
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
|
@@ -278,11 +264,10 @@ class MLPClassifierSpace(PerturbationSpace):
|
|
278
264
|
pert_adata.obs = pert_adata.obs.reset_index(drop=True)
|
279
265
|
if "perturbations" in self.adata_obs.columns:
|
280
266
|
self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)
|
267
|
+
obs_subset = self.adata_obs.iloc[: len(pert_adata.obs)].copy()
|
268
|
+
for col in obs_subset.columns:
|
269
|
+
if col not in ["perturbations", "encoded_perturbations"]:
|
270
|
+
pert_adata.obs[col] = obs_subset[col].values
|
286
271
|
|
287
272
|
return pert_adata
|
288
273
|
|
@@ -397,7 +382,13 @@ class PLDataset(Dataset):
|
|
397
382
|
else:
|
398
383
|
self.data = adata.X
|
399
384
|
|
400
|
-
|
385
|
+
if target_col in adata.obs.columns:
|
386
|
+
self.labels = adata.obs[target_col]
|
387
|
+
elif target_col in adata.obsm:
|
388
|
+
self.labels = adata.obsm[target_col]
|
389
|
+
else:
|
390
|
+
raise ValueError(f"Target column {target_col} not found in obs or obsm")
|
391
|
+
|
401
392
|
self.pert_labels = adata.obs[label_col]
|
402
393
|
|
403
394
|
def __len__(self):
|
@@ -405,8 +396,8 @@ class PLDataset(Dataset):
|
|
405
396
|
|
406
397
|
def __getitem__(self, idx):
|
407
398
|
"""Returns a sample and corresponding perturbations applied (labels)."""
|
408
|
-
sample = self.data[idx]
|
409
|
-
num_label = self.labels.iloc[idx]
|
399
|
+
sample = to_dense(self.data[idx]).squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
|
400
|
+
num_label = self.labels.iloc[idx] if hasattr(self.labels, "iloc") else self.labels[idx]
|
410
401
|
str_label = self.pert_labels.iloc[idx]
|
411
402
|
|
412
403
|
return sample, num_label, str_label
|
@@ -161,12 +161,20 @@ class PseudobulkSpace(PerturbationSpace):
|
|
161
161
|
adata = adata_emb
|
162
162
|
|
163
163
|
adata.obs[target_col] = adata.obs[target_col].astype("category")
|
164
|
+
grouping_cols = [target_col] if groups_col is None else [target_col, groups_col]
|
165
|
+
original_obs = adata.obs.copy()
|
164
166
|
ps_adata = sc.get.aggregate(
|
165
167
|
adata, by=[target_col] if groups_col is None else [target_col, groups_col], func=mode, layer=layer_key
|
166
168
|
)
|
169
|
+
|
167
170
|
if mode in ps_adata.layers:
|
168
171
|
ps_adata.X = ps_adata.layers[mode]
|
169
172
|
|
173
|
+
for col in original_obs.columns:
|
174
|
+
if col not in ps_adata.obs.columns:
|
175
|
+
grouped_values = original_obs.groupby(grouping_cols)[col].first()
|
176
|
+
ps_adata.obs[col] = grouped_values.reindex(ps_adata.obs.index).values
|
177
|
+
|
170
178
|
ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")
|
171
179
|
|
172
180
|
return ps_adata
|