pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pertpy/tools/_augur.py CHANGED
@@ -15,7 +15,6 @@ import statsmodels.api as sm
15
15
  from anndata import AnnData
16
16
  from joblib import Parallel, delayed
17
17
  from lamin_utils import logger
18
- from rich import print
19
18
  from rich.progress import track
20
19
  from scipy import sparse, stats
21
20
  from sklearn.base import is_classifier, is_regressor
@@ -26,17 +25,19 @@ from sklearn.metrics import (
26
25
  explained_variance_score,
27
26
  f1_score,
28
27
  make_scorer,
29
- mean_squared_error,
30
28
  precision_score,
31
29
  r2_score,
32
30
  recall_score,
33
31
  roc_auc_score,
32
+ root_mean_squared_error,
34
33
  )
35
34
  from sklearn.model_selection import StratifiedKFold, cross_validate
36
35
  from sklearn.preprocessing import LabelEncoder
37
36
  from skmisc.loess import loess
38
37
  from statsmodels.stats.multitest import fdrcorrection
39
38
 
39
+ from pertpy._doc import _doc_params, doc_common_plot_args
40
+
40
41
  if TYPE_CHECKING:
41
42
  from matplotlib.axes import Axes
42
43
  from matplotlib.figure import Figure
@@ -439,7 +440,7 @@ class Augur:
439
440
  "augur_score": make_scorer(self.ccc_score),
440
441
  "r2": make_scorer(r2_score),
441
442
  "ccc": make_scorer(self.ccc_score),
442
- "neg_mean_squared_error": make_scorer(mean_squared_error),
443
+ "neg_mean_squared_error": make_scorer(root_mean_squared_error),
443
444
  "explained_variance": make_scorer(explained_variance_score),
444
445
  }
445
446
  )
@@ -684,7 +685,7 @@ class Augur:
684
685
  span: float = 0.75,
685
686
  filter_negative_residuals: bool = False,
686
687
  n_threads: int = 4,
687
- augur_mode: Literal["permute"] | Literal["default"] | Literal["velocity"] = "default",
688
+ augur_mode: Literal["default", "permute", "velocity"] = "default",
688
689
  select_variance_features: bool = True,
689
690
  key_added: str = "augurpy_results",
690
691
  random_state: int | None = None,
@@ -907,41 +908,39 @@ class Augur:
907
908
  .mean()
908
909
  )
909
910
 
910
- sampled_permuted_cv_augur1 = []
911
- sampled_permuted_cv_augur2 = []
911
+ rng = np.random.default_rng()
912
+ sampled_data = []
912
913
 
913
914
  # draw mean aucs for permute1 and permute2
914
915
  for celltype in permuted_cv_augur1["cell_type"].unique():
915
916
  df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
916
917
  df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
917
- for permutation_idx in range(n_permutations):
918
- # subsample
919
- sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
920
- sampled_permuted_cv_augur1.append(
921
- pd.DataFrame(
922
- {
923
- "cell_type": [celltype],
924
- "permutation_idx": [permutation_idx],
925
- "mean": [sample1["augur_score"].mean(axis=0)],
926
- "std": [sample1["augur_score"].std(axis=0)],
927
- }
928
- )
929
- )
930
918
 
931
- sample2 = df2.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
932
- sampled_permuted_cv_augur2.append(
933
- pd.DataFrame(
934
- {
935
- "cell_type": [celltype],
936
- "permutation_idx": [permutation_idx],
937
- "mean": [sample2["augur_score"].mean(axis=0)],
938
- "std": [sample2["augur_score"].std(axis=0)],
939
- }
940
- )
919
+ indices1 = rng.choice(len(df1), size=(n_permutations, n_subsamples), replace=True)
920
+ indices2 = rng.choice(len(df2), size=(n_permutations, n_subsamples), replace=True)
921
+
922
+ scores1 = df1["augur_score"].values[indices1]
923
+ scores2 = df2["augur_score"].values[indices2]
924
+
925
+ means1 = scores1.mean(axis=1)
926
+ means2 = scores2.mean(axis=1)
927
+ stds1 = scores1.std(axis=1)
928
+ stds2 = scores2.std(axis=1)
929
+
930
+ sampled_data.append(
931
+ pd.DataFrame(
932
+ {
933
+ "cell_type": np.repeat(celltype, n_permutations),
934
+ "permutation_idx": np.arange(n_permutations),
935
+ "mean1": means1,
936
+ "mean2": means2,
937
+ "std1": stds1,
938
+ "std2": stds2,
939
+ }
941
940
  )
941
+ )
942
942
 
943
- permuted_samples1 = pd.concat(sampled_permuted_cv_augur1)
944
- permuted_samples2 = pd.concat(sampled_permuted_cv_augur2)
943
+ sampled_df = pd.concat(sampled_data)
945
944
 
946
945
  # delta between augur scores
947
946
  delta = augur_score1.merge(augur_score2, on=["cell_type"], suffixes=("1", "2")).assign(
@@ -949,9 +948,7 @@ class Augur:
949
948
  )
950
949
 
951
950
  # delta between permutation scores
952
- delta_rnd = permuted_samples1.merge(
953
- permuted_samples2, on=["cell_type", "permutation_idx"], suffixes=("1", "2")
954
- ).assign(delta_rnd=lambda x: x.mean2 - x.mean1)
951
+ delta_rnd = sampled_df.assign(delta_rnd=lambda x: x.mean2 - x.mean1)
955
952
 
956
953
  # number of values where permutations are larger than test statistic
957
954
  delta["b"] = (
@@ -966,7 +963,7 @@ class Augur:
966
963
  delta["z"] = (
967
964
  delta["delta_augur"] - delta_rnd.groupby("cell_type", as_index=False).mean()["delta_rnd"]
968
965
  ) / delta_rnd.groupby("cell_type", as_index=False).std()["delta_rnd"]
969
- # calculate pvalues
966
+
970
967
  delta["pval"] = np.minimum(
971
968
  2 * (delta["b"] + 1) / (delta["m"] + 1), 2 * (delta["m"] - delta["b"] + 1) / (delta["m"] + 1)
972
969
  )
@@ -974,24 +971,25 @@ class Augur:
974
971
 
975
972
  return delta
976
973
 
974
+ @_doc_params(common_plot_args=doc_common_plot_args)
977
975
  def plot_dp_scatter(
978
976
  self,
979
977
  results: pd.DataFrame,
978
+ *,
980
979
  top_n: int = None,
981
- return_fig: bool | None = None,
982
980
  ax: Axes = None,
983
- show: bool | None = None,
984
- save: str | bool | None = None,
985
- ) -> Axes | Figure | None:
981
+ return_fig: bool = False,
982
+ ) -> Figure | None:
986
983
  """Plot scatterplot of differential prioritization.
987
984
 
988
985
  Args:
989
986
  results: Results after running differential prioritization.
990
987
  top_n: optionally, the number of top prioritized cell types to label in the plot
991
988
  ax: optionally, axes used to draw plot
989
+ {common_plot_args}
992
990
 
993
991
  Returns:
994
- Axes of the plot.
992
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
995
993
 
996
994
  Examples:
997
995
  >>> import pertpy as pt
@@ -1038,37 +1036,32 @@ class Augur:
1038
1036
  legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
1039
1037
  ax.add_artist(legend1)
1040
1038
 
1041
- if save:
1042
- plt.savefig(save, bbox_inches="tight")
1043
- if show:
1044
- plt.show()
1045
1039
  if return_fig:
1046
1040
  return plt.gcf()
1047
- if not (show or save):
1048
- return ax
1041
+ plt.show()
1049
1042
  return None
1050
1043
 
1044
+ @_doc_params(common_plot_args=doc_common_plot_args)
1051
1045
  def plot_important_features(
1052
1046
  self,
1053
1047
  data: dict[str, Any],
1048
+ *,
1054
1049
  key: str = "augurpy_results",
1055
1050
  top_n: int = 10,
1056
- return_fig: bool | None = None,
1057
1051
  ax: Axes = None,
1058
- show: bool | None = None,
1059
- save: str | bool | None = None,
1060
- ) -> Axes | None:
1052
+ return_fig: bool = False,
1053
+ ) -> Figure | None:
1061
1054
  """Plot a lollipop plot of the n features with largest feature importances.
1062
1055
 
1063
1056
  Args:
1064
- results: results after running `predict()` as dictionary or the AnnData object.
1057
+ data: results after running `predict()` as dictionary or the AnnData object.
1065
1058
  key: Key in the AnnData object of the results
1066
1059
  top_n: n number feature importance values to plot. Default is 10.
1067
1060
  ax: optionally, axes used to draw plot
1068
- return_figure: if `True` returns figure of the plot, default is `False`
1061
+ {common_plot_args}
1069
1062
 
1070
1063
  Returns:
1071
- Axes of the plot.
1064
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1072
1065
 
1073
1066
  Examples:
1074
1067
  >>> import pertpy as pt
@@ -1109,35 +1102,30 @@ class Augur:
1109
1102
  plt.ylabel("Gene")
1110
1103
  plt.yticks(y_axes_range, n_features["genes"])
1111
1104
 
1112
- if save:
1113
- plt.savefig(save, bbox_inches="tight")
1114
- if show:
1115
- plt.show()
1116
1105
  if return_fig:
1117
1106
  return plt.gcf()
1118
- if not (show or save):
1119
- return ax
1107
+ plt.show()
1120
1108
  return None
1121
1109
 
1110
+ @_doc_params(common_plot_args=doc_common_plot_args)
1122
1111
  def plot_lollipop(
1123
1112
  self,
1124
- data: dict[str, Any],
1113
+ data: dict[str, Any] | AnnData,
1114
+ *,
1125
1115
  key: str = "augurpy_results",
1126
- return_fig: bool | None = None,
1127
1116
  ax: Axes = None,
1128
- show: bool | None = None,
1129
- save: str | bool | None = None,
1130
- ) -> Axes | Figure | None:
1117
+ return_fig: bool = False,
1118
+ ) -> Figure | None:
1131
1119
  """Plot a lollipop plot of the mean augur values.
1132
1120
 
1133
1121
  Args:
1134
- results: results after running `predict()` as dictionary or the AnnData object.
1135
- key: Key in the AnnData object of the results
1136
- ax: optionally, axes used to draw plot
1137
- return_figure: if `True` returns figure of the plot
1122
+ data: results after running `predict()` as dictionary or the AnnData object.
1123
+ key: .uns key in the results AnnData object.
1124
+ ax: optionally, axes used to draw plot.
1125
+ {common_plot_args}
1138
1126
 
1139
1127
  Returns:
1140
- Axes of the plot.
1128
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1141
1129
 
1142
1130
  Examples:
1143
1131
  >>> import pertpy as pt
@@ -1175,32 +1163,27 @@ class Augur:
1175
1163
  plt.ylabel("Cell Type")
1176
1164
  plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)
1177
1165
 
1178
- if save:
1179
- plt.savefig(save, bbox_inches="tight")
1180
- if show:
1181
- plt.show()
1182
1166
  if return_fig:
1183
1167
  return plt.gcf()
1184
- if not (show or save):
1185
- return ax
1168
+ plt.show()
1186
1169
  return None
1187
1170
 
1171
+ @_doc_params(common_plot_args=doc_common_plot_args)
1188
1172
  def plot_scatterplot(
1189
1173
  self,
1190
1174
  results1: dict[str, Any],
1191
1175
  results2: dict[str, Any],
1176
+ *,
1192
1177
  top_n: int = None,
1193
- return_fig: bool | None = None,
1194
- show: bool | None = None,
1195
- save: str | bool | None = None,
1196
- ) -> Axes | Figure | None:
1178
+ return_fig: bool = False,
1179
+ ) -> Figure | None:
1197
1180
  """Create scatterplot with two augur results.
1198
1181
 
1199
1182
  Args:
1200
1183
  results1: results after running `predict()`
1201
1184
  results2: results after running `predict()`
1202
1185
  top_n: optionally, the number of top prioritized cell types to label in the plot
1203
- return_figure: if `True` returns figure of the plot
1186
+ {common_plot_args}
1204
1187
 
1205
1188
  Returns:
1206
1189
  Axes of the plot.
@@ -1249,12 +1232,7 @@ class Augur:
1249
1232
  plt.xlabel("Augur scores 1")
1250
1233
  plt.ylabel("Augur scores 2")
1251
1234
 
1252
- if save:
1253
- plt.savefig(save, bbox_inches="tight")
1254
- if show:
1255
- plt.show()
1256
1235
  if return_fig:
1257
1236
  return plt.gcf()
1258
- if not (show or save):
1259
- return ax
1237
+ plt.show()
1260
1238
  return None
pertpy/tools/_cinemaot.py CHANGED
@@ -18,9 +18,12 @@ from sklearn.decomposition import FastICA
18
18
  from sklearn.linear_model import LinearRegression
19
19
  from sklearn.neighbors import NearestNeighbors
20
20
 
21
+ from pertpy._doc import _doc_params, doc_common_plot_args
22
+
21
23
  if TYPE_CHECKING:
22
24
  from anndata import AnnData
23
25
  from matplotlib.axes import Axes
26
+ from matplotlib.pyplot import Figure
24
27
  from statsmodels.tools.typing import ArrayLike
25
28
 
26
29
 
@@ -88,7 +91,7 @@ class Cinemaot:
88
91
  dim = self.get_dim(adata, use_rep=use_rep)
89
92
 
90
93
  transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
91
- X_transformed = transformer.fit_transform(adata.obsm[use_rep][:, :dim])
94
+ X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
92
95
  groupvec = (adata.obs[pert_key] == control * 1).values # control
93
96
  xi = np.zeros(dim)
94
97
  j = 0
@@ -97,9 +100,9 @@ class Cinemaot:
97
100
  xi[j] = xi_obj.correlation
98
101
  j = j + 1
99
102
 
100
- cf = X_transformed[:, xi < thres]
101
- cf1 = cf[adata.obs[pert_key] == control, :]
102
- cf2 = cf[adata.obs[pert_key] != control, :]
103
+ cf = np.array(X_transformed[:, xi < thres], np.float64)
104
+ cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
105
+ cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
103
106
  if sum(xi < thres) == 1:
104
107
  sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
105
108
  elif sum(xi < thres) == 0:
@@ -167,7 +170,7 @@ class Cinemaot:
167
170
  else:
168
171
  _solver = sinkhorn.Sinkhorn(threshold=eps)
169
172
  ot_sink = _solver(ot_prob)
170
- ot_matrix = ot_sink.matrix.T
173
+ ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
171
174
  embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
172
175
  ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
173
176
  )
@@ -639,6 +642,7 @@ class Cinemaot:
639
642
  s_effect = (np.linalg.norm(e1, axis=0) + 1e-6) / (np.linalg.norm(e0, axis=0) + 1e-6)
640
643
  return c_effect, s_effect
641
644
 
645
+ @_doc_params(common_plot_args=doc_common_plot_args)
642
646
  def plot_vis_matching(
643
647
  self,
644
648
  adata: AnnData,
@@ -647,16 +651,16 @@ class Cinemaot:
647
651
  control: str,
648
652
  de_label: str,
649
653
  source_label: str,
654
+ *,
650
655
  matching_rep: str = "ot",
651
656
  resolution: float = 0.5,
652
657
  normalize: str = "col",
653
658
  title: str = "CINEMA-OT matching matrix",
654
659
  min_val: float = 0.01,
655
- show: bool = True,
656
- save: str | None = None,
657
660
  ax: Axes | None = None,
661
+ return_fig: bool = False,
658
662
  **kwargs,
659
- ) -> None:
663
+ ) -> Figure | None:
660
664
  """Visualize the CINEMA-OT matching matrix.
661
665
 
662
666
  Args:
@@ -670,11 +674,12 @@ class Cinemaot:
670
674
  normalize: normalize the coarse-grained matching matrix by row / column.
671
675
  title: the title for the figure.
672
676
  min_val: The min value to truncate the matching matrix.
673
- show: Show the plot, do not return axis.
674
- save: If `True` or a `str`, save the figure. A string is appended to the default filename.
675
- Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
677
+ {common_plot_args}
676
678
  **kwargs: Other parameters to input for seaborn.heatmap.
677
679
 
680
+ Returns:
681
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
682
+
678
683
  Examples:
679
684
  >>> import pertpy as pt
680
685
  >>> adata = pt.dt.cinemaot_example()
@@ -710,12 +715,11 @@ class Cinemaot:
710
715
 
711
716
  g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
712
717
  plt.title(title)
713
- _utils.savefig_or_show("matching_heatmap", show=show, save=save)
714
- if not show:
715
- if ax is not None:
716
- return ax
717
- else:
718
- return g
718
+
719
+ if return_fig:
720
+ return g
721
+ plt.show()
722
+ return None
719
723
 
720
724
 
721
725
  class Xi: