pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (53) hide show
  1. pertpy/__init__.py +3 -2
  2. pertpy/data/__init__.py +5 -1
  3. pertpy/data/_dataloader.py +2 -4
  4. pertpy/data/_datasets.py +203 -92
  5. pertpy/metadata/__init__.py +4 -0
  6. pertpy/metadata/_cell_line.py +826 -0
  7. pertpy/metadata/_compound.py +129 -0
  8. pertpy/metadata/_drug.py +242 -0
  9. pertpy/metadata/_look_up.py +582 -0
  10. pertpy/metadata/_metadata.py +73 -0
  11. pertpy/metadata/_moa.py +129 -0
  12. pertpy/plot/__init__.py +1 -9
  13. pertpy/plot/_augur.py +53 -116
  14. pertpy/plot/_coda.py +277 -677
  15. pertpy/plot/_guide_rna.py +17 -35
  16. pertpy/plot/_milopy.py +59 -134
  17. pertpy/plot/_mixscape.py +152 -391
  18. pertpy/preprocessing/_guide_rna.py +88 -4
  19. pertpy/tools/__init__.py +8 -13
  20. pertpy/tools/_augur.py +315 -17
  21. pertpy/tools/_cinemaot.py +143 -4
  22. pertpy/tools/_coda/_base_coda.py +1210 -65
  23. pertpy/tools/_coda/_sccoda.py +50 -21
  24. pertpy/tools/_coda/_tasccoda.py +27 -19
  25. pertpy/tools/_dialogue.py +164 -56
  26. pertpy/tools/_differential_gene_expression.py +240 -14
  27. pertpy/tools/_distances/_distance_tests.py +8 -8
  28. pertpy/tools/_distances/_distances.py +184 -34
  29. pertpy/tools/_enrichment.py +465 -0
  30. pertpy/tools/_milo.py +345 -11
  31. pertpy/tools/_mixscape.py +668 -50
  32. pertpy/tools/_perturbation_space/_clustering.py +5 -1
  33. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
  34. pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
  35. pertpy/tools/_perturbation_space/_simple.py +51 -10
  36. pertpy/tools/_scgen/__init__.py +1 -1
  37. pertpy/tools/_scgen/_scgen.py +701 -0
  38. pertpy/tools/_scgen/_utils.py +1 -3
  39. pertpy/tools/decoupler_LICENSE +674 -0
  40. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
  41. pertpy-0.7.0.dist-info/RECORD +53 -0
  42. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
  43. pertpy/plot/_cinemaot.py +0 -81
  44. pertpy/plot/_dialogue.py +0 -91
  45. pertpy/plot/_scgen.py +0 -337
  46. pertpy/tools/_metadata/__init__.py +0 -0
  47. pertpy/tools/_metadata/_cell_line.py +0 -613
  48. pertpy/tools/_metadata/_look_up.py +0 -342
  49. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  50. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  51. pertpy-0.6.0.dist-info/RECORD +0 -50
  52. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  53. {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,465 @@
1
+ from collections import ChainMap
2
+ from collections.abc import Sequence
3
+ from typing import Any, Literal
4
+
5
+ import blitzgsea
6
+ import numpy as np
7
+ import pandas as pd
8
+ import scanpy as sc
9
+ from anndata import AnnData
10
+ from matplotlib.axes import Axes
11
+ from scanpy.plotting import DotPlot
12
+ from scanpy.tools._score_genes import _sparse_nanmean
13
+ from scipy.sparse import issparse
14
+ from scipy.stats import hypergeom
15
+ from statsmodels.stats.multitest import multipletests
16
+
17
+ from pertpy.metadata import Drug
18
+
19
+
20
+ def _prepare_targets(
21
+ targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None,
22
+ nested: bool = False,
23
+ categories: str | Sequence[str] = None,
24
+ ) -> ChainMap | dict:
25
+ if categories is not None:
26
+ if isinstance(categories, str):
27
+ categories = [categories]
28
+ else:
29
+ categories = list(categories)
30
+
31
+ if targets is None:
32
+ pt_drug = Drug()
33
+ pt_drug.chembl.set()
34
+ targets = pt_drug.chembl.dictionary
35
+ nested = True
36
+ else:
37
+ targets = targets.copy()
38
+ if categories is not None:
39
+ targets = {k: targets[k] for k in categories} # type: ignore
40
+ if nested:
41
+ targets = dict(ChainMap(*[targets[cat] for cat in targets])) # type: ignore
42
+
43
+ return targets
44
+
45
+
46
def _mean(X, names, axis):
    """Return the NaN-ignoring mean of ``X`` along ``axis`` as a Series indexed by ``names``.

    Dense input goes through ``np.nanmean``. Sparse input goes through scanpy's
    ``_sparse_nanmean``, whose result is flattened to 1-D.
    """
    if not issparse(X):
        return pd.Series(np.nanmean(X, axis=axis), index=names)
    sparse_avg = np.asarray(_sparse_nanmean(X, axis=axis)).ravel()
    return pd.Series(sparse_avg, index=names)
56
+
57
+
58
class Enrichment:
    def score(
        self,
        adata: AnnData,
        layer: str | None = None,
        targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | Sequence[str] | None = None,
        method: Literal["mean", "seurat"] = "mean",
        n_bins: int = 25,
        ctrl_size: int = 50,
        key_added: str = "pertpy_enrichment",
        random_state: int | None = None,
    ) -> None:
        """Obtain per-cell scoring of gene groups of interest.

        Inspired by drug2cell score: https://github.com/Teichlab/drug2cell.
        Ensure that the gene nomenclature in your target sets is compatible with your
        `.var_names`. The ChEMBL drug targets use HGNC.

        Args:
            adata: An AnnData object. It is recommended to use log-normalised data.
            layer: Specifies which `.layers` of AnnData to use for expression values. Defaults to `.X` if None.
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts two forms:
                - A dictionary with group names as keys and corresponding gene lists as entries.
                - A dictionary of dictionaries with group categories as keys. Use `nested=True` in this case.
                If not provided, ChEMBL-derived drug target sets are used.
            nested: Indicates if `targets` is a dictionary of dictionaries with group categories as keys.
                Defaults to False.
            categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            method: Method for scoring gene groups. `"mean"` calculates the mean over all genes,
                while `"seurat"` uses a background profile subtraction approach.
                Defaults to 'mean'.
            n_bins: The number of expression bins for the `'seurat'` method. Must be at least 2.
            ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
            key_added: Prefix key that adds the results to `uns`.
                Note that the actual values are `key_added_score`, `key_added_variables`, `key_added_genes`, `key_added_all_genes`.
                Defaults to `pertpy_enrichment`.
            random_state: Seed for sampling the control genes of the `"seurat"` method.
                Defaults to None, i.e. unseeded sampling that differs between runs.

        Raises:
            ValueError: If `method` is neither `"mean"` nor `"seurat"`, or if `n_bins < 2` with `method="seurat"`.

        Returns:
            Nothing. Results are written into `adata.uns`:
            - `{key_added}_score`: cells x gene groups score matrix.
            - `{key_added}_variables`: the gene group names, one per score column.
            - `{key_added}_genes`: `{"var": DataFrame}` with the `|`-joined genes of each group found in `.var_names`.
            - `{key_added}_all_genes`: `{"var": DataFrame}` with all `|`-joined genes of each group.
            Groups without any gene in `.var_names` are dropped.
        """
        # Validate before loading any (possibly downloaded) target database.
        if method not in ("mean", "seurat"):
            raise ValueError(f"method must be 'mean' or 'seurat', got {method!r}.")
        if method == "seurat" and n_bins < 2:
            raise ValueError(f"n_bins must be at least 2 for the 'seurat' method, got {n_bins}.")

        mtx = adata.layers[layer] if layer is not None else adata.X

        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        full_targets = targets.copy()

        # Replace each gene list by a boolean mask over `adata.var_names`.
        for drug in targets:
            targets[drug] = np.isin(adata.var_names, targets[drug])

        # Scoring is done via matrix multiplication of the original cell by gene matrix by a new gene by drug matrix
        # with the entries in the new matrix being the weights of each gene for that group (such as drug).
        # The mean across targets is constant -> prepare weights for that.
        weights = pd.DataFrame(targets, index=adata.var_names)
        weights = weights.loc[:, weights.sum() > 0]
        weights = weights / weights.sum()
        scores = mtx.dot(weights) if issparse(mtx) else np.dot(mtx, weights)

        if method == "seurat":
            obs_avg = _mean(mtx, names=adata.var_names, axis=0)
            # With fewer genes than bins the rounded bin size would be 0 -> division by zero; keep it >= 1.
            n_items = max(1, int(np.round(len(obs_avg) / (n_bins - 1))))
            obs_cut = (obs_avg.rank(method="min") // n_items).values

            # One generator for all bins: reproducible when seeded, unseeded behaviour otherwise unchanged.
            rng = np.random.default_rng(random_state)
            control_groups = {}
            for cut in np.unique(obs_cut):
                mask = obs_cut == cut
                r_genes = np.nonzero(mask)[0]
                rng.shuffle(r_genes)
                # Keep at most `ctrl_size` randomly chosen genes of this expression bin.
                mask[r_genes[ctrl_size:]] = False
                control_groups[cut] = mask
            control_gene_weights = pd.DataFrame(control_groups, index=adata.var_names)
            control_gene_weights = control_gene_weights / control_gene_weights.sum()

            if issparse(mtx):
                control_profiles = mtx.dot(control_gene_weights)
            else:
                control_profiles = np.dot(mtx, control_gene_weights)

            # Each group's background is the average of the control profiles of the bins its genes fall into.
            drug_bins = {}
            for drug in weights.columns:
                bins = np.unique(obs_cut[targets[drug]])
                drug_bins[drug] = np.isin(control_gene_weights.columns, bins)
            drug_weights = pd.DataFrame(drug_bins, index=control_gene_weights.columns)
            drug_weights = drug_weights / drug_weights.sum()
            seurat = np.dot(control_profiles, drug_weights)
            scores = scores - seurat

        adata.uns[f"{key_added}_score"] = scores
        adata.uns[f"{key_added}_variables"] = weights.columns

        # Build the gene tables in one construction instead of growing them row by row (quadratic).
        adata.uns[f"{key_added}_genes"] = {
            "var": pd.DataFrame(
                {"genes": ["|".join(adata.var_names[targets[drug]]) for drug in weights.columns]},
                index=weights.columns,
                dtype=object,
            )
        }
        adata.uns[f"{key_added}_all_genes"] = {
            "var": pd.DataFrame(
                {"all_genes": ["|".join(full_targets[drug]) for drug in weights.columns]},
                index=weights.columns,
                dtype=object,
            )
        }

    def hypergeometric(
        self,
        adata: AnnData,
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        pvals_adj_thresh: float = 0.05,
        direction: str = "both",
        corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg",
    ) -> dict[str, pd.DataFrame]:
        """Perform a hypergeometric test to assess the overrepresentation of gene group members.

        Args:
            adata: With marker genes computed via `sc.tl.rank_genes_groups()` in the original expression space.
            targets: The gene groups to evaluate. Can be targets of known drugs, GO terms, pathway memberships, anything you can assign genes to.
                If `None`, the ChEMBL-derived drug target sets are loaded via `pertpy.metadata.Drug`.
                Accepts two forms:
                - A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
                - A dictionary of dictionaries defined like above, with names of gene group categories as keys.
                If passing one of those, specify `nested=True`.
            nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
            categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
                In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
            direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`.
                Can be `up`, `down`, or `both` (for no selection).
            corr_method: Which FDR correction to apply to the p-values of the hypergeometric test.
                Can be `benjamini-hochberg` or `bonferroni`.

        Raises:
            ValueError: If `direction` or `corr_method` is not one of the accepted values.

        Returns:
            Dictionary with clusters for which the original object markers were computed as the keys,
            and data frames of test results sorted on `pvals_adj` (ascending) as the items.
            Gene groups without any gene in `.var_names` are omitted.
        """
        if direction not in ("up", "down", "both"):
            raise ValueError(f"direction must be 'up', 'down' or 'both', got {direction!r}.")
        if corr_method not in ("benjamini-hochberg", "bonferroni"):
            raise ValueError(f"corr_method must be 'benjamini-hochberg' or 'bonferroni', got {corr_method!r}.")

        universe = set(adata.var_names)
        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        # Restrict every group to genes present in `adata` and drop groups that end up empty.
        targets = {group: set(genes).intersection(universe) for group, genes in targets.items()}  # type: ignore
        targets = {group: genes for group, genes in targets.items() if genes}

        group_names = list(targets)
        group_sizes = np.array([len(targets[group]) for group in group_names], dtype=int)
        n_universe = len(universe)
        rank_genes = adata.uns["rank_genes_groups"]

        overrepresentation = {}
        for cluster in rank_genes["names"].dtype.names:
            mask = rank_genes["pvals_adj"][cluster] < pvals_adj_thresh
            if direction == "up":
                mask = mask & (rank_genes["scores"][cluster] > 0)
            elif direction == "down":
                mask = mask & (rank_genes["scores"][cluster] < 0)
            # NOTE(review): if rank_genes_groups ran on `.raw`, markers may include genes outside `universe`.
            markers = set(rank_genes["names"][cluster][mask])

            intersections = np.array([len(targets[group] & markers) for group in group_names], dtype=int)
            # sf(k - 1) = P(X >= k): subtract 1 from the intersection length
            # https://alexlenail.medium.com/understanding-and-implementing-the-hypergeometric-test-in-python-a7db688a7458
            pvals = hypergeom.sf(intersections - 1, n_universe, len(markers), group_sizes)
            results = pd.DataFrame(
                {
                    "intersection": intersections,
                    "gene_group": group_sizes,
                    "markers": len(markers),
                    "universe": n_universe,
                    "pvals": np.asarray(pvals, dtype=float),
                    "pvals_adj": 1.0,
                },
                index=group_names,
            )
            # Invalid hypergeometric parameters yield NaN; fill them to 1 so FDR works.
            results = results.fillna(1)
            # multipletests cannot handle zero tests, so only correct when there is something to correct.
            if not results.empty:
                if corr_method == "benjamini-hochberg":
                    results["pvals_adj"] = multipletests(results["pvals"], method="fdr_bh")[1]
                else:
                    results["pvals_adj"] = np.minimum(results["pvals"] * results.shape[0], 1.0)
            overrepresentation[cluster] = results.sort_values("pvals_adj")

        return overrepresentation

    def gsea(
        self,
        adata: "AnnData",
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        absolute: bool = False,
        key_added: str = "pertpy_enrichment_gsea",
    ) -> dict[str, pd.DataFrame]:  # pragma: no cover
        """Perform gene set enrichment analysis on the marker gene scores using blitzgsea.

        Args:
            adata: AnnData object with marker genes computed via `sc.tl.rank_genes_groups()`
                in the original expression space.
            targets: The gene groups to evaluate, either as a dictionary with names of the
                groups as keys and gene lists as values, or a dictionary of dictionaries
                with names of gene group categories as keys. Defaults to None, in which
                case the ChEMBL-derived drug target sets are loaded via `pertpy.metadata.Drug`.
            nested: Indicates if `targets` is a dictionary of dictionaries with group
                categories as keys. Defaults to False.
            categories: Used to subset the gene groups to one or more categories,
                applicable if `targets=None` or `nested=True`. Defaults to None.
            absolute: If True, passes the absolute values of scores to GSEA, improving
                statistical power. Defaults to False.
            key_added: Prefix key that adds the results to `uns`.
                Defaults to `pertpy_enrichment_gsea`.

        Returns:
            A dictionary with clusters as keys and blitzgsea result data frames as the items.
            The flattened targets and the per-cluster ranked scores are stored in
            `adata.uns[key_added]` under `"targets"` and `"scores"` for `plot_gsea()`.
        """
        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        rank_genes = adata.uns["rank_genes_groups"]
        enrichment = {}
        plot_gsea_args: dict[str, Any] = {"targets": targets, "scores": {}}
        for cluster in rank_genes["names"].dtype.names:
            # blitzgsea expects a two-column signature: gene names, then scores sorted descending.
            df = pd.DataFrame({"0": rank_genes["names"][cluster], "1": rank_genes["scores"][cluster]})
            if absolute:
                df["1"] = np.absolute(df["1"])
            df = df.sort_values("1", ascending=False)
            enrichment[cluster] = blitzgsea.gsea(df, targets)
            plot_gsea_args["scores"][cluster] = df

        adata.uns[key_added] = plot_gsea_args

        return enrichment

    @staticmethod
    def _nested_targets_from_interactions(
        interaction: pd.DataFrame,
        category_name: str,
        drug_column: str,
        gene_column: str,
        source: str,
    ) -> dict[str, dict[str, list[str]]]:
        """Group an interaction table into `{category: {drug: [genes]}}` for `plot_dotplot`."""
        if category_name not in interaction.columns:
            raise ValueError(f"The category name is not available in {source} drug target data.")
        # Work on a copy so the loaded database table is not modified in place.
        interaction = interaction.copy()
        interaction[category_name] = interaction[category_name].fillna("Unknown/Other")
        return (
            interaction.groupby(category_name)
            .apply(lambda x: x.groupby(drug_column)[gene_column].apply(list).to_dict())
            .to_dict()
        )

    def plot_dotplot(
        self,
        adata: AnnData,
        targets: dict[str, dict[str, list[str]]] | None = None,
        source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
        category_name: str = "interaction_type",
        categories: Sequence[str] | None = None,
        groupby: str | None = None,
        key: str = "pertpy_enrichment",
        ax: Axes | None = None,
        save: bool | str | None = None,
        show: bool | None = None,
        **kwargs,
    ) -> DotPlot | dict | None:
        """Plots a dotplot by groupby and categories.

        Wraps scanpy's dotplot but formats it nicely by categories.

        Args:
            adata: An AnnData object with enrichment results stored in `.uns["pertpy_enrichment_score"]`.
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts a dictionary of dictionaries with group categories as keys.
                If not provided, ChEMBL-, DGIdb- or PharmGKB-derived drug target sets are used, given by `source`.
            source: Source of drug target sets when `targets=None`, `chembl`, `dgidb` or `pharmgkb`. Defaults to `chembl`.
            category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`. Defaults to `interaction_type`.
            categories: To subset the gene groups to specific categories, especially when `targets=None`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            groupby: dotplot groupby such as clusters or cell types.
            key: Prefix key of enrichment results in `uns`.
                Defaults to `pertpy_enrichment`.
            ax: Matplotlib axes to plot into, passed to scanpy's dotplot.
            save: Passed to scanpy's dotplot.
            show: Passed to scanpy's dotplot.
            kwargs: Passed to scanpy dotplot.

        Raises:
            ValueError: If `targets=None` and `source` is unknown, or `category_name` is not a
                column of the chosen DGIdb/PharmGKB table.

        Returns:
            If `return_fig` is `True`, returns a :class:`~scanpy.pl.DotPlot` object,
            else if `show` is false, return axes dict.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> pt_enrichment.plot_dotplot(adata, categories=["B01", "B02", "B03"], groupby="louvain")

        Preview:
            .. image:: /_static/docstring_previews/enrichment_dotplot.png
        """
        if categories is not None:
            categories = [categories] if isinstance(categories, str) else list(categories)

        if targets is None:
            if source not in ("chembl", "dgidb", "pharmgkb"):
                raise ValueError(f"source must be 'chembl', 'dgidb' or 'pharmgkb', got {source!r}.")
            pt_drug = Drug()
            if source == "chembl":
                pt_drug.chembl.set()
                targets = pt_drug.chembl.dictionary
            elif source == "dgidb":
                pt_drug.dgidb.set()
                targets = self._nested_targets_from_interactions(
                    pt_drug.dgidb.data, category_name, "drug_claim_name", "gene_claim_name", "dgidb"
                )
            else:
                pt_drug.pharmgkb.set()
                targets = self._nested_targets_from_interactions(
                    pt_drug.pharmgkb.data, category_name, "Compound|Disease", "Gene", "pharmgkb"
                )
        else:
            targets = targets.copy()
        if categories is not None:
            targets = {k: targets[k] for k in categories}  # type: ignore

        # Keep only the drug names per category; build a new dict so the source mapping is not mutated.
        drug_names = {group: list(drugs.keys()) for group, drugs in targets.items()}  # type: ignore

        enrichment_score_adata = AnnData(adata.uns[f"{key}_score"], obs=adata.obs)
        enrichment_score_adata.var_names = adata.uns[f"{key}_variables"]
        scored = enrichment_score_adata.var_names

        # Concatenate the scored drugs of each category and remember each category's column span.
        var_names: list[str] = []
        var_group_positions: list[tuple[int, int]] = []
        var_group_labels: list[str] = []
        for group, drugs in drug_names.items():
            present = list(scored[np.isin(scored, drugs)])
            if not present:
                continue
            start = len(var_names)
            var_names.extend(present)
            var_group_positions.append((start, len(var_names) - 1))
            var_group_labels.append(group)

        return sc.pl.dotplot(
            enrichment_score_adata,
            groupby=groupby,
            swap_axes=True,
            ax=ax,
            save=save,
            show=show,
            var_names=var_names,
            var_group_positions=var_group_positions,
            var_group_labels=var_group_labels,
            **kwargs,
        )

    def plot_gsea(
        self,
        adata: AnnData,
        enrichment: dict[str, pd.DataFrame],
        n: int = 10,
        key: str = "pertpy_enrichment_gsea",
        interactive_plot: bool = False,
    ) -> None:
        """Generates a blitzgsea top_table plot.

        This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA).
        It uses the output from the `gsea()` method, which provides the enrichment data,
        and displays the top results using blitzgsea's `top_table()` plot, one figure per cluster.

        Args:
            adata: AnnData object to plot. Must hold the `gsea()` side output in `adata.uns[key]`.
            enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
            n: How many top scores to show for each group. Defaults to 10.
            key: GSEA results key in `uns`. Defaults to "pertpy_enrichment_gsea".
            interactive_plot: Whether to plot interactively or not. Defaults to False.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> enrichment = pt_enrichment.gsea(adata)
            >>> pt_enrichment.plot_gsea(adata, enrichment, interactive_plot=True)

        Preview:
            .. image:: /_static/docstring_previews/enrichment_gsea.png
        """
        gsea_store = adata.uns[key]
        for cluster, cluster_result in enrichment.items():
            fig = blitzgsea.plot.top_table(
                gsea_store["scores"][cluster],
                gsea_store["targets"],
                cluster_result,
                n=n,
                interactive_plot=interactive_plot,
            )
            fig.suptitle(cluster)
            fig.show()
+ fig.show()