pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

pertpy/tools/_augur.py CHANGED
@@ -15,7 +15,6 @@ import statsmodels.api as sm
 from anndata import AnnData
 from joblib import Parallel, delayed
 from lamin_utils import logger
-from rich import print
 from rich.progress import track
 from scipy import sparse, stats
 from sklearn.base import is_classifier, is_regressor
@@ -26,17 +25,19 @@ from sklearn.metrics import (
     explained_variance_score,
     f1_score,
     make_scorer,
-    mean_squared_error,
     precision_score,
     r2_score,
     recall_score,
     roc_auc_score,
+    root_mean_squared_error,
 )
 from sklearn.model_selection import StratifiedKFold, cross_validate
 from sklearn.preprocessing import LabelEncoder
 from skmisc.loess import loess
 from statsmodels.stats.multitest import fdrcorrection
 
+from pertpy._doc import _doc_params, doc_common_plot_args
+
 if TYPE_CHECKING:
     from matplotlib.axes import Axes
     from matplotlib.figure import Figure
@@ -439,7 +440,7 @@ class Augur:
                 "augur_score": make_scorer(self.ccc_score),
                 "r2": make_scorer(r2_score),
                 "ccc": make_scorer(self.ccc_score),
-                "neg_mean_squared_error": make_scorer(mean_squared_error),
+                "neg_mean_squared_error": make_scorer(root_mean_squared_error),
                 "explained_variance": make_scorer(explained_variance_score),
             }
         )
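
A note on the scorer change above: scikit-learn 1.4 introduced `root_mean_squared_error` and deprecated the `squared=False` switch of `mean_squared_error`, which is what the import swap and the scorer swap track. A minimal standalone sketch of the equivalence, independent of pertpy (the sample values are illustrative):

    import numpy as np
    from sklearn.metrics import mean_squared_error, root_mean_squared_error

    y_true = np.array([3.0, -0.5, 2.0, 7.0])
    y_pred = np.array([2.5, 0.0, 2.0, 8.0])

    # older spelling: take the square root of the MSE yourself
    rmse_old = np.sqrt(mean_squared_error(y_true, y_pred))
    # spelling used by the new import above (scikit-learn >= 1.4)
    rmse_new = root_mean_squared_error(y_true, y_pred)
    assert np.isclose(rmse_old, rmse_new)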
@@ -684,7 +685,7 @@ class Augur:
         span: float = 0.75,
         filter_negative_residuals: bool = False,
         n_threads: int = 4,
-        augur_mode: Literal["permute"] | Literal["default"] | Literal["velocity"] = "default",
+        augur_mode: Literal["default", "permute", "velocity"] = "default",
         select_variance_features: bool = True,
         key_added: str = "augurpy_results",
         random_state: int | None = None,
@@ -907,41 +908,39 @@ class Augur:
             .mean()
         )
 
-        sampled_permuted_cv_augur1 = []
-        sampled_permuted_cv_augur2 = []
+        rng = np.random.default_rng()
+        sampled_data = []
 
         # draw mean aucs for permute1 and permute2
         for celltype in permuted_cv_augur1["cell_type"].unique():
             df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
             df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
-            for permutation_idx in range(n_permutations):
-                # subsample
-                sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
-                sampled_permuted_cv_augur1.append(
-                    pd.DataFrame(
-                        {
-                            "cell_type": [celltype],
-                            "permutation_idx": [permutation_idx],
-                            "mean": [sample1["augur_score"].mean(axis=0)],
-                            "std": [sample1["augur_score"].std(axis=0)],
-                        }
-                    )
-                )
 
-                sample2 = df2.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
-                sampled_permuted_cv_augur2.append(
-                    pd.DataFrame(
-                        {
-                            "cell_type": [celltype],
-                            "permutation_idx": [permutation_idx],
-                            "mean": [sample2["augur_score"].mean(axis=0)],
-                            "std": [sample2["augur_score"].std(axis=0)],
-                        }
-                    )
+            indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
+            indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
+
+            scores1 = df1["augur_score"].values[indices1]
+            scores2 = df2["augur_score"].values[indices2]
+
+            means1 = scores1.mean(axis=1)
+            means2 = scores2.mean(axis=1)
+            stds1 = scores1.std(axis=1)
+            stds2 = scores2.std(axis=1)
+
+            sampled_data.append(
+                pd.DataFrame(
+                    {
+                        "cell_type": np.repeat(celltype, n_permutations),
+                        "permutation_idx": np.arange(n_permutations),
+                        "mean1": means1,
+                        "mean2": means2,
+                        "std1": stds1,
+                        "std2": stds2,
+                    }
                 )
+            )
 
-        permuted_samples1 = pd.concat(sampled_permuted_cv_augur1)
-        permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
+        sampled_df = pd.concat(sampled_data)
 
         # delta between augur scores
         delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
@@ -949,9 +948,7 @@ class Augur:
         )
 
         # delta between permutation scores
-        delta_rnd = permuted_samples1.merge(
-            permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
-        ).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
+        delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
 
         # number of values where permutations are larger than test statistic
         delta["b"] = (
@@ -966,7 +963,7 @@ class Augur:
         delta["z"] = (
             delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
         ) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
-        # calculate pvalues
+
         delta["pval"] = np.minimum(
             2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
         )
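
The rewrite above replaces `n_permutations` separate `DataFrame.sample` calls per cell type with a single vectorized draw of bootstrap indices; the unchanged `pval` line is the standard two-sided permutation estimate, `min(2 * (b + 1) / (m + 1), 2 * (m - b + 1) / (m + 1))`, where `b` counts permuted deltas exceeding the observed delta and `m` is the number of permutations. A standalone sketch of the vectorized resampling pattern, outside pertpy (sizes and the column name are illustrative):

    import numpy as np
    import pandas as pd

    n_permutations, n_subsamples = 1000, 50
    rng = np.random.default_rng()

    # stand-in for one cell type's cross-validated scores
    df = pd.DataFrame({"augur_score": rng.random(200)})

    # one (n_permutations x n_subsamples) index matrix instead of n_permutations .sample() calls
    idx = rng.choice(len(df), size=(n_permutations, n_subsamples), replace=True)
    resampled = df["augur_score"].values[idx]

    means = resampled.mean(axis=1)  # one bootstrap mean per permutation
    stds = resampled.std(axis=1)  # and the matching standard deviation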
@@ -974,24 +971,25 @@ class Augur:
 
         return delta
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_dp_scatter(
         self,
         results: pd.DataFrame,
+        *,
         top_n: int = None,
-        return_fig: bool | None = None,
         ax: Axes = None,
-        show: bool | None = None,
-        save: str | bool | None = None,
-    ) -> Axes | Figure | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Plot scatterplot of differential prioritization.
 
         Args:
             results: Results after running differential prioritization.
             top_n: optionally, the number of top prioritized cell types to label in the plot
             ax: optionally, axes used to draw plot
+            {common_plot_args}
 
         Returns:
-            Axes of the plot.
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
 
         Examples:
             >>> import pertpy as pt
@@ -1038,37 +1036,32 @@ class Augur:
         legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
         ax.add_artist(legend1)
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-        if not (show or save):
-            return ax
+        plt.show()
         return None
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_important_features(
         self,
         data: dict[str, Any],
+        *,
         key: str = "augurpy_results",
         top_n: int = 10,
-        return_fig: bool | None = None,
         ax: Axes = None,
-        show: bool | None = None,
-        save: str | bool | None = None,
-    ) -> Axes | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Plot a lollipop plot of the n features with largest feature importances.
 
         Args:
-            results: results after running `predict()` as dictionary or the AnnData object.
+            data: results after running `predict()` as dictionary or the AnnData object.
             key: Key in the AnnData object of the results
             top_n: n number feature importance values to plot. Default is 10.
             ax: optionally, axes used to draw plot
-            return_figure: if `True` returns figure of the plot, default is `False`
+            {common_plot_args}
 
         Returns:
-            Axes of the plot.
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
 
         Examples:
             >>> import pertpy as pt
@@ -1109,35 +1102,30 @@ class Augur:
         plt.ylabel("Gene")
         plt.yticks(y_axes_range, n_features["genes"])
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-        if not (show or save):
-            return ax
+        plt.show()
         return None
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_lollipop(
         self,
-        data: dict[str, Any],
+        data: dict[str, Any] | AnnData,
+        *,
         key: str = "augurpy_results",
-        return_fig: bool | None = None,
         ax: Axes = None,
-        show: bool | None = None,
-        save: str | bool | None = None,
-    ) -> Axes | Figure | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Plot a lollipop plot of the mean augur values.
 
         Args:
-            results: results after running `predict()` as dictionary or the AnnData object.
-            key: Key in the AnnData object of the results
-            ax: optionally, axes used to draw plot
-            return_figure: if `True` returns figure of the plot
+            data: results after running `predict()` as dictionary or the AnnData object.
+            key: .uns key in the results AnnData object.
+            ax: optionally, axes used to draw plot.
+            {common_plot_args}
 
         Returns:
-            Axes of the plot.
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
 
         Examples:
             >>> import pertpy as pt
@@ -1175,32 +1163,27 @@ class Augur:
         plt.ylabel("Cell Type")
         plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-        if not (show or save):
-            return ax
+        plt.show()
         return None
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_scatterplot(
         self,
         results1: dict[str, Any],
         results2: dict[str, Any],
+        *,
         top_n: int = None,
-        return_fig: bool | None = None,
-        show: bool | None = None,
-        save: str | bool | None = None,
-    ) -> Axes | Figure | None:
+        return_fig: bool = False,
+    ) -> Figure | None:
         """Create scatterplot with two augur results.
 
         Args:
             results1: results after running `predict()`
             results2: results after running `predict()`
             top_n: optionally, the number of top prioritized cell types to label in the plot
-            return_figure: if `True` returns figure of the plot
+            {common_plot_args}
 
         Returns:
             Axes of the plot.
@@ -1249,12 +1232,7 @@ class Augur:
         plt.xlabel("Augur scores 1")
         plt.ylabel("Augur scores 2")
 
-        if save:
-            plt.savefig(save, bbox_inches="tight")
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
-        if not (show or save):
-            return ax
+        plt.show()
         return None
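
Taken together, the Augur plotting helpers above drop the `show`/`save` keywords in favour of a single keyword-only `return_fig` flag; saving a figure is now done by the caller. A minimal usage sketch in the style of the docstring examples shown in this diff (the output filename is illustrative):

    >>> import pertpy as pt
    >>> adata = pt.dt.sc_sim_augur()
    >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
    >>> loaded_data = ag_rfc.load(adata)
    >>> v_adata, v_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
    >>> fig = ag_rfc.plot_lollipop(v_results, return_fig=True)  # returns the matplotlib figure instead of showing it
    >>> fig.savefig("augur_lollipop.png", bbox_inches="tight")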
pertpy/tools/_cinemaot.py CHANGED
@@ -18,9 +18,12 @@ from sklearn.decomposition import FastICA
 from sklearn.linear_model import LinearRegression
 from sklearn.neighbors import NearestNeighbors
 
+from pertpy._doc import _doc_params, doc_common_plot_args
+
 if TYPE_CHECKING:
     from anndata import AnnData
     from matplotlib.axes import Axes
+    from matplotlib.pyplot import Figure
     from statsmodels.tools.typing import ArrayLike
 
 
@@ -88,7 +91,7 @@ class Cinemaot:
         dim = self.get_dim(adata, use_rep=use_rep)
 
         transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
-        X_transformed = transformer.fit_transform(adata.obsm[use_rep][:, :dim])
+        X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
         groupvec = (adata.obs[pert_key] == control * 1).values  # control
         xi = np.zeros(dim)
         j = 0
@@ -97,9 +100,9 @@ class Cinemaot:
             xi[j] = xi_obj.correlation
             j = j + 1
 
-        cf = X_transformed[:, xi < thres]
-        cf1 = cf[adata.obs[pert_key] == control, :]
-        cf2 = cf[adata.obs[pert_key] != control, :]
+        cf = np.array(X_transformed[:, xi < thres], np.float64)
+        cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
+        cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
         if sum(xi < thres) == 1:
             sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
         elif sum(xi < thres) == 0:
@@ -167,7 +170,7 @@ class Cinemaot:
         else:
             _solver = sinkhorn.Sinkhorn(threshold=eps)
         ot_sink = _solver(ot_prob)
-        ot_matrix = ot_sink.matrix.T
+        ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
         embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
             ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
         )
@@ -639,6 +642,7 @@ class Cinemaot:
         s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
         return c_effect, s_effect
 
+    @_doc_params(common_plot_args=doc_common_plot_args)
     def plot_vis_matching(
         self,
         adata: AnnData,
@@ -647,16 +651,16 @@ class Cinemaot:
         control: str,
         de_label: str,
         source_label: str,
+        *,
         matching_rep: str = "ot",
         resolution: float = 0.5,
         normalize: str = "col",
         title: str = "CINEMA-OT matching matrix",
         min_val: float = 0.01,
-        show: bool = True,
-        save: str | None = None,
         ax: Axes | None = None,
+        return_fig: bool = False,
         **kwargs,
-    ) -> None:
+    ) -> Figure | None:
         """Visualize the CINEMA-OT matching matrix.
 
         Args:
@@ -670,11 +674,12 @@ class Cinemaot:
             normalize: normalize the coarse-grained matching matrix by row / column.
             title: the title for the figure.
             min_val: The min value to truncate the matching matrix.
-            show: Show the plot, do not return axis.
-            save: If `True` or a `str`, save the figure. A string is appended to the default filename.
-                Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
+            {common_plot_args}
             **kwargs: Other parameters to input for seaborn.heatmap.
 
+        Returns:
+            If `return_fig` is `True`, returns the figure, otherwise `None`.
+
         Examples:
             >>> import pertpy as pt
             >>> adata = pt.dt.cinemaot_example()
@@ -710,12 +715,11 @@ class Cinemaot:
 
         g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
         plt.title(title)
-        _utils.savefig_or_show("matching_heatmap", show=show, save=save)
-        if not show:
-            if ax is not None:
-                return ax
-            else:
-                return g
+
+        if return_fig:
+            return g
+        plt.show()
+        return None
 
 
 class Xi:
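
Both files also add a bare `*` to the plotting signatures, so everything after the required arguments becomes keyword-only. A small standalone illustration of that Python feature (the function and argument names are made up for the example):

    def plot_example(data, *, top_n: int = 10, return_fig: bool = False):
        """Arguments after the bare `*` can only be passed by keyword."""
        return data, top_n, return_fig

    plot_example([1, 2, 3], top_n=5)  # fine: keyword call
    # plot_example([1, 2, 3], 5)  # TypeError: takes 1 positional argument but 2 were given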