pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_augur.py
CHANGED
@@ -15,7 +15,6 @@ import statsmodels.api as sm
 from anndata import AnnData
 from joblib import Parallel, delayed
 from lamin_utils import logger
-from rich import print
 from rich.progress import track
 from scipy import sparse, stats
 from sklearn.base import is_classifier, is_regressor
@@ -26,17 +25,19 @@ from sklearn.metrics import (
     explained_variance_score,
     f1_score,
     make_scorer,
-    mean_squared_error,
     precision_score,
     r2_score,
     recall_score,
     roc_auc_score,
+    root_mean_squared_error,
 )
 from sklearn.model_selection import StratifiedKFold, cross_validate
 from sklearn.preprocessing import LabelEncoder
 from skmisc.loess import loess
 from statsmodels.stats.multitest import fdrcorrection
 
+from pertpy._doc import _doc_params, doc_common_plot_args
+
 if TYPE_CHECKING:
     from matplotlib.axes import Axes
     from matplotlib.figure import Figure
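The import swap above tracks scikit-learn's metrics API: root_mean_squared_error (introduced in scikit-learn 1.4) replaces the older mean_squared_error(..., squared=False) idiom, whose squared flag is deprecated in recent releases. A minimal standalone sketch of the two spellings, using toy arrays rather than pertpy data:

import numpy as np
from sklearn.metrics import mean_squared_error, root_mean_squared_error

# Toy data for illustration only.
y_true = np.array([3.0, 0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

# Older spelling: take the square root of the MSE manually.
rmse_old = np.sqrt(mean_squared_error(y_true, y_pred))

# Dedicated function used by the new pertpy code (scikit-learn >= 1.4).
rmse_new = root_mean_squared_error(y_true, y_pred)

assert np.isclose(rmse_old, rmse_new)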
@@ -439,7 +440,7 @@ class Augur:
                 "augur_score": make_scorer(self.ccc_score),
                 "r2": make_scorer(r2_score),
                 "ccc": make_scorer(self.ccc_score),
-                "neg_mean_squared_error": make_scorer(
+                "neg_mean_squared_error": make_scorer(root_mean_squared_error),
                 "explained_variance": make_scorer(explained_variance_score),
             }
         )
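As background for the scorer table above, a minimal sketch of how make_scorer turns a metric into a scorer usable with cross_validate. The estimator and data below are placeholders, not pertpy code; note that scikit-learn's built-in neg_* scorers are produced by passing greater_is_better=False, which negates an error metric so that larger values are always better.

from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import make_scorer, root_mean_squared_error
from sklearn.model_selection import cross_validate

X, y = make_regression(n_samples=100, n_features=5, random_state=0)

# greater_is_better=False makes cross_validate report the negated RMSE,
# matching scikit-learn's built-in "neg_root_mean_squared_error" scorer.
rmse_scorer = make_scorer(root_mean_squared_error, greater_is_better=False)

cv_results = cross_validate(Ridge(), X, y, scoring={"neg_rmse": rmse_scorer}, cv=3)
print(cv_results["test_neg_rmse"])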
@@ -684,7 +685,7 @@ class Augur:
         span: float = 0.75,
         filter_negative_residuals: bool = False,
         n_threads: int = 4,
-        augur_mode: Literal["
+        augur_mode: Literal["default", "permute", "velocity"] = "default",
         select_variance_features: bool = True,
         key_added: str = "augurpy_results",
         random_state: int | None = None,
@@ -907,41 +908,39 @@ class Augur:
             .mean()
         )
 
-
-
+        rng = np.random.default_rng()
+        sampled_data = []
 
         # draw mean aucs for permute1 and permute2
         for celltype in permuted_cv_augur1["cell_type"].unique():
             df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
             df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
-            for permutation_idx in range(n_permutations):
-                # subsample
-                sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
-                sampled_permuted_cv_augur1.append(
-                    pd.DataFrame(
-                        {
-                            "cell_type": [celltype],
-                            "permutation_idx": [permutation_idx],
-                            "mean": [sample1["augur_score"].mean(axis=0)],
-                            "std": [sample1["augur_score"].std(axis=0)],
-                        }
-                    )
-                )
 
-
-
-
-
-
-
-
-
-
-
+            indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
+            indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
+
+            scores1 = df1["augur_score"].values[indices1]
+            scores2 = df2["augur_score"].values[indices2]
+
+            means1 = scores1.mean(axis=1)
+            means2 = scores2.mean(axis=1)
+            stds1 = scores1.std(axis=1)
+            stds2 = scores2.std(axis=1)
+
+            sampled_data.append(
+                pd.DataFrame(
+                    {
+                        "cell_type": np.repeat(celltype, n_permutations),
+                        "permutation_idx": np.arange(n_permutations),
+                        "mean1": means1,
+                        "mean2": means2,
+                        "std1": stds1,
+                        "std2": stds2,
+                    }
                 )
+            )
 
-
-        permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
+        sampled_df = pd.concat(sampled_data)
 
         # delta between augur scores
         delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
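The rewritten block above replaces a Python-level permutation loop of per-iteration .sample() calls with vectorized NumPy resampling. A minimal standalone sketch of the same pattern, with made-up column names and sizes rather than pertpy's actual data:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Toy per-cell-type scores standing in for permuted Augur results.
df = pd.DataFrame({"augur_score": rng.random(50)})

n_permutations, n_subsamples = 1000, 20

# One (n_permutations, n_subsamples) index matrix replaces a loop of
# n_permutations separate .sample() calls.
indices = rng.choice(len(df), size=(n_permutations, n_subsamples), replace=True)
scores = df["augur_score"].to_numpy()[indices]

# Row-wise mean/std give one bootstrap statistic per permutation.
means = scores.mean(axis=1)
stds = scores.std(axis=1)
print(means.shape, stds.shape)  # (1000,) (1000,)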
@@ -949,9 +948,7 @@ class Augur:
         )
 
         # delta between permutation scores
-        delta_rnd =
-            permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
-        ).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
+        delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
 
         # number of values where permutations are larger than test statistic
         delta["b"] = (
@@ -966,7 +963,7 @@ class Augur:
         delta["z"] = (
             delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
         ) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
-
+
         delta["pval"] = np.minimum(
             2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
         )
@@ -974,24 +971,25 @@ class Augur:
 
         return delta
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_dp_scatter(
         self,
         results: pd.DataFrame,
+        *,
         top_n: int = None,
-        return_fig: bool | None = None,
         ax: Axes = None,
-
-
-    ) -> Axes | Figure | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Plot scatterplot of differential prioritization.
 
         Args:
             results: Results after running differential prioritization.
             top_n: optionally, the number of top prioritized cell types to label in the plot
             ax: optionally, axes used to draw plot
+            {common_plot_args}
 
         Returns:
-
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
 
         Examples:
             >>> import pertpy as pt
@@ -1038,37 +1036,32 @@ class Augur:
         legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
         ax.add_artist(legend1)
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-
-        return ax
+        plt.show()
         return None
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_important_features(
         self,
         data: dict[str, Any],
+        *,
         key: str = "augurpy_results",
         top_n: int = 10,
-        return_fig: bool | None = None,
         ax: Axes = None,
-
-
-    ) -> Axes | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Plot a lollipop plot of the n features with largest feature importances.
 
         Args:
-
+            data: results after running `predict()` as dictionary or the AnnData object.
             key: Key in the AnnData object of the results
             top_n: n number feature importance values to plot. Default is 10.
             ax: optionally, axes used to draw plot
-
+            {common_plot_args}
 
         Returns:
-
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
 
         Examples:
             >>> import pertpy as pt
@@ -1109,35 +1102,30 @@ class Augur:
         plt.ylabel("Gene")
         plt.yticks(y_axes_range, n_features["genes"])
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-
-        return ax
+        plt.show()
         return None
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_lollipop(
         self,
-        data: dict[str, Any],
+        data: dict[str, Any] | AnnData,
+        *,
         key: str = "augurpy_results",
-        return_fig: bool | None = None,
         ax: Axes = None,
-
-
-    ) -> Axes | Figure | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Plot a lollipop plot of the mean augur values.
 
         Args:
-
-            key:
-            ax: optionally, axes used to draw plot
-
+            data: results after running `predict()` as dictionary or the AnnData object.
+            key: .uns key in the results AnnData object.
+            ax: optionally, axes used to draw plot.
+            {common_plot_args}
 
         Returns:
-
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
 
         Examples:
             >>> import pertpy as pt
@@ -1175,32 +1163,27 @@ class Augur:
         plt.ylabel("Cell Type")
         plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-
-        return ax
+        plt.show()
         return None
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_scatterplot(
         self,
        results1: dict[str, Any],
         results2: dict[str, Any],
+        *,
         top_n: int = None,
-        return_fig: bool
-
-        save: str | bool | None = None,
-    ) -> Axes | Figure | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Create scatterplot with two augur results.
 
         Args:
             results1: results after running `predict()`
             results2: results after running `predict()`
             top_n: optionally, the number of top prioritized cell types to label in the plot
-
+            {common_plot_args}
 
         Returns:
             Axes of the plot.
@@ -1249,12 +1232,7 @@ class Augur:
         plt.xlabel("Augur scores 1")
         plt.ylabel("Augur scores 2")
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-
-        return ax
+        plt.show()
         return None
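Across the Augur plotting methods the hunks above follow one pattern: the show/save parameters are dropped, options after the data become keyword-only (*), and a single return_fig flag decides whether the Matplotlib figure is returned or shown. A generic sketch of that calling convention; plot_something is a hypothetical stand-in, not a pertpy function:

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# Generic sketch of the convention adopted above: keyword-only plot options,
# no show/save parameters, and an opt-in return of the Figure.
def plot_something(values, *, ax=None, return_fig: bool = False) -> Figure | None:
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(values)
    if return_fig:
        return plt.gcf()
    plt.show()
    return None

fig = plot_something([1, 3, 2], return_fig=True)  # caller handles saving now
fig.savefig("example.png", bbox_inches="tight")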
pertpy/tools/_cinemaot.py
CHANGED
@@ -18,9 +18,12 @@ from sklearn.decomposition import FastICA
 from sklearn.linear_model import LinearRegression
 from sklearn.neighbors import NearestNeighbors
 
+from pertpy._doc import _doc_params, doc_common_plot_args
+
 if TYPE_CHECKING:
     from anndata import AnnData
     from matplotlib.axes import Axes
+    from matplotlib.pyplot import Figure
     from statsmodels.tools.typing import ArrayLike
 
 
@@ -88,7 +91,7 @@ class Cinemaot:
         dim = self.get_dim(adata, use_rep=use_rep)
 
         transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
-        X_transformed = transformer.fit_transform(adata.obsm[use_rep][:, :dim])
+        X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
         groupvec = (adata.obs[pert_key] == control * 1).values  # control
         xi = np.zeros(dim)
         j = 0
@@ -97,9 +100,9 @@ class Cinemaot:
             xi[j] = xi_obj.correlation
             j = j + 1
 
-        cf = X_transformed[:, xi < thres]
-        cf1 = cf[adata.obs[pert_key] == control, :]
-        cf2 = cf[adata.obs[pert_key] != control, :]
+        cf = np.array(X_transformed[:, xi < thres], np.float64)
+        cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
+        cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
         if sum(xi < thres) == 1:
             sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
         elif sum(xi < thres) == 0:
@@ -167,7 +170,7 @@ class Cinemaot:
         else:
             _solver = sinkhorn.Sinkhorn(threshold=eps)
         ot_sink = _solver(ot_prob)
-        ot_matrix = ot_sink.matrix.T
+        ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
         embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
             ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
         )
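The three hunks above all wrap intermediate results in np.array(..., dtype=np.float64), forcing plain float64 NumPy arrays before downstream NumPy/scikit-learn operations: the FastICA embedding, the subsetted confounder matrices, and the Sinkhorn transport matrix (typically a JAX array when the OTT solver is used). A minimal sketch of the casting pattern, using a toy array rather than pertpy's data:

import numpy as np

# Toy stand-in for a result that may come back as a non-NumPy array type
# (e.g. a JAX array from a Sinkhorn solve) or with a narrower dtype.
result = np.arange(6, dtype=np.float32).reshape(2, 3)

# Explicit conversion guarantees a plain float64 ndarray for downstream
# NumPy / scikit-learn code, regardless of what produced `result`.
ot_matrix = np.array(result.T, dtype=np.float64)

assert isinstance(ot_matrix, np.ndarray) and ot_matrix.dtype == np.float64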
@@ -639,6 +642,7 @@ class Cinemaot:
         s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
         return c_effect, s_effect
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_vis_matching(
         self,
         adata: AnnData,
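_doc_params and doc_common_plot_args come from the new pertpy/_doc.py module (see the file list at the top) and splice the shared {common_plot_args} text into each plotting docstring, here and in _augur.py. The actual implementation is not shown in this diff; below is a minimal sketch of how such a docstring-templating decorator is commonly written, with names and wording that are assumptions rather than pertpy's code:

# Hypothetical sketch of a docstring-templating decorator; pertpy's actual
# _doc.py is not shown in this diff and may differ.
doc_common_plot_args = "return_fig: if `True`, returns the figure so the caller can save or modify it."

def _doc_params(**kwds):
    """Fill `{placeholders}` in a function's docstring with shared fragments."""
    def decorator(obj):
        obj.__doc__ = obj.__doc__.format_map(kwds)
        return obj
    return decorator

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_example():
    """Example plot.

    Args:
        {common_plot_args}
    """

print(plot_example.__doc__)  # placeholder replaced by the shared text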
@@ -647,16 +651,16 @@ class Cinemaot:
         control: str,
         de_label: str,
         source_label: str,
+        *,
         matching_rep: str = "ot",
         resolution: float = 0.5,
         normalize: str = "col",
         title: str = "CINEMA-OT matching matrix",
         min_val: float = 0.01,
-        show: bool = True,
-        save: str | None = None,
         ax: Axes | None = None,
+        return_fig: bool = False,
         **kwargs,
-    ) -> None:
+    ) -> Figure | None:
         """Visualize the CINEMA-OT matching matrix.
 
         Args:
@@ -670,11 +674,12 @@ class Cinemaot:
             normalize: normalize the coarse-grained matching matrix by row / column.
             title: the title for the figure.
             min_val: The min value to truncate the matching matrix.
-
-            save: If `True` or a `str`, save the figure. A string is appended to the default filename.
-                Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
+            {common_plot_args}
             **kwargs: Other parameters to input for seaborn.heatmap.
 
+        Returns:
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
+
         Examples:
             >>> import pertpy as pt
             >>> adata = pt.dt.cinemaot_example()
@@ -710,12 +715,11 @@ class Cinemaot:
 
         g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
         plt.title(title)
-
-        if
-
-
-
-        return g
+
+        if return_fig:
+            return g
+        plt.show()
+        return None
 
 
 class Xi: