pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,460 @@
1
+ from collections import ChainMap
2
+ from collections.abc import Sequence
3
+ from typing import Any, Literal
4
+
5
+ import blitzgsea
6
+ import numpy as np
7
+ import pandas as pd
8
+ import scanpy as sc
9
+ from anndata import AnnData
10
+ from matplotlib.axes import Axes
11
+ from scanpy.plotting import DotPlot
12
+ from scanpy.tools._score_genes import _sparse_nanmean
13
+ from scipy.sparse import issparse
14
+ from scipy.stats import hypergeom
15
+ from statsmodels.stats.multitest import multipletests
16
+
17
+ from pertpy.metadata import Drug
18
+
19
+
20
+ def _prepare_targets(
21
+ targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None,
22
+ nested: bool = False,
23
+ categories: str | Sequence[str] = None,
24
+ ) -> ChainMap | dict:
25
+ if categories is not None:
26
+ if isinstance(categories, str):
27
+ categories = [categories]
28
+ else:
29
+ categories = list(categories)
30
+
31
+ if targets is None:
32
+ pt_drug = Drug()
33
+ pt_drug.chembl.set()
34
+ targets = pt_drug.chembl.dictionary
35
+ nested = True
36
+ else:
37
+ targets = targets.copy()
38
+ if categories is not None:
39
+ targets = {k: targets[k] for k in categories} # type: ignore
40
+ if nested:
41
+ targets = dict(ChainMap(*[targets[cat] for cat in targets])) # type: ignore
42
+
43
+ return targets
44
+
45
+
46
+ def _mean(X, names, axis):
47
+ """Helper function to compute a mean of X across an axis, respecting names and possible nans."""
48
+ if issparse(X):
49
+ obs_avg = pd.Series(
50
+ np.array(_sparse_nanmean(X, axis=axis)).flatten(),
51
+ index=names,
52
+ )
53
+ else:
54
+ obs_avg = pd.Series(np.nanmean(X, axis=axis), index=names)
55
+ return obs_avg
56
+
57
+
58
class Enrichment:
    """Score, test and plot the enrichment of gene groups (e.g. drug targets) in an AnnData object."""

    def score(
        self,
        adata: AnnData,
        layer: str = None,
        targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None,
        nested: bool = False,
        categories: Sequence[str] = None,
        method: Literal["mean", "seurat"] = "mean",
        n_bins: int = 25,
        ctrl_size: int = 50,
        key_added: str = "pertpy_enrichment",
        random_state: int | np.random.Generator | None = None,
    ) -> None:
        """Obtain per-cell scoring of gene groups of interest.

        Inspired by drug2cell score: https://github.com/Teichlab/drug2cell.
        Ensure that the gene nomenclature in your target sets is compatible with your
        `.var_names`. The ChEMBL drug targets use HGNC.

        Args:
            adata: An AnnData object. It is recommended to use log-normalised data.
            layer: Specifies which `.layers` of AnnData to use for expression values. Uses `.X` if `None`.
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts two forms:
                - A dictionary with group names as keys and corresponding gene lists as entries.
                - A dictionary of dictionaries with group categories as keys. Use `nested=True` in this case.
                If not provided, ChEMBL-derived drug target sets are used.
            nested: Indicates if `targets` is a dictionary of dictionaries with group categories as keys.
            categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            method: Method for scoring gene groups. `"mean"` calculates the mean over all genes,
                while `"seurat"` uses a background profile subtraction approach.
            n_bins: The number of expression bins for the `"seurat"` method. Must be at least 2.
            ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
            key_added: Prefix key that adds the results to `uns`.
                Note that the actual values are `key_added_score`, `key_added_variables`, `key_added_genes`, `key_added_all_genes`.
            random_state: Seed or generator used to sample control genes for the `"seurat"` method.
                With `None`, fresh entropy is used, so repeated calls sample different control genes.

        Returns:
            Nothing. Results are written to `adata.uns`:
            - `{key_added}_score`: cells x gene-group score matrix.
            - `{key_added}_variables`: names of the scored gene groups; groups with no gene in `.var_names` are dropped.
            - `{key_added}_genes`: `{"var": DataFrame}` whose `genes` column holds each group's `|`-joined genes
              that are present in `.var_names`.
            - `{key_added}_all_genes`: `{"var": DataFrame}` whose `all_genes` column holds all `|`-joined genes of each group.

        Raises:
            ValueError: If `method` is not `"mean"` or `"seurat"`, or if `n_bins < 2` with `method="seurat"`.
        """
        if method not in ("mean", "seurat"):
            raise ValueError(f"method must be 'mean' or 'seurat', got {method!r}.")
        if method == "seurat" and n_bins < 2:
            # n_bins - 1 is used as a divisor below.
            raise ValueError(f"n_bins must be at least 2 for the 'seurat' method, got {n_bins}.")

        mtx = adata.layers[layer] if layer is not None else adata.X

        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        # Keep the full gene lists (including genes absent from var_names) for `{key_added}_all_genes`.
        full_targets = targets.copy()

        # From here on, each group maps to a boolean mask over `adata.var_names`.
        for drug in targets:
            targets[drug] = np.isin(adata.var_names, targets[drug])

        # Scoring is a matrix product of the cell x gene matrix with a gene x group weight matrix.
        # Each column holds 1/|genes found| on the group's genes, so the product is the mean expression
        # of those genes. Groups without any gene in var_names are dropped (their column sum would be 0).
        weights = pd.DataFrame(targets, index=adata.var_names)
        weights = weights.loc[:, weights.sum() > 0]
        weights = weights / weights.sum()
        if issparse(mtx):
            scores = mtx.dot(weights)
        else:
            scores = np.dot(mtx, weights)

        if method == "seurat":
            # Bin genes by their average expression across cells.
            obs_avg = _mean(mtx, names=adata.var_names, axis=0)
            n_items = int(np.round(len(obs_avg) / (n_bins - 1)))
            obs_cut = (obs_avg.rank(method="min") // n_items).values

            # Sample up to `ctrl_size` control genes from every expression bin.
            # One generator for all bins, so a given `random_state` reproduces the whole selection.
            rng = np.random.default_rng(random_state)
            control_groups = {}
            for cut in np.unique(obs_cut):
                mask = obs_cut == cut
                r_genes = np.nonzero(mask)[0]
                rng.shuffle(r_genes)
                mask[r_genes[ctrl_size:]] = False
                control_groups[cut] = mask
            control_gene_weights = pd.DataFrame(control_groups, index=adata.var_names)
            control_gene_weights = control_gene_weights / control_gene_weights.sum()

            # Per-cell mean expression of each bin's control genes (cells x bins).
            if issparse(mtx):
                control_profiles = mtx.dot(control_gene_weights)
            else:
                control_profiles = np.dot(mtx, control_gene_weights)

            # A group's background is the average control profile over the bins its genes fall into.
            drug_bins = {}
            for drug in weights.columns:
                bins = np.unique(obs_cut[targets[drug]])
                drug_bins[drug] = np.isin(control_gene_weights.columns, bins)
            drug_weights = pd.DataFrame(drug_bins, index=control_gene_weights.columns)
            drug_weights = drug_weights / drug_weights.sum()
            seurat = np.dot(control_profiles, drug_weights)
            scores = scores - seurat

        adata.uns[f"{key_added}_score"] = scores
        adata.uns[f"{key_added}_variables"] = weights.columns

        # Build the per-group gene tables in one go: growing a DataFrame row by row via `.loc`
        # reallocates it on every new row, which is quadratic in the number of groups.
        groups = weights.columns
        adata.uns[f"{key_added}_genes"] = {
            "var": pd.DataFrame(
                {"genes": ["|".join(adata.var_names[targets[drug]]) for drug in groups]},
                index=groups,
            ).astype(object)
        }
        adata.uns[f"{key_added}_all_genes"] = {
            "var": pd.DataFrame(
                {"all_genes": ["|".join(full_targets[drug]) for drug in groups]},
                index=groups,
            ).astype(object)
        }

    def hypergeometric(
        self,
        adata: AnnData,
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        pvals_adj_thresh: float = 0.05,
        direction: str = "both",
        corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg",
    ):
        """Perform a hypergeometric test to assess the overrepresentation of gene group members.

        The universe is `adata.var_names`; gene groups and marker genes are both restricted to it.

        Args:
            adata: With marker genes computed via `sc.tl.rank_genes_groups()` in the original expression space.
            targets: The gene groups to evaluate. Can be targets of known drugs, GO terms, pathway memberships, anything you can assign genes to.
                If `None`, the ChEMBL-derived drug target sets distributed with the package are loaded.
                Accepts two forms:
                - A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
                - A dictionary of dictionaries defined like above, with names of gene group categories as keys.
                If passing one of those, specify `nested=True`.
            nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
            categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
                In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
            direction: Whether to seek out up/down-regulated genes for the groups, based on the `scores` of `sc.tl.rank_genes_groups()`.
                Can be `up`, `down`, or `both` (for no selection).
            corr_method: Which FDR correction to apply to the p-values of the hypergeometric test.
                Can be `benjamini-hochberg` or `bonferroni`.

        Returns:
            Dictionary with the clusters for which markers were computed as keys, and data frames of test results
            (columns `intersection`, `gene_group`, `markers`, `universe`, `pvals`, `pvals_adj`) sorted on `pvals_adj` as values.

        Raises:
            ValueError: If `direction` or `corr_method` is not one of the supported values.
        """
        if direction not in ("up", "down", "both"):
            raise ValueError(f"direction must be 'up', 'down' or 'both', got {direction!r}.")
        if corr_method not in ("benjamini-hochberg", "bonferroni"):
            raise ValueError(f"corr_method must be 'benjamini-hochberg' or 'bonferroni', got {corr_method!r}.")

        universe = set(adata.var_names)
        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        # Only genes present in the data can be drawn; groups left empty are removed.
        targets = {group: set(genes) & universe for group, genes in targets.items()}  # type: ignore
        targets = {group: genes for group, genes in targets.items() if genes}

        rank_genes = adata.uns["rank_genes_groups"]
        overrepresentation = {}
        for cluster in rank_genes["names"].dtype.names:
            results = pd.DataFrame(
                1,
                index=list(targets.keys()),
                columns=[
                    "intersection",
                    "gene_group",
                    "markers",
                    "universe",
                    "pvals",
                    "pvals_adj",
                ],
            )
            mask = rank_genes["pvals_adj"][cluster] < pvals_adj_thresh
            if direction == "up":
                mask = mask & (rank_genes["scores"][cluster] > 0)
            elif direction == "down":
                mask = mask & (rank_genes["scores"][cluster] < 0)
            # Markers must come from the same universe as the gene groups: rank_genes_groups may have been
            # computed on `adata.raw`, whose genes are not necessarily all in `adata.var_names`.
            markers = set(rank_genes["names"][cluster][mask]) & universe
            results["markers"] = len(markers)
            results["universe"] = len(universe)
            results["pvals"] = results["pvals"].astype(float)

            for ind in results.index:
                gene_group = targets[ind]
                common = gene_group & markers
                results.loc[ind, "intersection"] = len(common)
                results.loc[ind, "gene_group"] = len(gene_group)
                # sf(k - 1) = P(X >= k): probability of at least the observed overlap.
                # https://alexlenail.medium.com/understanding-and-implementing-the-hypergeometric-test-in-python-a7db688a7458
                pval = hypergeom.sf(len(common) - 1, len(universe), len(markers), len(gene_group))
                results.loc[ind, "pvals"] = pval
            # Just in case any NaNs popped up somehow, fill them to 1 so FDR works
            results = results.fillna(1)
            if corr_method == "benjamini-hochberg":
                results["pvals_adj"] = multipletests(results["pvals"], method="fdr_bh")[1]
            else:
                results["pvals_adj"] = np.minimum(results["pvals"] * results.shape[0], 1.0)
            overrepresentation[cluster] = results.sort_values("pvals_adj")

        return overrepresentation

    def gsea(
        self,
        adata: AnnData,
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        absolute: bool = False,
        key_added: str = "pertpy_enrichment_gsea",
    ) -> dict[str, pd.DataFrame] | tuple[dict[str, pd.DataFrame], dict[str, dict]]:  # pragma: no cover
        """Perform gene set enrichment analysis on the marker gene scores using blitzgsea.

        Args:
            adata: AnnData object with marker genes computed via `sc.tl.rank_genes_groups()`
                in the original expression space.
            targets: The gene groups to evaluate, either as a dictionary with names of the
                groups as keys and gene lists as values, or a dictionary of dictionaries
                with names of gene group categories as keys (use `nested=True`).
                If `None`, the ChEMBL-derived drug target sets are loaded.
            nested: Indicates if `targets` is a dictionary of dictionaries with group
                categories as keys.
            categories: Used to subset the gene groups to one or more categories,
                applicable if `targets=None` or `nested=True`.
            absolute: If True, passes the absolute values of scores to GSEA, improving
                statistical power.
            key_added: Key in `uns` under which the prepared targets and the per-cluster ranked
                scores are stored (used by `plot_gsea()`).

        Returns:
            A dictionary with clusters as keys and blitzgsea's `gsea()` result data frames as values.
        """
        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        enrichment = {}
        plot_gsea_args: dict[str, Any] = {"targets": targets, "scores": {}}
        for cluster in adata.uns["rank_genes_groups"]["names"].dtype.names:
            # blitzgsea expects a two-column frame: gene name and ranking score.
            df = pd.DataFrame(
                {
                    "0": adata.uns["rank_genes_groups"]["names"][cluster],
                    "1": adata.uns["rank_genes_groups"]["scores"][cluster],
                }
            )
            if absolute:
                df["1"] = np.absolute(df["1"])
                df = df.sort_values("1", ascending=False)
            enrichment[cluster] = blitzgsea.gsea(df, targets)
            plot_gsea_args["scores"][cluster] = df

        adata.uns[key_added] = plot_gsea_args

        return enrichment

    def plot_dotplot(
        self,
        adata: AnnData,
        targets: dict[str, dict[str, list[str]]] = None,
        source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
        category_name: str = "interaction_type",
        categories: Sequence[str] = None,
        groupby: str = None,
        key: str = "pertpy_enrichment",
        ax: Axes | None = None,
        save: bool | str | None = None,
        show: bool | None = None,
        **kwargs,
    ) -> DotPlot | dict | None:
        """Plots a dotplot by groupby and categories.

        Wraps scanpy's dotplot but formats it nicely by categories.

        Args:
            adata: An AnnData object with enrichment results stored in `.uns[f"{key}_score"]` and
                `.uns[f"{key}_variables"]`, as written by `score()`.
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts a dictionary of dictionaries with group categories as keys.
                If not provided, ChEMBL-, DGIdb- or PharmGKB-derived drug target sets are used, given by `source`.
            source: Source of drug target sets when `targets=None`, `chembl`, `dgidb` or `pharmgkb`.
            category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
            categories: To subset the gene groups to specific categories, especially when `targets=None`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            groupby: dotplot groupby such as clusters or cell types.
            key: Prefix key of enrichment results in `uns`.
            ax: Matplotlib axes to draw on; passed to scanpy's dotplot.
            save: Passed to scanpy's dotplot (save the figure).
            show: Passed to scanpy's dotplot (show the figure).
            kwargs: Passed to scanpy dotplot.

        Returns:
            If `return_fig` is `True`, returns a :class:`~scanpy.pl.DotPlot` object,
            else if `show` is false, return axes dict.

        Raises:
            ValueError: If `category_name` is not a column of the DGIdb/PharmGKB interaction data.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> pt_enrichment.plot_dotplot(adata, categories=["B01", "B02", "B03"], groupby="louvain")

        Preview:
            .. image:: /_static/docstring_previews/enrichment_dotplot.png
        """
        if categories is not None:
            if isinstance(categories, str):
                categories = [categories]
            else:
                categories = list(categories)

        if targets is None:
            pt_drug = Drug()
            if source == "chembl":
                pt_drug.chembl.set()
                targets = pt_drug.chembl.dictionary
            elif source == "dgidb":
                pt_drug.dgidb.set()
                interaction = pt_drug.dgidb.data
                if category_name not in interaction.columns:
                    raise ValueError("The category name is not available in dgidb drug target data.")
                interaction[category_name] = interaction[category_name].fillna("Unknown/Other")
                # {category: {drug: [genes]}}
                targets = (
                    interaction.groupby(category_name)
                    .apply(lambda x: x.groupby("drug_claim_name")["gene_claim_name"].apply(list).to_dict())
                    .to_dict()
                )
            else:
                pt_drug.pharmgkb.set()
                interaction = pt_drug.pharmgkb.data
                if category_name not in interaction.columns:
                    raise ValueError("The category name is not available in pharmgkb drug target data.")
                interaction[category_name] = interaction[category_name].fillna("Unknown/Other")
                # {category: {compound_or_disease: [genes]}}
                targets = (
                    interaction.groupby(category_name)
                    .apply(lambda x: x.groupby("Compound|Disease")["Gene"].apply(list).to_dict())
                    .to_dict()
                )
        else:
            targets = targets.copy()
        if categories is not None:
            targets = {k: targets[k] for k in categories}  # type: ignore

        # Only the group names (e.g. drug names) of each category are needed for plotting.
        for group in targets:
            targets[group] = list(targets[group].keys())  # type: ignore

        var_names: list[str] = []
        var_group_positions: list[tuple[int, int]] = []
        var_group_labels: list[str] = []
        start = 0

        # Wrap the cell x group score matrix as an AnnData so scanpy can plot it.
        enrichment_score_adata = AnnData(adata.uns[f"{key}_score"], obs=adata.obs)
        enrichment_score_adata.var_names = adata.uns[f"{key}_variables"]

        # Keep only the groups that were actually scored and record each category's span of plotted variables.
        for group in targets:
            targets[group] = list(  # type: ignore
                enrichment_score_adata.var_names[np.isin(enrichment_score_adata.var_names, targets[group])]
            )
            if len(targets[group]) == 0:
                continue
            var_names = var_names + targets[group]  # type: ignore
            var_group_positions = var_group_positions + [(start, len(var_names) - 1)]
            var_group_labels = var_group_labels + [group]
            start = len(var_names)

        plot_args = {
            "var_names": var_names,
            "var_group_positions": var_group_positions,
            "var_group_labels": var_group_labels,
        }

        return sc.pl.dotplot(
            enrichment_score_adata,
            groupby=groupby,
            swap_axes=True,
            ax=ax,
            save=save,
            show=show,
            **plot_args,
            **kwargs,
        )

    def plot_gsea(
        self,
        adata: AnnData,
        enrichment: dict[str, pd.DataFrame],
        n: int = 10,
        key: str = "pertpy_enrichment_gsea",
        interactive_plot: bool = False,
    ) -> None:
        """Generates a blitzgsea top_table plot.

        This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA).
        It uses the output from the `gsea()` method, which provides the enrichment data,
        and displays the top results using blitzgsea's `top_table()` plot.

        Args:
            adata: AnnData object to plot.
            enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
            n: How many top scores to show for each group.
            key: GSEA results key in `uns`, as written by `gsea()` via its `key_added` argument.
            interactive_plot: Whether to plot interactively or not.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> enrichment = pt_enrichment.gsea(adata)
            >>> pt_enrichment.plot_gsea(adata, enrichment, interactive_plot=True)

        Preview:
            .. image:: /_static/docstring_previews/enrichment_gsea.png
        """
        for cluster in enrichment:
            fig = blitzgsea.plot.top_table(
                adata.uns[key]["scores"][cluster],
                adata.uns[key]["targets"],
                enrichment[cluster],
                n=n,
                interactive_plot=interactive_plot,
            )
            fig.suptitle(cluster)
            fig.show()
@@ -31,7 +31,7 @@ def kernel_pca(
31
31
 
32
32
  Returns:
33
33
  If `copy=True`, returns the copy of `adata` with kernel pca in `.obsm["X_kpca"]`.
34
- Otherwise writes kernel pca directly to `.obsm["X_kpca"]` of the provided `adata`.
34
+ Otherwise, writes kernel pca directly to `.obsm["X_kpca"]` of the provided `adata`.
35
35
  If `return_transformer=True`, returns also the fitted `KernelPCA` transformer.
36
36
  """
37
37
  if copy: