pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +1 -2
- pertpy/metadata/_cell_line.py +3 -5
- pertpy/preprocessing/_guide_rna.py +98 -10
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/_augur.py +32 -44
- pertpy/tools/_cinemaot.py +1 -3
- pertpy/tools/_coda/_base_coda.py +21 -29
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/_base.py +4 -12
- pertpy/tools/_distances/_distances.py +56 -48
- pertpy/tools/_enrichment.py +1 -3
- pertpy/tools/_milo.py +4 -12
- pertpy/tools/_mixscape.py +215 -127
- pertpy/tools/_perturbation_space/_simple.py +1 -3
- pertpy/tools/_scgen/_scgen.py +1 -3
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/METADATA +2 -2
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/RECORD +20 -19
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/WHEEL +0 -0
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/__init__.py
CHANGED
pertpy/_doc.py
CHANGED
pertpy/metadata/_cell_line.py
CHANGED
@@ -703,7 +703,6 @@ class CellLine(MetaData):
|
|
703
703
|
metadata_key: str = "bulk_rna_broad",
|
704
704
|
category: str = "cell line",
|
705
705
|
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
|
706
|
-
show: bool = True,
|
707
706
|
return_fig: bool = False,
|
708
707
|
) -> Figure | None:
|
709
708
|
"""Visualise the correlation of cell lines with annotated metadata.
|
@@ -747,7 +746,7 @@ class CellLine(MetaData):
|
|
747
746
|
if all(isinstance(id, str) for id in subset_identifier_list):
|
748
747
|
if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
|
749
748
|
subset_identifier_list = np.where(
|
750
|
-
np.
|
749
|
+
np.isin(adata.obs[identifier].values, subset_identifier_list)
|
751
750
|
)[0]
|
752
751
|
else:
|
753
752
|
raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
|
@@ -798,10 +797,9 @@ class CellLine(MetaData):
|
|
798
797
|
},
|
799
798
|
)
|
800
799
|
|
801
|
-
if show:
|
802
|
-
plt.show()
|
803
800
|
if return_fig:
|
804
801
|
return plt.gcf()
|
802
|
+
plt.show()
|
805
803
|
return None
|
806
804
|
else:
|
807
|
-
raise NotImplementedError
|
805
|
+
raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
|
@@ -1,15 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import uuid
|
4
|
-
from typing import TYPE_CHECKING
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
5
|
+
from warnings import warn
|
5
6
|
|
6
7
|
import matplotlib.pyplot as plt
|
7
8
|
import numpy as np
|
8
9
|
import pandas as pd
|
9
10
|
import scanpy as sc
|
10
11
|
import scipy
|
12
|
+
from rich.progress import track
|
13
|
+
from scipy.sparse import issparse
|
11
14
|
|
12
15
|
from pertpy._doc import _doc_params, doc_common_plot_args
|
16
|
+
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
|
13
17
|
|
14
18
|
if TYPE_CHECKING:
|
15
19
|
from anndata import AnnData
|
@@ -17,7 +21,7 @@ if TYPE_CHECKING:
|
|
17
21
|
|
18
22
|
|
19
23
|
class GuideAssignment:
|
20
|
-
"""
|
24
|
+
"""Assign cells to guide RNAs."""
|
21
25
|
|
22
26
|
def assign_by_threshold(
|
23
27
|
self,
|
@@ -33,12 +37,12 @@ class GuideAssignment:
|
|
33
37
|
This function expects unnormalized data as input.
|
34
38
|
|
35
39
|
Args:
|
36
|
-
adata:
|
40
|
+
adata: AnnData object containing gRNA values.
|
37
41
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
38
42
|
layer: Key to the layer containing raw count values of the gRNAs.
|
39
43
|
adata.X is used if layer is None. Expects count data.
|
40
44
|
output_layer: Assigned guide will be saved on adata.layers[output_key].
|
41
|
-
only_return_results:
|
45
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
|
42
46
|
|
43
47
|
Examples:
|
44
48
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
|
@@ -67,7 +71,7 @@ class GuideAssignment:
|
|
67
71
|
assignment_threshold: float,
|
68
72
|
layer: str | None = None,
|
69
73
|
output_key: str = "assigned_guide",
|
70
|
-
no_grna_assigned_key: str = "
|
74
|
+
no_grna_assigned_key: str = "Negative",
|
71
75
|
only_return_results: bool = False,
|
72
76
|
) -> np.ndarray | None:
|
73
77
|
"""Simple threshold based max gRNA assignment function.
|
@@ -76,13 +80,13 @@ class GuideAssignment:
|
|
76
80
|
This function expects unnormalized data as input.
|
77
81
|
|
78
82
|
Args:
|
79
|
-
adata:
|
83
|
+
adata: AnnData object containing gRNA values.
|
80
84
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
81
85
|
layer: Key to the layer containing raw count values of the gRNAs.
|
82
86
|
adata.X is used if layer is None. Expects count data.
|
83
87
|
output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
|
84
88
|
no_grna_assigned_key: The key to return if no gRNA is expressed enough.
|
85
|
-
only_return_results:
|
89
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an np.ndarray.
|
86
90
|
|
87
91
|
Examples:
|
88
92
|
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
|
@@ -109,6 +113,92 @@ class GuideAssignment:
|
|
109
113
|
|
110
114
|
return None
|
111
115
|
|
116
|
+
def assign_mixture_model(
|
117
|
+
self,
|
118
|
+
adata: AnnData,
|
119
|
+
model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
|
120
|
+
assigned_guides_key: str = "assigned_guide",
|
121
|
+
no_grna_assigned_key: str = "negative",
|
122
|
+
max_assignments_per_cell: int = 5,
|
123
|
+
multiple_grna_assigned_key: str = "multiple",
|
124
|
+
multiple_grna_assignment_string: str = "+",
|
125
|
+
only_return_results: bool = False,
|
126
|
+
uns_key: str = "guide_assignment_params",
|
127
|
+
show_progress: bool = False,
|
128
|
+
**mixture_model_kwargs,
|
129
|
+
) -> np.ndarray | None:
|
130
|
+
"""Assigns gRNAs to cells using a mixture model.
|
131
|
+
|
132
|
+
Args:
|
133
|
+
adata: AnnData object containing gRNA values.
|
134
|
+
model: The model to use for the mixture model. Currently only `Poisson_Gauss_Mixture` is supported.
|
135
|
+
output_key: Assigned guide will be saved on adata.obs[output_key].
|
136
|
+
no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
|
137
|
+
max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
|
138
|
+
multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
|
139
|
+
multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
|
140
|
+
only_return_results: Whether input AnnData is not modified and the result is returned as an np.ndarray.
|
141
|
+
show_progress: Whether to shows progress bar.
|
142
|
+
mixture_model_kwargs: Are passed to the mixture model.
|
143
|
+
|
144
|
+
Examples:
|
145
|
+
>>> import pertpy as pt
|
146
|
+
>>> mdata = pt.dt.papalexi_2021()
|
147
|
+
>>> gdo = mdata.mod["gdo"]
|
148
|
+
>>> ga = pt.pp.GuideAssignment()
|
149
|
+
>>> ga.assign_mixture_model(gdo)
|
150
|
+
"""
|
151
|
+
if model == "poisson_gauss_mixture":
|
152
|
+
mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
|
153
|
+
else:
|
154
|
+
raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
|
155
|
+
|
156
|
+
if uns_key not in adata.uns:
|
157
|
+
adata.uns[uns_key] = {}
|
158
|
+
elif type(adata.uns[uns_key]) is not dict:
|
159
|
+
raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")
|
160
|
+
|
161
|
+
res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
|
162
|
+
fct = track if show_progress else lambda iterable: iterable
|
163
|
+
for gene in fct(adata.var_names):
|
164
|
+
is_nonzero = (
|
165
|
+
np.ravel((adata[:, gene].X != 0).todense()) if issparse(adata.X) else np.ravel(adata[:, gene].X != 0)
|
166
|
+
)
|
167
|
+
if sum(is_nonzero) < 2:
|
168
|
+
warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
|
169
|
+
continue
|
170
|
+
# We are only fitting the model to the non-zero values, the rest is
|
171
|
+
# automatically assigned to the negative class
|
172
|
+
data = adata[is_nonzero, gene].X.todense().A1 if issparse(adata.X) else adata[is_nonzero, gene].X
|
173
|
+
data = np.ravel(data)
|
174
|
+
|
175
|
+
if np.any(data < 0):
|
176
|
+
raise ValueError(
|
177
|
+
"Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
|
178
|
+
)
|
179
|
+
|
180
|
+
# Log2 transform the data so positive population is approximately normal
|
181
|
+
data = np.log2(data)
|
182
|
+
assignments = mixture_model.run_model(data)
|
183
|
+
res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
|
184
|
+
adata.uns[uns_key][gene] = mixture_model.params
|
185
|
+
|
186
|
+
# Assign guides to cells
|
187
|
+
# Some cells might have multiple guides assigned
|
188
|
+
series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
|
189
|
+
num_guides_assigned = res.sum(1)
|
190
|
+
series.loc[(num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)] = res.apply(
|
191
|
+
lambda row: row.index[row == 1].tolist(), axis=1
|
192
|
+
).str.join(multiple_grna_assignment_string)
|
193
|
+
series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key
|
194
|
+
|
195
|
+
if only_return_results:
|
196
|
+
return series.values
|
197
|
+
|
198
|
+
adata.obs[assigned_guides_key] = series.values
|
199
|
+
|
200
|
+
return None
|
201
|
+
|
112
202
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
113
203
|
def plot_heatmap(
|
114
204
|
self,
|
@@ -117,7 +207,6 @@ class GuideAssignment:
|
|
117
207
|
layer: str | None = None,
|
118
208
|
order_by: np.ndarray | str | None = None,
|
119
209
|
key_to_save_order: str = None,
|
120
|
-
show: bool = True,
|
121
210
|
return_fig: bool = False,
|
122
211
|
**kwargs,
|
123
212
|
) -> Figure | None:
|
@@ -194,8 +283,7 @@ class GuideAssignment:
|
|
194
283
|
finally:
|
195
284
|
del adata.obs[temp_col_name]
|
196
285
|
|
197
|
-
if show:
|
198
|
-
plt.show()
|
199
286
|
if return_fig:
|
200
287
|
return fig
|
288
|
+
plt.show()
|
201
289
|
return None
|
@@ -0,0 +1,179 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from collections.abc import Mapping
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import numpy as np
|
9
|
+
import numpyro
|
10
|
+
import numpyro.distributions as dist
|
11
|
+
from jax import random
|
12
|
+
from numpyro.infer import MCMC, NUTS
|
13
|
+
|
14
|
+
ParamsDict = Mapping[str, jnp.ndarray]
|
15
|
+
|
16
|
+
|
17
|
+
class MixtureModel(ABC):
|
18
|
+
"""Abstract base class for 2-component mixture models.
|
19
|
+
|
20
|
+
Args:
|
21
|
+
num_warmup: Number of warmup steps for MCMC sampling.
|
22
|
+
num_samples: Number of samples to draw after warmup.
|
23
|
+
fraction_positive_expected: Prior belief about fraction of positive components.
|
24
|
+
poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
|
25
|
+
gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
|
26
|
+
gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(
|
30
|
+
self,
|
31
|
+
num_warmup: int = 50,
|
32
|
+
num_samples: int = 100,
|
33
|
+
fraction_positive_expected: float = 0.15,
|
34
|
+
poisson_rate_prior: float = 0.2,
|
35
|
+
gaussian_mean_prior: tuple[float, float] = (3, 2),
|
36
|
+
gaussian_std_prior: float = 1,
|
37
|
+
) -> None:
|
38
|
+
self.num_warmup = num_warmup
|
39
|
+
self.num_samples = num_samples
|
40
|
+
self.fraction_positive_expected = fraction_positive_expected
|
41
|
+
self.poisson_rate_prior = poisson_rate_prior
|
42
|
+
self.gaussian_mean_prior = gaussian_mean_prior
|
43
|
+
self.gaussian_std_prior = gaussian_std_prior
|
44
|
+
|
45
|
+
@abstractmethod
|
46
|
+
def initialize_params(self) -> ParamsDict:
|
47
|
+
"""Initialize model parameters via sampling from priors.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
Dictionary of sampled parameter values.
|
51
|
+
"""
|
52
|
+
pass
|
53
|
+
|
54
|
+
@abstractmethod
|
55
|
+
def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
|
56
|
+
"""Calculate log likelihood of data under current parameters.
|
57
|
+
|
58
|
+
Args:
|
59
|
+
data: Input data array.
|
60
|
+
params: Current parameter values.
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
Log likelihood values for each datapoint.
|
64
|
+
"""
|
65
|
+
pass
|
66
|
+
|
67
|
+
def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
|
68
|
+
"""Fit the mixture model using MCMC.
|
69
|
+
|
70
|
+
Args:
|
71
|
+
data: Input data to fit.
|
72
|
+
seed: Random seed for reproducibility.
|
73
|
+
|
74
|
+
Returns:
|
75
|
+
Fitted MCMC object containing samples.
|
76
|
+
"""
|
77
|
+
nuts_kernel = NUTS(self.mixture_model)
|
78
|
+
mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
|
79
|
+
mcmc.run(random.PRNGKey(seed), data=data)
|
80
|
+
return mcmc
|
81
|
+
|
82
|
+
def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
|
83
|
+
"""Run model fitting and assign components.
|
84
|
+
|
85
|
+
Args:
|
86
|
+
data: Input data array.
|
87
|
+
seed: Random seed.
|
88
|
+
|
89
|
+
Returns:
|
90
|
+
Array of "Positive"/"Negative" assignments for each datapoint.
|
91
|
+
"""
|
92
|
+
self.mcmc = self.fit_model(data, seed)
|
93
|
+
self.samples = self.mcmc.get_samples()
|
94
|
+
self.assignments = self.assignment(self.samples, data)
|
95
|
+
return self.assignments
|
96
|
+
|
97
|
+
def mixture_model(self, data: jnp.ndarray) -> None:
|
98
|
+
"""Define mixture model structure for NumPyro.
|
99
|
+
|
100
|
+
Args:
|
101
|
+
data: Input data array.
|
102
|
+
"""
|
103
|
+
params = self.initialize_params()
|
104
|
+
|
105
|
+
with numpyro.plate("data", data.shape[0]):
|
106
|
+
log_likelihoods = self.log_likelihood(data, params)
|
107
|
+
log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
|
108
|
+
numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)
|
109
|
+
|
110
|
+
def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
|
111
|
+
"""Assign data points to mixture components.
|
112
|
+
|
113
|
+
Args:
|
114
|
+
samples: MCMC samples of parameters.
|
115
|
+
data: Input data array.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
Array of component assignments.
|
119
|
+
"""
|
120
|
+
params = {key: samples[key].mean(axis=0) for key in samples.keys()}
|
121
|
+
self.params = params
|
122
|
+
|
123
|
+
log_likelihoods = self.log_likelihood(data, params)
|
124
|
+
guide_assignments = jnp.argmax(log_likelihoods, axis=-1)
|
125
|
+
|
126
|
+
assignments = ["Negative" if assign == 0 else "Positive" for assign in guide_assignments]
|
127
|
+
return np.array(assignments)
|
128
|
+
|
129
|
+
|
130
|
+
class PoissonGaussMixture(MixtureModel):
|
131
|
+
"""Mixture model combining Poisson and Gaussian distributions."""
|
132
|
+
|
133
|
+
def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
|
134
|
+
"""Calculate component-wise log likelihoods.
|
135
|
+
|
136
|
+
Args:
|
137
|
+
data: Input data array.
|
138
|
+
params: Current parameter values.
|
139
|
+
|
140
|
+
Returns:
|
141
|
+
Log likelihood values for each component.
|
142
|
+
"""
|
143
|
+
poisson_rate = params["poisson_rate"]
|
144
|
+
gaussian_mean = params["gaussian_mean"]
|
145
|
+
gaussian_std = params["gaussian_std"]
|
146
|
+
mix_probs = params["mix_probs"]
|
147
|
+
|
148
|
+
# We penalize the model for positioning the Poisson component to the right of the Gaussian component
|
149
|
+
# by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
|
150
|
+
# Heuristic regularization term to prevent flipping of the components
|
151
|
+
numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
|
152
|
+
|
153
|
+
log_likelihoods = jnp.stack(
|
154
|
+
[
|
155
|
+
# Poisson component
|
156
|
+
jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
|
157
|
+
# Gaussian component
|
158
|
+
jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
|
159
|
+
],
|
160
|
+
axis=-1,
|
161
|
+
)
|
162
|
+
|
163
|
+
return log_likelihoods
|
164
|
+
|
165
|
+
def initialize_params(self) -> ParamsDict:
|
166
|
+
"""Initialize model parameters via prior sampling.
|
167
|
+
|
168
|
+
Returns:
|
169
|
+
Dictionary of sampled parameter values.
|
170
|
+
"""
|
171
|
+
params = {}
|
172
|
+
params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
|
173
|
+
params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
|
174
|
+
params["gaussian_std"] = numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior))
|
175
|
+
params["mix_probs"] = numpyro.sample(
|
176
|
+
"mix_probs",
|
177
|
+
dist.Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
|
178
|
+
)
|
179
|
+
return params
|
pertpy/tools/_augur.py
CHANGED
@@ -685,7 +685,7 @@ class Augur:
|
|
685
685
|
span: float = 0.75,
|
686
686
|
filter_negative_residuals: bool = False,
|
687
687
|
n_threads: int = 4,
|
688
|
-
augur_mode: Literal["
|
688
|
+
augur_mode: Literal["default", "permute", "velocity"] = "default",
|
689
689
|
select_variance_features: bool = True,
|
690
690
|
key_added: str = "augurpy_results",
|
691
691
|
random_state: int | None = None,
|
@@ -908,41 +908,39 @@ class Augur:
|
|
908
908
|
.mean()
|
909
909
|
)
|
910
910
|
|
911
|
-
|
912
|
-
|
911
|
+
rng = np.random.default_rng()
|
912
|
+
sampled_data = []
|
913
913
|
|
914
914
|
# draw mean aucs for permute1 and permute2
|
915
915
|
for celltype in permuted_cv_augur1["cell_type"].unique():
|
916
916
|
df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
|
917
917
|
df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
|
918
|
-
for permutation_idx in range(n_permutations):
|
919
|
-
# subsample
|
920
|
-
sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
|
921
|
-
sampled_permuted_cv_augur1.append(
|
922
|
-
pd.DataFrame(
|
923
|
-
{
|
924
|
-
"cell_type": [celltype],
|
925
|
-
"permutation_idx": [permutation_idx],
|
926
|
-
"mean": [sample1["augur_score"].mean(axis=0)],
|
927
|
-
"std": [sample1["augur_score"].std(axis=0)],
|
928
|
-
}
|
929
|
-
)
|
930
|
-
)
|
931
918
|
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
919
|
+
indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
|
920
|
+
indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
|
921
|
+
|
922
|
+
scores1 = df1["augur_score"].values[indices1]
|
923
|
+
scores2 = df2["augur_score"].values[indices2]
|
924
|
+
|
925
|
+
means1 = scores1.mean(axis=1)
|
926
|
+
means2 = scores2.mean(axis=1)
|
927
|
+
stds1 = scores1.std(axis=1)
|
928
|
+
stds2 = scores2.std(axis=1)
|
929
|
+
|
930
|
+
sampled_data.append(
|
931
|
+
pd.DataFrame(
|
932
|
+
{
|
933
|
+
"cell_type": np.repeat(celltype, n_permutations),
|
934
|
+
"permutation_idx": np.arange(n_permutations),
|
935
|
+
"mean1": means1,
|
936
|
+
"mean2": means2,
|
937
|
+
"std1": stds1,
|
938
|
+
"std2": stds2,
|
939
|
+
}
|
942
940
|
)
|
941
|
+
)
|
943
942
|
|
944
|
-
|
945
|
-
permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
|
943
|
+
sampled_df = pd.concat(sampled_data)
|
946
944
|
|
947
945
|
# delta between augur scores
|
948
946
|
delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
|
@@ -950,9 +948,7 @@ class Augur:
|
|
950
948
|
)
|
951
949
|
|
952
950
|
# delta between permutation scores
|
953
|
-
delta_rnd =
|
954
|
-
permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
|
955
|
-
).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
|
951
|
+
delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
|
956
952
|
|
957
953
|
# number of values where permutations are larger than test statistic
|
958
954
|
delta["b"] = (
|
@@ -967,7 +963,7 @@ class Augur:
|
|
967
963
|
delta["z"] = (
|
968
964
|
delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
|
969
965
|
) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
|
970
|
-
|
966
|
+
|
971
967
|
delta["pval"] = np.minimum(
|
972
968
|
2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
|
973
969
|
)
|
@@ -982,7 +978,6 @@ class Augur:
|
|
982
978
|
*,
|
983
979
|
top_n: int = None,
|
984
980
|
ax: Axes = None,
|
985
|
-
show: bool = True,
|
986
981
|
return_fig: bool = False,
|
987
982
|
) -> Figure | None:
|
988
983
|
"""Plot scatterplot of differential prioritization.
|
@@ -1041,10 +1036,9 @@ class Augur:
|
|
1041
1036
|
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
|
1042
1037
|
ax.add_artist(legend1)
|
1043
1038
|
|
1044
|
-
if show:
|
1045
|
-
plt.show()
|
1046
1039
|
if return_fig:
|
1047
1040
|
return plt.gcf()
|
1041
|
+
plt.show()
|
1048
1042
|
return None
|
1049
1043
|
|
1050
1044
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1055,7 +1049,6 @@ class Augur:
|
|
1055
1049
|
key: str = "augurpy_results",
|
1056
1050
|
top_n: int = 10,
|
1057
1051
|
ax: Axes = None,
|
1058
|
-
show: bool = True,
|
1059
1052
|
return_fig: bool = False,
|
1060
1053
|
) -> Figure | None:
|
1061
1054
|
"""Plot a lollipop plot of the n features with largest feature importances.
|
@@ -1109,10 +1102,9 @@ class Augur:
|
|
1109
1102
|
plt.ylabel("Gene")
|
1110
1103
|
plt.yticks(y_axes_range, n_features["genes"])
|
1111
1104
|
|
1112
|
-
if show:
|
1113
|
-
plt.show()
|
1114
1105
|
if return_fig:
|
1115
1106
|
return plt.gcf()
|
1107
|
+
plt.show()
|
1116
1108
|
return None
|
1117
1109
|
|
1118
1110
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1122,7 +1114,6 @@ class Augur:
|
|
1122
1114
|
*,
|
1123
1115
|
key: str = "augurpy_results",
|
1124
1116
|
ax: Axes = None,
|
1125
|
-
show: bool = True,
|
1126
1117
|
return_fig: bool = False,
|
1127
1118
|
) -> Figure | None:
|
1128
1119
|
"""Plot a lollipop plot of the mean augur values.
|
@@ -1172,10 +1163,9 @@ class Augur:
|
|
1172
1163
|
plt.ylabel("Cell Type")
|
1173
1164
|
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
|
1174
1165
|
|
1175
|
-
if show:
|
1176
|
-
plt.show()
|
1177
1166
|
if return_fig:
|
1178
1167
|
return plt.gcf()
|
1168
|
+
plt.show()
|
1179
1169
|
return None
|
1180
1170
|
|
1181
1171
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1185,7 +1175,6 @@ class Augur:
|
|
1185
1175
|
results2: dict[str, Any],
|
1186
1176
|
*,
|
1187
1177
|
top_n: int = None,
|
1188
|
-
show: bool = True,
|
1189
1178
|
return_fig: bool = False,
|
1190
1179
|
) -> Figure | None:
|
1191
1180
|
"""Create scatterplot with two augur results.
|
@@ -1243,8 +1232,7 @@ class Augur:
|
|
1243
1232
|
plt.xlabel("Augur scores 1")
|
1244
1233
|
plt.ylabel("Augur scores 2")
|
1245
1234
|
|
1246
|
-
if show:
|
1247
|
-
plt.show()
|
1248
1235
|
if return_fig:
|
1249
1236
|
return plt.gcf()
|
1237
|
+
plt.show()
|
1250
1238
|
return None
|
pertpy/tools/_cinemaot.py
CHANGED
@@ -658,7 +658,6 @@ class Cinemaot:
|
|
658
658
|
title: str = "CINEMA-OT matching matrix",
|
659
659
|
min_val: float = 0.01,
|
660
660
|
ax: Axes | None = None,
|
661
|
-
show: bool = True,
|
662
661
|
return_fig: bool = False,
|
663
662
|
**kwargs,
|
664
663
|
) -> Figure | None:
|
@@ -717,10 +716,9 @@ class Cinemaot:
|
|
717
716
|
g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
|
718
717
|
plt.title(title)
|
719
718
|
|
720
|
-
if show:
|
721
|
-
plt.show()
|
722
719
|
if return_fig:
|
723
720
|
return g
|
721
|
+
plt.show()
|
724
722
|
return None
|
725
723
|
|
726
724
|
|