pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_dialogue.py CHANGED
@@ -2,27 +2,33 @@ from __future__ import annotations
2
2
 
3
3
  import itertools
4
4
  from collections import defaultdict
5
- from typing import Any, Literal
5
+ from typing import TYPE_CHECKING, Any, Literal
6
6
 
7
7
  import anndata as ad
8
+ import matplotlib.pyplot as plt
8
9
  import numpy as np
9
10
  import pandas as pd
10
11
  import scanpy as sc
11
- import scipy.sparse as sp
12
+ import seaborn as sns
12
13
  import statsmodels.formula.api as smf
13
14
  import statsmodels.stats.multitest as ssm
14
15
  from anndata import AnnData
16
+ from lamin_utils import logger
15
17
  from pandas import DataFrame
16
- from rich import print
17
18
  from rich.console import Group
18
19
  from rich.live import Live
19
20
  from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
20
21
  from scipy import stats
21
22
  from scipy.optimize import nnls
23
+ from seaborn import PairGrid
22
24
  from sklearn.linear_model import LinearRegression
23
25
  from sparsecca import lp_pmd, multicca_permute, multicca_pmd
24
26
  from statsmodels.sandbox.stats.multicomp import multipletests
25
27
 
28
+ if TYPE_CHECKING:
29
+ from matplotlib.axes import Axes
30
+ from matplotlib.figure import Figure
31
+
26
32
 
27
33
  class Dialogue:
28
34
  """Python implementation of DIALOGUE"""
@@ -53,8 +59,6 @@ class Dialogue:
53
59
 
54
60
  Copied from `https://github.com/schillerlab/sc-toolbox/blob/397e80dc5e8fb8017b75f6c3fa634a1e1213d484/sc_toolbox/tools/__init__.py#L458`
55
61
 
56
- # TODO: Replace with decoupler's implementation
57
-
58
62
  Args:
59
63
  groupby: The key to groupby for pseudobulks
60
64
  strategy: The pseudobulking strategy. One of "median" or "mean"
@@ -62,14 +66,15 @@ class Dialogue:
62
66
  Returns:
63
67
  A Pandas DataFrame of pseudobulk counts
64
68
  """
69
+ # TODO: Replace with decoupler's implementation
65
70
  pseudobulk = {"Genes": adata.var_names.values}
66
71
 
67
72
  for category in adata.obs.loc[:, groupby].cat.categories:
68
73
  temp = adata.obs.loc[:, groupby] == category
69
74
  if strategy == "median":
70
- pseudobulk[category] = adata[temp].X.median(axis=0).A1
75
+ pseudobulk[category] = adata[temp].X.median(axis=0)
71
76
  elif strategy == "mean":
72
- pseudobulk[category] = adata[temp].X.mean(axis=0).A1
77
+ pseudobulk[category] = adata[temp].X.mean(axis=0)
73
78
 
74
79
  pseudobulk = pd.DataFrame(pseudobulk).set_index("Genes")
75
80
 
@@ -101,8 +106,6 @@ class Dialogue:
101
106
  def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
102
107
  """Row-wise mean center and scale by the standard deviation.
103
108
 
104
- TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
105
-
106
109
  Args:
107
110
  pseudobulks: The pseudobulk PCA components.
108
111
  normalize: Whether to mimic DIALOGUE behavior or not.
@@ -110,9 +113,9 @@ class Dialogue:
110
113
  Returns:
111
114
  The scaled count matrix.
112
115
  """
116
+ # TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
113
117
  # DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
114
118
  # However, performing this scaling _does_ increase overall correlation of the end result
115
- # WHEN SAMPLE ORDER AND DIALOGUE2+3 PROCESSING IS IGNORED.
116
119
  if normalize:
117
120
  return pseudobulks.to_numpy()
118
121
  else:
@@ -288,7 +291,7 @@ class Dialogue:
288
291
  mcp_name: Name of mcp which was used for calculation of column value.
289
292
  max_length: Value needed to later decide at what index the threshold value should be extracted from column.
290
293
  min_threshold: Minimal threshold to select final scores by if it is smaller than calculated threshold.
291
- index: Column index to use eto calculate the significant genes. Defaults to `z_score`.
294
+ index: Column index to use to calculate the significant genes.
292
295
 
293
296
  Returns:
294
297
  According to the values in a df column (default: zscore) the significant up and downregulated gene names
@@ -313,13 +316,13 @@ class Dialogue:
313
316
  def _apply_HLM_per_MCP_for_one_pair(
314
317
  self,
315
318
  mcp_name: str,
316
- scores_df: dict,
319
+ scores_df: pd.DataFrame,
317
320
  ct_data: AnnData,
318
321
  tme: pd.DataFrame,
319
322
  sig: dict,
320
323
  n_counts: str,
321
324
  formula: str,
322
- confounder: str,
325
+ confounder: str | None,
323
326
  ) -> tuple[pd.DataFrame, dict[str, Any]]:
324
327
  """Applies hierarchical modeling for a single MCP.
325
328
 
@@ -340,7 +343,7 @@ class Dialogue:
340
343
  """
341
344
  HLM_result = self._mixed_effects(
342
345
  scores=scores_df[[mcp_name]],
343
- x_labels=ct_data.obs[[n_counts, confounder]],
346
+ x_labels=ct_data.obs[[n_counts, confounder]] if confounder else ct_data.obs[[n_counts]],
344
347
  tme=tme,
345
348
  genes_in_mcp=list(sig[mcp_name]["up"]) + list(sig[mcp_name]["down"]),
346
349
  formula=formula,
@@ -367,19 +370,13 @@ class Dialogue:
367
370
  return np.array(resid)
368
371
 
369
372
  def _iterative_nnls(self, A_orig: np.ndarray, y_orig: np.ndarray, feature_ranks: list[int], n_iter: int = 1000):
370
- """Solves non-negative least squares separately for different feature categories.
373
+ """Solves non-negative least-squares separately for different feature categories.
371
374
 
372
375
  Mimics DLG.iterative.nnls.
373
376
  Variables are notated according to:
374
377
 
375
378
  `argmin|Ax - y|`
376
379
 
377
- Args:
378
- A_orig:
379
- y_orig:
380
- feature_ranks:
381
- n_iter: Passed to scipy.optimize.nnls. Defaults to 1000.
382
-
383
380
  Returns:
384
381
  Returns the aggregated coefficients from nnls.
385
382
  """
@@ -398,7 +395,7 @@ class Dialogue:
398
395
 
399
396
  x_final = np.zeros(A_orig.shape[0])
400
397
  Ax = np.zeros(A_orig.shape[1])
401
- for _, mask in zip(sig_ranks, masks):
398
+ for _, mask in zip(sig_ranks, masks, strict=False):
402
399
  A = A_orig[mask].T
403
400
  coef_nnls, _ = nnls(A, y, maxiter=n_iter)
404
401
  y = y - A @ coef_nnls # residuals
@@ -516,8 +513,8 @@ class Dialogue:
516
513
  # TODO: probably format the up and down within get_top_elements
517
514
  cca_sig: dict[str, Any] = defaultdict(dict)
518
515
  for i in range(0, int(len(cca_sig_unformatted) / 2)):
519
- cca_sig[f"MCP{i + 1}"]["up"] = cca_sig_unformatted[i * 2]
520
- cca_sig[f"MCP{i + 1}"]["down"] = cca_sig_unformatted[i * 2 + 1]
516
+ cca_sig[f"MCP{i}"]["up"] = cca_sig_unformatted[i * 2]
517
+ cca_sig[f"MCP{i}"]["down"] = cca_sig_unformatted[i * 2 + 1]
521
518
 
522
519
  cca_sig = dict(cca_sig)
523
520
  cca_sig_results[ct] = cca_sig
@@ -555,7 +552,7 @@ class Dialogue:
555
552
 
556
553
  return cca_sig_results, new_mcp_scores
557
554
 
558
- def load(
555
+ def _load(
559
556
  self,
560
557
  adata: AnnData,
561
558
  ct_order: list[str],
@@ -569,21 +566,11 @@ class Dialogue:
569
566
  Args:
570
567
  adata: AnnData object to generate celltype objects for
571
568
  ct_order: The order of cell types
572
- agg_pca: Whether to aggregate pseudobulks with PCA or not. Defaults to True.
573
- normalize: Whether to mimic DIALOGUE behavior or not. Defaults to True.
569
+ agg_pca: Whether to aggregate pseudobulks with PCA or not.
570
+ normalize: Whether to mimic DIALOGUE behavior or not.
574
571
 
575
572
  Returns:
576
573
  A celltype_label:array dictionary.
577
-
578
- Examples:
579
- >>> import pertpy as pt
580
- >>> import scanpy as sc
581
- >>> adata = pt.dt.dialogue_example()
582
- >>> sc.pp.pca(adata)
583
- >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
584
- n_counts_key = "nCount_RNA", n_mpcs = 3)
585
- >>> cell_types = adata.obs[dl.celltype_key].astype("category").cat.categories
586
- >>> mcca_in, ct_subs = dl.load(adata, ct_order=cell_types)
587
574
  """
588
575
  ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
589
576
  fn = self._pseudobulk_pca if agg_pca else self._get_pseudobulks
@@ -620,7 +607,6 @@ class Dialogue:
620
607
  agg_pca: Whether to calculate cell-averaged PCA components.
621
608
  solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
622
609
  For differences between these two, please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
623
- Defaults to 'bs'.
624
610
  normalize: Whether to mimic DIALOGUE as close as possible
625
611
 
626
612
  Returns:
@@ -631,25 +617,31 @@ class Dialogue:
631
617
  >>> import scanpy as sc
632
618
  >>> adata = pt.dt.dialogue_example()
633
619
  >>> sc.pp.pca(adata)
634
- >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
635
- n_counts_key = "nCount_RNA", n_mpcs = 3)
620
+ >>> dl = pt.tl.Dialogue(
621
+ ... sample_id="clinical.status", celltype_key="cell.subtypes", n_counts_key="nCount_RNA", n_mpcs=3
622
+ ... )
636
623
  >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
637
624
  """
638
- # IMPORTANT NOTE: the order in which matrices are passed to multicca matters. As such,
639
- # it is important here that to obtain the same result as in R, we pass the matrices in
640
- # in the same order.
625
+ # IMPORTANT NOTE: the order in which matrices are passed to multicca matters.
626
+ # As such, it is important here that to obtain the same result as in R, we pass the matrices in the same order.
641
627
  if ct_order is not None:
642
628
  cell_types = ct_order
643
629
  else:
644
630
  ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
645
631
 
646
- mcca_in, ct_subs = self.load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
632
+ mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_pca=agg_pca, normalize=normalize)
647
633
 
648
634
  n_samples = mcca_in[0].shape[1]
649
635
  if penalties is None:
650
- penalties = multicca_permute(
651
- mcca_in, penalties=np.sqrt(n_samples) / 2, nperms=10, niter=50, standardize=True
652
- )["bestpenalties"]
636
+ try:
637
+ penalties = multicca_permute(
638
+ mcca_in, penalties=np.sqrt(n_samples) / 2, nperms=10, niter=50, standardize=True
639
+ )["bestpenalties"]
640
+ except ValueError as e:
641
+ if "matmul: input operand 1 has a mismatch in its core dimension" in str(e):
642
+ raise ValueError("Please ensure that every cell type is represented in every sample.") from e
643
+ else:
644
+ raise
653
645
  else:
654
646
  penalties = penalties
655
647
 
@@ -685,7 +677,7 @@ class Dialogue:
685
677
  ct_subs: dict,
686
678
  mcp_scores: dict,
687
679
  ws_dict: dict,
688
- confounder: str,
680
+ confounder: str | None,
689
681
  formula: str = None,
690
682
  ):
691
683
  """Runs the multilevel modeling step to match genes to MCPs and generate p-values for MCPs.
@@ -700,7 +692,6 @@ class Dialogue:
700
692
  A Pandas DataFrame containing:
701
693
  - for each mcp: HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2
702
694
  - merged HLM_result_1, HLM_result_2, sig_genes_1, sig_genes_2 of all mcps
703
- TODO: Describe both returns
704
695
 
705
696
  Examples:
706
697
  >>> import pertpy as pt
@@ -713,7 +704,9 @@ class Dialogue:
713
704
  >>> all_results, new_mcps = dl.multilevel_modeling(ct_subs=ct_subs, mcp_scores=mcps, ws_dict=ws, \
714
705
  confounder="gender")
715
706
  """
716
- # all possible pairs of cell types with out pairing same cell type
707
+ # TODO: Describe the returns of the function better
708
+
709
+ # all possible pairs of cell types without pairing same cell type
717
710
  cell_types = list(ct_subs.keys())
718
711
  pairs = list(itertools.combinations(cell_types, 2))
719
712
 
@@ -721,9 +714,9 @@ class Dialogue:
721
714
  formula = f"y ~ x + {self.n_counts_key}"
722
715
 
723
716
  # Hierarchical modeling expects DataFrames
724
- mcp_cell_types = {f"MCP{i + 1}": cell_types for i in range(self.n_mcps)}
717
+ mcp_cell_types = {f"MCP{i}": cell_types for i in range(self.n_mcps)}
725
718
  mcp_scores_df = {
726
- ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=mcp_cell_types.keys())
719
+ ct: pd.DataFrame(v, index=ct_subs[ct].obs.index, columns=list(mcp_cell_types.keys()))
727
720
  for ct, v in mcp_scores.items()
728
721
  }
729
722
 
@@ -762,10 +755,10 @@ class Dialogue:
762
755
  mcps.append(mcp)
763
756
 
764
757
  if len(mcps) == 0:
765
- print(f"[bold red]No shared MCPs between {cell_type_1} and {cell_type_2}.")
758
+ logger.warning(f"No shared MCPs between {cell_type_1} and {cell_type_2}.")
766
759
  continue
767
760
 
768
- print(f"[bold blue]{len(mcps)} MCPs identified for {cell_type_1} and {cell_type_2}.")
761
+ logger.info(f"{len(mcps)} MCPs identified for {cell_type_1} and {cell_type_2}.")
769
762
 
770
763
  new_mcp_scores: dict[Any, list[Any]]
771
764
  cca_sig, new_mcp_scores = self._calculate_cca_sig(
@@ -805,7 +798,7 @@ class Dialogue:
805
798
  for mcp in mcps:
806
799
  mixed_model_progress.update(mm_task, description=f"[bold blue]Determining mixed effects for {mcp}")
807
800
 
808
- # TODO Check that the genes in result{sig_genes_1] are different and if so note that somewhere and explain why
801
+ # TODO Check whether the genes in result["sig_genes_1"] are different and if so note that somewhere and explain why
809
802
  result = {}
810
803
  result["HLM_result_1"], result["sig_genes_1"] = self._apply_HLM_per_MCP_for_one_pair(
811
804
  mcp_name=mcp,
@@ -875,22 +868,19 @@ class Dialogue:
875
868
  sample_label = self.sample_id
876
869
  n_mcps = self.n_mcps
877
870
 
878
- # create conditions_compare if not supplied
879
871
  if conditions_compare is None:
880
- conditions_compare = list(adata.obs["path_str"].cat.categories) # type: ignore
872
+ conditions_compare = list(adata.obs[condition_label].cat.categories) # type: ignore
881
873
  if len(conditions_compare) != 2:
882
874
  raise ValueError("Please specify conditions to compare or supply an object with only 2 conditions")
883
875
 
884
- # create data frames to store results
885
876
  pvals = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
886
877
  tstats = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
887
878
  pvals_adj = pd.DataFrame(1, adata.obs[celltype_label].unique(), ["mcp_" + str(n) for n in range(0, n_mcps)])
888
879
 
889
880
  response = adata.obs.groupby(sample_label)[condition_label].agg(pd.Series.mode)
890
881
  for celltype in adata.obs[celltype_label].unique():
891
- # subset data to cell type
892
882
  df = adata.obs[adata.obs[celltype_label] == celltype]
893
- # run t-test for each MCP
883
+
894
884
  for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
895
885
  mns = df.groupby(sample_label)[mcpnum].mean()
896
886
  mns = pd.concat([mns, response], axis=1)
@@ -900,11 +890,10 @@ class Dialogue:
900
890
  )
901
891
  pvals.loc[celltype, mcpnum] = res[1]
902
892
  tstats.loc[celltype, mcpnum] = res[0]
903
- # return(res)
904
893
 
905
- # benjamini-hochberg correction for number of cell types (use BH because correlated MCPs)
906
894
  for mcpnum in ["mcp_" + str(n) for n in range(0, n_mcps)]:
907
895
  pvals_adj[mcpnum] = multipletests(pvals[mcpnum], method="fdr_bh")[1]
896
+
908
897
  return {"pvals": pvals, "tstats": tstats, "pvals_adj": pvals_adj}
909
898
 
910
899
  def get_mlm_mcp_genes(
@@ -921,10 +910,8 @@ class Dialogue:
921
910
  celltype: Cell type of interest.
922
911
  results: dl.MultilevelModeling result object.
923
912
  MCP: MCP key of the result object.
924
- threshhold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
925
- Defaults to 0.70.
913
+ threshold: Number between [0,1]. The fraction of cell types compared against which must have the associated MCP gene.
926
914
  focal_celltypes: None (compare against all cell types) or a list of other cell types which you want to compare against.
927
- Defaults to None.
928
915
 
929
916
  Returns:
930
917
  Dict with keys 'up_genes' and 'down_genes' and values of lists of genes
@@ -945,7 +932,6 @@ class Dialogue:
945
932
  # REMOVE THIS BLOCK ONCE MLM OUTPUT MATCHES STANDARD
946
933
  if MCP.startswith("mcp_"):
947
934
  MCP = MCP.replace("mcp_", "MCP")
948
- # convert from MCPx to MCPx+1
949
935
  MCP = "MCP" + str(int(MCP[3:]) - 1)
950
936
 
951
937
  # Extract all comparison keys from the results object
@@ -1004,27 +990,24 @@ class Dialogue:
1004
990
  Args:
1005
991
  ct_subs: Dialogue output ct_subs dictionary
1006
992
  mcp: The name of the marker gene expression column.
1007
- Defaults to "mcp_0".
1008
993
  fraction: Fraction of extreme cells to consider for gene ranking.
1009
994
  Should be between 0 and 1.
1010
- Defaults to 0.1.
1011
995
 
1012
996
  Returns:
1013
997
  Dictionary where keys are subpopulation names and values are Anndata
1014
998
  objects containing the results of gene ranking analysis.
1015
999
 
1016
1000
  Examples:
1017
- ct_subs = {
1018
- "subpop1": anndata_obj1,
1019
- "subpop2": anndata_obj2,
1020
- # ... more subpopulations ...
1021
- }
1022
- genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1001
+ >>> ct_subs = {
1002
+ ... "subpop1": anndata_obj1,
1003
+ ... "subpop2": anndata_obj2,
1004
+ ... # ... more subpopulations ...
1005
+ ... }
1006
+ >>> genes_results = _get_extrema_MCP_genes_single(ct_subs, mcp="mcp_4", fraction=0.2)
1023
1007
  """
1024
1008
  genes = {}
1025
1009
  for ct in ct_subs.keys():
1026
1010
  mini = ct_subs[ct]
1027
- mini.obs[mcp]
1028
1011
  mini.obs["extrema"] = pd.qcut(
1029
1012
  mini.obs[mcp],
1030
1013
  [0, 0 + fraction, 1 - fraction, 1.0],
@@ -1034,6 +1017,7 @@ class Dialogue:
1034
1017
  mini, "extrema", groups=["high" + mcp + " " + ct], reference="low " + mcp + " " + ct
1035
1018
  )
1036
1019
  genes[ct] = mini # .uns['rank_genes_groups']
1020
+
1037
1021
  return genes
1038
1022
 
1039
1023
  def get_extrema_MCP_genes(self, ct_subs: dict, fraction: float = 0.1):
@@ -1046,7 +1030,7 @@ class Dialogue:
1046
1030
  Args:
1047
1031
  ct_subs: Dialogue output ct_subs dictionary
1048
1032
  fraction: Fraction of extreme cells to consider for gene ranking.
1049
- Should be between 0 and 1. Defaults to 0.1.
1033
+ Should be between 0 and 1.
1050
1034
 
1051
1035
  Returns:
1052
1036
  Nested dictionary where keys of the first level are MCPs (of the form "mcp_0" etc)
@@ -1064,7 +1048,7 @@ class Dialogue:
1064
1048
  >>> extrema_mcp_genes = dl.get_extrema_MCP_genes(ct_subs)
1065
1049
  """
1066
1050
  rank_dfs: dict[str, dict[Any, Any]] = {}
1067
- _, ct_sub = next(iter(ct_subs.items()))
1051
+ ct_sub = next(iter(ct_subs.values()))
1068
1052
  mcps = [col for col in ct_sub.obs.columns if col.startswith("mcp_")]
1069
1053
 
1070
1054
  for mcp in mcps:
@@ -1072,4 +1056,123 @@ class Dialogue:
1072
1056
  ct_ranked = self._get_extrema_MCP_genes_single(ct_subs, mcp=mcp, fraction=fraction)
1073
1057
  for celltype in ct_ranked.keys():
1074
1058
  rank_dfs[mcp][celltype] = sc.get.rank_genes_groups_df(ct_ranked[celltype], group=None)
1059
+
1075
1060
  return rank_dfs
1061
+
1062
+ def plot_split_violins(
1063
+ self,
1064
+ adata: AnnData,
1065
+ split_key: str,
1066
+ celltype_key: str,
1067
+ split_which: tuple[str, str] = None,
1068
+ mcp: str = "mcp_0",
1069
+ return_fig: bool | None = None,
1070
+ ax: Axes | None = None,
1071
+ save: bool | str | None = None,
1072
+ show: bool | None = None,
1073
+ ) -> Axes | Figure | None:
1074
+ """Plots split violin plots for a given MCP and split variable.
1075
+
1076
+ Any cells with a value for split_key not in split_which are removed from the plot.
1077
+
1078
+ Args:
1079
+ adata: Annotated data object.
1080
+ split_key: Variable in adata.obs used to split the data.
1081
+ celltype_key: Key for cell type annotations.
1082
+ split_which: Which values of split_key to plot. Required if more than 2 values in split_key.
1083
+ mcp: Key for MCP data.
1084
+
1085
+ Returns:
1086
+ A :class:`~matplotlib.axes.Axes` object
1087
+
1088
+ Examples:
1089
+ >>> import pertpy as pt
1090
+ >>> import scanpy as sc
1091
+ >>> adata = pt.dt.dialogue_example()
1092
+ >>> sc.pp.pca(adata)
1093
+ >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
1094
+ n_counts_key = "nCount_RNA", n_mpcs = 3)
1095
+ >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
1096
+ >>> dl.plot_split_violins(adata, split_key='gender', celltype_key='cell.subtypes')
1097
+
1098
+ Preview:
1099
+ .. image:: /_static/docstring_previews/dialogue_violin.png
1100
+ """
1101
+ df = sc.get.obs_df(adata, [celltype_key, mcp, split_key])
1102
+ if split_which is None:
1103
+ split_which = df[split_key].unique()
1104
+ df = df[df[split_key].isin(split_which)]
1105
+ df[split_key] = df[split_key].cat.remove_unused_categories()
1106
+
1107
+ ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
1108
+
1109
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
1110
+
1111
+ if save:
1112
+ plt.savefig(save, bbox_inches="tight")
1113
+ if show:
1114
+ plt.show()
1115
+ if return_fig:
1116
+ return plt.gcf()
1117
+ if not (show or save):
1118
+ return ax
1119
+ return None
1120
+
1121
+ def plot_pairplot(
1122
+ self,
1123
+ adata: AnnData,
1124
+ celltype_key: str,
1125
+ color: str,
1126
+ sample_id: str,
1127
+ mcp: str = "mcp_0",
1128
+ return_fig: bool | None = None,
1129
+ show: bool | None = None,
1130
+ save: bool | str | None = None,
1131
+ ) -> PairGrid | Figure | None:
1132
+ """Generate a pairplot visualization for multicellular program (MCP) data.
1133
+
1134
+ Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type,
1135
+ then creates a pairplot to visualize the relationships between these mean MCP values.
1136
+
1137
+ Args:
1138
+ adata: Annotated data object.
1139
+ celltype_key: Key in `adata.obs` containing cell type annotations.
1140
+ color: Key in `adata.obs` for color annotations. This parameter is used as the hue
1141
+ sample_id: Key in `adata.obs` for the sample annotations.
1142
+ mcp: Key in `adata.obs` for MCP feature values.
1143
+
1144
+ Returns:
1145
+ Seaborn Pairgrid object.
1146
+
1147
+ Examples:
1148
+ >>> import pertpy as pt
1149
+ >>> import scanpy as sc
1150
+ >>> adata = pt.dt.dialogue_example()
1151
+ >>> sc.pp.pca(adata)
1152
+ >>> dl = pt.tl.Dialogue(sample_id = "clinical.status", celltype_key = "cell.subtypes", \
1153
+ n_counts_key = "nCount_RNA", n_mpcs = 3)
1154
+ >>> adata, mcps, ws, ct_subs = dl.calculate_multifactor_PMD(adata, normalize=True)
1155
+ >>> dl.plot_pairplot(adata, celltype_key="cell.subtypes", color="gender", sample_id="clinical.status")
1156
+
1157
+ Preview:
1158
+ .. image:: /_static/docstring_previews/dialogue_pairplot.png
1159
+ """
1160
+ mean_mcps = adata.obs.groupby([sample_id, celltype_key])[mcp].mean()
1161
+ mean_mcps = mean_mcps.reset_index()
1162
+ mcp_pivot = pd.pivot(mean_mcps[[sample_id, celltype_key, mcp]], index=sample_id, columns=celltype_key)[mcp]
1163
+
1164
+ aggstats = adata.obs.groupby([sample_id])[color].describe()
1165
+ aggstats = aggstats.loc[list(mcp_pivot.index), :]
1166
+ aggstats[color] = aggstats["top"]
1167
+ mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
1168
+ ax = sns.pairplot(mcp_pivot, hue=color, corner=True)
1169
+
1170
+ if save:
1171
+ plt.savefig(save, bbox_inches="tight")
1172
+ if show:
1173
+ plt.show()
1174
+ if return_fig:
1175
+ return plt.gcf()
1176
+ if not (show or save):
1177
+ return ax
1178
+ return None
@@ -0,0 +1,20 @@
1
+ from ._base import ContrastType, LinearModelBase, MethodBase
2
+ from ._dge_comparison import DGEEVAL
3
+ from ._edger import EdgeR
4
+ from ._pydeseq2 import PyDESeq2
5
+ from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
6
+ from ._statsmodels import Statsmodels
7
+
8
+ __all__ = [
9
+ "MethodBase",
10
+ "LinearModelBase",
11
+ "EdgeR",
12
+ "PyDESeq2",
13
+ "Statsmodels",
14
+ "SimpleComparisonBase",
15
+ "WilcoxonTest",
16
+ "TTest",
17
+ "ContrastType",
18
+ ]
19
+
20
+ AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]