pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  import itertools
4
4
  from collections import defaultdict
5
- from typing import Any, Literal
5
+ from typing import TYPE_CHECKING, Any, Literal
6
6
 
7
7
  import anndata as ad
8
+ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
10
11
  import scanpy as sc
11
- import scipy.sparse as sp
12
+ import seaborn as sns
12
13
  import statsmodels.formula.api as smf
13
14
  import statsmodels.stats.multitest as ssm
14
15
  from anndata import AnnData
@@ -19,10 +20,15 @@ from rich.live import Live
19
20
  from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
20
21
  from scipy import stats
21
22
  from scipy.optimize import nnls
23
+ from seaborn import PairGrid
22
24
  from sklearn.linear_model import LinearRegression
23
25
  from sparsecca import lp_pmd, multicca_permute, multicca_pmd
24
26
  from statsmodels.sandbox.stats.multicomp import multipletests
25
27
 
28
+ if TYPE_CHECKING:
29
+ from matplotlib.axes import Axes
30
+ from matplotlib.figure import Figure
31
+
26
32
 
27
33
  class Dialogue:
28
34
  """Python implementation of DIALOGUE"""
@@ -53,8 +59,6 @@ class Dialogue:
53
59
 
54
60
  Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
55
61
 
56
- # TODO: Replace with decoupler's implementation
57
-
58
62
  Args:
59
63
  groupby: The key to groupby for pseudobulks
60
64
  strategy: The pseudobulking strategy. One of "median" or "mean"
@@ -62,14 +66,15 @@ class Dialogue:
62
66
  Returns:
63
67
  A Pandas DataFrame of pseudobulk counts
64
68
  """
69
+ # TODO: Replace with decoupler's implementation
65
70
  pseudobulk = {"Genes": adata.var_names.values}
66
71
 
67
72
  for category in adata.obs.loc[:, groupby].cat.categories:
68
73
  temp = adata.obs.loc[:, groupby] == category
69
74
  if strategy == "median":
70
- pseudobulk[category] = adata[temp].X.median(axis=0).A1
75
+ pseudobulk[category] = adata[temp].X.median(axis=0)
71
76
  elif strategy == "mean":
72
- pseudobulk[category] = adata[temp].X.mean(axis=0).A1
77
+ pseudobulk[category] = adata[temp].X.mean(axis=0)
73
78
 
74
79
  pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
75
80
 
@@ -101,8 +106,6 @@ class Dialogue:
101
106
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
102
107
  """Row-wise mean center and scale by the standard deviation.
103
108
 
104
- TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
105
-
106
109
  Args:
107
110
  pseudobulks: The pseudobulk PCA components.
108
111
  normalize: Whether to mimic DIALOGUE behavior or not.
@@ -110,9 +113,9 @@ class Dialogue:
110
113
  Returns:
111
114
  The scaled count matrix.
112
115
  """
116
+ # TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
113
117
  # DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
114
118
  # However, performing this scaling _does_ increase overall correlation of the end result
115
- # WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
116
119
  if normalize:
117
120
  return pseudobulks.to_numpy()
118
121
  else:
@@ -313,13 +316,13 @@ class Dialogue:
313
316
  def _apply_HLM_per_MCP_for_one_pair(
314
317
  self,
315
318
  mcp_name: str,
316
- scores_df: dict,
319
+ scores_df: pd.DataFrame,
317
320
  ct_data: AnnData,
318
321
  tme: pd.DataFrame,
319
322
  sig: dict,
320
323
  n_counts: str,
321
324
  formula: str,
322
- confounder: str,
325
+ confounder: str | None,
323
326
  ) -> tuple[pd.DataFrame, dict[str, Any]]:
324
327
  """Applies hierarchical modeling for a single MCP.
325
328
 
@@ -340,7 +343,7 @@ class Dialogue:
340
343
  """
341
344
  HLM_result = self._mixed_effects(
342
345
  scores=scores_df[[mcp_name]],
343
- x_labels=ct_data.obs[[n_counts, confounder]],
346
+ x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
344
347
  tme=tme,
345
348
  genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
346
349
  formula=formula,
@@ -367,7 +370,7 @@ class Dialogue:
367
370
  return np.array(resid)
368
371
 
369
372
  def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
370
- """Solves non-negative least squares separately for different feature categories.
373
+ """Solves non-negative least-squares separately for different feature categories.
371
374
 
372
375
  Mimics DLG.iterative.nnls.
373
376
  Variables are notated according to:
@@ -398,7 +401,7 @@ class Dialogue:
398
401
 
399
402
  x_final = np.zeros(A_orig.shape[0])
400
403
  Ax = np.zeros(A_orig.shape[1])
401
- for _, mask in zip(sig_ranks, masks):
404
+ for _, mask in zip(sig_ranks, masks, strict=False):
402
405
  A = A_orig[mask].T
403
406
  coef_nnls, _ = nnls(A, y, maxiter=n_iter)
404
407
  y = y - A @ coef_nnls # residuals
@@ -516,8 +519,8 @@ class Dialogue:
516
519
  # TODO: probably format the up and down within get_top_elements
517
520
  cca_sig: dict[str, Any] = defaultdict(dict)
518
521
  for i in range(0, int(len(cca_sig_unformatted) / 2)):
519
- cca_sig[f"MCP{i + 1}"]["up"] = cca_sig_unformatted[i * 2]
520
- cca_sig[f"MCP{i + 1}"]["down"] = cca_sig_unformatted[i * 2 + 1]
522
+ cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
523
+ cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
521
524
 
522
525
  cca_sig = dict(cca_sig)
523
526
  cca_sig_results[ct] = cca_sig
@@ -555,7 +558,7 @@ class Dialogue:
555
558
 
556
559
  return cca_sig_results, new_mcp_scores
557
560
 
558
- def load(
561
+ def _load(
559
562
  self,
560
563
  adata: AnnData,
561
564
  ct_order: list[str],
@@ -574,16 +577,6 @@ class Dialogue:
574
577
 
575
578
  Returns:
576
579
  A celltype_label:array dictionary.
577
-
578
- Examples:
579
- >>> import pertpy as pt
580
- >>> import scanpy as sc
581
- >>> adata = pt.dt.dialogue_example()
582
- >>> sc.pp.pca(adata)
583
- >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
584
- n_counts_key = "nCount_RNA", n_mpcs = 3)
585
- >>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
586
- >>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
587
580
  """
588
581
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
589
582
  fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
@@ -631,19 +624,19 @@ class Dialogue:
631
624
  >>> import scanpy as sc
632
625
  >>> adata = pt.dt.dialogue_example()
633
626
  >>> sc.pp.pca(adata)
634
- >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
635
- n_counts_key = "nCount_RNA", n_mpcs = 3)
627
+ >>> dl = pt.tl.Dialogue(
628
+ ... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
629
+ ... )
636
630
  >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
637
631
  """
638
- # IMPORTANT NOTE: the order in which matrices are passed to multicca matters. As such,
639
- # it is important here that to obtain the same result as in R, we pass the matrices in
640
- # in the same order.
632
+ # IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
633
+ # As such, it is important here that to obtain the same result as in R, we pass the matrices in the same order.
641
634
  if ct_order is not None:
642
635
  cell_types = ct_order
643
636
  else:
644
637
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
645
638
 
646
- mcca_in, ct_subs = self.load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
639
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
647
640
 
648
641
  n_samples = mcca_in[0].shape[1]
649
642
  if penalties is None:
@@ -685,7 +678,7 @@ class Dialogue:
685
678
  ct_subs: dict,
686
679
  mcp_scores: dict,
687
680
  ws_dict: dict,
688
- confounder: str,
681
+ confounder: str | None,
689
682
  formula: str = None,
690
683
  ):
691
684
  """Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
@@ -700,7 +693,6 @@ class Dialogue:
700
693
  A Pandas DataFrame containing:
701
694
  - for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
702
695
  - merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
703
- TODO: Describe both returns
704
696
 
705
697
  Examples:
706
698
  >>> import pertpy as pt
@@ -713,7 +705,9 @@ class Dialogue:
713
705
  >>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
714
706
  confounder="gender")
715
707
  """
716
- # all possible pairs of cell types with out pairing same cell type
708
+ # TODO the returns of the function better
709
+
710
+ # all possible pairs of cell types without pairing same cell type
717
711
  cell_types = list(ct_subs.keys())
718
712
  pairs = list(itertools.combinations(cell_types, 2))
719
713
 
@@ -721,9 +715,9 @@ class Dialogue:
721
715
  formula = f"y ~ x + {self.n_counts_key}"
722
716
 
723
717
  # Hierarchical modeling expects DataFrames
724
- mcp_cell_types = {f"MCP{i + 1}": cell_types for i in range(self.n_mcps)}
718
+ mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
725
719
  mcp_scores_df = {
726
- ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
720
+ ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
727
721
  for ct, v in mcp_scores.items()
728
722
  }
729
723
 
@@ -805,7 +799,7 @@ class Dialogue:
805
799
  for mcp in mcps:
806
800
  mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
807
801
 
808
- # TODO Check that the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
802
+ # TODO Check whether the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
809
803
  result = {}
810
804
  result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
811
805
  mcp_name=mcp,
@@ -875,22 +869,19 @@ class Dialogue:
875
869
  sample_label = self.sample_id
876
870
  n_mcps = self.n_mcps
877
871
 
878
- # create conditions_compare if not supplied
879
872
  if conditions_compare is None:
880
- conditions_compare = list(adata.obs["path_str"].cat.categories) # type: ignore
873
+ conditions_compare = list(adata.obs[condition_label].cat.categories) # type: ignore
881
874
  if len(conditions_compare) != 2:
882
875
  raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
883
876
 
884
- # create data frames to store results
885
877
  pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
886
878
  tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
887
879
  pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
888
880
 
889
881
  response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
890
882
  for celltype in adata.obs[celltype_label].unique():
891
- # subset data to cell type
892
883
  df = adata.obs[adata.obs[celltype_label] == celltype]
893
- # run t-test for each MCP
884
+
894
885
  for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
895
886
  mns = df.groupby(sample_label)[mcpnum].mean()
896
887
  mns = pd.concat([mns, response], axis=1)
@@ -900,11 +891,10 @@ class Dialogue:
900
891
  )
901
892
  pvals.loc[celltype, mcpnum] = res[1]
902
893
  tstats.loc[celltype, mcpnum] = res[0]
903
- # return(res)
904
894
 
905
- # benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
906
895
  for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
907
896
  pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
897
+
908
898
  return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
909
899
 
910
900
  def get_mlm_mcp_genes(
@@ -921,7 +911,7 @@ class Dialogue:
921
911
  celltype: Cell type of interest.
922
912
  results: dl.MultilevelModeling result object.
923
913
  MCP: MCP key of the result object.
924
- threshhold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
914
+ threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
925
915
  Defaults to 0.70.
926
916
  focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
927
917
  Defaults to None.
@@ -945,7 +935,6 @@ class Dialogue:
945
935
  # REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
946
936
  if MCP.startswith("mcp_"):
947
937
  MCP = MCP.replace("mcp_", "MCP")
948
- # convert from MCPx to MCPx+1
949
938
  MCP = "MCP" + str(int(MCP[3:]) - 1)
950
939
 
951
940
  # Extract all comparison keys from the results object
@@ -1014,17 +1003,16 @@ class Dialogue:
1014
1003
  objects containing the results of gene ranking analysis.
1015
1004
 
1016
1005
  Examples:
1017
- ct_subs = {
1018
- "subpop1": anndata_obj1,
1019
- "subpop2": anndata_obj2,
1020
- # ... more subpopulations ...
1021
- }
1022
- genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1006
+ >>> ct_subs = {
1007
+ ... "subpop1": anndata_obj1,
1008
+ ... "subpop2": anndata_obj2,
1009
+ ... # ... more subpopulations ...
1010
+ ... }
1011
+ >>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1023
1012
  """
1024
1013
  genes = {}
1025
1014
  for ct in ct_subs.keys():
1026
1015
  mini = ct_subs[ct]
1027
- mini.obs[mcp]
1028
1016
  mini.obs["extrema"] = pd.qcut(
1029
1017
  mini.obs[mcp],
1030
1018
  [0, 0 + fraction, 1 - fraction, 1.0],
@@ -1034,6 +1022,7 @@ class Dialogue:
1034
1022
  mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
1035
1023
  )
1036
1024
  genes[ct] = mini # .uns['rank_genes_groups']
1025
+
1037
1026
  return genes
1038
1027
 
1039
1028
  def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
@@ -1064,7 +1053,7 @@ class Dialogue:
1064
1053
  >>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
1065
1054
  """
1066
1055
  rank_dfs: dict[str, dict[Any, Any]] = {}
1067
- _, ct_sub = next(iter(ct_subs.items()))
1056
+ ct_sub = next(iter(ct_subs.values()))
1068
1057
  mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
1069
1058
 
1070
1059
  for mcp in mcps:
@@ -1072,4 +1061,123 @@ class Dialogue:
1072
1061
  ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
1073
1062
  for celltype in ct_ranked.keys():
1074
1063
  rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
1064
+
1075
1065
  return rank_dfs
1066
+
1067
+ def plot_split_violins(
1068
+ self,
1069
+ adata: AnnData,
1070
+ split_key: str,
1071
+ celltype_key: str,
1072
+ split_which: tuple[str, str] = None,
1073
+ mcp: str = "mcp_0",
1074
+ return_fig: bool | None = None,
1075
+ ax: Axes | None = None,
1076
+ save: bool | str | None = None,
1077
+ show: bool | None = None,
1078
+ ) -> Axes | Figure | None:
1079
+ """Plots split violin plots for a given MCP and split variable.
1080
+
1081
+ Any cells with a value for split_key not in split_which are removed from the plot.
1082
+
1083
+ Args:
1084
+ adata: Annotated data object.
1085
+ split_key: Variable in adata.obs used to split the data.
1086
+ celltype_key: Key for cell type annotations.
1087
+ split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
1088
+ mcp: Key for MCP data. Defaults to "mcp_0".
1089
+
1090
+ Returns:
1091
+ A :class:`~matplotlib.axes.Axes` object
1092
+
1093
+ Examples:
1094
+ >>> import pertpy as pt
1095
+ >>> import scanpy as sc
1096
+ >>> adata = pt.dt.dialogue_example()
1097
+ >>> sc.pp.pca(adata)
1098
+ >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
1099
+ n_counts_key = "nCount_RNA", n_mpcs = 3)
1100
+ >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
1101
+ >>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
1102
+
1103
+ Preview:
1104
+ .. image:: /_static/docstring_previews/dialogue_violin.png
1105
+ """
1106
+ df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
1107
+ if split_which is None:
1108
+ split_which = df[split_key].unique()
1109
+ df = df[df[split_key].isin(split_which)]
1110
+ df[split_key] = df[split_key].cat.remove_unused_categories()
1111
+
1112
+ ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1113
+
1114
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1115
+
1116
+ if save:
1117
+ plt.savefig(save, bbox_inches="tight")
1118
+ if show:
1119
+ plt.show()
1120
+ if return_fig:
1121
+ return plt.gcf()
1122
+ if not (show or save):
1123
+ return ax
1124
+ return None
1125
+
1126
+ def plot_pairplot(
1127
+ self,
1128
+ adata: AnnData,
1129
+ celltype_key: str,
1130
+ color: str,
1131
+ sample_id: str,
1132
+ mcp: str = "mcp_0",
1133
+ return_fig: bool | None = None,
1134
+ show: bool | None = None,
1135
+ save: bool | str | None = None,
1136
+ ) -> PairGrid | Figure | None:
1137
+ """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
1138
+
1139
+ Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
1140
+ then creates a pairplot to visualize the relationships between these mean MCP values.
1141
+
1142
+ Args:
1143
+ adata: Annotated data object.
1144
+ celltype_key: Key in `adata.obs` containing cell type annotations.
1145
+ color: Key in `adata.obs` for color annotations. This parameter is used as the hue
1146
+ sample_id: Key in `adata.obs` for the sample annotations.
1147
+ mcp: Key in `adata.obs` for MCP feature values. Defaults to `"mcp_0"`.
1148
+
1149
+ Returns:
1150
+ Seaborn Pairgrid object.
1151
+
1152
+ Examples:
1153
+ >>> import pertpy as pt
1154
+ >>> import scanpy as sc
1155
+ >>> adata = pt.dt.dialogue_example()
1156
+ >>> sc.pp.pca(adata)
1157
+ >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
1158
+ n_counts_key = "nCount_RNA", n_mpcs = 3)
1159
+ >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
1160
+ >>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
1161
+
1162
+ Preview:
1163
+ .. image:: /_static/docstring_previews/dialogue_pairplot.png
1164
+ """
1165
+ mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
1166
+ mean_mcps = mean_mcps.reset_index()
1167
+ mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
1168
+
1169
+ aggstats = adata.obs.groupby([sample_id])[color].describe()
1170
+ aggstats = aggstats.loc[list(mcp_pivot.index), :]
1171
+ aggstats[color] = aggstats["top"]
1172
+ mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1173
+ ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
1174
+
1175
+ if save:
1176
+ plt.savefig(save, bbox_inches="tight")
1177
+ if show:
1178
+ plt.show()
1179
+ if return_fig:
1180
+ return plt.gcf()
1181
+ if not (show or save):
1182
+ return ax
1183
+ return None