pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,465 @@
|
|
1
|
+
from collections import ChainMap
|
2
|
+
from collections.abc import Sequence
|
3
|
+
from typing import Any, Literal
|
4
|
+
|
5
|
+
import blitzgsea
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
import scanpy as sc
|
9
|
+
from anndata import AnnData
|
10
|
+
from matplotlib.axes import Axes
|
11
|
+
from scanpy.plotting import DotPlot
|
12
|
+
from scanpy.tools._score_genes import _sparse_nanmean
|
13
|
+
from scipy.sparse import issparse
|
14
|
+
from scipy.stats import hypergeom
|
15
|
+
from statsmodels.stats.multitest import multipletests
|
16
|
+
|
17
|
+
from pertpy.metadata import Drug
|
18
|
+
|
19
|
+
|
20
|
+
def _prepare_targets(
|
21
|
+
targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None,
|
22
|
+
nested: bool = False,
|
23
|
+
categories: str | Sequence[str] = None,
|
24
|
+
) -> ChainMap | dict:
|
25
|
+
if categories is not None:
|
26
|
+
if isinstance(categories, str):
|
27
|
+
categories = [categories]
|
28
|
+
else:
|
29
|
+
categories = list(categories)
|
30
|
+
|
31
|
+
if targets is None:
|
32
|
+
pt_drug = Drug()
|
33
|
+
pt_drug.chembl.set()
|
34
|
+
targets = pt_drug.chembl.dictionary
|
35
|
+
nested = True
|
36
|
+
else:
|
37
|
+
targets = targets.copy()
|
38
|
+
if categories is not None:
|
39
|
+
targets = {k: targets[k] for k in categories} # type: ignore
|
40
|
+
if nested:
|
41
|
+
targets = dict(ChainMap(*[targets[cat] for cat in targets])) # type: ignore
|
42
|
+
|
43
|
+
return targets
|
44
|
+
|
45
|
+
|
46
|
+
def _mean(X, names, axis):
|
47
|
+
"""Helper function to compute a mean of X across an axis, respecting names and possible nans."""
|
48
|
+
if issparse(X):
|
49
|
+
obs_avg = pd.Series(
|
50
|
+
np.array(_sparse_nanmean(X, axis=axis)).flatten(),
|
51
|
+
index=names,
|
52
|
+
)
|
53
|
+
else:
|
54
|
+
obs_avg = pd.Series(np.nanmean(X, axis=axis), index=names)
|
55
|
+
return obs_avg
|
56
|
+
|
57
|
+
|
58
|
+
class Enrichment:
    """Gene-group enrichment: per-cell scoring, over-representation tests, GSEA and plots."""

    def score(
        self,
        adata: AnnData,
        layer: str | None = None,
        targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | Sequence[str] | None = None,
        method: Literal["mean", "seurat"] = "mean",
        n_bins: int = 25,
        ctrl_size: int = 50,
        key_added: str = "pertpy_enrichment",
    ) -> None:
        """Obtain per-cell scoring of gene groups of interest.

        Inspired by drug2cell score: https://github.com/Teichlab/drug2cell.
        Ensure that the gene nomenclature in your target sets is compatible with your
        `.var_names`. The ChEMBL drug targets use HGNC.

        Args:
            adata: An AnnData object. It is recommended to use log-normalised data.
            layer: Specifies which `.layers` of AnnData to use for expression values. Defaults to `.X` if None.
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts two forms:
                - A dictionary with group names as keys and corresponding gene lists as entries.
                - A dictionary of dictionaries with group categories as keys. Use `nested=True` in this case.
                If not provided, ChEMBL-derived drug target sets are used.
            nested: Indicates if `targets` is a dictionary of dictionaries with group categories as keys.
                Defaults to False.
            categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            method: Method for scoring gene groups. `"mean"` calculates the mean over all genes,
                while `"seurat"` uses a background profile subtraction approach.
                Defaults to 'mean'.
            n_bins: The number of expression bins for the `'seurat'` method.
            ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
            key_added: Prefix key that adds the results to `uns`.
                Note that the actual values are `key_added_score`, `key_added_variables`, `key_added_genes`, `key_added_all_genes`.
                Defaults to `pertpy_enrichment`.

        Returns:
            Nothing. The scores and their bookkeeping are written to `adata.uns` under the
            `key_added` prefix. The control genes of the `"seurat"` method are drawn from an
            unseeded random generator, so those scores differ between runs.
        """
        mtx = adata.layers[layer] if layer is not None else adata.X

        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        # Keep the untouched gene lists; `targets` is turned into boolean masks below.
        full_targets = targets.copy()

        for drug in targets:
            targets[drug] = np.isin(adata.var_names, targets[drug])

        # Scoring is done via matrix multiplication of the original cell by gene matrix by a new gene by drug matrix
        # with the entries in the new matrix being the weights of each gene for that group (such as drug)
        # The mean across targets is constant -> prepare weights for that
        weights = pd.DataFrame(targets, index=adata.var_names)
        # Groups without a single gene in `var_names` are dropped.
        weights = weights.loc[:, weights.sum() > 0]
        weights = weights / weights.sum()
        if issparse(mtx):
            scores = mtx.dot(weights)
        else:
            scores = np.dot(mtx, weights)

        if method == "seurat":
            # Bin genes by the rank of their mean expression.
            obs_avg = _mean(mtx, names=adata.var_names, axis=0)
            n_items = int(np.round(len(obs_avg) / (n_bins - 1)))
            obs_cut = obs_avg.rank(method="min") // n_items
            obs_cut = obs_cut.values

            # One generator for all bins instead of a fresh one per iteration.
            rng = np.random.default_rng()
            control_groups = {}
            for cut in np.unique(obs_cut):
                mask = obs_cut == cut
                r_genes = np.nonzero(mask)[0]
                rng.shuffle(r_genes)
                # Keep at most `ctrl_size` randomly chosen genes of this bin as controls.
                mask[r_genes[ctrl_size:]] = False
                control_groups[cut] = mask
            control_gene_weights = pd.DataFrame(control_groups, index=adata.var_names)
            control_gene_weights = control_gene_weights / control_gene_weights.sum()

            if issparse(mtx):
                control_profiles = mtx.dot(control_gene_weights)
            else:
                control_profiles = np.dot(mtx, control_gene_weights)

            # Each group's background is the average control profile of the bins its genes fall into.
            drug_bins = {}
            for drug in weights.columns:
                bins = np.unique(obs_cut[targets[drug]])
                drug_bins[drug] = np.isin(control_gene_weights.columns, bins)
            drug_weights = pd.DataFrame(drug_bins, index=control_gene_weights.columns)
            drug_weights = drug_weights / drug_weights.sum()
            seurat = np.dot(control_profiles, drug_weights)
            scores = scores - seurat

        adata.uns[f"{key_added}_score"] = scores
        adata.uns[f"{key_added}_variables"] = weights.columns

        adata.uns[f"{key_added}_genes"] = {"var": pd.DataFrame(columns=["genes"]).astype(object)}
        adata.uns[f"{key_added}_all_genes"] = {"var": pd.DataFrame(columns=["all_genes"]).astype(object)}

        for drug in weights.columns:
            # `genes`: group members found in `var_names`; `all_genes`: the full input list.
            adata.uns[f"{key_added}_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]])
            adata.uns[f"{key_added}_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug])

    def hypergeometric(
        self,
        adata: AnnData,
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        pvals_adj_thresh: float = 0.05,
        direction: str = "both",
        corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg",
    ):
        """Perform a hypergeometric test to assess the overrepresentation of gene group members.

        Args:
            adata: With marker genes computed via `sc.tl.rank_genes_groups()` in the original expression space.
            targets: The gene groups to evaluate. Can be targets of known drugs, GO terms, pathway memberships, anything you can assign genes to.
                If `None`, the ChEMBL-derived drug target sets are loaded.
                Accepts two forms:
                - A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
                - A dictionary of dictionaries defined like above, with names of gene group categories as keys.
                  If passing one of those, specify `nested=True`.
            nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
            categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
                In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
            direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`.
                Can be `up`, `down`, or `both` (for no selection). Any other value behaves like `both`.
            corr_method: Which multiple-testing correction to apply to the p-values of the hypergeometric test.
                Can be `benjamini-hochberg` or `bonferroni`.

        Returns:
            Dictionary with clusters for which the original object markers were computed as the keys,
            and data frames of test results sorted on q-value as the items.

        Raises:
            ValueError: If `corr_method` is not one of the supported methods.
        """
        # Fail early: an unknown method would otherwise leave `pvals_adj` at its placeholder value.
        if corr_method not in ("benjamini-hochberg", "bonferroni"):
            raise ValueError(
                f"Unknown corr_method '{corr_method}'. Use 'benjamini-hochberg' or 'bonferroni'."
            )

        universe = set(adata.var_names)
        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        for group in targets:
            targets[group] = set(targets[group]).intersection(universe)  # type: ignore
        # We remove empty keys since we don't need them
        targets = {k: v for k, v in targets.items() if v}

        rank_genes = adata.uns["rank_genes_groups"]
        overrepresentation = {}
        for cluster in rank_genes["names"].dtype.names:
            results = pd.DataFrame(
                1,
                index=list(targets.keys()),
                columns=[
                    "intersection",
                    "gene_group",
                    "markers",
                    "universe",
                    "pvals",
                    "pvals_adj",
                ],
            )
            mask = rank_genes["pvals_adj"][cluster] < pvals_adj_thresh
            if direction == "up":
                mask = mask & (rank_genes["scores"][cluster] > 0)
            elif direction == "down":
                mask = mask & (rank_genes["scores"][cluster] < 0)
            markers = set(rank_genes["names"][cluster][mask])
            results["markers"] = len(markers)
            results["universe"] = len(universe)
            # The frame was initialised with the integer 1; p-values need a float column.
            results["pvals"] = results["pvals"].astype(float)

            for ind in results.index:
                gene_group = targets[ind]
                common = gene_group.intersection(markers)  # type: ignore
                results.loc[ind, "intersection"] = len(common)
                results.loc[ind, "gene_group"] = len(gene_group)
                # need to subtract 1 from the intersection length
                # https://alexlenail.medium.com/understanding-and-implementing-the-hypergeometric-test-in-python-a7db688a7458
                pval = hypergeom.sf(len(common) - 1, len(universe), len(markers), len(gene_group))
                results.loc[ind, "pvals"] = pval
            # Just in case any NaNs popped up somehow, fill them to 1 so FDR works
            results = results.fillna(1)
            if corr_method == "benjamini-hochberg":
                results["pvals_adj"] = multipletests(results["pvals"], method="fdr_bh")[1]
            else:
                results["pvals_adj"] = np.minimum(results["pvals"] * results.shape[0], 1.0)
            overrepresentation[cluster] = results.sort_values("pvals_adj")

        return overrepresentation

    def gsea(
        self,
        adata: "AnnData",
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        absolute: bool = False,
        key_added: str = "pertpy_enrichment_gsea",
    ) -> dict[str, pd.DataFrame]:  # pragma: no cover
        """Perform gene set enrichment analysis on the marker gene scores using blitzgsea.

        Args:
            adata: AnnData object with marker genes computed via `sc.tl.rank_genes_groups()`
                in the original expression space.
            targets: The gene groups to evaluate, either as a dictionary with names of the
                groups as keys and gene lists as values, or a dictionary of dictionaries
                with names of gene group categories as keys. Defaults to None, in which
                case the ChEMBL-derived drug target sets are loaded.
            nested: Indicates if `targets` is a dictionary of dictionaries with group
                categories as keys. Defaults to False.
            categories: Used to subset the gene groups to one or more categories,
                applicable if `targets=None` or `nested=True`. Defaults to None.
            absolute: If True, passes the absolute values of scores to GSEA, improving
                statistical power. Defaults to False.
            key_added: Key under which the inputs needed by `plot_gsea()` are stored in `uns`:
                a dictionary with the prepared `targets` and the per-cluster `scores` frames.
                Defaults to `pertpy_enrichment_gsea`.

        Returns:
            A dictionary with clusters as keys and the output of `blitzgsea.gsea()`
            for that cluster as the items.
        """
        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        rank_genes = adata.uns["rank_genes_groups"]
        enrichment = {}
        plot_gsea_args: dict[str, Any] = {"targets": targets, "scores": {}}
        for cluster in rank_genes["names"].dtype.names:
            # Two-column frame (gene, score) under the column names "0" and "1".
            df = pd.DataFrame(
                {
                    "0": rank_genes["names"][cluster],
                    "1": rank_genes["scores"][cluster],
                }
            )
            if absolute:
                df["1"] = np.absolute(df["1"])
                df = df.sort_values("1", ascending=False)
            enrichment[cluster] = blitzgsea.gsea(df, targets)
            plot_gsea_args["scores"][cluster] = df

        adata.uns[key_added] = plot_gsea_args

        return enrichment

    @staticmethod
    def _nested_targets_from_interactions(
        interaction: pd.DataFrame,
        source: str,
        category_name: str,
        drug_column: str,
        gene_column: str,
    ) -> dict:
        """Build ``{category: {drug: [genes]}}`` from a drug-gene interaction table.

        Missing values in `category_name` are filled in place with "Unknown/Other".

        Raises:
            ValueError: If `category_name` is not a column of `interaction`.
        """
        if category_name not in interaction.columns:
            raise ValueError(f"The category name is not available in {source} drug target data.")
        interaction[category_name] = interaction[category_name].fillna("Unknown/Other")
        return (
            interaction.groupby(category_name)
            .apply(lambda x: x.groupby(drug_column)[gene_column].apply(list).to_dict())
            .to_dict()
        )

    def plot_dotplot(
        self,
        adata: AnnData,
        targets: dict[str, dict[str, list[str]]] | None = None,
        source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
        category_name: str = "interaction_type",
        categories: Sequence[str] | None = None,
        groupby: str | None = None,
        key: str = "pertpy_enrichment",
        ax: Axes | None = None,
        save: bool | str | None = None,
        show: bool | None = None,
        **kwargs,
    ) -> DotPlot | dict | None:
        """Plots a dotplot by groupby and categories.

        Wraps scanpy's dotplot but formats it nicely by categories.

        Args:
            adata: An AnnData object with enrichment results stored in `.uns["pertpy_enrichment_score"]`.
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts a dictionary of dictionaries with group categories as keys.
                If not provided, ChEMBL-derived, dgidb or pharmgkb drug target sets are used, given by `source`.
            source: Source of drug target sets when `targets=None`, `chembl`, `dgidb` or `pharmgkb`. Defaults to `chembl`.
            category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`. Defaults to `interaction_type`.
            categories: To subset the gene groups to specific categories, especially when `targets=None`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            groupby: dotplot groupby such as clusters or cell types.
            key: Prefix key of enrichment results in `uns`.
                Defaults to `pertpy_enrichment`.
            ax: Passed to scanpy dotplot as `ax`.
            save: Passed to scanpy dotplot as `save`.
            show: Passed to scanpy dotplot as `show`.
            kwargs: Passed to scanpy dotplot.

        Returns:
            If `return_fig` is `True`, returns a :class:`~scanpy.pl.DotPlot` object,
            else if `show` is false, return axes dict.

        Raises:
            ValueError: If `category_name` is not a column of the dgidb/pharmgkb table.
            KeyError: If any requested category is not a key of `targets`.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> pt_enrichment.plot_dotplot(adata, categories=["B01", "B02", "B03"], groupby="louvain")

        Preview:
            .. image:: /_static/docstring_previews/enrichment_dotplot.png
        """
        if categories is not None:
            categories = [categories] if isinstance(categories, str) else list(categories)

        if targets is None:
            pt_drug = Drug()
            if source == "chembl":
                pt_drug.chembl.set()
                targets = pt_drug.chembl.dictionary
            elif source == "dgidb":
                pt_drug.dgidb.set()
                targets = Enrichment._nested_targets_from_interactions(
                    pt_drug.dgidb.data, "dgidb", category_name, "drug_claim_name", "gene_claim_name"
                )
            else:
                pt_drug.pharmgkb.set()
                targets = Enrichment._nested_targets_from_interactions(
                    pt_drug.pharmgkb.data, "pharmgkb", category_name, "Compound|Disease", "Gene"
                )

        if categories is not None:
            missing = [cat for cat in categories if cat not in targets]
            if missing:
                raise KeyError(f"Categories not found in targets: {missing}")
            selected = {cat: targets[cat] for cat in categories}
        else:
            selected = targets

        # Build a fresh mapping so neither the caller's dictionary nor the loaded
        # drug dictionary is overwritten; only the group names of each category are needed.
        group_members = {category: list(members.keys()) for category, members in selected.items()}

        enrichment_score_adata = AnnData(adata.uns[f"{key}_score"], obs=adata.obs)
        enrichment_score_adata.var_names = adata.uns[f"{key}_variables"]
        scored_names = enrichment_score_adata.var_names

        var_names: list[str] = []
        var_group_positions: list[tuple[int, int]] = []
        var_group_labels: list[str] = []

        for category, members in group_members.items():
            # Only groups that were actually scored can be plotted.
            present = list(scored_names[np.isin(scored_names, members)])
            if not present:
                continue
            start = len(var_names)
            var_names.extend(present)
            # Inclusive (first, last) positions of this category's block of variables.
            var_group_positions.append((start, len(var_names) - 1))
            var_group_labels.append(category)

        plot_args = {
            "var_names": var_names,
            "var_group_positions": var_group_positions,
            "var_group_labels": var_group_labels,
        }

        return sc.pl.dotplot(
            enrichment_score_adata,
            groupby=groupby,
            swap_axes=True,
            ax=ax,
            save=save,
            show=show,
            **plot_args,
            **kwargs,
        )

    def plot_gsea(
        self,
        adata: AnnData,
        enrichment: dict[str, pd.DataFrame],
        n: int = 10,
        key: str = "pertpy_enrichment_gsea",
        interactive_plot: bool = False,
    ) -> None:
        """Generates a blitzgsea top_table plot.

        This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA).
        It uses the output from the `gsea()` method, which provides the enrichment data,
        and displays the top results using blitzgsea's `top_table()` plot.
        One figure is produced and shown per cluster in `enrichment`.

        Args:
            adata: AnnData object holding the `scores` and `targets` stored by `gsea()` in `uns[key]`.
            enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
            n: How many top scores to show for each group. Defaults to 10.
            key: GSEA results key in `uns`. Defaults to "pertpy_enrichment_gsea".
            interactive_plot: Whether to plot interactively or not. Defaults to False.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> enrichment = pt_enrichment.gsea(adata)
            >>> pt_enrichment.plot_gsea(adata, enrichment, interactive_plot=True)

        Preview:
            .. image:: /_static/docstring_previews/enrichment_gsea.png
        """
        for cluster in enrichment:
            fig = blitzgsea.plot.top_table(
                adata.uns[key]["scores"][cluster],
                adata.uns[key]["targets"],
                enrichment[cluster],
                n=n,
                interactive_plot=interactive_plot,
            )
            fig.suptitle(cluster)
            fig.show()
|