pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +1 -2
- pertpy/metadata/_cell_line.py +3 -5
- pertpy/preprocessing/_guide_rna.py +98 -10
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/_augur.py +32 -44
- pertpy/tools/_cinemaot.py +1 -3
- pertpy/tools/_coda/_base_coda.py +21 -29
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/_base.py +4 -12
- pertpy/tools/_distances/_distances.py +56 -48
- pertpy/tools/_enrichment.py +1 -3
- pertpy/tools/_milo.py +4 -12
- pertpy/tools/_mixscape.py +215 -127
- pertpy/tools/_perturbation_space/_simple.py +1 -3
- pertpy/tools/_scgen/_scgen.py +1 -3
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/METADATA +2 -2
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/RECORD +20 -19
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/WHEEL +0 -0
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/__init__.py
CHANGED
pertpy/_doc.py
CHANGED
pertpy/metadata/_cell_line.py
CHANGED
@@ -703,7 +703,6 @@ class CellLine(MetaData):
|
|
703
703
|
metadata_key: str = "bulk_rna_broad",
|
704
704
|
category: str = "cell line",
|
705
705
|
subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
|
706
|
-
show: bool = True,
|
707
706
|
return_fig: bool = False,
|
708
707
|
) -> Figure | None:
|
709
708
|
"""Visualise the correlation of cell lines with annotated metadata.
|
@@ -747,7 +746,7 @@ class CellLine(MetaData):
|
|
747
746
|
if all(isinstance(id, str) for id in subset_identifier_list):
|
748
747
|
if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
|
749
748
|
subset_identifier_list = np.where(
|
750
|
-
np.
|
749
|
+
np.isin(adata.obs[identifier].values, subset_identifier_list)
|
751
750
|
)[0]
|
752
751
|
else:
|
753
752
|
raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
|
@@ -798,10 +797,9 @@ class CellLine(MetaData):
|
|
798
797
|
},
|
799
798
|
)
|
800
799
|
|
801
|
-
if show:
|
802
|
-
plt.show()
|
803
800
|
if return_fig:
|
804
801
|
return plt.gcf()
|
802
|
+
plt.show()
|
805
803
|
return None
|
806
804
|
else:
|
807
|
-
raise NotImplementedError
|
805
|
+
raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
|
@@ -1,15 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import uuid
|
4
|
-
from typing import TYPE_CHECKING
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
5
|
+
from warnings import warn
|
5
6
|
|
6
7
|
import matplotlib.pyplot as plt
|
7
8
|
import numpy as np
|
8
9
|
import pandas as pd
|
9
10
|
import scanpy as sc
|
10
11
|
import scipy
|
12
|
+
from rich.progress import track
|
13
|
+
from scipy.sparse import issparse
|
11
14
|
|
12
15
|
from pertpy._doc import _doc_params, doc_common_plot_args
|
16
|
+
from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
|
13
17
|
|
14
18
|
if TYPE_CHECKING:
|
15
19
|
from anndata import AnnData
|
@@ -17,7 +21,7 @@ if TYPE_CHECKING:
|
|
17
21
|
|
18
22
|
|
19
23
|
class GuideAssignment:
|
20
|
-
"""
|
24
|
+
"""Assign cells to guide RNAs."""
|
21
25
|
|
22
26
|
def assign_by_threshold(
|
23
27
|
self,
|
@@ -33,12 +37,12 @@ class GuideAssignment:
|
|
33
37
|
This function expects unnormalized data as input.
|
34
38
|
|
35
39
|
Args:
|
36
|
-
adata:
|
40
|
+
adata: AnnData object containing gRNA values.
|
37
41
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
38
42
|
layer: Key to the layer containing raw count values of the gRNAs.
|
39
43
|
adata.X is used if layer is None. Expects count data.
|
40
44
|
output_layer: Assigned guide will be saved on adata.layers[output_key].
|
41
|
-
only_return_results:
|
45
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
|
42
46
|
|
43
47
|
Examples:
|
44
48
|
Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
|
@@ -67,7 +71,7 @@ class GuideAssignment:
|
|
67
71
|
assignment_threshold: float,
|
68
72
|
layer: str | None = None,
|
69
73
|
output_key: str = "assigned_guide",
|
70
|
-
no_grna_assigned_key: str = "
|
74
|
+
no_grna_assigned_key: str = "Negative",
|
71
75
|
only_return_results: bool = False,
|
72
76
|
) -> np.ndarray | None:
|
73
77
|
"""Simple threshold based max gRNA assignment function.
|
@@ -76,13 +80,13 @@ class GuideAssignment:
|
|
76
80
|
This function expects unnormalized data as input.
|
77
81
|
|
78
82
|
Args:
|
79
|
-
adata:
|
83
|
+
adata: AnnData object containing gRNA values.
|
80
84
|
assignment_threshold: The count threshold that is required for an assignment to be viable.
|
81
85
|
layer: Key to the layer containing raw count values of the gRNAs.
|
82
86
|
adata.X is used if layer is None. Expects count data.
|
83
87
|
output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
|
84
88
|
no_grna_assigned_key: The key to return if no gRNA is expressed enough.
|
85
|
-
only_return_results:
|
89
|
+
only_return_results: Whether to input AnnData is not modified and the result is returned as an np.ndarray.
|
86
90
|
|
87
91
|
Examples:
|
88
92
|
Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
|
@@ -109,6 +113,92 @@ class GuideAssignment:
|
|
109
113
|
|
110
114
|
return None
|
111
115
|
|
116
|
+
def assign_mixture_model(
|
117
|
+
self,
|
118
|
+
adata: AnnData,
|
119
|
+
model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
|
120
|
+
assigned_guides_key: str = "assigned_guide",
|
121
|
+
no_grna_assigned_key: str = "negative",
|
122
|
+
max_assignments_per_cell: int = 5,
|
123
|
+
multiple_grna_assigned_key: str = "multiple",
|
124
|
+
multiple_grna_assignment_string: str = "+",
|
125
|
+
only_return_results: bool = False,
|
126
|
+
uns_key: str = "guide_assignment_params",
|
127
|
+
show_progress: bool = False,
|
128
|
+
**mixture_model_kwargs,
|
129
|
+
) -> np.ndarray | None:
|
130
|
+
"""Assigns gRNAs to cells using a mixture model.
|
131
|
+
|
132
|
+
Args:
|
133
|
+
adata: AnnData object containing gRNA values.
|
134
|
+
model: The model to use for the mixture model. Currently only `Poisson_Gauss_Mixture` is supported.
|
135
|
+
output_key: Assigned guide will be saved on adata.obs[output_key].
|
136
|
+
no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
|
137
|
+
max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
|
138
|
+
multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
|
139
|
+
multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
|
140
|
+
only_return_results: Whether input AnnData is not modified and the result is returned as an np.ndarray.
|
141
|
+
show_progress: Whether to shows progress bar.
|
142
|
+
mixture_model_kwargs: Are passed to the mixture model.
|
143
|
+
|
144
|
+
Examples:
|
145
|
+
>>> import pertpy as pt
|
146
|
+
>>> mdata = pt.dt.papalexi_2021()
|
147
|
+
>>> gdo = mdata.mod["gdo"]
|
148
|
+
>>> ga = pt.pp.GuideAssignment()
|
149
|
+
>>> ga.assign_mixture_model(gdo)
|
150
|
+
"""
|
151
|
+
if model == "poisson_gauss_mixture":
|
152
|
+
mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
|
153
|
+
else:
|
154
|
+
raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
|
155
|
+
|
156
|
+
if uns_key not in adata.uns:
|
157
|
+
adata.uns[uns_key] = {}
|
158
|
+
elif type(adata.uns[uns_key]) is not dict:
|
159
|
+
raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")
|
160
|
+
|
161
|
+
res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
|
162
|
+
fct = track if show_progress else lambda iterable: iterable
|
163
|
+
for gene in fct(adata.var_names):
|
164
|
+
is_nonzero = (
|
165
|
+
np.ravel((adata[:, gene].X != 0).todense()) if issparse(adata.X) else np.ravel(adata[:, gene].X != 0)
|
166
|
+
)
|
167
|
+
if sum(is_nonzero) < 2:
|
168
|
+
warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
|
169
|
+
continue
|
170
|
+
# We are only fitting the model to the non-zero values, the rest is
|
171
|
+
# automatically assigned to the negative class
|
172
|
+
data = adata[is_nonzero, gene].X.todense().A1 if issparse(adata.X) else adata[is_nonzero, gene].X
|
173
|
+
data = np.ravel(data)
|
174
|
+
|
175
|
+
if np.any(data < 0):
|
176
|
+
raise ValueError(
|
177
|
+
"Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
|
178
|
+
)
|
179
|
+
|
180
|
+
# Log2 transform the data so positive population is approximately normal
|
181
|
+
data = np.log2(data)
|
182
|
+
assignments = mixture_model.run_model(data)
|
183
|
+
res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
|
184
|
+
adata.uns[uns_key][gene] = mixture_model.params
|
185
|
+
|
186
|
+
# Assign guides to cells
|
187
|
+
# Some cells might have multiple guides assigned
|
188
|
+
series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
|
189
|
+
num_guides_assigned = res.sum(1)
|
190
|
+
series.loc[(num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)] = res.apply(
|
191
|
+
lambda row: row.index[row == 1].tolist(), axis=1
|
192
|
+
).str.join(multiple_grna_assignment_string)
|
193
|
+
series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key
|
194
|
+
|
195
|
+
if only_return_results:
|
196
|
+
return series.values
|
197
|
+
|
198
|
+
adata.obs[assigned_guides_key] = series.values
|
199
|
+
|
200
|
+
return None
|
201
|
+
|
112
202
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
113
203
|
def plot_heatmap(
|
114
204
|
self,
|
@@ -117,7 +207,6 @@ class GuideAssignment:
|
|
117
207
|
layer: str | None = None,
|
118
208
|
order_by: np.ndarray | str | None = None,
|
119
209
|
key_to_save_order: str = None,
|
120
|
-
show: bool = True,
|
121
210
|
return_fig: bool = False,
|
122
211
|
**kwargs,
|
123
212
|
) -> Figure | None:
|
@@ -194,8 +283,7 @@ class GuideAssignment:
|
|
194
283
|
finally:
|
195
284
|
del adata.obs[temp_col_name]
|
196
285
|
|
197
|
-
if show:
|
198
|
-
plt.show()
|
199
286
|
if return_fig:
|
200
287
|
return fig
|
288
|
+
plt.show()
|
201
289
|
return None
|
@@ -0,0 +1,179 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from collections.abc import Mapping
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import numpy as np
|
9
|
+
import numpyro
|
10
|
+
import numpyro.distributions as dist
|
11
|
+
from jax import random
|
12
|
+
from numpyro.infer import MCMC, NUTS
|
13
|
+
|
14
|
+
ParamsDict = Mapping[str, jnp.ndarray]
|
15
|
+
|
16
|
+
|
17
|
+
class MixtureModel(ABC):
|
18
|
+
"""Abstract base class for 2-component mixture models.
|
19
|
+
|
20
|
+
Args:
|
21
|
+
num_warmup: Number of warmup steps for MCMC sampling.
|
22
|
+
num_samples: Number of samples to draw after warmup.
|
23
|
+
fraction_positive_expected: Prior belief about fraction of positive components.
|
24
|
+
poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
|
25
|
+
gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
|
26
|
+
gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(
|
30
|
+
self,
|
31
|
+
num_warmup: int = 50,
|
32
|
+
num_samples: int = 100,
|
33
|
+
fraction_positive_expected: float = 0.15,
|
34
|
+
poisson_rate_prior: float = 0.2,
|
35
|
+
gaussian_mean_prior: tuple[float, float] = (3, 2),
|
36
|
+
gaussian_std_prior: float = 1,
|
37
|
+
) -> None:
|
38
|
+
self.num_warmup = num_warmup
|
39
|
+
self.num_samples = num_samples
|
40
|
+
self.fraction_positive_expected = fraction_positive_expected
|
41
|
+
self.poisson_rate_prior = poisson_rate_prior
|
42
|
+
self.gaussian_mean_prior = gaussian_mean_prior
|
43
|
+
self.gaussian_std_prior = gaussian_std_prior
|
44
|
+
|
45
|
+
@abstractmethod
|
46
|
+
def initialize_params(self) -> ParamsDict:
|
47
|
+
"""Initialize model parameters via sampling from priors.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
Dictionary of sampled parameter values.
|
51
|
+
"""
|
52
|
+
pass
|
53
|
+
|
54
|
+
@abstractmethod
|
55
|
+
def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
|
56
|
+
"""Calculate log likelihood of data under current parameters.
|
57
|
+
|
58
|
+
Args:
|
59
|
+
data: Input data array.
|
60
|
+
params: Current parameter values.
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
Log likelihood values for each datapoint.
|
64
|
+
"""
|
65
|
+
pass
|
66
|
+
|
67
|
+
def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
|
68
|
+
"""Fit the mixture model using MCMC.
|
69
|
+
|
70
|
+
Args:
|
71
|
+
data: Input data to fit.
|
72
|
+
seed: Random seed for reproducibility.
|
73
|
+
|
74
|
+
Returns:
|
75
|
+
Fitted MCMC object containing samples.
|
76
|
+
"""
|
77
|
+
nuts_kernel = NUTS(self.mixture_model)
|
78
|
+
mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
|
79
|
+
mcmc.run(random.PRNGKey(seed), data=data)
|
80
|
+
return mcmc
|
81
|
+
|
82
|
+
def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
|
83
|
+
"""Run model fitting and assign components.
|
84
|
+
|
85
|
+
Args:
|
86
|
+
data: Input data array.
|
87
|
+
seed: Random seed.
|
88
|
+
|
89
|
+
Returns:
|
90
|
+
Array of "Positive"/"Negative" assignments for each datapoint.
|
91
|
+
"""
|
92
|
+
self.mcmc = self.fit_model(data, seed)
|
93
|
+
self.samples = self.mcmc.get_samples()
|
94
|
+
self.assignments = self.assignment(self.samples, data)
|
95
|
+
return self.assignments
|
96
|
+
|
97
|
+
def mixture_model(self, data: jnp.ndarray) -> None:
|
98
|
+
"""Define mixture model structure for NumPyro.
|
99
|
+
|
100
|
+
Args:
|
101
|
+
data: Input data array.
|
102
|
+
"""
|
103
|
+
params = self.initialize_params()
|
104
|
+
|
105
|
+
with numpyro.plate("data", data.shape[0]):
|
106
|
+
log_likelihoods = self.log_likelihood(data, params)
|
107
|
+
log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
|
108
|
+
numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)
|
109
|
+
|
110
|
+
def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
|
111
|
+
"""Assign data points to mixture components.
|
112
|
+
|
113
|
+
Args:
|
114
|
+
samples: MCMC samples of parameters.
|
115
|
+
data: Input data array.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
Array of component assignments.
|
119
|
+
"""
|
120
|
+
params = {key: samples[key].mean(axis=0) for key in samples.keys()}
|
121
|
+
self.params = params
|
122
|
+
|
123
|
+
log_likelihoods = self.log_likelihood(data, params)
|
124
|
+
guide_assignments = jnp.argmax(log_likelihoods, axis=-1)
|
125
|
+
|
126
|
+
assignments = ["Negative" if assign == 0 else "Positive" for assign in guide_assignments]
|
127
|
+
return np.array(assignments)
|
128
|
+
|
129
|
+
|
130
|
+
class PoissonGaussMixture(MixtureModel):
|
131
|
+
"""Mixture model combining Poisson and Gaussian distributions."""
|
132
|
+
|
133
|
+
def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
|
134
|
+
"""Calculate component-wise log likelihoods.
|
135
|
+
|
136
|
+
Args:
|
137
|
+
data: Input data array.
|
138
|
+
params: Current parameter values.
|
139
|
+
|
140
|
+
Returns:
|
141
|
+
Log likelihood values for each component.
|
142
|
+
"""
|
143
|
+
poisson_rate = params["poisson_rate"]
|
144
|
+
gaussian_mean = params["gaussian_mean"]
|
145
|
+
gaussian_std = params["gaussian_std"]
|
146
|
+
mix_probs = params["mix_probs"]
|
147
|
+
|
148
|
+
# We penalize the model for positioning the Poisson component to the right of the Gaussian component
|
149
|
+
# by imposing a soft constraint to penalize the Poisson rate being larger than the Gaussian mean
|
150
|
+
# Heuristic regularization term to prevent flipping of the components
|
151
|
+
numpyro.factor("separation_penalty", +10 * jnp.heaviside(-poisson_rate + gaussian_mean, 0))
|
152
|
+
|
153
|
+
log_likelihoods = jnp.stack(
|
154
|
+
[
|
155
|
+
# Poisson component
|
156
|
+
jnp.log(mix_probs[0]) + dist.Poisson(poisson_rate).log_prob(data),
|
157
|
+
# Gaussian component
|
158
|
+
jnp.log(mix_probs[1]) + dist.Normal(gaussian_mean, gaussian_std).log_prob(data),
|
159
|
+
],
|
160
|
+
axis=-1,
|
161
|
+
)
|
162
|
+
|
163
|
+
return log_likelihoods
|
164
|
+
|
165
|
+
def initialize_params(self) -> ParamsDict:
|
166
|
+
"""Initialize model parameters via prior sampling.
|
167
|
+
|
168
|
+
Returns:
|
169
|
+
Dictionary of sampled parameter values.
|
170
|
+
"""
|
171
|
+
params = {}
|
172
|
+
params["poisson_rate"] = numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior))
|
173
|
+
params["gaussian_mean"] = numpyro.sample("gaussian_mean", dist.Normal(*self.gaussian_mean_prior))
|
174
|
+
params["gaussian_std"] = numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior))
|
175
|
+
params["mix_probs"] = numpyro.sample(
|
176
|
+
"mix_probs",
|
177
|
+
dist.Dirichlet(jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])),
|
178
|
+
)
|
179
|
+
return params
|
pertpy/tools/_augur.py
CHANGED
@@ -685,7 +685,7 @@ class Augur:
|
|
685
685
|
span: float = 0.75,
|
686
686
|
filter_negative_residuals: bool = False,
|
687
687
|
n_threads: int = 4,
|
688
|
-
augur_mode: Literal["
|
688
|
+
augur_mode: Literal["default", "permute", "velocity"] = "default",
|
689
689
|
select_variance_features: bool = True,
|
690
690
|
key_added: str = "augurpy_results",
|
691
691
|
random_state: int | None = None,
|
@@ -908,41 +908,39 @@ class Augur:
|
|
908
908
|
.mean()
|
909
909
|
)
|
910
910
|
|
911
|
-
|
912
|
-
|
911
|
+
rng = np.random.default_rng()
|
912
|
+
sampled_data = []
|
913
913
|
|
914
914
|
# draw mean aucs for permute1 and permute2
|
915
915
|
for celltype in permuted_cv_augur1["cell_type"].unique():
|
916
916
|
df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
|
917
917
|
df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
|
918
|
-
for permutation_idx in range(n_permutations):
|
919
|
-
# subsample
|
920
|
-
sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
|
921
|
-
sampled_permuted_cv_augur1.append(
|
922
|
-
pd.DataFrame(
|
923
|
-
{
|
924
|
-
"cell_type": [celltype],
|
925
|
-
"permutation_idx": [permutation_idx],
|
926
|
-
"mean": [sample1["augur_score"].mean(axis=0)],
|
927
|
-
"std": [sample1["augur_score"].std(axis=0)],
|
928
|
-
}
|
929
|
-
)
|
930
|
-
)
|
931
918
|
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
919
|
+
indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
|
920
|
+
indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
|
921
|
+
|
922
|
+
scores1 = df1["augur_score"].values[indices1]
|
923
|
+
scores2 = df2["augur_score"].values[indices2]
|
924
|
+
|
925
|
+
means1 = scores1.mean(axis=1)
|
926
|
+
means2 = scores2.mean(axis=1)
|
927
|
+
stds1 = scores1.std(axis=1)
|
928
|
+
stds2 = scores2.std(axis=1)
|
929
|
+
|
930
|
+
sampled_data.append(
|
931
|
+
pd.DataFrame(
|
932
|
+
{
|
933
|
+
"cell_type": np.repeat(celltype, n_permutations),
|
934
|
+
"permutation_idx": np.arange(n_permutations),
|
935
|
+
"mean1": means1,
|
936
|
+
"mean2": means2,
|
937
|
+
"std1": stds1,
|
938
|
+
"std2": stds2,
|
939
|
+
}
|
942
940
|
)
|
941
|
+
)
|
943
942
|
|
944
|
-
|
945
|
-
permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
|
943
|
+
sampled_df = pd.concat(sampled_data)
|
946
944
|
|
947
945
|
# delta between augur scores
|
948
946
|
delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
|
@@ -950,9 +948,7 @@ class Augur:
|
|
950
948
|
)
|
951
949
|
|
952
950
|
# delta between permutation scores
|
953
|
-
delta_rnd =
|
954
|
-
permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
|
955
|
-
).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
|
951
|
+
delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
|
956
952
|
|
957
953
|
# number of values where permutations are larger than test statistic
|
958
954
|
delta["b"] = (
|
@@ -967,7 +963,7 @@ class Augur:
|
|
967
963
|
delta["z"] = (
|
968
964
|
delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
|
969
965
|
) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
|
970
|
-
|
966
|
+
|
971
967
|
delta["pval"] = np.minimum(
|
972
968
|
2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
|
973
969
|
)
|
@@ -982,7 +978,6 @@ class Augur:
|
|
982
978
|
*,
|
983
979
|
top_n: int = None,
|
984
980
|
ax: Axes = None,
|
985
|
-
show: bool = True,
|
986
981
|
return_fig: bool = False,
|
987
982
|
) -> Figure | None:
|
988
983
|
"""Plot scatterplot of differential prioritization.
|
@@ -1041,10 +1036,9 @@ class Augur:
|
|
1041
1036
|
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
|
1042
1037
|
ax.add_artist(legend1)
|
1043
1038
|
|
1044
|
-
if show:
|
1045
|
-
plt.show()
|
1046
1039
|
if return_fig:
|
1047
1040
|
return plt.gcf()
|
1041
|
+
plt.show()
|
1048
1042
|
return None
|
1049
1043
|
|
1050
1044
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1055,7 +1049,6 @@ class Augur:
|
|
1055
1049
|
key: str = "augurpy_results",
|
1056
1050
|
top_n: int = 10,
|
1057
1051
|
ax: Axes = None,
|
1058
|
-
show: bool = True,
|
1059
1052
|
return_fig: bool = False,
|
1060
1053
|
) -> Figure | None:
|
1061
1054
|
"""Plot a lollipop plot of the n features with largest feature importances.
|
@@ -1109,10 +1102,9 @@ class Augur:
|
|
1109
1102
|
plt.ylabel("Gene")
|
1110
1103
|
plt.yticks(y_axes_range, n_features["genes"])
|
1111
1104
|
|
1112
|
-
if show:
|
1113
|
-
plt.show()
|
1114
1105
|
if return_fig:
|
1115
1106
|
return plt.gcf()
|
1107
|
+
plt.show()
|
1116
1108
|
return None
|
1117
1109
|
|
1118
1110
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1122,7 +1114,6 @@ class Augur:
|
|
1122
1114
|
*,
|
1123
1115
|
key: str = "augurpy_results",
|
1124
1116
|
ax: Axes = None,
|
1125
|
-
show: bool = True,
|
1126
1117
|
return_fig: bool = False,
|
1127
1118
|
) -> Figure | None:
|
1128
1119
|
"""Plot a lollipop plot of the mean augur values.
|
@@ -1172,10 +1163,9 @@ class Augur:
|
|
1172
1163
|
plt.ylabel("Cell Type")
|
1173
1164
|
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
|
1174
1165
|
|
1175
|
-
if show:
|
1176
|
-
plt.show()
|
1177
1166
|
if return_fig:
|
1178
1167
|
return plt.gcf()
|
1168
|
+
plt.show()
|
1179
1169
|
return None
|
1180
1170
|
|
1181
1171
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1185,7 +1175,6 @@ class Augur:
|
|
1185
1175
|
results2: dict[str, Any],
|
1186
1176
|
*,
|
1187
1177
|
top_n: int = None,
|
1188
|
-
show: bool = True,
|
1189
1178
|
return_fig: bool = False,
|
1190
1179
|
) -> Figure | None:
|
1191
1180
|
"""Create scatterplot with two augur results.
|
@@ -1243,8 +1232,7 @@ class Augur:
|
|
1243
1232
|
plt.xlabel("Augur scores 1")
|
1244
1233
|
plt.ylabel("Augur scores 2")
|
1245
1234
|
|
1246
|
-
if show:
|
1247
|
-
plt.show()
|
1248
1235
|
if return_fig:
|
1249
1236
|
return plt.gcf()
|
1237
|
+
plt.show()
|
1250
1238
|
return None
|
pertpy/tools/_cinemaot.py
CHANGED
@@ -658,7 +658,6 @@ class Cinemaot:
|
|
658
658
|
title: str = "CINEMA-OT matching matrix",
|
659
659
|
min_val: float = 0.01,
|
660
660
|
ax: Axes | None = None,
|
661
|
-
show: bool = True,
|
662
661
|
return_fig: bool = False,
|
663
662
|
**kwargs,
|
664
663
|
) -> Figure | None:
|
@@ -717,10 +716,9 @@ class Cinemaot:
|
|
717
716
|
g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
|
718
717
|
plt.title(title)
|
719
718
|
|
720
|
-
if show:
|
721
|
-
plt.show()
|
722
719
|
if return_fig:
|
723
720
|
return g
|
721
|
+
plt.show()
|
724
722
|
return None
|
725
723
|
|
726
724
|
|