pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import uuid
3
4
  from typing import TYPE_CHECKING
4
5
 
5
6
  import numpy as np
7
+ import pandas as pd
8
+ import scanpy as sc
6
9
  import scipy
7
10
 
8
11
  if TYPE_CHECKING:
9
12
  from anndata import AnnData
13
+ from matplotlib.axes import Axes
10
14
 
11
15
 
12
16
  class GuideAssignment:
@@ -39,7 +43,7 @@ class GuideAssignment:
39
43
 
40
44
  >>> import pertpy as pt
41
45
  >>> mdata = pt.data.papalexi_2021()
42
- >>> gdo = mdata.mod['gdo']
46
+ >>> gdo = mdata.mod["gdo"]
43
47
  >>> ga = pt.pp.GuideAssignment()
44
48
  >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
45
49
  """
@@ -71,7 +75,6 @@ class GuideAssignment:
71
75
 
72
76
  Args:
73
77
  adata: Annotated data matrix containing gRNA values
74
- assignment_threshold: If a gRNA is available for at least `assignment_threshold`, it will be recognized as assigned.
75
78
  assignment_threshold: The count threshold that is required for an assignment to be viable.
76
79
  layer: Key to the layer containing raw count values of the gRNAs.
77
80
  adata.X is used if layer is None. Expects count data.
@@ -83,8 +86,8 @@ class GuideAssignment:
83
86
  Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
84
87
 
85
88
  >>> import pertpy as pt
86
- >>> mdata = pt.data.papalexi_2021()
87
- >>> gdo = mdata.mod['gdo']
89
+ >>> mdata = pt.dt.papalexi_2021()
90
+ >>> gdo = mdata.mod["gdo"]
88
91
  >>> ga = pt.pp.GuideAssignment()
89
92
  >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
90
93
  """
@@ -103,3 +106,84 @@ class GuideAssignment:
103
106
  adata.obs[output_key] = assigned_grna
104
107
 
105
108
  return None
109
+
110
def plot_heatmap(
    self,
    adata: AnnData,
    layer: str | None = None,
    order_by: np.ndarray | str | None = None,
    key_to_save_order: str | None = None,
    **kwargs,
) -> list[Axes]:
    """Heatmap plotting of guide RNA expression matrix.

    Assuming guides have sparse expression, this function reorders cells
    and plots guide RNA expression so that a nice sparse representation is achieved.
    The cell ordering can be stored and reused in future plots to obtain consistent
    plots before and after analysis of the guide RNA expression.
    Note: This function expects log-normalized or binary data.

    Args:
        adata: Annotated data matrix containing gRNA values.
        layer: Key to the layer containing log normalized count values of the gRNAs.
            adata.X is used if layer is None.
        order_by: The order of cells in the y axis. Defaults to None.
            If None, cells are grouped by their most expressed guide; cells whose guide values
            are all equal (e.g. no guide detected) are placed first.
            If a string is provided, cells are sorted by the values of adata.obs[order_by].
            If a numpy array is provided, it is used directly as the positional cell order.
        key_to_save_order: The obs key under which the rank of every cell in the current plot is saved.
            Passing the same key as ``order_by`` in a later call reproduces this plot's ordering.
            Only saves if not None.
        kwargs: Are passed to sc.pl.heatmap.

    Returns:
        List of Axes. Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
        The rank of each cell in the y-axis will be saved on adata.obs[key_to_save_order] if provided.

    Examples:
        Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
        visualized using a heatmap.

        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
        >>> ga.plot_heatmap(gdo)
    """
    data = adata.X if layer is None else adata.layers[layer]

    def _flat(values) -> np.ndarray:
        # Row reductions return np.matrix / sparse matrices for scipy spmatrix inputs and plain
        # arrays for ndarray / sparse-array inputs; normalise all of them to a flat 1-D ndarray.
        if scipy.sparse.issparse(values):
            values = values.toarray()
        return np.asarray(values).ravel()

    if order_by is None:
        max_values = _flat(data.max(axis=1))
        min_values = _flat(data.min(axis=1))
        top_guide = _flat(data.argmax(axis=1))
        # Rows without a distinguishable maximum get -1 so they sort before every real guide index.
        max_guide_index = np.where(max_values != min_values, top_guide, -1)
        # A stable sort keeps cells sharing the same top guide in their original relative order.
        order = np.argsort(max_guide_index, kind="stable")
    elif isinstance(order_by, str):
        # Convert to a positional ndarray so AnnData receives a plain integer indexer.
        order = np.asarray(adata.obs[order_by].argsort(kind="stable"))
    else:
        order = np.asarray(order_by)

    temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
    # sc.pl.heatmap needs a groupby column; a single empty category draws all cells as one group.
    adata.obs[temp_col_name] = pd.Categorical([""] * adata.shape[0])

    if key_to_save_order is not None:
        # Store each cell's rank (the inverse permutation of `order`): sorting by this column, as
        # done above for a string `order_by`, then yields exactly `order` again.
        rank = np.empty(len(order), dtype=np.intp)
        rank[order] = np.arange(len(order))
        adata.obs[key_to_save_order] = pd.Categorical(rank)

    try:
        axis_group = sc.pl.heatmap(
            adata[order, :],
            var_names=adata.var_names.tolist(),
            groupby=temp_col_name,
            cmap="viridis",
            use_raw=False,
            dendrogram=False,
            layer=layer,
            **kwargs,
        )
    finally:
        # Always remove the helper column, even if plotting fails.
        del adata.obs[temp_col_name]

    return axis_group
pertpy/tools/__init__.py CHANGED
@@ -1,24 +1,19 @@
1
- from rich import print
2
-
3
1
  from pertpy.tools._augur import Augur
4
2
  from pertpy.tools._cinemaot import Cinemaot
3
+ from pertpy.tools._coda._sccoda import Sccoda
4
+ from pertpy.tools._coda._tasccoda import Tasccoda
5
5
  from pertpy.tools._dialogue import Dialogue
6
6
  from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
7
7
  from pertpy.tools._distances._distance_tests import DistanceTest
8
8
  from pertpy.tools._distances._distances import Distance
9
- from pertpy.tools._metadata._cell_line import CellLineMetaData
9
+ from pertpy.tools._enrichment import Enrichment
10
10
  from pertpy.tools._milo import Milo
11
11
  from pertpy.tools._mixscape import Mixscape
12
12
  from pertpy.tools._perturbation_space._clustering import ClusteringSpace
13
- from pertpy.tools._perturbation_space._discriminator_classifier import DiscriminatorClassifierSpace
13
+ from pertpy.tools._perturbation_space._discriminator_classifiers import (
14
+ DiscriminatorClassifierSpace,
15
+ LRClassifierSpace,
16
+ MLPClassifierSpace,
17
+ )
14
18
  from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace
15
19
  from pertpy.tools._scgen import SCGEN
16
-
17
- try:
18
- from pertpy.tools._coda._sccoda import Sccoda
19
- from pertpy.tools._coda._tasccoda import Tasccoda
20
- except ImportError as e:
21
- if "ete3" in str(e):
22
- print("[bold yellow]To use sccoda or tasccoda please install ete3 with [green]pip install ete3")
23
- else:
24
- raise e
pertpy/tools/_augur.py CHANGED
@@ -4,8 +4,10 @@ import random
4
4
  from collections import defaultdict
5
5
  from dataclasses import dataclass
6
6
  from math import floor, nan
7
- from typing import Any, Literal
7
+ from typing import TYPE_CHECKING, Any, Literal
8
8
 
9
+ import anndata as ad
10
+ import matplotlib.pyplot as plt
9
11
  import numpy as np
10
12
  import pandas as pd
11
13
  import scanpy as sc
@@ -34,6 +36,10 @@ from sklearn.preprocessing import LabelEncoder
34
36
  from skmisc.loess import loess
35
37
  from statsmodels.stats.multitest import fdrcorrection
36
38
 
39
+ if TYPE_CHECKING:
40
+ from matplotlib.axes import Axes
41
+ from matplotlib.figure import Figure
42
+
37
43
 
38
44
  @dataclass
39
45
  class Params:
@@ -135,8 +141,8 @@ class Augur:
135
141
  # filter samples according to label
136
142
  if condition_label is not None and treatment_label is not None:
137
143
  print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
138
- adata = AnnData.concatenate(
139
- adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
144
+ adata = ad.concat(
145
+ [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
140
146
  )
141
147
  label_encoder = LabelEncoder()
142
148
  adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
@@ -214,7 +220,9 @@ class Augur:
214
220
  >>> loaded_data = ag_rfc.load(adata)
215
221
  >>> ag_rfc.select_highly_variable(loaded_data)
216
222
  >>> features = loaded_data.var_names
217
- >>> subsample = ag_rfc.sample(loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names)
223
+ >>> subsample = ag_rfc.sample(
224
+ ... loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
225
+ ... )
218
226
  """
219
227
  # export subsampling.
220
228
  random.seed(random_state)
@@ -230,7 +238,7 @@ class Augur:
230
238
  random_state=random_state,
231
239
  )
232
240
  )
233
- subsample = AnnData.concatenate(*label_subsamples, index_unique=None)
241
+ subsample = ad.concat([*label_subsamples], index_unique=None)
234
242
  else:
235
243
  subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
236
244
 
@@ -409,8 +417,8 @@ class Augur:
409
417
  """
410
418
  if multiclass:
411
419
  return {
412
- "augur_score": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
413
- "auc": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
420
+ "augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
421
+ "auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
414
422
  "accuracy": make_scorer(accuracy_score),
415
423
  "precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
416
424
  "f1": make_scorer(f1_score, average="macro"),
@@ -418,8 +426,8 @@ class Augur:
418
426
  }
419
427
  return (
420
428
  {
421
- "augur_score": make_scorer(roc_auc_score, needs_proba=True),
422
- "auc": make_scorer(roc_auc_score, needs_proba=True),
429
+ "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
430
+ "auc": make_scorer(roc_auc_score, response_method="predict_proba"),
423
431
  "accuracy": make_scorer(accuracy_score),
424
432
  "precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
425
433
  "f1": make_scorer(f1_score, average="binary"),
@@ -488,7 +496,7 @@ class Augur:
488
496
  # feature importances
489
497
  feature_importances = defaultdict(list)
490
498
  if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
491
- for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
499
+ for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
492
500
  feature_importances["genes"].extend(x.columns.tolist())
493
501
  feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
494
502
  feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
@@ -497,7 +505,7 @@ class Augur:
497
505
  # standardized coefficients with Agresti method
498
506
  # cf. https://think-lab.github.io/d/205/#3
499
507
  if isinstance(self.estimator, LogisticRegression):
500
- for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
508
+ for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
501
509
  feature_importances["genes"].extend(x.columns.tolist())
502
510
  feature_importances["feature_importances"].extend(
503
511
  (self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
@@ -723,6 +731,7 @@ class Augur:
723
731
  >>> loaded_data = ag_rfc.load(adata)
724
732
  >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
725
733
  """
734
+ adata = adata.copy()
726
735
  if augur_mode == "permute" and n_subsamples < 100:
727
736
  n_subsamples = 500
728
737
  if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
@@ -765,6 +774,7 @@ class Augur:
765
774
  elif (
766
775
  cell_type_subsample.obs.groupby(
767
776
  ["cell_type", "label"],
777
+ observed=True,
768
778
  ).y_.count()
769
779
  < subsample_size
770
780
  ).any():
@@ -804,7 +814,7 @@ class Augur:
804
814
  * (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
805
815
  )
806
816
 
807
- for idx, cv in zip(range(n_subsamples), results[cell_type]):
817
+ for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
808
818
  results["full_results"]["idx"].extend([idx] * folds)
809
819
  results["full_results"]["augur_score"].extend(cv["test_augur_score"])
810
820
  results["full_results"]["folds"].extend(range(folds))
@@ -869,28 +879,31 @@ class Augur:
869
879
  & set(permuted_results1["summary_metrics"].columns)
870
880
  & set(permuted_results2["summary_metrics"].columns)
871
881
  )
882
+
883
+ cell_types_list = list(cell_types)
884
+
872
885
  # mean augur scores
873
886
  augur_score1 = (
874
887
  augur_results1["summary_metrics"]
875
- .loc["mean_augur_score", cell_types]
888
+ .loc["mean_augur_score", cell_types_list]
876
889
  .reset_index()
877
890
  .rename(columns={"index": "cell_type"})
878
891
  )
879
892
  augur_score2 = (
880
893
  augur_results2["summary_metrics"]
881
- .loc["mean_augur_score", cell_types]
894
+ .loc["mean_augur_score", cell_types_list]
882
895
  .reset_index()
883
896
  .rename(columns={"index": "cell_type"})
884
897
  )
885
898
 
886
899
  # mean permuted scores over cross validation runs
887
900
  permuted_cv_augur1 = (
888
- permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types)]
901
+ permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
889
902
  .groupby(["cell_type", "idx"], as_index=False)
890
903
  .mean()
891
904
  )
892
905
  permuted_cv_augur2 = (
893
- permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types)]
906
+ permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
894
907
  .groupby(["cell_type", "idx"], as_index=False)
895
908
  .mean()
896
909
  )
@@ -901,7 +914,7 @@ class Augur:
901
914
  # draw mean aucs for permute1 and permute2
902
915
  for celltype in permuted_cv_augur1["cell_type"].unique():
903
916
  df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
904
- df2 = permuted_cv_augur2[permuted_cv_augur1["cell_type"] == celltype]
917
+ df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
905
918
  for permutation_idx in range(n_permutations):
906
919
  # subsample
907
920
  sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
@@ -961,3 +974,288 @@ class Augur:
961
974
  delta["padj"] = fdrcorrection(delta["pval"])[1]
962
975
 
963
976
  return delta
977
+
978
def plot_dp_scatter(
    self,
    results: pd.DataFrame,
    top_n: int | None = None,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot scatterplot of differential prioritization.

    Args:
        results: Results after running differential prioritization. The columns ``mean_augur_score1``,
            ``mean_augur_score2``, ``z``, ``pval`` and ``cell_type`` are read.
        top_n: optionally, the number of top prioritized cell types (smallest p-values) to label in the plot.
            All cell types are labelled if None.
        return_fig: if `True` returns the figure of the plot.
        ax: optionally, axes used to draw plot. A new figure is created if None.
        show: if `True` shows the plot.
        save: path to save the figure to, if provided.

    Returns:
        The figure if `return_fig` is set, the axes if neither `show` nor `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.bhattacherjee()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")

        >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
        >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
        >>> adata_15_permute, results_15_permute = ag_rfc.predict(
        ...     data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
        >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
        >>> adata_48_permute, results_48_permute = ag_rfc.predict(
        ...     data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> pvals = ag_rfc.predict_differential_prioritization(
        ...     augur_results1=results_15,
        ...     augur_results2=results_48,
        ...     permuted_results1=results_15_permute,
        ...     permuted_results2=results_48_permute,
        ... )
        >>> ag_rfc.plot_dp_scatter(pvals)

    Preview:
        .. image:: /_static/docstring_previews/augur_dp_scatter.png
    """
    x = results["mean_augur_score1"]
    y = results["mean_augur_score2"]

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure owning `ax`, which may not be pyplot's current figure.
    fig = ax.get_figure()
    scatter = ax.scatter(x, y, c=results["z"], cmap="Greens")

    # Label the `top_n` cell types with the smallest p-values.
    for idx in results.sort_values(by="pval").index[:top_n]:
        ax.annotate(
            results.loc[idx, "cell_type"],
            (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
        )

    # Diagonal spanning the union of both axis ranges (a tuple `max` would compare lexicographically).
    (x_low, x_high), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")
    legend = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
    # Re-add explicitly so the z-score legend survives any later ax.legend() call.
    ax.add_artist(legend)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1051
+
1052
def plot_important_features(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    top_n: int = 10,
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the n features with largest feature importances.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `adata.uns` holding the results; only used when `data` is an AnnData object.
        top_n: n number feature importance values to plot. Default is 10.
            Fewer are drawn if the results contain fewer genes.
        return_fig: if `True` returns the figure of the plot.
        ax: optionally, axes used to draw plot. A new figure is created if None.
        show: if `True` shows the plot.
        save: path to save the figure to, if provided.

    Returns:
        The figure if `return_fig` is set, the axes if neither `show` nor `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_important_features(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_important_features.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data

    # Mean importance per gene across folds/subsamples; keep the `top_n` largest (ascending order).
    # `tail` also handles top_n == 0, where the slice `[-0:]` would have returned every gene.
    top_features = (
        results["feature_importances"]
        .groupby("genes", as_index=False)
        .feature_importances.mean()
        .sort_values(by="feature_importances")
        .tail(top_n)
    )

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure owning `ax`, which may not be pyplot's current figure.
    fig = ax.get_figure()

    # Size the y positions by the genes actually available, not by the requested `top_n`.
    y_positions = list(range(1, len(top_features) + 1))
    ax.hlines(y_positions, xmin=0, xmax=top_features["feature_importances"])
    ax.plot(top_features["feature_importances"], y_positions, "o")

    ax.set_xlabel("Mean Feature Importance")
    ax.set_ylabel("Gene")
    ax.set_yticks(y_positions)
    ax.set_yticklabels(top_features["genes"])

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1122
+
1123
def plot_lollipop(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    return_fig: bool | None = None,
    ax: Axes | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the mean augur values.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `adata.uns` holding the results; only used when `data` is an AnnData object.
        return_fig: if `True` returns the figure of the plot.
        ax: optionally, axes used to draw plot. A new figure is created if None.
        show: if `True` shows the plot.
        save: path to save the figure to, if provided.

    Returns:
        The figure if `return_fig` is set, the axes if neither `show` nor `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_lollipop(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_lollipop.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data

    # Sort the cell types (columns) once by their mean augur score, ascending, so the highest
    # score ends up at the largest y position (top of the plot).
    sorted_metrics = results["summary_metrics"].sort_values("mean_augur_score", axis=1)
    mean_scores = sorted_metrics.loc["mean_augur_score"]

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure owning `ax`, which may not be pyplot's current figure.
    fig = ax.get_figure()

    y_positions = list(range(1, len(sorted_metrics.columns) + 1))
    ax.hlines(y_positions, xmin=0, xmax=mean_scores)
    ax.plot(mean_scores, y_positions, "o")

    ax.set_xlabel("Mean Augur Score")
    ax.set_ylabel("Cell Type")
    ax.set_yticks(y_positions)
    ax.set_yticklabels(sorted_metrics.columns)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1188
+
1189
def plot_scatterplot(
    self,
    results1: dict[str, Any],
    results2: dict[str, Any],
    top_n: int | None = None,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
    ax: Axes | None = None,
) -> Axes | Figure | None:
    """Create scatterplot with two augur results.

    Args:
        results1: results after running `predict()`
        results2: results after running `predict()`
        top_n: optionally, the number of cell types with the largest score difference
            (results1 minus results2) to label in the plot. All cell types are labelled if None.
        return_fig: if `True` returns the figure of the plot.
        show: if `True` shows the plot.
        save: path to save the figure to, if provided.
        ax: optionally, axes used to draw plot. A new figure is created if None.

    Returns:
        The figure if `return_fig` is set, the axes if neither `show` nor `save` is set, otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_scatterplot(v_results, h_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_scatterplot.png
    """
    scores1 = results1["summary_metrics"].loc["mean_augur_score"]
    scores2 = results2["summary_metrics"].loc["mean_augur_score"]
    cell_types = results1["summary_metrics"].columns

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure owning `ax`, which may not be pyplot's current figure.
    fig = ax.get_figure()
    ax.scatter(scores1.loc[cell_types], scores2.loc[cell_types])

    # Label the cell types whose score drops the most from results1 to results2.
    top_n_cell_types = (scores1 - scores2).sort_values(ascending=False).index[:top_n]
    for cell_type in top_n_cell_types:
        ax.annotate(cell_type, (scores1.loc[cell_type], scores2.loc[cell_type]))

    # Diagonal spanning the union of both axis ranges (a tuple `max` would compare lexicographically).
    (x_low, x_high), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None