pertpy 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pertpy/tools/_dialogue.py CHANGED
@@ -25,6 +25,8 @@ from sklearn.linear_model import LinearRegression
25
25
  from sparsecca import lp_pmd, multicca_permute, multicca_pmd
26
26
  from statsmodels.sandbox.stats.multicomp import multipletests
27
27
 
28
+ from pertpy._doc import _doc_params, doc_common_plot_args
29
+
28
30
  if TYPE_CHECKING:
29
31
  from matplotlib.axes import Axes
30
32
  from matplotlib.figure import Figure
@@ -80,27 +82,27 @@ class Dialogue:
80
82
 
81
83
  return pseudobulk
82
84
 
83
- def _pseudobulk_pca(self, adata: AnnData, groupby: str, n_components: int = 50) -> pd.DataFrame:
84
- """Return cell-averaged PCA components.
85
+ def _pseudobulk_feature_space(
86
+ self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
87
+ ) -> pd.DataFrame:
88
+ """Return Cell-averaged components from a passed feature space.
85
89
 
86
90
  TODO: consider merging with `get_pseudobulks`
87
91
  TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
88
92
 
89
93
  Args:
90
- groupby: The key to groupby for pseudobulks
91
- n_components: The number of PCA components
94
+ groupby: The key to groupby for pseudobulks.
95
+ n_components: The number of components to use.
96
+ feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
92
97
 
93
98
  Returns:
94
- A pseudobulk of PCA components.
99
+ A pseudobulk DataFrame of the averaged components.
95
100
  """
96
101
  aggr = {}
97
-
98
102
  for category in adata.obs.loc[:, groupby].cat.categories:
99
103
  temp = adata.obs.loc[:, groupby] == category
100
- aggr[category] = adata[temp].obsm["X_pca"][:, :n_components].mean(axis=0)
101
-
104
+ aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
102
105
  aggr = pd.DataFrame(aggr)
103
-
104
106
  return aggr
105
107
 
106
108
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
@@ -556,7 +558,7 @@ class Dialogue:
556
558
  self,
557
559
  adata: AnnData,
558
560
  ct_order: list[str],
559
- agg_pca: bool = True,
561
+ agg_feature: bool = True,
560
562
  normalize: bool = True,
561
563
  ) -> tuple[list, dict]:
562
564
  """Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
@@ -566,14 +568,14 @@ class Dialogue:
566
568
  Args:
567
569
  adata: AnnData object generate celltype objects for
568
570
  ct_order: The order of cell types
569
- agg_pca: Whether to aggregate pseudobulks with PCA or not.
571
+ agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
570
572
  normalize: Whether to mimic DIALOGUE behavior or not.
571
573
 
572
574
  Returns:
573
575
  A celltype_label:array dictionary.
574
576
  """
575
577
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
576
- fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
578
+ fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
577
579
  ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
578
580
 
579
581
  # TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
@@ -591,7 +593,7 @@ class Dialogue:
591
593
  adata: AnnData,
592
594
  penalties: list[int] = None,
593
595
  ct_order: list[str] = None,
594
- agg_pca: bool = True,
596
+ agg_feature: bool = True,
595
597
  solver: Literal["lp", "bs"] = "bs",
596
598
  normalize: bool = True,
597
599
  ) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
@@ -604,7 +606,7 @@ class Dialogue:
604
606
  sample_id: Key to use for pseudobulk determination.
605
607
  penalties: PMD penalties.
606
608
  ct_order: The order of cell types.
607
- agg_pca: Whether to calculate cell-averaged PCA components.
609
+ agg_features: Whether to calculate cell-averaged principal components.
608
610
  solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
609
611
  For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
610
612
  normalize: Whether to mimic DIALOGUE as close as possible
@@ -629,7 +631,7 @@ class Dialogue:
629
631
  else:
630
632
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
631
633
 
632
- mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
634
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
633
635
 
634
636
  n_samples = mcca_in[0].shape[1]
635
637
  if penalties is None:
@@ -1059,18 +1061,17 @@ class Dialogue:
1059
1061
 
1060
1062
  return rank_dfs
1061
1063
 
1064
+ @_doc_params(common_plot_args=doc_common_plot_args)
1062
1065
  def plot_split_violins(
1063
1066
  self,
1064
1067
  adata: AnnData,
1065
1068
  split_key: str,
1066
1069
  celltype_key: str,
1070
+ *,
1067
1071
  split_which: tuple[str, str] = None,
1068
1072
  mcp: str = "mcp_0",
1069
- return_fig: bool | None = None,
1070
- ax: Axes | None = None,
1071
- save: bool | str | None = None,
1072
- show: bool | None = None,
1073
- ) -> Axes | Figure | None:
1073
+ return_fig: bool = False,
1074
+ ) -> Figure | None:
1074
1075
  """Plots split violin plots for a given MCP and split variable.
1075
1076
 
1076
1077
  Any cells with a value for split_key not in split_which are removed from the plot.
@@ -1081,9 +1082,10 @@ class Dialogue:
1081
1082
  celltype_key: Key for cell type annotations.
1082
1083
  split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
1083
1084
  mcp: Key for MCP data.
1085
+ {common_plot_args}
1084
1086
 
1085
1087
  Returns:
1086
- A :class:`~matplotlib.axes.Axes` object
1088
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1087
1089
 
1088
1090
  Examples:
1089
1091
  >>> import pertpy as pt
@@ -1105,30 +1107,24 @@ class Dialogue:
1105
1107
  df[split_key] = df[split_key].cat.remove_unused_categories()
1106
1108
 
1107
1109
  ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1108
-
1109
1110
  ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1110
1111
 
1111
- if save:
1112
- plt.savefig(save, bbox_inches="tight")
1113
- if show:
1114
- plt.show()
1115
1112
  if return_fig:
1116
1113
  return plt.gcf()
1117
- if not (show or save):
1118
- return ax
1114
+ plt.show()
1119
1115
  return None
1120
1116
 
1117
+ @_doc_params(common_plot_args=doc_common_plot_args)
1121
1118
  def plot_pairplot(
1122
1119
  self,
1123
1120
  adata: AnnData,
1124
1121
  celltype_key: str,
1125
1122
  color: str,
1126
1123
  sample_id: str,
1124
+ *,
1127
1125
  mcp: str = "mcp_0",
1128
- return_fig: bool | None = None,
1129
- show: bool | None = None,
1130
- save: bool | str | None = None,
1131
- ) -> PairGrid | Figure | None:
1126
+ return_fig: bool = False,
1127
+ ) -> Figure | None:
1132
1128
  """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
1133
1129
 
1134
1130
  Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
@@ -1140,9 +1136,10 @@ class Dialogue:
1140
1136
  color: Key in `adata.obs` for color annotations. This parameter is used as the hue
1141
1137
  sample_id: Key in `adata.obs` for the sample annotations.
1142
1138
  mcp: Key in `adata.obs` for MCP feature values.
1139
+ {common_plot_args}
1143
1140
 
1144
1141
  Returns:
1145
- Seaborn Pairgrid object.
1142
+ If `return_fig` is `True`, returns the figure, otherwise `None`.
1146
1143
 
1147
1144
  Examples:
1148
1145
  >>> import pertpy as pt
@@ -1165,14 +1162,9 @@ class Dialogue:
1165
1162
  aggstats = aggstats.loc[list(mcp_pivot.index), :]
1166
1163
  aggstats[color] = aggstats["top"]
1167
1164
  mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1168
- ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
1165
+ sns.pairplot(mcp_pivot, hue=color, corner=True)
1169
1166
 
1170
- if save:
1171
- plt.savefig(save, bbox_inches="tight")
1172
- if show:
1173
- plt.show()
1174
1167
  if return_fig:
1175
1168
  return plt.gcf()
1176
- if not (show or save):
1177
- return ax
1169
+ plt.show()
1178
1170
  return None
@@ -1,4 +1,4 @@
1
- from ._base import ContrastType, LinearModelBase, MethodBase
1
+ from ._base import LinearModelBase, MethodBase
2
2
  from ._dge_comparison import DGEEVAL
3
3
  from ._edger import EdgeR
4
4
  from ._pydeseq2 import PyDESeq2
@@ -14,7 +14,6 @@ __all__ = [
14
14
  "SimpleComparisonBase",
15
15
  "WilcoxonTest",
16
16
  "TTest",
17
- "ContrastType",
18
17
  ]
19
18
 
20
19
  AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]