pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,460 @@
|
|
1
|
+
from collections import ChainMap
|
2
|
+
from collections.abc import Sequence
|
3
|
+
from typing import Any, Literal
|
4
|
+
|
5
|
+
import blitzgsea
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
import scanpy as sc
|
9
|
+
from anndata import AnnData
|
10
|
+
from matplotlib.axes import Axes
|
11
|
+
from scanpy.plotting import DotPlot
|
12
|
+
from scanpy.tools._score_genes import _sparse_nanmean
|
13
|
+
from scipy.sparse import issparse
|
14
|
+
from scipy.stats import hypergeom
|
15
|
+
from statsmodels.stats.multitest import multipletests
|
16
|
+
|
17
|
+
from pertpy.metadata import Drug
|
18
|
+
|
19
|
+
|
20
|
+
def _prepare_targets(
|
21
|
+
targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] = None,
|
22
|
+
nested: bool = False,
|
23
|
+
categories: str | Sequence[str] = None,
|
24
|
+
) -> ChainMap | dict:
|
25
|
+
if categories is not None:
|
26
|
+
if isinstance(categories, str):
|
27
|
+
categories = [categories]
|
28
|
+
else:
|
29
|
+
categories = list(categories)
|
30
|
+
|
31
|
+
if targets is None:
|
32
|
+
pt_drug = Drug()
|
33
|
+
pt_drug.chembl.set()
|
34
|
+
targets = pt_drug.chembl.dictionary
|
35
|
+
nested = True
|
36
|
+
else:
|
37
|
+
targets = targets.copy()
|
38
|
+
if categories is not None:
|
39
|
+
targets = {k: targets[k] for k in categories} # type: ignore
|
40
|
+
if nested:
|
41
|
+
targets = dict(ChainMap(*[targets[cat] for cat in targets])) # type: ignore
|
42
|
+
|
43
|
+
return targets
|
44
|
+
|
45
|
+
|
46
|
+
def _mean(X, names, axis):
    """Return the NaN-ignoring mean of ``X`` along ``axis`` as a Series labelled by ``names``.

    Dense inputs go through ``np.nanmean``; sparse inputs go through scanpy's
    ``_sparse_nanmean`` and the result is flattened to one dimension.
    """
    if not issparse(X):
        return pd.Series(np.nanmean(X, axis=axis), index=names)
    sparse_means = np.asarray(_sparse_nanmean(X, axis=axis)).ravel()
    return pd.Series(sparse_means, index=names)
|
56
|
+
|
57
|
+
|
58
|
+
class Enrichment:
    @staticmethod
    def _project(mtx, weights: np.ndarray) -> np.ndarray:
        """Multiply a (dense or scipy-sparse) expression matrix by a dense weight matrix."""
        return mtx.dot(weights) if issparse(mtx) else np.dot(mtx, weights)

    @staticmethod
    def _interaction_targets(
        interaction: pd.DataFrame,
        category_name: str,
        compound_col: str,
        gene_col: str,
        source: str,
    ) -> dict[str, dict[str, list[str]]]:
        """Nest an interaction table as ``{category: {compound: [genes]}}``.

        Missing category values are grouped under ``"Unknown/Other"``. The input frame is not modified.

        Raises:
            ValueError: If ``category_name`` is not a column of ``interaction``.
        """
        if category_name not in interaction.columns:
            raise ValueError(f"The category name is not available in {source} drug target data.")
        category_values = interaction[category_name].fillna("Unknown/Other")
        return {
            category: frame.groupby(compound_col)[gene_col].apply(list).to_dict()
            for category, frame in interaction.groupby(category_values)
        }

    def score(
        self,
        adata: AnnData,
        layer: str | None = None,
        targets: dict[str, list[str]] | dict[str, dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | Sequence[str] | None = None,
        method: Literal["mean", "seurat"] = "mean",
        n_bins: int = 25,
        ctrl_size: int = 50,
        key_added: str = "pertpy_enrichment",
        random_state: int | np.random.Generator | None = None,
    ) -> None:
        """Obtain per-cell scoring of gene groups of interest.

        Inspired by drug2cell score: https://github.com/Teichlab/drug2cell.
        Ensure that the gene nomenclature in your target sets is compatible with your
        `.var_names`. The ChEMBL drug targets use HGNC.

        Args:
            adata: An AnnData object. It is recommended to use log-normalised data.
            layer: Specifies which `.layers` of AnnData to use for expression values. `.X` is used if `None`.
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts two forms:
                - A dictionary with group names as keys and corresponding gene lists as entries.
                - A dictionary of dictionaries with group categories as keys. Use `nested=True` in this case.
                If not provided, ChEMBL-derived drug target sets are used.
            nested: Indicates if `targets` is a dictionary of dictionaries with group categories as keys.
            categories: To subset the gene groups to specific categories, especially when `targets=None` or `nested=True`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            method: Method for scoring gene groups. `"mean"` calculates the mean over all genes,
                while `"seurat"` uses a background profile subtraction approach.
            n_bins: The number of expression bins for the `"seurat"` method. Must be at least 2.
            ctrl_size: The number of genes to randomly sample from each expression bin for the `"seurat"` method.
            key_added: Prefix key that adds the results to `uns`.
                Note that the actual values are `key_added_score`, `key_added_variables`, `key_added_genes`, `key_added_all_genes`.
            random_state: Seed or generator for sampling control genes in the `"seurat"` method.
                `None` uses fresh entropy, so repeated runs can differ.

        Returns:
            None. Results are written to `adata.uns`:
            `{key_added}_score` holds the cells x scored-groups score matrix.
            `{key_added}_variables` holds the names of the scored groups. Groups with no gene in `.var_names` are dropped.
            `{key_added}_genes` and `{key_added}_all_genes` are `{"var": DataFrame}` tables with the `|`-joined genes
            of each group that were found in `.var_names` and that were given in `targets`, respectively.

        Raises:
            ValueError: If `method` is unknown, or if `n_bins < 2` with `method="seurat"`.
        """
        if method not in ("mean", "seurat"):
            raise ValueError(f"Unknown method {method!r}; expected 'mean' or 'seurat'.")
        if method == "seurat" and n_bins < 2:
            raise ValueError(f"n_bins must be at least 2 for the 'seurat' method, got {n_bins}.")

        mtx = adata.layers[layer] if layer is not None else adata.X

        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore

        # Boolean mask over `.var_names` for every gene group (original gene lists stay in `targets`).
        membership = {group: np.isin(adata.var_names, genes) for group, genes in targets.items()}

        # Scoring is a single product of the cell x gene matrix with a gene x group weight matrix;
        # for the mean, each member gene of a group gets weight 1 / (number of members present).
        weights = pd.DataFrame(membership, index=adata.var_names)
        weights = weights.loc[:, weights.sum() > 0]
        weights = weights / weights.sum()
        scores = self._project(mtx, weights.to_numpy())

        if method == "seurat":
            rng = np.random.default_rng(random_state)
            obs_avg = _mean(mtx, names=adata.var_names, axis=0)
            # Rank genes by mean expression and cut them into roughly `n_bins` bins.
            # Clamp to 1 so that fewer genes than bins cannot lead to a division by zero.
            n_items = max(1, int(np.round(len(obs_avg) / (n_bins - 1))))
            obs_cut = (obs_avg.rank(method="min") // n_items).to_numpy()

            # Sample up to `ctrl_size` random control genes from each expression bin.
            control_groups = {}
            for cut in np.unique(obs_cut):
                bin_genes = np.flatnonzero(obs_cut == cut)
                rng.shuffle(bin_genes)
                chosen = np.zeros(obs_cut.shape[0], dtype=bool)
                chosen[bin_genes[:ctrl_size]] = True
                control_groups[cut] = chosen
            control_gene_weights = pd.DataFrame(control_groups, index=adata.var_names)
            control_gene_weights = control_gene_weights / control_gene_weights.sum()
            control_profiles = self._project(mtx, control_gene_weights.to_numpy())

            # Background of each group: average control profile over the bins its genes fall into.
            drug_bins = {
                drug: np.isin(control_gene_weights.columns, np.unique(obs_cut[membership[drug]]))
                for drug in weights.columns
            }
            drug_weights = pd.DataFrame(drug_bins, index=control_gene_weights.columns)
            drug_weights = drug_weights / drug_weights.sum()
            scores = scores - np.dot(control_profiles, drug_weights.to_numpy())

        adata.uns[f"{key_added}_score"] = scores
        adata.uns[f"{key_added}_variables"] = weights.columns

        # Build the gene tables in one go; row-by-row `.loc` enlargement is quadratic.
        groups = list(weights.columns)
        group_index = pd.Index(groups, dtype=object)
        adata.uns[f"{key_added}_genes"] = {
            "var": pd.DataFrame(
                {"genes": ["|".join(adata.var_names[membership[group]]) for group in groups]},
                index=group_index,
                dtype=object,
            )
        }
        adata.uns[f"{key_added}_all_genes"] = {
            "var": pd.DataFrame(
                {"all_genes": ["|".join(targets[group]) for group in groups]},
                index=group_index,
                dtype=object,
            )
        }

    def hypergeometric(
        self,
        adata: AnnData,
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        pvals_adj_thresh: float = 0.05,
        direction: str = "both",
        corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg",
    ) -> dict[str, pd.DataFrame]:
        """Perform a hypergeometric test to assess the overrepresentation of gene group members.

        Args:
            adata: With marker genes computed via `sc.tl.rank_genes_groups()` in the original expression space.
                Markers that are not in `adata.var_names` are ignored, since they cannot be part of the universe.
            targets: The gene groups to evaluate. Can be targets of known drugs, GO terms, pathway memberships, anything you can assign genes to.
                If `None`, the ChEMBL-derived drug target sets distributed with pertpy are loaded.
                Accepts two forms:
                - A dictionary with the names of the groups as keys, and the entries being the corresponding gene lists.
                - A dictionary of dictionaries defined like above, with names of gene group categories as keys.
                If passing one of those, specify `nested=True`.
            nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
            categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
                In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
            direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`.
                Can be `up`, `down`, or `both` (for no selection).
            corr_method: Which FDR correction to apply to the p-values of the hypergeometric test.
                Can be `benjamini-hochberg` or `bonferroni`.

        Returns:
            Dictionary with clusters for which the original object markers were computed as the keys,
            and data frames of test results sorted on q-value as the items.

        Raises:
            ValueError: If `direction` or `corr_method` is not one of the supported values.
        """
        if direction not in ("up", "down", "both"):
            raise ValueError(f"Unknown direction {direction!r}; expected 'up', 'down' or 'both'.")
        if corr_method not in ("benjamini-hochberg", "bonferroni"):
            raise ValueError(f"Unknown corr_method {corr_method!r}; expected 'benjamini-hochberg' or 'bonferroni'.")

        universe = set(adata.var_names)
        n_universe = len(universe)

        prepared = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        # Restrict every gene group to the universe and drop groups that end up empty.
        gene_groups: dict[str, set] = {}
        for group, genes in prepared.items():
            members = set(genes) & universe
            if members:
                gene_groups[group] = members
        group_names = list(gene_groups)
        group_sizes = np.array([len(gene_groups[group]) for group in group_names], dtype=int)

        rank_genes = adata.uns["rank_genes_groups"]
        overrepresentation = {}
        for cluster in rank_genes["names"].dtype.names:
            mask = rank_genes["pvals_adj"][cluster] < pvals_adj_thresh
            if direction == "up":
                mask = mask & (rank_genes["scores"][cluster] > 0)
            elif direction == "down":
                mask = mask & (rank_genes["scores"][cluster] < 0)
            # The hypergeometric model requires markers to be a subset of the universe; markers
            # computed on `.raw` may include genes absent from `adata.var_names`.
            markers = set(rank_genes["names"][cluster][mask]) & universe
            n_markers = len(markers)

            overlaps = np.array([len(gene_groups[group] & markers) for group in group_names], dtype=int)
            if group_names:
                # sf(k - 1) == P(X >= k), the probability of at least the observed overlap.
                # https://alexlenail.medium.com/understanding-and-implementing-the-hypergeometric-test-in-python-a7db688a7458
                pvals = np.atleast_1d(hypergeom.sf(overlaps - 1, n_universe, n_markers, group_sizes)).astype(float)
                # Just in case any NaNs popped up somehow, fill them to 1 so FDR works
                pvals = np.where(np.isnan(pvals), 1.0, pvals)
            else:
                pvals = np.array([], dtype=float)

            if pvals.size == 0:
                # multipletests divides by the number of tests and fails on empty input.
                pvals_adj = pvals
            elif corr_method == "benjamini-hochberg":
                pvals_adj = multipletests(pvals, method="fdr_bh")[1]
            else:
                pvals_adj = np.minimum(pvals * pvals.size, 1.0)

            results = pd.DataFrame(
                {
                    "intersection": overlaps,
                    "gene_group": group_sizes,
                    "markers": n_markers,
                    "universe": n_universe,
                    "pvals": pvals,
                    "pvals_adj": pvals_adj,
                },
                index=group_names,
            )
            overrepresentation[cluster] = results.sort_values("pvals_adj")

        return overrepresentation

    def gsea(
        self,
        adata: AnnData,
        targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
        nested: bool = False,
        categories: str | list[str] | None = None,
        absolute: bool = False,
        key_added: str = "pertpy_enrichment_gsea",
    ) -> dict[str, pd.DataFrame]:  # pragma: no cover
        """Perform gene set enrichment analysis on the marker gene scores using blitzgsea.

        Args:
            adata: AnnData object with marker genes computed via `sc.tl.rank_genes_groups()`
                in the original expression space.
            targets: The gene groups to evaluate, either as a dictionary with names of the
                groups as keys and gene lists as values, or a dictionary of dictionaries
                with names of gene group categories as keys.
                If `None`, the ChEMBL-derived drug target sets distributed with pertpy are loaded.
            nested: Indicates if `targets` is a dictionary of dictionaries with group
                categories as keys.
            categories: Used to subset the gene groups to one or more categories,
                applicable if `targets=None` or `nested=True`.
            absolute: If True, passes the absolute values of scores to GSEA (re-sorted in
                descending order), which can improve statistical power.
            key_added: Key in `uns` under which the inputs needed by `plot_gsea()` are stored
                (`{"targets": ..., "scores": {cluster: DataFrame}}`).

        Returns:
            A dictionary with clusters as keys and blitzgsea result data frames as items.
        """
        targets = _prepare_targets(targets=targets, nested=nested, categories=categories)  # type: ignore
        rank_genes = adata.uns["rank_genes_groups"]

        enrichment = {}
        plot_gsea_args: dict[str, Any] = {"targets": targets, "scores": {}}
        for cluster in rank_genes["names"].dtype.names:
            # blitzgsea expects a two-column signature frame: gene names and their scores.
            df = pd.DataFrame(
                {
                    "0": rank_genes["names"][cluster],
                    "1": rank_genes["scores"][cluster],
                }
            )
            if absolute:
                df["1"] = np.absolute(df["1"])
                df = df.sort_values("1", ascending=False)
            enrichment[cluster] = blitzgsea.gsea(df, targets)
            plot_gsea_args["scores"][cluster] = df

        adata.uns[key_added] = plot_gsea_args

        return enrichment

    def plot_dotplot(
        self,
        adata: AnnData,
        targets: dict[str, dict[str, list[str]]] | None = None,
        source: Literal["chembl", "dgidb", "pharmgkb"] = "chembl",
        category_name: str = "interaction_type",
        categories: Sequence[str] | None = None,
        groupby: str | None = None,
        key: str = "pertpy_enrichment",
        ax: Axes | None = None,
        save: bool | str | None = None,
        show: bool | None = None,
        **kwargs,
    ) -> DotPlot | dict | None:
        """Plots a dotplot by groupby and categories.

        Wraps scanpy's dotplot but formats it nicely by categories.

        Args:
            adata: An AnnData object with enrichment results stored in `.uns[f"{key}_score"]`
                and `.uns[f"{key}_variables"]` (as written by `score()`).
            targets: Gene groups to evaluate, which can be targets of known drugs, GO terms, pathway memberships, etc.
                Accepts a dictionary of dictionaries with group categories as keys.
                If not provided, ChEMBL-derived, dgidb or pharmgkb drug target sets are used, given by `source`.
            source: Source of drug target sets when `targets=None`, `chembl`, `dgidb` or `pharmgkb`.
            category_name: The name of category used to generate a nested drug target set when `targets=None` and `source=dgidb|pharmgkb`.
            categories: To subset the gene groups to specific categories, especially when `targets=None`.
                For ChEMBL drug targets, these are ATC level 1/level 2 category codes.
            groupby: dotplot groupby such as clusters or cell types.
            key: Prefix key of enrichment results in `uns`.
            ax: Matplotlib axes passed to `sc.pl.dotplot`.
            save: Passed to `sc.pl.dotplot`.
            show: Passed to `sc.pl.dotplot`.
            kwargs: Passed to scanpy dotplot.

        Returns:
            If `return_fig` is `True`, returns a :class:`~scanpy.pl.DotPlot` object,
            else if `show` is false, return axes dict.

        Raises:
            ValueError: If `targets=None` and `source` is unknown, or if `category_name`
                is not a column of the selected interaction table.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> pt_enrichment.plot_dotplot(adata, categories=["B01", "B02", "B03"], groupby="louvain")

        Preview:
            .. image:: /_static/docstring_previews/enrichment_dotplot.png
        """
        if categories is not None:
            categories = [categories] if isinstance(categories, str) else list(categories)

        if targets is None:
            pt_drug = Drug()
            if source == "chembl":
                pt_drug.chembl.set()
                targets = pt_drug.chembl.dictionary
            elif source == "dgidb":
                pt_drug.dgidb.set()
                targets = self._interaction_targets(
                    pt_drug.dgidb.data, category_name, "drug_claim_name", "gene_claim_name", "dgidb"
                )
            elif source == "pharmgkb":
                pt_drug.pharmgkb.set()
                targets = self._interaction_targets(
                    pt_drug.pharmgkb.data, category_name, "Compound|Disease", "Gene", "pharmgkb"
                )
            else:
                raise ValueError(f"Unknown source {source!r}; expected 'chembl', 'dgidb' or 'pharmgkb'.")

        if categories is not None:
            targets = {k: targets[k] for k in categories}  # type: ignore

        enrichment_score_adata = AnnData(adata.uns[f"{key}_score"], obs=adata.obs)
        enrichment_score_adata.var_names = adata.uns[f"{key}_variables"]
        scored = enrichment_score_adata.var_names

        # Lay out scored groups category by category; `targets` itself is left untouched so that
        # a loaded drug dictionary is not overwritten.
        var_names: list[str] = []
        var_group_positions: list[tuple[int, int]] = []
        var_group_labels: list[str] = []
        for group, members in targets.items():  # type: ignore
            present = list(scored[np.isin(scored, list(members.keys()))])
            if not present:
                continue
            start = len(var_names)
            var_names.extend(present)
            var_group_positions.append((start, len(var_names) - 1))
            var_group_labels.append(group)

        return sc.pl.dotplot(
            enrichment_score_adata,
            groupby=groupby,
            swap_axes=True,
            ax=ax,
            save=save,
            show=show,
            var_names=var_names,
            var_group_positions=var_group_positions,
            var_group_labels=var_group_labels,
            **kwargs,
        )

    def plot_gsea(
        self,
        adata: AnnData,
        enrichment: dict[str, pd.DataFrame],
        n: int = 10,
        key: str = "pertpy_enrichment_gsea",
        interactive_plot: bool = False,
    ) -> None:
        """Generates a blitzgsea top_table plot.

        This function is designed to visualize the results from a Gene Set Enrichment Analysis (GSEA).
        It uses the output from the `gsea()` method, which provides the enrichment data,
        and displays the top results using blitzgsea's `top_table()` plot.

        Args:
            adata: AnnData object to plot. Must contain `.uns[key]` as written by `gsea()`.
            enrichment: Cluster names as keys, blitzgsea's ``gsea()`` output as values.
            n: How many top scores to show for each group.
            key: GSEA results key in `uns`.
            interactive_plot: Whether to plot interactively or not.

        Examples:
            >>> import pertpy as pt
            >>> import scanpy as sc
            >>> pt_enrichment = pt.tl.Enrichment()
            >>> adata = sc.datasets.pbmc3k_processed()
            >>> pt_enrichment.score(adata)
            >>> sc.tl.rank_genes_groups(adata, method="wilcoxon", groupby="louvain")
            >>> enrichment = pt_enrichment.gsea(adata)
            >>> pt_enrichment.plot_gsea(adata, enrichment, interactive_plot=True)

        Preview:
            .. image:: /_static/docstring_previews/enrichment_gsea.png
        """
        stored = adata.uns[key]
        for cluster in enrichment:
            # One figure per cluster, titled with the cluster name.
            fig = blitzgsea.plot.top_table(
                stored["scores"][cluster],
                stored["targets"],
                enrichment[cluster],
                n=n,
                interactive_plot=interactive_plot,
            )
            fig.suptitle(cluster)
            fig.show()
|
pertpy/tools/_kernel_pca.py
CHANGED
@@ -31,7 +31,7 @@ def kernel_pca(
|
|
31
31
|
|
32
32
|
Returns:
|
33
33
|
If `copy=True`, returns the copy of `adata` with kernel pca in `.obsm["X_kpca"]`.
|
34
|
-
Otherwise writes kernel pca directly to `.obsm["X_kpca"]` of the provided `adata`.
|
34
|
+
Otherwise, writes kernel pca directly to `.obsm["X_kpca"]` of the provided `adata`.
|
35
35
|
If `return_transformer=True`, returns also the fitted `KernelPCA` transformer.
|
36
36
|
"""
|
37
37
|
if copy:
|