pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
pertpy/tools/_dialogue.py CHANGED
@@ -25,6 +25,8 @@ from sklearn.linear_model import LinearRegression
25
25
  from sparsecca import lp_pmd, multicca_permute, multicca_pmd
26
26
  from statsmodels.sandbox.stats.multicomp import multipletests
27
27
 
28
+ from pertpy._doc import _doc_params, doc_common_plot_args
29
+
28
30
  if TYPE_CHECKING:
29
31
  from matplotlib.axes import Axes
30
32
  from matplotlib.figure import Figure
@@ -80,27 +82,27 @@ class Dialogue:
80
82
 
81
83
  return pseudobulk
82
84
 
83
- def _pseudobulk_pca(self, adata: AnnData, groupby: str, n_components: int = 50) -> pd.DataFrame:
84
- """Return cell-averaged PCA components.
85
+ def _pseudobulk_feature_space(
86
+ self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
87
+ ) -> pd.DataFrame:
88
+ """Return cell-averaged components from a passed feature space.
85
89
 
86
90
  TODO: consider merging with `get_pseudobulks`
87
91
  TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
88
92
 
89
93
  Args:
90
- groupby: The key to groupby for pseudobulks
91
- n_components: The number of PCA components
94
+ groupby: The key to groupby for pseudobulks.
95
+ n_components: The number of components to use.
96
+ feature_space_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
92
97
 
93
98
  Returns:
94
- A pseudobulk of PCA components.
99
+ A pseudobulk DataFrame of the averaged components.
95
100
  """
96
101
  aggr = {}
97
-
98
102
  for category in adata.obs.loc[:, groupby].cat.categories:
99
103
  temp = adata.obs.loc[:, groupby] == category
100
- aggr[category] = adata[temp].obsm["X_pca"][:, :n_components].mean(axis=0)
101
-
104
+ aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
102
105
  aggr = pd.DataFrame(aggr)
103
-
104
106
  return aggr
105
107
 
106
108
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
@@ -556,7 +558,7 @@ class Dialogue:
556
558
  self,
557
559
  adata: AnnData,
558
560
  ct_order: list[str],
559
- agg_pca: bool = True,
561
+ agg_feature: bool = True,
560
562
  normalize: bool = True,
561
563
  ) -> tuple[list, dict]:
562
564
  """Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
@@ -566,14 +568,14 @@ class Dialogue:
566
568
  Args:
567
569
  adata: AnnData object generate celltype objects for
568
570
  ct_order: The order of cell types
569
- agg_pca: Whether to aggregate pseudobulks with PCA or not.
571
+ agg_feature: Whether to build pseudobulks by averaging a feature-space embedding from `adata.obsm` ("X_pca" by default) instead of averaging expression values.
570
572
  normalize: Whether to mimic DIALOGUE behavior or not.
571
573
 
572
574
  Returns:
573
575
  A celltype_label:array dictionary.
574
576
  """
575
577
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
576
- fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
578
+ fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
577
579
  ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
578
580
 
579
581
  # TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
@@ -591,7 +593,7 @@ class Dialogue:
591
593
  adata: AnnData,
592
594
  penalties: list[int] = None,
593
595
  ct_order: list[str] = None,
594
- agg_pca: bool = True,
596
+ agg_feature: bool = True,
595
597
  solver: Literal["lp", "bs"] = "bs",
596
598
  normalize: bool = True,
597
599
  ) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
@@ -604,7 +606,7 @@ class Dialogue:
604
606
  sample_id: Key to use for pseudobulk determination.
605
607
  penalties: PMD penalties.
606
608
  ct_order: The order of cell types.
607
- agg_pca: Whether to calculate cell-averaged PCA components.
609
+ agg_feature: Whether to calculate cell-averaged feature-space components (principal components from "X_pca" by default).
608
610
  solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
609
611
  For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
610
612
  normalize: Whether to mimic DIALOGUE as close as possible
@@ -629,7 +631,7 @@ class Dialogue:
629
631
  else:
630
632
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
631
633
 
632
- mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
634
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
633
635
 
634
636
  n_samples = mcca_in[0].shape[1]
635
637
  if penalties is None:
@@ -1059,18 +1061,17 @@ class Dialogue:
1059
1061
 
1060
1062
  return rank_dfs
1061
1063
 
1064
+ @_doc_params(common_plot_args=doc_common_plot_args)
1062
1065
  def plot_split_violins(
1063
1066
  self,
1064
1067
  adata: AnnData,
1065
1068
  split_key: str,
1066
1069
  celltype_key: str,
1070
+ *,
1067
1071
  split_which: tuple[str, str] = None,
1068
1072
  mcp: str = "mcp_0",
1069
- return_fig: bool | None = None,
1070
- ax: Axes | None = None,
1071
- save: bool | str | None = None,
1072
- show: bool | None = None,
1073
- ) -> Axes | Figure | None:
1073
+ return_fig: bool = False,
1074
+ ) -> Figure | None:
1074
1075
  """Plots split violin plots for a given MCP and split variable.
1075
1076
 
1076
1077
  Any cells with a value for split_key not in split_which are removed from the plot.
@@ -1081,9 +1082,10 @@ class Dialogue:
1081
1082
  celltype_key: Key for cell type annotations.
1082
1083
  split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
1083
1084
  mcp: Key for MCP data.
1085
+ {common_plot_args}
1084
1086
 
1085
1087
  Returns:
1086
- A :class:`~matplotlib.axes.Axes` object
1088
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1087
1089
 
1088
1090
  Examples:
1089
1091
  >>> import pertpy as pt
@@ -1105,30 +1107,24 @@ class Dialogue:
1105
1107
  df[split_key] = df[split_key].cat.remove_unused_categories()
1106
1108
 
1107
1109
  ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1108
-
1109
1110
  ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1110
1111
 
1111
- if save:
1112
- plt.savefig(save, bbox_inches="tight")
1113
- if show:
1114
- plt.show()
1115
1112
  if return_fig:
1116
1113
  return plt.gcf()
1117
- if not (show or save):
1118
- return ax
1114
+ plt.show()
1119
1115
  return None
1120
1116
 
1117
+ @_doc_params(common_plot_args=doc_common_plot_args)
1121
1118
  def plot_pairplot(
1122
1119
  self,
1123
1120
  adata: AnnData,
1124
1121
  celltype_key: str,
1125
1122
  color: str,
1126
1123
  sample_id: str,
1124
+ *,
1127
1125
  mcp: str = "mcp_0",
1128
- return_fig: bool | None = None,
1129
- show: bool | None = None,
1130
- save: bool | str | None = None,
1131
- ) -> PairGrid | Figure | None:
1126
+ return_fig: bool = False,
1127
+ ) -> Figure | None:
1132
1128
  """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
1133
1129
 
1134
1130
  Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
@@ -1140,9 +1136,10 @@ class Dialogue:
1140
1136
  color: Key in `adata.obs` for color annotations. This parameter is used as the hue
1141
1137
  sample_id: Key in `adata.obs` for the sample annotations.
1142
1138
  mcp: Key in `adata.obs` for MCP feature values.
1139
+ {common_plot_args}
1143
1140
 
1144
1141
  Returns:
1145
- Seaborn Pairgrid object.
1142
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1146
1143
 
1147
1144
  Examples:
1148
1145
  >>> import pertpy as pt
@@ -1165,14 +1162,9 @@ class Dialogue:
1165
1162
  aggstats = aggstats.loc[list(mcp_pivot.index), :]
1166
1163
  aggstats[color] = aggstats["top"]
1167
1164
  mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1168
- ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
1165
+ sns.pairplot(mcp_pivot, hue=color, corner=True)
1169
1166
 
1170
- if save:
1171
- plt.savefig(save, bbox_inches="tight")
1172
- if show:
1173
- plt.show()
1174
1167
  if return_fig:
1175
1168
  return plt.gcf()
1176
- if not (show or save):
1177
- return ax
1169
+ plt.show()
1178
1170
  return None
@@ -1,4 +1,4 @@
1
- from ._base import ContrastType, LinearModelBase, MethodBase
1
+ from ._base import LinearModelBase, MethodBase
2
2
  from ._dge_comparison import DGEEVAL
3
3
  from ._edger import EdgeR
4
4
  from ._pydeseq2 import PyDESeq2
@@ -14,7 +14,6 @@ __all__ = [
14
14
  "SimpleComparisonBase",
15
15
  "WilcoxonTest",
16
16
  "TTest",
17
- "ContrastType",
18
17
  ]
19
18
 
20
19
  AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]