pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_augur.py
CHANGED
@@ -15,7 +15,6 @@ import statsmodels.api as sm
|
|
15
15
|
from anndata import AnnData
|
16
16
|
from joblib import Parallel, delayed
|
17
17
|
from lamin_utils import logger
|
18
|
-
from rich import print
|
19
18
|
from rich.progress import track
|
20
19
|
from scipy import sparse, stats
|
21
20
|
from sklearn.base import is_classifier, is_regressor
|
@@ -26,17 +25,19 @@ from sklearn.metrics import (
|
|
26
25
|
explained_variance_score,
|
27
26
|
f1_score,
|
28
27
|
make_scorer,
|
29
|
-
mean_squared_error,
|
30
28
|
precision_score,
|
31
29
|
r2_score,
|
32
30
|
recall_score,
|
33
31
|
roc_auc_score,
|
32
|
+
root_mean_squared_error,
|
34
33
|
)
|
35
34
|
from sklearn.model_selection import StratifiedKFold, cross_validate
|
36
35
|
from sklearn.preprocessing import LabelEncoder
|
37
36
|
from skmisc.loess import loess
|
38
37
|
from statsmodels.stats.multitest import fdrcorrection
|
39
38
|
|
39
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
40
|
+
|
40
41
|
if TYPE_CHECKING:
|
41
42
|
from matplotlib.axes import Axes
|
42
43
|
from matplotlib.figure import Figure
|
@@ -439,7 +440,7 @@ class Augur:
|
|
439
440
|
"augur_score": make_scorer(self.ccc_score),
|
440
441
|
"r2": make_scorer(r2_score),
|
441
442
|
"ccc": make_scorer(self.ccc_score),
|
442
|
-
"neg_mean_squared_error": make_scorer(
|
443
|
+
"neg_mean_squared_error": make_scorer(root_mean_squared_error),
|
443
444
|
"explained_variance": make_scorer(explained_variance_score),
|
444
445
|
}
|
445
446
|
)
|
@@ -684,7 +685,7 @@ class Augur:
|
|
684
685
|
span: float = 0.75,
|
685
686
|
filter_negative_residuals: bool = False,
|
686
687
|
n_threads: int = 4,
|
687
|
-
augur_mode: Literal["
|
688
|
+
augur_mode: Literal["default", "permute", "velocity"] = "default",
|
688
689
|
select_variance_features: bool = True,
|
689
690
|
key_added: str = "augurpy_results",
|
690
691
|
random_state: int | None = None,
|
@@ -907,41 +908,39 @@ class Augur:
|
|
907
908
|
.mean()
|
908
909
|
)
|
909
910
|
|
910
|
-
|
911
|
-
|
911
|
+
rng = np.random.default_rng()
|
912
|
+
sampled_data = []
|
912
913
|
|
913
914
|
# draw mean aucs for permute1 and permute2
|
914
915
|
for celltype in permuted_cv_augur1["cell_type"].unique():
|
915
916
|
df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
|
916
917
|
df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
|
917
|
-
for permutation_idx in range(n_permutations):
|
918
|
-
# subsample
|
919
|
-
sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
|
920
|
-
sampled_permuted_cv_augur1.append(
|
921
|
-
pd.DataFrame(
|
922
|
-
{
|
923
|
-
"cell_type": [celltype],
|
924
|
-
"permutation_idx": [permutation_idx],
|
925
|
-
"mean": [sample1["augur_score"].mean(axis=0)],
|
926
|
-
"std": [sample1["augur_score"].std(axis=0)],
|
927
|
-
}
|
928
|
-
)
|
929
|
-
)
|
930
918
|
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
919
|
+
indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
|
920
|
+
indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
|
921
|
+
|
922
|
+
scores1 = df1["augur_score"].values[indices1]
|
923
|
+
scores2 = df2["augur_score"].values[indices2]
|
924
|
+
|
925
|
+
means1 = scores1.mean(axis=1)
|
926
|
+
means2 = scores2.mean(axis=1)
|
927
|
+
stds1 = scores1.std(axis=1)
|
928
|
+
stds2 = scores2.std(axis=1)
|
929
|
+
|
930
|
+
sampled_data.append(
|
931
|
+
pd.DataFrame(
|
932
|
+
{
|
933
|
+
"cell_type": np.repeat(celltype, n_permutations),
|
934
|
+
"permutation_idx": np.arange(n_permutations),
|
935
|
+
"mean1": means1,
|
936
|
+
"mean2": means2,
|
937
|
+
"std1": stds1,
|
938
|
+
"std2": stds2,
|
939
|
+
}
|
941
940
|
)
|
941
|
+
)
|
942
942
|
|
943
|
-
|
944
|
-
permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
|
943
|
+
sampled_df = pd.concat(sampled_data)
|
945
944
|
|
946
945
|
# delta between augur scores
|
947
946
|
delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
|
@@ -949,9 +948,7 @@ class Augur:
|
|
949
948
|
)
|
950
949
|
|
951
950
|
# delta between permutation scores
|
952
|
-
delta_rnd =
|
953
|
-
permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
|
954
|
-
).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
|
951
|
+
delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
|
955
952
|
|
956
953
|
# number of values where permutations are larger than test statistic
|
957
954
|
delta["b"] = (
|
@@ -966,7 +963,7 @@ class Augur:
|
|
966
963
|
delta["z"] = (
|
967
964
|
delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
|
968
965
|
) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
|
969
|
-
|
966
|
+
|
970
967
|
delta["pval"] = np.minimum(
|
971
968
|
2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
|
972
969
|
)
|
@@ -974,24 +971,25 @@ class Augur:
|
|
974
971
|
|
975
972
|
return delta
|
976
973
|
|
974
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
977
975
|
def plot_dp_scatter(
|
978
976
|
self,
|
979
977
|
results: pd.DataFrame,
|
978
|
+
*,
|
980
979
|
top_n: int = None,
|
981
|
-
return_fig: bool | None = None,
|
982
980
|
ax: Axes = None,
|
983
|
-
|
984
|
-
|
985
|
-
) -> Axes | Figure | None:
|
981
|
+
return_fig: bool = False,
|
982
|
+
) -> Figure | None:
|
986
983
|
"""Plot scatterplot of differential prioritization.
|
987
984
|
|
988
985
|
Args:
|
989
986
|
results: Results after running differential prioritization.
|
990
987
|
top_n: optionally, the number of top prioritized cell types to label in the plot
|
991
988
|
ax: optionally, axes used to draw plot
|
989
|
+
{common_plot_args}
|
992
990
|
|
993
991
|
Returns:
|
994
|
-
|
992
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
995
993
|
|
996
994
|
Examples:
|
997
995
|
>>> import pertpy as pt
|
@@ -1038,37 +1036,32 @@ class Augur:
|
|
1038
1036
|
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
|
1039
1037
|
ax.add_artist(legend1)
|
1040
1038
|
|
1041
|
-
if save:
|
1042
|
-
plt.savefig(save, bbox_inches="tight")
|
1043
|
-
if show:
|
1044
|
-
plt.show()
|
1045
1039
|
if return_fig:
|
1046
1040
|
return plt.gcf()
|
1047
|
-
|
1048
|
-
return ax
|
1041
|
+
plt.show()
|
1049
1042
|
return None
|
1050
1043
|
|
1044
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1051
1045
|
def plot_important_features(
|
1052
1046
|
self,
|
1053
1047
|
data: dict[str, Any],
|
1048
|
+
*,
|
1054
1049
|
key: str = "augurpy_results",
|
1055
1050
|
top_n: int = 10,
|
1056
|
-
return_fig: bool | None = None,
|
1057
1051
|
ax: Axes = None,
|
1058
|
-
|
1059
|
-
|
1060
|
-
) -> Axes | None:
|
1052
|
+
return_fig: bool = False,
|
1053
|
+
) -> Figure | None:
|
1061
1054
|
"""Plot a lollipop plot of the n features with largest feature importances.
|
1062
1055
|
|
1063
1056
|
Args:
|
1064
|
-
|
1057
|
+
data: results after running `predict()` as dictionary or the AnnData object.
|
1065
1058
|
key: Key in the AnnData object of the results
|
1066
1059
|
top_n: n number feature importance values to plot. Default is 10.
|
1067
1060
|
ax: optionally, axes used to draw plot
|
1068
|
-
|
1061
|
+
{common_plot_args}
|
1069
1062
|
|
1070
1063
|
Returns:
|
1071
|
-
|
1064
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1072
1065
|
|
1073
1066
|
Examples:
|
1074
1067
|
>>> import pertpy as pt
|
@@ -1109,35 +1102,30 @@ class Augur:
|
|
1109
1102
|
plt.ylabel("Gene")
|
1110
1103
|
plt.yticks(y_axes_range, n_features["genes"])
|
1111
1104
|
|
1112
|
-
if save:
|
1113
|
-
plt.savefig(save, bbox_inches="tight")
|
1114
|
-
if show:
|
1115
|
-
plt.show()
|
1116
1105
|
if return_fig:
|
1117
1106
|
return plt.gcf()
|
1118
|
-
|
1119
|
-
return ax
|
1107
|
+
plt.show()
|
1120
1108
|
return None
|
1121
1109
|
|
1110
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1122
1111
|
def plot_lollipop(
|
1123
1112
|
self,
|
1124
|
-
data: dict[str, Any],
|
1113
|
+
data: dict[str, Any] | AnnData,
|
1114
|
+
*,
|
1125
1115
|
key: str = "augurpy_results",
|
1126
|
-
return_fig: bool | None = None,
|
1127
1116
|
ax: Axes = None,
|
1128
|
-
|
1129
|
-
|
1130
|
-
) -> Axes | Figure | None:
|
1117
|
+
return_fig: bool = False,
|
1118
|
+
) -> Figure | None:
|
1131
1119
|
"""Plot a lollipop plot of the mean augur values.
|
1132
1120
|
|
1133
1121
|
Args:
|
1134
|
-
|
1135
|
-
key:
|
1136
|
-
ax: optionally, axes used to draw plot
|
1137
|
-
|
1122
|
+
data: results after running `predict()` as dictionary or the AnnData object.
|
1123
|
+
key: .uns key in the results AnnData object.
|
1124
|
+
ax: optionally, axes used to draw plot.
|
1125
|
+
{common_plot_args}
|
1138
1126
|
|
1139
1127
|
Returns:
|
1140
|
-
|
1128
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
1141
1129
|
|
1142
1130
|
Examples:
|
1143
1131
|
>>> import pertpy as pt
|
@@ -1175,32 +1163,27 @@ class Augur:
|
|
1175
1163
|
plt.ylabel("Cell Type")
|
1176
1164
|
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
|
1177
1165
|
|
1178
|
-
if save:
|
1179
|
-
plt.savefig(save, bbox_inches="tight")
|
1180
|
-
if show:
|
1181
|
-
plt.show()
|
1182
1166
|
if return_fig:
|
1183
1167
|
return plt.gcf()
|
1184
|
-
|
1185
|
-
return ax
|
1168
|
+
plt.show()
|
1186
1169
|
return None
|
1187
1170
|
|
1171
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1188
1172
|
def plot_scatterplot(
|
1189
1173
|
self,
|
1190
1174
|
results1: dict[str, Any],
|
1191
1175
|
results2: dict[str, Any],
|
1176
|
+
*,
|
1192
1177
|
top_n: int = None,
|
1193
|
-
return_fig: bool
|
1194
|
-
|
1195
|
-
save: str | bool | None = None,
|
1196
|
-
) -> Axes | Figure | None:
|
1178
|
+
return_fig: bool = False,
|
1179
|
+
) -> Figure | None:
|
1197
1180
|
"""Create scatterplot with two augur results.
|
1198
1181
|
|
1199
1182
|
Args:
|
1200
1183
|
results1: results after running `predict()`
|
1201
1184
|
results2: results after running `predict()`
|
1202
1185
|
top_n: optionally, the number of top prioritized cell types to label in the plot
|
1203
|
-
|
1186
|
+
{common_plot_args}
|
1204
1187
|
|
1205
1188
|
Returns:
|
1206
1189
|
Axes of the plot.
|
@@ -1249,12 +1232,7 @@ class Augur:
|
|
1249
1232
|
plt.xlabel("Augur scores 1")
|
1250
1233
|
plt.ylabel("Augur scores 2")
|
1251
1234
|
|
1252
|
-
if save:
|
1253
|
-
plt.savefig(save, bbox_inches="tight")
|
1254
|
-
if show:
|
1255
|
-
plt.show()
|
1256
1235
|
if return_fig:
|
1257
1236
|
return plt.gcf()
|
1258
|
-
|
1259
|
-
return ax
|
1237
|
+
plt.show()
|
1260
1238
|
return None
|
pertpy/tools/_cinemaot.py
CHANGED
@@ -18,9 +18,12 @@ from sklearn.decomposition import FastICA
|
|
18
18
|
from sklearn.linear_model import LinearRegression
|
19
19
|
from sklearn.neighbors import NearestNeighbors
|
20
20
|
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
22
|
+
|
21
23
|
if TYPE_CHECKING:
|
22
24
|
from anndata import AnnData
|
23
25
|
from matplotlib.axes import Axes
|
26
|
+
from matplotlib.pyplot import Figure
|
24
27
|
from statsmodels.tools.typing import ArrayLike
|
25
28
|
|
26
29
|
|
@@ -88,7 +91,7 @@ class Cinemaot:
|
|
88
91
|
dim = self.get_dim(adata, use_rep=use_rep)
|
89
92
|
|
90
93
|
transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
|
91
|
-
X_transformed = transformer.fit_transform(adata.obsm[use_rep][:, :dim])
|
94
|
+
X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
|
92
95
|
groupvec = (adata.obs[pert_key] == control * 1).values # control
|
93
96
|
xi = np.zeros(dim)
|
94
97
|
j = 0
|
@@ -97,9 +100,9 @@ class Cinemaot:
|
|
97
100
|
xi[j] = xi_obj.correlation
|
98
101
|
j = j + 1
|
99
102
|
|
100
|
-
cf = X_transformed[:, xi < thres]
|
101
|
-
cf1 = cf[adata.obs[pert_key] == control, :]
|
102
|
-
cf2 = cf[adata.obs[pert_key] != control, :]
|
103
|
+
cf = np.array(X_transformed[:, xi < thres], np.float64)
|
104
|
+
cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
|
105
|
+
cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
|
103
106
|
if sum(xi < thres) == 1:
|
104
107
|
sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
|
105
108
|
elif sum(xi < thres) == 0:
|
@@ -167,7 +170,7 @@ class Cinemaot:
|
|
167
170
|
else:
|
168
171
|
_solver = sinkhorn.Sinkhorn(threshold=eps)
|
169
172
|
ot_sink = _solver(ot_prob)
|
170
|
-
ot_matrix = ot_sink.matrix.T
|
173
|
+
ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
|
171
174
|
embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
|
172
175
|
ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
|
173
176
|
)
|
@@ -639,6 +642,7 @@ class Cinemaot:
|
|
639
642
|
s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
|
640
643
|
return c_effect, s_effect
|
641
644
|
|
645
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
642
646
|
def plot_vis_matching(
|
643
647
|
self,
|
644
648
|
adata: AnnData,
|
@@ -647,16 +651,16 @@ class Cinemaot:
|
|
647
651
|
control: str,
|
648
652
|
de_label: str,
|
649
653
|
source_label: str,
|
654
|
+
*,
|
650
655
|
matching_rep: str = "ot",
|
651
656
|
resolution: float = 0.5,
|
652
657
|
normalize: str = "col",
|
653
658
|
title: str = "CINEMA-OT matching matrix",
|
654
659
|
min_val: float = 0.01,
|
655
|
-
show: bool = True,
|
656
|
-
save: str | None = None,
|
657
660
|
ax: Axes | None = None,
|
661
|
+
return_fig: bool = False,
|
658
662
|
**kwargs,
|
659
|
-
) -> None:
|
663
|
+
) -> Figure | None:
|
660
664
|
"""Visualize the CINEMA-OT matching matrix.
|
661
665
|
|
662
666
|
Args:
|
@@ -670,11 +674,12 @@ class Cinemaot:
|
|
670
674
|
normalize: normalize the coarse-grained matching matrix by row / column.
|
671
675
|
title: the title for the figure.
|
672
676
|
min_val: The min value to truncate the matching matrix.
|
673
|
-
|
674
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
675
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
677
|
+
{common_plot_args}
|
676
678
|
**kwargs: Other parameters to input for seaborn.heatmap.
|
677
679
|
|
680
|
+
Returns:
|
681
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
682
|
+
|
678
683
|
Examples:
|
679
684
|
>>> import pertpy as pt
|
680
685
|
>>> adata = pt.dt.cinemaot_example()
|
@@ -710,12 +715,11 @@ class Cinemaot:
|
|
710
715
|
|
711
716
|
g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
|
712
717
|
plt.title(title)
|
713
|
-
|
714
|
-
if
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
return g
|
718
|
+
|
719
|
+
if return_fig:
|
720
|
+
return g
|
721
|
+
plt.show()
|
722
|
+
return None
|
719
723
|
|
720
724
|
|
721
725
|
class Xi:
|