pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66) hide show
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_augur.py CHANGED
@@ -4,14 +4,17 @@ import random
4
4
  from collections import defaultdict
5
5
  from dataclasses import dataclass
6
6
  from math import floor, nan
7
- from typing import Any, Literal
7
+ from typing import TYPE_CHECKING, Any, Literal
8
8
 
9
+ import anndata as ad
10
+ import matplotlib.pyplot as plt
9
11
  import numpy as np
10
12
  import pandas as pd
11
13
  import scanpy as sc
12
14
  import statsmodels.api as sm
13
15
  from anndata import AnnData
14
16
  from joblib import Parallel, delayed
17
+ from lamin_utils import logger
15
18
  from rich import print
16
19
  from rich.progress import track
17
20
  from scipy import sparse, stats
@@ -34,6 +37,10 @@ from sklearn.preprocessing import LabelEncoder
34
37
  from skmisc.loess import loess
35
38
  from statsmodels.stats.multitest import fdrcorrection
36
39
 
40
+ if TYPE_CHECKING:
41
+ from matplotlib.axes import Axes
42
+ from matplotlib.figure import Figure
43
+
37
44
 
38
45
  @dataclass
39
46
  class Params:
@@ -121,7 +128,7 @@ class Augur:
121
128
  _ = input[cell_type_col]
122
129
  _ = input[label_col]
123
130
  except KeyError:
124
- print("[bold red]No column names matching cell_type_col and label_col.")
131
+ logger.error("No column names matching cell_type_col and label_col.")
125
132
 
126
133
  label = input[label_col] if meta is None else meta[label_col]
127
134
  cell_type = input[cell_type_col] if meta is None else meta[cell_type_col]
@@ -134,9 +141,9 @@ class Augur:
134
141
  if adata.obs["label"].dtype.name == "category":
135
142
  # filter samples according to label
136
143
  if condition_label is not None and treatment_label is not None:
137
- print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
138
- adata = AnnData.concatenate(
139
- adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
144
+ logger.info(f"Filtering samples with {condition_label} and {treatment_label} labels.")
145
+ adata = ad.concat(
146
+ [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
140
147
  )
141
148
  label_encoder = LabelEncoder()
142
149
  adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
@@ -214,7 +221,9 @@ class Augur:
214
221
  >>> loaded_data = ag_rfc.load(adata)
215
222
  >>> ag_rfc.select_highly_variable(loaded_data)
216
223
  >>> features = loaded_data.var_names
217
- >>> subsample = ag_rfc.sample(loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names)
224
+ >>> subsample = ag_rfc.sample(
225
+ ... loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
226
+ ... )
218
227
  """
219
228
  # export subsampling.
220
229
  random.seed(random_state)
@@ -230,7 +239,7 @@ class Augur:
230
239
  random_state=random_state,
231
240
  )
232
241
  )
233
- subsample = AnnData.concatenate(*label_subsamples, index_unique=None)
242
+ subsample = ad.concat([*label_subsamples], index_unique=None)
234
243
  else:
235
244
  subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
236
245
 
@@ -409,8 +418,8 @@ class Augur:
409
418
  """
410
419
  if multiclass:
411
420
  return {
412
- "augur_score": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
413
- "auc": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
421
+ "augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
422
+ "auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
414
423
  "accuracy": make_scorer(accuracy_score),
415
424
  "precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
416
425
  "f1": make_scorer(f1_score, average="macro"),
@@ -418,8 +427,8 @@ class Augur:
418
427
  }
419
428
  return (
420
429
  {
421
- "augur_score": make_scorer(roc_auc_score, needs_proba=True),
422
- "auc": make_scorer(roc_auc_score, needs_proba=True),
430
+ "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
431
+ "auc": make_scorer(roc_auc_score, response_method="predict_proba"),
423
432
  "accuracy": make_scorer(accuracy_score),
424
433
  "precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
425
434
  "f1": make_scorer(f1_score, average="binary"),
@@ -488,7 +497,7 @@ class Augur:
488
497
  # feature importances
489
498
  feature_importances = defaultdict(list)
490
499
  if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
491
- for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
500
+ for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
492
501
  feature_importances["genes"].extend(x.columns.tolist())
493
502
  feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
494
503
  feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
@@ -497,7 +506,7 @@ class Augur:
497
506
  # standardized coefficients with Agresti method
498
507
  # cf. https://think-lab.github.io/d/205/#3
499
508
  if isinstance(self.estimator, LogisticRegression):
500
- for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
509
+ for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
501
510
  feature_importances["genes"].extend(x.columns.tolist())
502
511
  feature_importances["feature_importances"].extend(
503
512
  (self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
@@ -548,7 +557,7 @@ class Augur:
548
557
  try:
549
558
  sc.pp.highly_variable_genes(adata)
550
559
  except ValueError:
551
- print("[bold yellow]Data not normalized. Normalizing now using scanpy log1p normalize.")
560
+ logger.warn("Data not normalized. Normalizing now using scanpy log1p normalize.")
552
561
  sc.pp.log1p(adata)
553
562
  sc.pp.highly_variable_genes(adata)
554
563
 
@@ -600,7 +609,7 @@ class Augur:
600
609
  var_quantile: The quantile below which features will be filtered, based on their residuals in a loess model.
601
610
  filter_negative_residuals: if `True`, filter residuals at a fixed threshold of zero, instead of `var_quantile`
602
611
  span: Smoothing factor, as a fraction of the number of points to take into account.
603
- Should be in the range (0, 1]. Defaults to 0.75
612
+ Should be in the range (0, 1].
604
613
 
605
614
  Return:
606
615
  AnnData object with additional select_variance column in var.
@@ -692,13 +701,11 @@ class Augur:
692
701
  feature_perc: proportion of genes that are randomly selected as features for input to the classifier in each
693
702
  subsample using the random gene filter
694
703
  var_quantile: The quantile below which features will be filtered, based on their residuals in a loess model.
695
- Defaults to 0.5.
696
704
  span: Smoothing factor, as a fraction of the number of points to take into account. Should be in the range (0, 1].
697
- Defaults to 0.75.
698
705
  filter_negative_residuals: if `True`, filter residuals at a fixed threshold of zero, instead of `var_quantile`
699
706
  n_threads: number of threads to use for parallelization
700
707
  select_variance_features: Whether to select genes based on the original Augur implementation (True)
701
- or using scanpy's highly_variable_genes (False). Defaults to True.
708
+ or using scanpy's highly_variable_genes (False).
702
709
  key_added: Key to add results to in .uns
703
710
  augur_mode: One of 'default', 'velocity' or 'permute'. Setting augur_mode = "velocity" disables feature selection,
704
711
  assuming feature selection has been performed by the RNA velocity procedure to produce the input matrix,
@@ -723,6 +730,7 @@ class Augur:
723
730
  >>> loaded_data = ag_rfc.load(adata)
724
731
  >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
725
732
  """
733
+ adata = adata.copy()
726
734
  if augur_mode == "permute" and n_subsamples < 100:
727
735
  n_subsamples = 500
728
736
  if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
@@ -742,8 +750,8 @@ class Augur:
742
750
  "full_results": defaultdict(list),
743
751
  }
744
752
  if select_variance_features:
745
- print("[bold yellow]Set smaller span value in the case of a `segmentation fault` error.")
746
- print("[bold yellow]Set larger span in case of svddc or other near singularities error.")
753
+ logger.warning("Set smaller span value in the case of a `segmentation fault` error.")
754
+ logger.warning("Set larger span in case of svddc or other near singularities error.")
747
755
  adata.obs["augur_score"] = nan
748
756
  for cell_type in track(adata.obs["cell_type"].unique(), description="Processing data..."):
749
757
  cell_type_subsample = adata[adata.obs["cell_type"] == cell_type].copy()
@@ -759,17 +767,18 @@ class Augur:
759
767
  )
760
768
  )
761
769
  if len(cell_type_subsample) < min_cells:
762
- print(
763
- f"[bold red]Skipping {cell_type} cell type - {len(cell_type_subsample)} samples is less than min_cells {min_cells}."
770
+ logger.warning(
771
+ f"Skipping {cell_type} cell type - {len(cell_type_subsample)} samples is less than min_cells {min_cells}."
764
772
  )
765
773
  elif (
766
774
  cell_type_subsample.obs.groupby(
767
775
  ["cell_type", "label"],
776
+ observed=True,
768
777
  ).y_.count()
769
778
  < subsample_size
770
779
  ).any():
771
- print(
772
- f"[bold red]Skipping {cell_type} cell type - the number of samples for at least one class type is less than "
780
+ logger.warning(
781
+ f"Skipping {cell_type} cell type - the number of samples for at least one class type is less than "
773
782
  f"subsample size {subsample_size}."
774
783
  )
775
784
  else:
@@ -804,14 +813,14 @@ class Augur:
804
813
  * (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
805
814
  )
806
815
 
807
- for idx, cv in zip(range(n_subsamples), results[cell_type]):
816
+ for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
808
817
  results["full_results"]["idx"].extend([idx] * folds)
809
818
  results["full_results"]["augur_score"].extend(cv["test_augur_score"])
810
819
  results["full_results"]["folds"].extend(range(folds))
811
820
  results["full_results"]["cell_type"].extend([cell_type] * folds * n_subsamples)
812
821
  # make sure one cell type worked
813
822
  if len(results) <= 2:
814
- print("[bold red]No cells types had more than min_cells needed. Please adjust data or min_cells parameter.")
823
+ logger.warning("No cells types had more than min_cells needed. Please adjust data or min_cells parameter.")
815
824
 
816
825
  results["summary_metrics"] = pd.DataFrame(results["summary_metrics"])
817
826
  results["feature_importances"] = pd.DataFrame(results["feature_importances"])
@@ -840,7 +849,7 @@ class Augur:
840
849
  augur2: Augurpy results from condition 2, obtained from `predict()[1]`
841
850
  permuted1: permuted Augurpy results from condition 1, obtained from `predict()` with argument `augur_mode=permute`
842
851
  permuted2: permuted Augurpy results from condition 2, obtained from `predict()` with argument `augur_mode=permute`
843
- n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation; Defaults to 50.
852
+ n_subsamples: number of subsamples to pool when calculating the mean augur score for each permutation.
844
853
  n_permutations: the total number of mean augur scores to calculate from a background distribution
845
854
 
846
855
  Returns:
@@ -869,28 +878,31 @@ class Augur:
869
878
  & set(permuted_results1["summary_metrics"].columns)
870
879
  & set(permuted_results2["summary_metrics"].columns)
871
880
  )
881
+
882
+ cell_types_list = list(cell_types)
883
+
872
884
  # mean augur scores
873
885
  augur_score1 = (
874
886
  augur_results1["summary_metrics"]
875
- .loc["mean_augur_score", cell_types]
887
+ .loc["mean_augur_score", cell_types_list]
876
888
  .reset_index()
877
889
  .rename(columns={"index": "cell_type"})
878
890
  )
879
891
  augur_score2 = (
880
892
  augur_results2["summary_metrics"]
881
- .loc["mean_augur_score", cell_types]
893
+ .loc["mean_augur_score", cell_types_list]
882
894
  .reset_index()
883
895
  .rename(columns={"index": "cell_type"})
884
896
  )
885
897
 
886
898
  # mean permuted scores over cross validation runs
887
899
  permuted_cv_augur1 = (
888
- permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types)]
900
+ permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
889
901
  .groupby(["cell_type", "idx"], as_index=False)
890
902
  .mean()
891
903
  )
892
904
  permuted_cv_augur2 = (
893
- permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types)]
905
+ permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
894
906
  .groupby(["cell_type", "idx"], as_index=False)
895
907
  .mean()
896
908
  )
@@ -901,7 +913,7 @@ class Augur:
901
913
  # draw mean aucs for permute1 and permute2
902
914
  for celltype in permuted_cv_augur1["cell_type"].unique():
903
915
  df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
904
- df2 = permuted_cv_augur2[permuted_cv_augur1["cell_type"] == celltype]
916
+ df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
905
917
  for permutation_idx in range(n_permutations):
906
918
  # subsample
907
919
  sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
@@ -961,3 +973,288 @@ class Augur:
961
973
  delta["padj"] = fdrcorrection(delta["pval"])[1]
962
974
 
963
975
  return delta
976
+
977
def plot_dp_scatter(
    self,
    results: pd.DataFrame,
    top_n: int | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot scatterplot of differential prioritization.

    Args:
        results: Results after running differential prioritization. The columns
            `mean_augur_score1`, `mean_augur_score2`, `z`, `pval` and `cell_type` are read.
        top_n: Number of cell types (ordered by ascending `pval`) to label in the plot.
            With `None`, every cell type is labelled.
        return_fig: If truthy, return the figure that owns `ax`.
        ax: Axes to draw on. A new figure and axes are created when `None`.
        show: If truthy, call `plt.show()`.
        save: If truthy, passed as the target to the figure's `savefig`.

    Returns:
        The figure if `return_fig` is truthy; otherwise the axes if neither `show` nor `save`
        is truthy; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.bhattacherjee()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")

        >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
        >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
        >>> adata_15_permute, results_15_permute = ag_rfc.predict(
        ...     data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
        >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
        >>> adata_48_permute, results_48_permute = ag_rfc.predict(
        ...     data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> pvals = ag_rfc.predict_differential_prioritization(
        ...     augur_results1=results_15,
        ...     augur_results2=results_48,
        ...     permuted_results1=results_15_permute,
        ...     permuted_results2=results_48_permute,
        ... )
        >>> ag_rfc.plot_dp_scatter(pvals)

    Preview:
        .. image:: /_static/docstring_previews/augur_dp_scatter.png
    """
    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`, not on pyplot's "current" figure, so that a
    # caller-supplied axes is labelled, saved and returned correctly.
    fig = ax.get_figure()

    scatter = ax.scatter(
        results["mean_augur_score1"],
        results["mean_augur_score2"],
        c=results.z,
        cmap="Greens",
    )

    # adding optional labels; index[:None] selects every row
    for idx in results.sort_values(by="pval").index[:top_n]:
        ax.annotate(
            results.loc[idx, "cell_type"],
            (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
        )

    # add diagonal; max() compares the two (low, high) tuples lexicographically and
    # picks one of the two axis ranges as the extent of the line
    limits = max(ax.get_xlim(), ax.get_ylim())
    ax.plot(limits, limits, ls="--", c=".3")

    # formatting and details
    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")
    z_legend = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
    ax.add_artist(z_legend)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1050
+
1051
def plot_important_features(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    top_n: int = 10,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the n features with largest feature importances.

    Args:
        data: Results after running `predict()` as dictionary, or the AnnData object
            holding them in `.uns[key]`.
        key: Key in `.uns` of the AnnData object under which the results are stored.
        top_n: Number of feature importance values to plot. Fewer are drawn when the
            results contain fewer distinct genes.
        return_fig: If truthy, return the figure that owns `ax`.
        ax: Axes to draw on. A new figure and axes are created when `None`.
        show: If truthy, call `plt.show()`.
        save: If truthy, passed as the target to the figure's `savefig`.

    Returns:
        The figure if `return_fig` is truthy; otherwise the axes if neither `show` nor `save`
        is truthy; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_important_features(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_important_features.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data

    # Mean importance per gene, ascending, keeping the top_n largest at the end.
    top_features = (
        results["feature_importances"]
        .groupby("genes", as_index=False)
        .feature_importances.mean()
        .sort_values(by="feature_importances")[-top_n:]
    )

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`, not on pyplot's "current" figure.
    fig = ax.get_figure()

    # Size the y positions by the rows actually selected: with fewer than top_n genes a
    # range based on top_n would not match the number of importance values.
    y_positions = range(1, len(top_features) + 1)
    ax.hlines(
        y_positions,
        xmin=0,
        xmax=top_features["feature_importances"],
    )
    ax.plot(top_features["feature_importances"], y_positions, "o")

    ax.set_xlabel("Mean Feature Importance")
    ax.set_ylabel("Gene")
    ax.set_yticks(list(y_positions))
    ax.set_yticklabels(top_features["genes"].tolist())

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1121
+
1122
def plot_lollipop(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the mean augur values.

    Args:
        data: Results after running `predict()` as dictionary, or the AnnData object
            holding them in `.uns[key]`.
        key: Key in `.uns` of the AnnData object under which the results are stored.
        return_fig: If truthy, return the figure that owns `ax`.
        ax: Axes to draw on. A new figure and axes are created when `None`.
        show: If truthy, call `plt.show()`.
        save: If truthy, passed as the target to the figure's `savefig`.

    Returns:
        The figure if `return_fig` is truthy; otherwise the axes if neither `show` nor `save`
        is truthy; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_lollipop(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_lollipop.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure that owns `ax`, not on pyplot's "current" figure.
    fig = ax.get_figure()

    # Order the columns (one per cell type) by ascending mean augur score once and reuse
    # the result for the stems, the markers and the tick labels.
    ordered_metrics = results["summary_metrics"].sort_values("mean_augur_score", axis=1)
    mean_scores = ordered_metrics.loc["mean_augur_score"]
    y_positions = range(1, len(ordered_metrics.columns) + 1)

    ax.hlines(
        y_positions,
        xmin=0,
        xmax=mean_scores,
    )
    ax.plot(mean_scores, y_positions, "o")

    ax.set_xlabel("Mean Augur Score")
    ax.set_ylabel("Cell Type")
    ax.set_yticks(list(y_positions))
    ax.set_yticklabels(ordered_metrics.columns.tolist())

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1187
+
1188
def plot_scatterplot(
    self,
    results1: dict[str, Any],
    results2: dict[str, Any],
    top_n: int | None = None,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    ax: Axes | None = None,
) -> Axes | Figure | None:
    """Create scatterplot with two augur results.

    Only cell types present in the `summary_metrics` of both results are plotted.

    Args:
        results1: results after running `predict()`
        results2: results after running `predict()`
        top_n: Number of cell types to label, ordered by descending
            `mean_augur_score` of `results1` minus that of `results2`.
            With `None`, every plotted cell type is labelled.
        return_fig: If truthy, return the figure of the plot.
        show: If truthy, call `plt.show()`.
        save: If truthy, passed as the target to the figure's `savefig`.
        ax: Axes to draw on. A new figure and axes are created when `None`.

    Returns:
        The figure if `return_fig` is truthy; otherwise the axes if neither `show` nor `save`
        is truthy; otherwise `None`.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_scatterplot(v_results, h_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_scatterplot.png
    """
    metrics1 = results1["summary_metrics"]
    metrics2 = results2["summary_metrics"]

    # Restrict to cell types available in both results (in results1's column order);
    # selecting results1's columns from results2 raises a KeyError when the sets differ.
    shared_cell_types = [cell_type for cell_type in metrics1.columns if cell_type in metrics2.columns]
    scores1 = metrics1.loc["mean_augur_score", shared_cell_types]
    scores2 = metrics2.loc["mean_augur_score", shared_cell_types]

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
    ax.scatter(scores1, scores2)

    # adding optional labels; index[:None] selects every cell type
    labelled_cell_types = (scores1 - scores2).sort_values(ascending=False).index[:top_n]
    for cell_type in labelled_cell_types:
        ax.annotate(cell_type, (scores1[cell_type], scores2[cell_type]))

    # adding diagonal; max() compares the two (low, high) tuples lexicographically and
    # picks one of the two axis ranges as the extent of the line
    limits = max(ax.get_xlim(), ax.get_ylim())
    ax.plot(limits, limits, ls="--", c=".3")

    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None