pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53)
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import uuid
3
4
  from typing import TYPE_CHECKING
4
5
 
5
6
  import numpy as np
7
+ import pandas as pd
8
+ import scanpy as sc
6
9
  import scipy
7
10
 
8
11
  if TYPE_CHECKING:
9
12
  from anndata import AnnData
13
+ from matplotlib.axes import Axes
10
14
 
11
15
 
12
16
  class GuideAssignment:
@@ -39,7 +43,7 @@ class GuideAssignment:
39
43
 
40
44
  >>> import pertpy as pt
41
45
  >>> mdata = pt.data.papalexi_2021()
42
- >>> gdo = mdata.mod['gdo']
46
+ >>> gdo = mdata.mod["gdo"]
43
47
  >>> ga = pt.pp.GuideAssignment()
44
48
  >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
45
49
  """
@@ -71,7 +75,6 @@ class GuideAssignment:
71
75
 
72
76
  Args:
73
77
  adata: Annotated data matrix containing gRNA values
74
- assignment_threshold: If a gRNA is available for at least `assignment_threshold`, it will be recognized as assigned.
75
78
  assignment_threshold: The count threshold that is required for an assignment to be viable.
76
79
  layer: Key to the layer containing raw count values of the gRNAs.
77
80
  adata.X is used if layer is None. Expects count data.
@@ -83,8 +86,8 @@ class GuideAssignment:
83
86
  Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
84
87
 
85
88
  >>> import pertpy as pt
86
- >>> mdata = pt.data.papalexi_2021()
87
- >>> gdo = mdata.mod['gdo']
89
+ >>> mdata = pt.dt.papalexi_2021()
90
+ >>> gdo = mdata.mod["gdo"]
88
91
  >>> ga = pt.pp.GuideAssignment()
89
92
  >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
90
93
  """
@@ -103,3 +106,84 @@ class GuideAssignment:
103
106
  adata.obs[output_key] = assigned_grna
104
107
 
105
108
  return None
109
+
110
def plot_heatmap(
    self,
    adata: AnnData,
    layer: str | None = None,
    order_by: np.ndarray | str | None = None,
    key_to_save_order: str | None = None,
    **kwargs,
) -> list[Axes]:
    """Heatmap plotting of guide RNA expression matrix.

    Assuming guides have sparse expression, this function reorders cells
    and plots guide RNA expression so that a nice sparse representation is achieved.
    The cell ordering can be stored and reused in future plots to obtain consistent
    plots before and after analysis of the guide RNA expression.
    Note: This function expects log-normalized or binary data.

    Args:
        adata: Annotated data matrix containing gRNA values
        layer: Key to the layer containing log normalized count values of the gRNAs.
               adata.X is used if layer is None.
        order_by: The order of cells in y axis. Defaults to None.
                  If None, cells will be reordered to have a nice sparse representation
                  (grouped by their most expressed guide; cells whose row is constant come first).
                  If a string is provided, adata.obs[order_by] is used directly as the order,
                  e.g. a column previously written via `key_to_save_order`.
                  If a numpy array is provided, the array will be used for ordering.
        key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
        kwargs: Are passed to sc.pl.heatmap.

    Returns:
        Whatever sc.pl.heatmap returns (the axes of the plot). Alternatively you can pass save or show
        parameters as they will be passed to sc.pl.heatmap.
        Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.

    Examples:
        Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
        visualized using a heatmap.

        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> gdo = mdata.mod["gdo"]
        >>> ga = pt.pp.GuideAssignment()
        >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
        >>> ga.plot_heatmap(gdo)
    """
    data = adata.X if layer is None else adata.layers[layer]

    def _as_flat_array(values):
        # Row-wise reductions of sparse inputs return sparse matrices / np.matrix objects.
        # Densify and flatten explicitly instead of relying on the `.A` shortcut, which
        # newer SciPy sparse objects no longer provide.
        if scipy.sparse.issparse(values):
            values = values.toarray()
        return np.asarray(values).ravel()

    if order_by is None:
        max_values = _as_flat_array(data.max(axis=1))
        min_values = _as_flat_array(data.min(axis=1))
        data_argmax = _as_flat_array(data.argmax(axis=1))
        # A constant row carries no guide signal: mark it with -1 so such cells are sorted first.
        max_guide_index = np.where(max_values != min_values, data_argmax, -1)
        # Stable sort keeps cells that share a top guide in their original relative order.
        order = np.argsort(max_guide_index, kind="stable")
    elif isinstance(order_by, str):
        # The column *is* the order (as stored by `key_to_save_order`); argsort-ing it
        # would apply the inverse permutation and not reproduce the saved plot.
        order = np.asarray(adata.obs[order_by])
    else:
        order = order_by

    if key_to_save_order is not None:
        adata.obs[key_to_save_order] = pd.Categorical(order)

    # sc.pl.heatmap needs a groupby column; a single dummy category keeps all cells in one group.
    # It is created right before the try block so it is always removed again.
    temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
    adata.obs[temp_col_name] = pd.Categorical([""] * adata.shape[0])
    try:
        axis_group = sc.pl.heatmap(
            adata[order, :],
            var_names=adata.var.index.tolist(),
            groupby=temp_col_name,
            cmap="viridis",
            use_raw=False,
            dendrogram=False,
            layer=layer,
            **kwargs,
        )
    finally:
        del adata.obs[temp_col_name]

    return axis_group
pertpy/tools/__init__.py CHANGED
@@ -1,24 +1,19 @@
1
- from rich import print
2
-
3
1
  from pertpy.tools._augur import Augur
4
2
  from pertpy.tools._cinemaot import Cinemaot
3
+ from pertpy.tools._coda._sccoda import Sccoda
4
+ from pertpy.tools._coda._tasccoda import Tasccoda
5
5
  from pertpy.tools._dialogue import Dialogue
6
6
  from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
7
7
  from pertpy.tools._distances._distance_tests import DistanceTest
8
8
  from pertpy.tools._distances._distances import Distance
9
- from pertpy.tools._metadata._cell_line import CellLineMetaData
9
+ from pertpy.tools._enrichment import Enrichment
10
10
  from pertpy.tools._milo import Milo
11
11
  from pertpy.tools._mixscape import Mixscape
12
12
  from pertpy.tools._perturbation_space._clustering import ClusteringSpace
13
- from pertpy.tools._perturbation_space._discriminator_classifier import DiscriminatorClassifierSpace
13
+ from pertpy.tools._perturbation_space._discriminator_classifiers import (
14
+ DiscriminatorClassifierSpace,
15
+ LRClassifierSpace,
16
+ MLPClassifierSpace,
17
+ )
14
18
  from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace
15
19
  from pertpy.tools._scgen import SCGEN
16
-
17
- try:
18
- from pertpy.tools._coda._sccoda import Sccoda
19
- from pertpy.tools._coda._tasccoda import Tasccoda
20
- except ImportError as e:
21
- if "ete3" in str(e):
22
- print("[bold yellow]To use sccoda or tasccoda please install ete3 with [green]pip install ete3")
23
- else:
24
- raise e
pertpy/tools/_augur.py CHANGED
@@ -4,8 +4,10 @@ import random
4
4
  from collections import defaultdict
5
5
  from dataclasses import dataclass
6
6
  from math import floor, nan
7
- from typing import Any, Literal
7
+ from typing import TYPE_CHECKING, Any, Literal
8
8
 
9
+ import anndata as ad
10
+ import matplotlib.pyplot as plt
9
11
  import numpy as np
10
12
  import pandas as pd
11
13
  import scanpy as sc
@@ -34,6 +36,10 @@ from sklearn.preprocessing import LabelEncoder
34
36
  from skmisc.loess import loess
35
37
  from statsmodels.stats.multitest import fdrcorrection
36
38
 
39
+ if TYPE_CHECKING:
40
+ from matplotlib.axes import Axes
41
+ from matplotlib.figure import Figure
42
+
37
43
 
38
44
  @dataclass
39
45
  class Params:
@@ -135,8 +141,8 @@ class Augur:
135
141
  # filter samples according to label
136
142
  if condition_label is not None and treatment_label is not None:
137
143
  print(f"Filtering samples with {condition_label} and {treatment_label} labels.")
138
- adata = AnnData.concatenate(
139
- adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]
144
+ adata = ad.concat(
145
+ [adata[adata.obs["label"] == condition_label], adata[adata.obs["label"] == treatment_label]]
140
146
  )
141
147
  label_encoder = LabelEncoder()
142
148
  adata.obs["y_"] = label_encoder.fit_transform(adata.obs["label"])
@@ -214,7 +220,9 @@ class Augur:
214
220
  >>> loaded_data = ag_rfc.load(adata)
215
221
  >>> ag_rfc.select_highly_variable(loaded_data)
216
222
  >>> features = loaded_data.var_names
217
- >>> subsample = ag_rfc.sample(loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names)
223
+ >>> subsample = ag_rfc.sample(
224
+ ... loaded_data, categorical=True, subsample_size=20, random_state=42, features=loaded_data.var_names
225
+ ... )
218
226
  """
219
227
  # export subsampling.
220
228
  random.seed(random_state)
@@ -230,7 +238,7 @@ class Augur:
230
238
  random_state=random_state,
231
239
  )
232
240
  )
233
- subsample = AnnData.concatenate(*label_subsamples, index_unique=None)
241
+ subsample = ad.concat([*label_subsamples], index_unique=None)
234
242
  else:
235
243
  subsample = sc.pp.subsample(adata[:, features], n_obs=subsample_size, copy=True, random_state=random_state)
236
244
 
@@ -409,8 +417,8 @@ class Augur:
409
417
  """
410
418
  if multiclass:
411
419
  return {
412
- "augur_score": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
413
- "auc": make_scorer(roc_auc_score, multi_class="ovo", needs_proba=True),
420
+ "augur_score": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
421
+ "auc": make_scorer(roc_auc_score, multi_class="ovo", response_method="predict_proba"),
414
422
  "accuracy": make_scorer(accuracy_score),
415
423
  "precision": make_scorer(precision_score, average="macro", zero_division=zero_division),
416
424
  "f1": make_scorer(f1_score, average="macro"),
@@ -418,8 +426,8 @@ class Augur:
418
426
  }
419
427
  return (
420
428
  {
421
- "augur_score": make_scorer(roc_auc_score, needs_proba=True),
422
- "auc": make_scorer(roc_auc_score, needs_proba=True),
429
+ "augur_score": make_scorer(roc_auc_score, response_method="predict_proba"),
430
+ "auc": make_scorer(roc_auc_score, response_method="predict_proba"),
423
431
  "accuracy": make_scorer(accuracy_score),
424
432
  "precision": make_scorer(precision_score, average="binary", zero_division=zero_division),
425
433
  "f1": make_scorer(f1_score, average="binary"),
@@ -488,7 +496,7 @@ class Augur:
488
496
  # feature importances
489
497
  feature_importances = defaultdict(list)
490
498
  if isinstance(self.estimator, RandomForestClassifier) or isinstance(self.estimator, RandomForestRegressor):
491
- for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
499
+ for fold, estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
492
500
  feature_importances["genes"].extend(x.columns.tolist())
493
501
  feature_importances["feature_importances"].extend(estimator.feature_importances_.tolist())
494
502
  feature_importances["subsample_idx"].extend(len(x.columns) * [subsample_idx])
@@ -497,7 +505,7 @@ class Augur:
497
505
  # standardized coefficients with Agresti method
498
506
  # cf. https://think-lab.github.io/d/205/#3
499
507
  if isinstance(self.estimator, LogisticRegression):
500
- for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"])):
508
+ for fold, self.estimator in list(zip(range(len(results["estimator"])), results["estimator"], strict=False)):
501
509
  feature_importances["genes"].extend(x.columns.tolist())
502
510
  feature_importances["feature_importances"].extend(
503
511
  (self.estimator.coef_ * self.estimator.coef_.std()).flatten().tolist()
@@ -723,6 +731,7 @@ class Augur:
723
731
  >>> loaded_data = ag_rfc.load(adata)
724
732
  >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
725
733
  """
734
+ adata = adata.copy()
726
735
  if augur_mode == "permute" and n_subsamples < 100:
727
736
  n_subsamples = 500
728
737
  if is_regressor(self.estimator) and len(adata.obs["y_"].unique()) <= 3:
@@ -765,6 +774,7 @@ class Augur:
765
774
  elif (
766
775
  cell_type_subsample.obs.groupby(
767
776
  ["cell_type", "label"],
777
+ observed=True,
768
778
  ).y_.count()
769
779
  < subsample_size
770
780
  ).any():
@@ -804,7 +814,7 @@ class Augur:
804
814
  * (len(results["feature_importances"]["genes"]) - len(results["feature_importances"]["cell_type"]))
805
815
  )
806
816
 
807
- for idx, cv in zip(range(n_subsamples), results[cell_type]):
817
+ for idx, cv in zip(range(n_subsamples), results[cell_type], strict=False):
808
818
  results["full_results"]["idx"].extend([idx] * folds)
809
819
  results["full_results"]["augur_score"].extend(cv["test_augur_score"])
810
820
  results["full_results"]["folds"].extend(range(folds))
@@ -869,28 +879,31 @@ class Augur:
869
879
  & set(permuted_results1["summary_metrics"].columns)
870
880
  & set(permuted_results2["summary_metrics"].columns)
871
881
  )
882
+
883
+ cell_types_list = list(cell_types)
884
+
872
885
  # mean augur scores
873
886
  augur_score1 = (
874
887
  augur_results1["summary_metrics"]
875
- .loc["mean_augur_score", cell_types]
888
+ .loc["mean_augur_score", cell_types_list]
876
889
  .reset_index()
877
890
  .rename(columns={"index": "cell_type"})
878
891
  )
879
892
  augur_score2 = (
880
893
  augur_results2["summary_metrics"]
881
- .loc["mean_augur_score", cell_types]
894
+ .loc["mean_augur_score", cell_types_list]
882
895
  .reset_index()
883
896
  .rename(columns={"index": "cell_type"})
884
897
  )
885
898
 
886
899
  # mean permuted scores over cross validation runs
887
900
  permuted_cv_augur1 = (
888
- permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types)]
901
+ permuted_results1["full_results"][permuted_results1["full_results"]["cell_type"].isin(cell_types_list)]
889
902
  .groupby(["cell_type", "idx"], as_index=False)
890
903
  .mean()
891
904
  )
892
905
  permuted_cv_augur2 = (
893
- permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types)]
906
+ permuted_results2["full_results"][permuted_results2["full_results"]["cell_type"].isin(cell_types_list)]
894
907
  .groupby(["cell_type", "idx"], as_index=False)
895
908
  .mean()
896
909
  )
@@ -901,7 +914,7 @@ class Augur:
901
914
  # draw mean aucs for permute1 and permute2
902
915
  for celltype in permuted_cv_augur1["cell_type"].unique():
903
916
  df1 = permuted_cv_augur1[permuted_cv_augur1["cell_type"] == celltype]
904
- df2 = permuted_cv_augur2[permuted_cv_augur1["cell_type"] == celltype]
917
+ df2 = permuted_cv_augur2[permuted_cv_augur2["cell_type"] == celltype]
905
918
  for permutation_idx in range(n_permutations):
906
919
  # subsample
907
920
  sample1 = df1.sample(n=n_subsamples, random_state=permutation_idx, axis="index")
@@ -961,3 +974,288 @@ class Augur:
961
974
  delta["padj"] = fdrcorrection(delta["pval"])[1]
962
975
 
963
976
  return delta
977
+
978
def plot_dp_scatter(
    self,
    results: pd.DataFrame,
    top_n: int | None = None,
    return_fig: bool | None = None,
    ax: Axes = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot scatterplot of differential prioritization.

    Args:
        results: Results after running differential prioritization (`predict_differential_prioritization`).
                 The columns `mean_augur_score1`, `mean_augur_score2`, `z`, `pval` and `cell_type` are used.
        top_n: optionally, the number of top prioritized cell types (smallest `pval`) to label in the plot.
               If None, every cell type is labelled.
        return_fig: if `True`, return the figure containing the plot.
        ax: optionally, axes used to draw plot
        show: if `True`, call `plt.show()` after drawing.
        save: file path the figure is saved to (with a tight bounding box) if given.

    Returns:
        The figure if `return_fig` is set, otherwise the axes when neither `show` nor `save` is set,
        otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.bhattacherjee()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")

        >>> data_15 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_15d_Cocaine")
        >>> adata_15, results_15 = ag_rfc.predict(data_15, random_state=None, n_threads=4)
        >>> adata_15_permute, results_15_permute = ag_rfc.predict(
        ...     data_15, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> data_48 = ag_rfc.load(adata, condition_label="Maintenance_Cocaine", treatment_label="withdraw_48h_Cocaine")
        >>> adata_48, results_48 = ag_rfc.predict(data_48, random_state=None, n_threads=4)
        >>> adata_48_permute, results_48_permute = ag_rfc.predict(
        ...     data_48, augur_mode="permute", n_subsamples=100, random_state=None, n_threads=4
        ... )

        >>> pvals = ag_rfc.predict_differential_prioritization(
        ...     augur_results1=results_15,
        ...     augur_results2=results_48,
        ...     permuted_results1=results_15_permute,
        ...     permuted_results2=results_48_permute,
        ... )
        >>> ag_rfc.plot_dp_scatter(pvals)

    Preview:
        .. image:: /_static/docstring_previews/augur_dp_scatter.png
    """
    x = results["mean_augur_score1"]
    y = results["mean_augur_score2"]

    if ax is None:
        _, ax = plt.subplots()
    # Work on the figure owning `ax`, which need not be the pyplot "current" figure.
    fig = ax.figure
    scatter = ax.scatter(x, y, c=results["z"], cmap="Greens")

    # adding optional labels (most significant cell types first)
    top_n_index = results.sort_values(by="pval").index[:top_n]
    for idx in top_n_index:
        ax.annotate(
            results.loc[idx, "cell_type"],
            (results.loc[idx, "mean_augur_score1"], results.loc[idx, "mean_augur_score2"]),
        )

    # add diagonal spanning both axis ranges; taking max() of the limit tuples would compare
    # them lexicographically and could leave part of the other axis uncovered.
    (x_low, x_high), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    # formatting and details
    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")
    legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
    ax.add_artist(legend1)

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1051
+
1052
def plot_important_features(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    top_n: int = 10,
    return_fig: bool | None = None,
    ax: Axes = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the n features with largest feature importances.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `data.uns` holding the results when `data` is an AnnData object.
        top_n: n number feature importance values to plot. Default is 10. If fewer genes
               are available, all of them are plotted.
        return_fig: if `True` returns figure of the plot, default is `False`
        ax: optionally, axes used to draw plot
        show: if `True`, call `plt.show()` after drawing.
        save: file path the figure is saved to (with a tight bounding box) if given.

    Returns:
        The figure if `return_fig` is set, otherwise the axes when neither `show` nor `save` is set,
        otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_important_features(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_important_features.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data

    # Mean importance per gene, ascending, keeping the top_n largest (largest ends up on top).
    n_features = (
        results["feature_importances"]
        .groupby("genes", as_index=False)["feature_importances"]
        .mean()
        .sort_values(by="feature_importances")
        .tail(top_n)
    )

    if ax is None:
        _, ax = plt.subplots()
    fig = ax.figure

    # Size the y positions by the actual number of genes; a fixed range(1, top_n + 1)
    # breaks hlines when fewer than top_n genes exist.
    y_axes_range = list(range(1, len(n_features) + 1))
    importances = n_features["feature_importances"].to_numpy()
    ax.hlines(y_axes_range, xmin=0, xmax=importances)
    ax.plot(importances, y_axes_range, "o")

    ax.set_xlabel("Mean Feature Importance")
    ax.set_ylabel("Gene")
    ax.set_yticks(y_axes_range)
    ax.set_yticklabels(n_features["genes"].tolist())

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1122
+
1123
def plot_lollipop(
    self,
    data: dict[str, Any] | AnnData,
    key: str = "augurpy_results",
    return_fig: bool | None = None,
    ax: Axes = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Plot a lollipop plot of the mean augur values.

    Args:
        data: results after running `predict()` as dictionary or the AnnData object.
        key: Key in `data.uns` holding the results when `data` is an AnnData object.
        return_fig: if `True` returns figure of the plot
        ax: optionally, axes used to draw plot
        show: if `True`, call `plt.show()` after drawing.
        save: file path the figure is saved to (with a tight bounding box) if given.

    Returns:
        The figure if `return_fig` is set, otherwise the axes when neither `show` nor `save` is set,
        otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_lollipop(v_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_lollipop.png
    """
    results = data.uns[key] if isinstance(data, AnnData) else data
    if ax is None:
        _, ax = plt.subplots()
    fig = ax.figure

    # Cell types (columns of summary_metrics) sorted ascending by mean augur score; computed once.
    mean_scores = results["summary_metrics"].sort_values("mean_augur_score", axis=1).loc["mean_augur_score"]
    y_axes_range = list(range(1, len(mean_scores) + 1))

    ax.hlines(y_axes_range, xmin=0, xmax=mean_scores.to_numpy())
    ax.plot(mean_scores.to_numpy(), y_axes_range, "o")

    ax.set_xlabel("Mean Augur Score")
    ax.set_ylabel("Cell Type")
    ax.set_yticks(y_axes_range)
    ax.set_yticklabels(mean_scores.index.tolist())

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None
1188
+
1189
def plot_scatterplot(
    self,
    results1: dict[str, Any],
    results2: dict[str, Any],
    top_n: int | None = None,
    return_fig: bool | None = None,
    show: bool | None = None,
    save: str | bool | None = None,
) -> Axes | Figure | None:
    """Create scatterplot with two augur results.

    Each cell type of `results1` is drawn at (mean augur score in results1, mean augur score in results2).

    Args:
        results1: results after running `predict()`
        results2: results after running `predict()`; must contain every cell type of `results1`.
        top_n: optionally, the number of cell types with the largest score difference
               (results1 minus results2) to label. If None, every cell type is labelled.
        return_fig: if `True` returns figure of the plot
        show: if `True`, call `plt.show()` after drawing.
        save: file path the figure is saved to (with a tight bounding box) if given.

    Returns:
        The figure if `return_fig` is set, otherwise the axes when neither `show` nor `save` is set,
        otherwise None.

    Examples:
        >>> import pertpy as pt
        >>> adata = pt.dt.sc_sim_augur()
        >>> ag_rfc = pt.tl.Augur("random_forest_classifier")
        >>> loaded_data = ag_rfc.load(adata)
        >>> h_adata, h_results = ag_rfc.predict(loaded_data, subsample_size=20, n_threads=4)
        >>> v_adata, v_results = ag_rfc.predict(
        ...     loaded_data, subsample_size=20, select_variance_features=True, n_threads=4
        ... )
        >>> ag_rfc.plot_scatterplot(v_results, h_results)

    Preview:
        .. image:: /_static/docstring_previews/augur_scatterplot.png
    """
    cell_types = results1["summary_metrics"].columns
    scores1 = results1["summary_metrics"].loc["mean_augur_score", cell_types]
    scores2 = results2["summary_metrics"].loc["mean_augur_score", cell_types]

    fig, ax = plt.subplots()
    ax.scatter(scores1, scores2)

    # adding optional labels; restrict the difference to the plotted cell types so that
    # cell types present only in results2 are never selected (they cannot be annotated).
    top_n_cell_types = (scores1 - scores2).sort_values(ascending=False).index[:top_n]
    for txt in top_n_cell_types:
        ax.annotate(txt, (scores1[txt], scores2[txt]))

    # adding diagonal spanning both axis ranges (max() over the limit tuples compares
    # them lexicographically and may not cover both axes).
    (x_low, x_high), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
    limits = (min(x_low, y_low), max(x_high, y_high))
    ax.plot(limits, limits, ls="--", c=".3")

    ax.set_xlabel("Augur scores 1")
    ax.set_ylabel("Augur scores 2")

    if save:
        fig.savefig(save, bbox_inches="tight")
    if show:
        plt.show()
    if return_fig:
        return fig
    if not (show or save):
        return ax
    return None