pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  import itertools
4
4
  from collections import defaultdict
5
- from typing import Any, Literal
5
+ from typing import TYPE_CHECKING, Any, Literal
6
6
 
7
7
  import anndata as ad
8
+ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
10
11
  import scanpy as sc
11
- import scipy.sparse as sp
12
+ import seaborn as sns
12
13
  import statsmodels.formula.api as smf
13
14
  import statsmodels.stats.multitest as ssm
14
15
  from anndata import AnnData
@@ -19,10 +20,15 @@ from rich.live import Live
19
20
  from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
20
21
  from scipy import stats
21
22
  from scipy.optimize import nnls
23
+ from seaborn import PairGrid
22
24
  from sklearn.linear_model import LinearRegression
23
25
  from sparsecca import lp_pmd, multicca_permute, multicca_pmd
24
26
  from statsmodels.sandbox.stats.multicomp import multipletests
25
27
 
28
+ if TYPE_CHECKING:
29
+ from matplotlib.axes import Axes
30
+ from matplotlib.figure import Figure
31
+
26
32
 
27
33
  class Dialogue:
28
34
  """Python implementation of DIALOGUE"""
@@ -53,8 +59,6 @@ class Dialogue:
53
59
 
54
60
  Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
55
61
 
56
- # TODO: Replace with decoupler's implementation
57
-
58
62
  Args:
59
63
  groupby: The key to groupby for pseudobulks
60
64
  strategy: The pseudobulking strategy. One of "median" or "mean"
@@ -62,14 +66,15 @@ class Dialogue:
62
66
  Returns:
63
67
  A Pandas DataFrame of pseudobulk counts
64
68
  """
69
+ # TODO: Replace with decoupler's implementation
65
70
  pseudobulk = {"Genes": adata.var_names.values}
66
71
 
67
72
  for category in adata.obs.loc[:, groupby].cat.categories:
68
73
  temp = adata.obs.loc[:, groupby] == category
69
74
  if strategy == "median":
70
- pseudobulk[category] = adata[temp].X.median(axis=0).A1
75
+ pseudobulk[category] = adata[temp].X.median(axis=0)
71
76
  elif strategy == "mean":
72
- pseudobulk[category] = adata[temp].X.mean(axis=0).A1
77
+ pseudobulk[category] = adata[temp].X.mean(axis=0)
73
78
 
74
79
  pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
75
80
 
@@ -101,8 +106,6 @@ class Dialogue:
101
106
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
102
107
  """Row-wise mean center and scale by the standard deviation.
103
108
 
104
- TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
105
-
106
109
  Args:
107
110
  pseudobulks: The pseudobulk PCA components.
108
111
  normalize: Whether to mimic DIALOGUE behavior or not.
@@ -110,9 +113,9 @@ class Dialogue:
110
113
  Returns:
111
114
  The scaled count matrix.
112
115
  """
116
+ # TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
113
117
  # DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
114
118
  # However, performing this scaling _does_ increase overall correlation of the end result
115
- # WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
116
119
  if normalize:
117
120
  return pseudobulks.to_numpy()
118
121
  else:
@@ -313,13 +316,13 @@ class Dialogue:
313
316
  def _apply_HLM_per_MCP_for_one_pair(
314
317
  self,
315
318
  mcp_name: str,
316
- scores_df: dict,
319
+ scores_df: pd.DataFrame,
317
320
  ct_data: AnnData,
318
321
  tme: pd.DataFrame,
319
322
  sig: dict,
320
323
  n_counts: str,
321
324
  formula: str,
322
- confounder: str,
325
+ confounder: str | None,
323
326
  ) -> tuple[pd.DataFrame, dict[str, Any]]:
324
327
  """Applies hierarchical modeling for a single MCP.
325
328
 
@@ -340,7 +343,7 @@ class Dialogue:
340
343
  """
341
344
  HLM_result = self._mixed_effects(
342
345
  scores=scores_df[[mcp_name]],
343
- x_labels=ct_data.obs[[n_counts, confounder]],
346
+ x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
344
347
  tme=tme,
345
348
  genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
346
349
  formula=formula,
@@ -367,7 +370,7 @@ class Dialogue:
367
370
  return np.array(resid)
368
371
 
369
372
  def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
370
- """Solves non-negative least squares separately for different feature categories.
373
+ """Solves non-negative least-squares separately for different feature categories.
371
374
 
372
375
  Mimics DLG.iterative.nnls.
373
376
  Variables are notated according to:
@@ -398,7 +401,7 @@ class Dialogue:
398
401
 
399
402
  x_final = np.zeros(A_orig.shape[0])
400
403
  Ax = np.zeros(A_orig.shape[1])
401
- for _, mask in zip(sig_ranks, masks):
404
+ for _, mask in zip(sig_ranks, masks, strict=False):
402
405
  A = A_orig[mask].T
403
406
  coef_nnls, _ = nnls(A, y, maxiter=n_iter)
404
407
  y = y - A @ coef_nnls # residuals
@@ -516,8 +519,8 @@ class Dialogue:
516
519
  # TODO: probably format the up and down within get_top_elements
517
520
  cca_sig: dict[str, Any] = defaultdict(dict)
518
521
  for i in range(0, int(len(cca_sig_unformatted) / 2)):
519
- cca_sig[f"MCP{i + 1}"]["up"] = cca_sig_unformatted[i * 2]
520
- cca_sig[f"MCP{i + 1}"]["down"] = cca_sig_unformatted[i * 2 + 1]
522
+ cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
523
+ cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
521
524
 
522
525
  cca_sig = dict(cca_sig)
523
526
  cca_sig_results[ct] = cca_sig
@@ -555,7 +558,7 @@ class Dialogue:
555
558
 
556
559
  return cca_sig_results, new_mcp_scores
557
560
 
558
- def load(
561
+ def _load(
559
562
  self,
560
563
  adata: AnnData,
561
564
  ct_order: list[str],
@@ -574,16 +577,6 @@ class Dialogue:
574
577
 
575
578
  Returns:
576
579
  A celltype_label:array dictionary.
577
-
578
- Examples:
579
- >>> import pertpy as pt
580
- >>> import scanpy as sc
581
- >>> adata = pt.dt.dialogue_example()
582
- >>> sc.pp.pca(adata)
583
- >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
584
- n_counts_key = "nCount_RNA", n_mpcs = 3)
585
- >>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
586
- >>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
587
580
  """
588
581
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
589
582
  fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
@@ -631,19 +624,19 @@ class Dialogue:
631
624
  >>> import scanpy as sc
632
625
  >>> adata = pt.dt.dialogue_example()
633
626
  >>> sc.pp.pca(adata)
634
- >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
635
- n_counts_key = "nCount_RNA", n_mpcs = 3)
627
+ >>> dl = pt.tl.Dialogue(
628
+ ... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
629
+ ... )
636
630
  >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
637
631
  """
638
- # IMPORTANT NOTE: the order in which matrices are passed to multicca matters. As such,
639
- # it is important here that to obtain the same result as in R, we pass the matrices in
640
- # in the same order.
632
+ # IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
633
+ # As such, it is important here that to obtain the same result as in R, we pass the matrices in the same order.
641
634
  if ct_order is not None:
642
635
  cell_types = ct_order
643
636
  else:
644
637
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
645
638
 
646
- mcca_in, ct_subs = self.load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
639
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
647
640
 
648
641
  n_samples = mcca_in[0].shape[1]
649
642
  if penalties is None:
@@ -685,7 +678,7 @@ class Dialogue:
685
678
  ct_subs: dict,
686
679
  mcp_scores: dict,
687
680
  ws_dict: dict,
688
- confounder: str,
681
+ confounder: str | None,
689
682
  formula: str = None,
690
683
  ):
691
684
  """Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
@@ -700,7 +693,6 @@ class Dialogue:
700
693
  A Pandas DataFrame containing:
701
694
  - for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
702
695
  - merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
703
- TODO: Describe both returns
704
696
 
705
697
  Examples:
706
698
  >>> import pertpy as pt
@@ -713,7 +705,9 @@ class Dialogue:
713
705
  >>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
714
706
  confounder="gender")
715
707
  """
716
- # all possible pairs of cell types with out pairing same cell type
708
+ # TODO the returns of the function better
709
+
710
+ # all possible pairs of cell types without pairing same cell type
717
711
  cell_types = list(ct_subs.keys())
718
712
  pairs = list(itertools.combinations(cell_types, 2))
719
713
 
@@ -721,9 +715,9 @@ class Dialogue:
721
715
  formula = f"y ~ x + {self.n_counts_key}"
722
716
 
723
717
  # Hierarchical modeling expects DataFrames
724
- mcp_cell_types = {f"MCP{i + 1}": cell_types for i in range(self.n_mcps)}
718
+ mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
725
719
  mcp_scores_df = {
726
- ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
720
+ ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
727
721
  for ct, v in mcp_scores.items()
728
722
  }
729
723
 
@@ -805,7 +799,7 @@ class Dialogue:
805
799
  for mcp in mcps:
806
800
  mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
807
801
 
808
- # TODO Check that the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
802
+ # TODO Check whether the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
809
803
  result = {}
810
804
  result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
811
805
  mcp_name=mcp,
@@ -875,22 +869,19 @@ class Dialogue:
875
869
  sample_label = self.sample_id
876
870
  n_mcps = self.n_mcps
877
871
 
878
- # create conditions_compare if not supplied
879
872
  if conditions_compare is None:
880
- conditions_compare = list(adata.obs["path_str"].cat.categories) # type: ignore
873
+ conditions_compare = list(adata.obs[condition_label].cat.categories) # type: ignore
881
874
  if len(conditions_compare) != 2:
882
875
  raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
883
876
 
884
- # create data frames to store results
885
877
  pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
886
878
  tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
887
879
  pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
888
880
 
889
881
  response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
890
882
  for celltype in adata.obs[celltype_label].unique():
891
- # subset data to cell type
892
883
  df = adata.obs[adata.obs[celltype_label] == celltype]
893
- # run t-test for each MCP
884
+
894
885
  for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
895
886
  mns = df.groupby(sample_label)[mcpnum].mean()
896
887
  mns = pd.concat([mns, response], axis=1)
@@ -900,11 +891,10 @@ class Dialogue:
900
891
  )
901
892
  pvals.loc[celltype, mcpnum] = res[1]
902
893
  tstats.loc[celltype, mcpnum] = res[0]
903
- # return(res)
904
894
 
905
- # benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
906
895
  for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
907
896
  pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
897
+
908
898
  return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
909
899
 
910
900
  def get_mlm_mcp_genes(
@@ -921,7 +911,7 @@ class Dialogue:
921
911
  celltype: Cell type of interest.
922
912
  results: dl.MultilevelModeling result object.
923
913
  MCP: MCP key of the result object.
924
- threshhold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
914
+ threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
925
915
  Defaults to 0.70.
926
916
  focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
927
917
  Defaults to None.
@@ -945,7 +935,6 @@ class Dialogue:
945
935
  # REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
946
936
  if MCP.startswith("mcp_"):
947
937
  MCP = MCP.replace("mcp_", "MCP")
948
- # convert from MCPx to MCPx+1
949
938
  MCP = "MCP" + str(int(MCP[3:]) - 1)
950
939
 
951
940
  # Extract all comparison keys from the results object
@@ -1014,17 +1003,16 @@ class Dialogue:
1014
1003
  objects containing the results of gene ranking analysis.
1015
1004
 
1016
1005
  Examples:
1017
- ct_subs = {
1018
- "subpop1": anndata_obj1,
1019
- "subpop2": anndata_obj2,
1020
- # ... more subpopulations ...
1021
- }
1022
- genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1006
+ >>> ct_subs = {
1007
+ ... "subpop1": anndata_obj1,
1008
+ ... "subpop2": anndata_obj2,
1009
+ ... # ... more subpopulations ...
1010
+ ... }
1011
+ >>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1023
1012
  """
1024
1013
  genes = {}
1025
1014
  for ct in ct_subs.keys():
1026
1015
  mini = ct_subs[ct]
1027
- mini.obs[mcp]
1028
1016
  mini.obs["extrema"] = pd.qcut(
1029
1017
  mini.obs[mcp],
1030
1018
  [0, 0 + fraction, 1 - fraction, 1.0],
@@ -1034,6 +1022,7 @@ class Dialogue:
1034
1022
  mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
1035
1023
  )
1036
1024
  genes[ct] = mini # .uns['rank_genes_groups']
1025
+
1037
1026
  return genes
1038
1027
 
1039
1028
  def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
@@ -1064,7 +1053,7 @@ class Dialogue:
1064
1053
  >>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
1065
1054
  """
1066
1055
  rank_dfs: dict[str, dict[Any, Any]] = {}
1067
- _, ct_sub = next(iter(ct_subs.items()))
1056
+ ct_sub = next(iter(ct_subs.values()))
1068
1057
  mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
1069
1058
 
1070
1059
  for mcp in mcps:
@@ -1072,4 +1061,123 @@ class Dialogue:
1072
1061
  ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
1073
1062
  for celltype in ct_ranked.keys():
1074
1063
  rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
1064
+
1075
1065
  return rank_dfs
1066
+
1067
+ def plot_split_violins(
1068
+ self,
1069
+ adata: AnnData,
1070
+ split_key: str,
1071
+ celltype_key: str,
1072
+ split_which: tuple[str, str] = None,
1073
+ mcp: str = "mcp_0",
1074
+ return_fig: bool | None = None,
1075
+ ax: Axes | None = None,
1076
+ save: bool | str | None = None,
1077
+ show: bool | None = None,
1078
+ ) -> Axes | Figure | None:
1079
+ """Plots split violin plots for a given MCP and split variable.
1080
+
1081
+ Any cells with a value for split_key not in split_which are removed from the plot.
1082
+
1083
+ Args:
1084
+ adata: Annotated data object.
1085
+ split_key: Variable in adata.obs used to split the data.
1086
+ celltype_key: Key for cell type annotations.
1087
+ split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
1088
+ mcp: Key for MCP data. Defaults to "mcp_0".
1089
+
1090
+ Returns:
1091
+ A :class:`~matplotlib.axes.Axes` object
1092
+
1093
+ Examples:
1094
+ >>> import pertpy as pt
1095
+ >>> import scanpy as sc
1096
+ >>> adata = pt.dt.dialogue_example()
1097
+ >>> sc.pp.pca(adata)
1098
+ >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
1099
+ n_counts_key = "nCount_RNA", n_mpcs = 3)
1100
+ >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
1101
+ >>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
1102
+
1103
+ Preview:
1104
+ .. image:: /_static/docstring_previews/dialogue_violin.png
1105
+ """
1106
+ df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
1107
+ if split_which is None:
1108
+ split_which = df[split_key].unique()
1109
+ df = df[df[split_key].isin(split_which)]
1110
+ df[split_key] = df[split_key].cat.remove_unused_categories()
1111
+
1112
+ ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1113
+
1114
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1115
+
1116
+ if save:
1117
+ plt.savefig(save, bbox_inches="tight")
1118
+ if show:
1119
+ plt.show()
1120
+ if return_fig:
1121
+ return plt.gcf()
1122
+ if not (show or save):
1123
+ return ax
1124
+ return None
1125
+
1126
+ def plot_pairplot(
1127
+ self,
1128
+ adata: AnnData,
1129
+ celltype_key: str,
1130
+ color: str,
1131
+ sample_id: str,
1132
+ mcp: str = "mcp_0",
1133
+ return_fig: bool | None = None,
1134
+ show: bool | None = None,
1135
+ save: bool | str | None = None,
1136
+ ) -> PairGrid | Figure | None:
1137
+ """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
1138
+
1139
+ Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
1140
+ then creates a pairplot to visualize the relationships between these mean MCP values.
1141
+
1142
+ Args:
1143
+ adata: Annotated data object.
1144
+ celltype_key: Key in `adata.obs` containing cell type annotations.
1145
+ color: Key in `adata.obs` for color annotations. This parameter is used as the hue
1146
+ sample_id: Key in `adata.obs` for the sample annotations.
1147
+ mcp: Key in `adata.obs` for MCP feature values. Defaults to `"mcp_0"`.
1148
+
1149
+ Returns:
1150
+ Seaborn Pairgrid object.
1151
+
1152
+ Examples:
1153
+ >>> import pertpy as pt
1154
+ >>> import scanpy as sc
1155
+ >>> adata = pt.dt.dialogue_example()
1156
+ >>> sc.pp.pca(adata)
1157
+ >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
1158
+ n_counts_key = "nCount_RNA", n_mpcs = 3)
1159
+ >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
1160
+ >>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
1161
+
1162
+ Preview:
1163
+ .. image:: /_static/docstring_previews/dialogue_pairplot.png
1164
+ """
1165
+ mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
1166
+ mean_mcps = mean_mcps.reset_index()
1167
+ mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
1168
+
1169
+ aggstats = adata.obs.groupby([sample_id])[color].describe()
1170
+ aggstats = aggstats.loc[list(mcp_pivot.index), :]
1171
+ aggstats[color] = aggstats["top"]
1172
+ mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1173
+ ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
1174
+
1175
+ if save:
1176
+ plt.savefig(save, bbox_inches="tight")
1177
+ if show:
1178
+ plt.show()
1179
+ if return_fig:
1180
+ return plt.gcf()
1181
+ if not (show or save):
1182
+ return ax
1183
+ return None