pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_coda.py
DELETED
@@ -1,1001 +0,0 @@
|
|
1
|
-
from pathlib import Path
|
2
|
-
from typing import Literal, Optional, Union
|
3
|
-
|
4
|
-
import matplotlib.image as mpimg
|
5
|
-
import matplotlib.pyplot as plt
|
6
|
-
import numpy as np
|
7
|
-
import pandas as pd
|
8
|
-
import scanpy as sc
|
9
|
-
import seaborn as sns
|
10
|
-
from adjustText import adjust_text
|
11
|
-
from anndata import AnnData
|
12
|
-
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
|
13
|
-
from matplotlib import cm, rcParams
|
14
|
-
from matplotlib.axes import Axes
|
15
|
-
from matplotlib.colors import ListedColormap
|
16
|
-
from mudata import MuData
|
17
|
-
|
18
|
-
from pertpy.tools._coda._base_coda import CompositionalModel2, collapse_singularities_2
|
19
|
-
|
20
|
-
sns.set_style("ticks")
|
21
|
-
|
22
|
-
|
23
|
-
class CodaPlot:
|
24
|
-
@staticmethod
|
25
|
-
def __stackbar( # pragma: no cover
|
26
|
-
y: np.ndarray,
|
27
|
-
type_names: list[str],
|
28
|
-
title: str,
|
29
|
-
level_names: list[str],
|
30
|
-
figsize: Optional[tuple[float, float]] = None,
|
31
|
-
dpi: Optional[int] = 100,
|
32
|
-
cmap: Optional[ListedColormap] = cm.tab20,
|
33
|
-
show_legend: Optional[bool] = True,
|
34
|
-
) -> plt.Axes:
|
35
|
-
"""Plots a stacked barplot for one (discrete) covariate.
|
36
|
-
|
37
|
-
Typical use (only inside stacked_barplot): plot_one_stackbar(data.X, data.var.index, "xyz", data.obs.index)
|
38
|
-
|
39
|
-
Args:
|
40
|
-
y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
|
41
|
-
one for each group, containing the count mean of each cell type
|
42
|
-
type_names: The names of all cell types
|
43
|
-
title: Plot title, usually the covariate's name
|
44
|
-
level_names: Names of the covariate's levels
|
45
|
-
figsize: Figure size. Defaults to None.
|
46
|
-
dpi: Dpi setting. Defaults to 100.
|
47
|
-
cmap: The color map for the barplot. Defaults to cm.tab20.
|
48
|
-
show_legend: If True, adds a legend. Defaults to True.
|
49
|
-
|
50
|
-
Returns:
|
51
|
-
A :class:`~matplotlib.axes.Axes` object
|
52
|
-
"""
|
53
|
-
n_bars, n_types = y.shape
|
54
|
-
|
55
|
-
figsize = rcParams["figure.figsize"] if figsize is None else figsize
|
56
|
-
|
57
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
58
|
-
r = np.array(range(n_bars))
|
59
|
-
sample_sums = np.sum(y, axis=1)
|
60
|
-
|
61
|
-
barwidth = 0.85
|
62
|
-
cum_bars = np.zeros(n_bars)
|
63
|
-
|
64
|
-
for n in range(n_types):
|
65
|
-
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums)]
|
66
|
-
plt.bar(
|
67
|
-
r,
|
68
|
-
bars,
|
69
|
-
bottom=cum_bars,
|
70
|
-
color=cmap(n % cmap.N),
|
71
|
-
width=barwidth,
|
72
|
-
label=type_names[n],
|
73
|
-
linewidth=0,
|
74
|
-
)
|
75
|
-
cum_bars += bars
|
76
|
-
|
77
|
-
ax.set_title(title)
|
78
|
-
if show_legend:
|
79
|
-
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
|
80
|
-
ax.set_xticks(r)
|
81
|
-
ax.set_xticklabels(level_names, rotation=45, ha="right")
|
82
|
-
ax.set_ylabel("Proportion")
|
83
|
-
|
84
|
-
return ax
|
85
|
-
|
86
|
-
@staticmethod
|
87
|
-
def stacked_barplot( # pragma: no cover
|
88
|
-
data: Union[AnnData, MuData],
|
89
|
-
feature_name: str,
|
90
|
-
modality_key: str = "coda",
|
91
|
-
figsize: Optional[tuple[float, float]] = None,
|
92
|
-
dpi: Optional[int] = 100,
|
93
|
-
cmap: Optional[ListedColormap] = cm.tab20,
|
94
|
-
show_legend: Optional[bool] = True,
|
95
|
-
level_order: list[str] = None,
|
96
|
-
) -> plt.Axes:
|
97
|
-
"""Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
|
98
|
-
|
99
|
-
Args:
|
100
|
-
data: AnnData object or MuData object.
|
101
|
-
feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
|
102
|
-
modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
|
103
|
-
figsize: Figure size. Defaults to None.
|
104
|
-
dpi: Dpi setting. Defaults to 100.
|
105
|
-
cmap: The matplotlib color map for the barplot. Defaults to cm.tab20.
|
106
|
-
show_legend: If True, adds a legend. Defaults to True.
|
107
|
-
level_order: Custom ordering of bars on the x-axis. Defaults to None.
|
108
|
-
|
109
|
-
Returns:
|
110
|
-
A :class:`~matplotlib.axes.Axes` object
|
111
|
-
|
112
|
-
Examples:
|
113
|
-
Example with scCODA:
|
114
|
-
>>> import pertpy as pt
|
115
|
-
>>> haber_cells = pt.dt.haber_2017_regions()
|
116
|
-
>>> sccoda = pt.tl.Sccoda()
|
117
|
-
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
118
|
-
sample_identifier="batch", covariate_obs=["condition"])
|
119
|
-
>>> pt.pl.coda.stacked_barplot(mdata, feature_name="samples")
|
120
|
-
"""
|
121
|
-
if isinstance(data, MuData):
|
122
|
-
data = data[modality_key]
|
123
|
-
if isinstance(data, AnnData):
|
124
|
-
data = data
|
125
|
-
|
126
|
-
ct_names = data.var.index
|
127
|
-
|
128
|
-
# option to plot one stacked barplot per sample
|
129
|
-
if feature_name == "samples":
|
130
|
-
if level_order:
|
131
|
-
assert set(level_order) == set(data.obs.index), "level order is inconsistent with levels"
|
132
|
-
data = data[level_order]
|
133
|
-
ax = CodaPlot.__stackbar(
|
134
|
-
data.X,
|
135
|
-
type_names=data.var.index,
|
136
|
-
title="samples",
|
137
|
-
level_names=data.obs.index,
|
138
|
-
figsize=figsize,
|
139
|
-
dpi=dpi,
|
140
|
-
cmap=cmap,
|
141
|
-
show_legend=show_legend,
|
142
|
-
)
|
143
|
-
else:
|
144
|
-
# Order levels
|
145
|
-
if level_order:
|
146
|
-
assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
|
147
|
-
levels = level_order
|
148
|
-
elif hasattr(data.obs[feature_name], "cat"):
|
149
|
-
levels = data.obs[feature_name].cat.categories.to_list()
|
150
|
-
else:
|
151
|
-
levels = pd.unique(data.obs[feature_name])
|
152
|
-
n_levels = len(levels)
|
153
|
-
feature_totals = np.zeros([n_levels, data.X.shape[1]])
|
154
|
-
|
155
|
-
for level in range(n_levels):
|
156
|
-
l_indices = np.where(data.obs[feature_name] == levels[level])
|
157
|
-
feature_totals[level] = np.sum(data.X[l_indices], axis=0)
|
158
|
-
|
159
|
-
ax = CodaPlot.__stackbar(
|
160
|
-
feature_totals,
|
161
|
-
type_names=ct_names,
|
162
|
-
title=feature_name,
|
163
|
-
level_names=levels,
|
164
|
-
figsize=figsize,
|
165
|
-
dpi=dpi,
|
166
|
-
cmap=cmap,
|
167
|
-
show_legend=show_legend,
|
168
|
-
)
|
169
|
-
return ax
|
170
|
-
|
171
|
-
@staticmethod
|
172
|
-
def effects_barplot( # pragma: no cover
|
173
|
-
data: Union[AnnData, MuData],
|
174
|
-
modality_key: str = "coda",
|
175
|
-
covariates: Optional[Union[str, list]] = None,
|
176
|
-
parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
|
177
|
-
plot_facets: bool = True,
|
178
|
-
plot_zero_covariate: bool = True,
|
179
|
-
plot_zero_cell_type: bool = False,
|
180
|
-
figsize: Optional[tuple[float, float]] = None,
|
181
|
-
dpi: Optional[int] = 100,
|
182
|
-
cmap: Optional[Union[str, ListedColormap]] = cm.tab20,
|
183
|
-
level_order: list[str] = None,
|
184
|
-
args_barplot: Optional[dict] = None,
|
185
|
-
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
|
186
|
-
"""Barplot visualization for effects.
|
187
|
-
|
188
|
-
The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
|
189
|
-
The covariates groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).
|
190
|
-
|
191
|
-
Args:
|
192
|
-
data: AnnData object or MuData object.
|
193
|
-
modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
|
194
|
-
covariates: The name of the covariates in data.obs to plot. Defaults to None.
|
195
|
-
parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
|
196
|
-
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to True.
|
197
|
-
plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot. Defaults to True.
|
198
|
-
plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot. Defaults to False.
|
199
|
-
figsize: Figure size. Defaults to None.
|
200
|
-
dpi: Figure size. Defaults to 100.
|
201
|
-
cmap: The seaborn color map for the barplot. Defaults to cm.tab20.
|
202
|
-
level_order: Custom ordering of bars on the x-axis. Defaults to None.
|
203
|
-
args_barplot: Arguments passed to sns.barplot. Defaults to None.
|
204
|
-
|
205
|
-
Returns:
|
206
|
-
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
|
207
|
-
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
208
|
-
|
209
|
-
Examples:
|
210
|
-
Example with scCODA:
|
211
|
-
>>> import pertpy as pt
|
212
|
-
>>> haber_cells = pt.dt.haber_2017_regions()
|
213
|
-
>>> sccoda = pt.tl.Sccoda()
|
214
|
-
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
215
|
-
sample_identifier="batch", covariate_obs=["condition"])
|
216
|
-
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
217
|
-
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
218
|
-
>>> pt.pl.coda.effects_barplot(mdata)
|
219
|
-
"""
|
220
|
-
if args_barplot is None:
|
221
|
-
args_barplot = {}
|
222
|
-
if isinstance(data, MuData):
|
223
|
-
data = data[modality_key]
|
224
|
-
if isinstance(data, AnnData):
|
225
|
-
data = data
|
226
|
-
# Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
|
227
|
-
covariate_names = data.uns["scCODA_params"]["covariate_names"]
|
228
|
-
if covariates is not None:
|
229
|
-
if isinstance(covariates, str):
|
230
|
-
covariates = [covariates]
|
231
|
-
partial_covariate_names = [
|
232
|
-
covariate_name
|
233
|
-
for covariate_name in covariate_names
|
234
|
-
if any(covariate in covariate_name for covariate in covariates)
|
235
|
-
]
|
236
|
-
covariate_names = partial_covariate_names
|
237
|
-
covariate_names_non_zero = [
|
238
|
-
covariate_name
|
239
|
-
for covariate_name in covariate_names
|
240
|
-
if data.varm[f"effect_df_{covariate_name}"][parameter].any()
|
241
|
-
]
|
242
|
-
covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
|
243
|
-
if not plot_zero_covariate:
|
244
|
-
covariate_names = covariate_names_non_zero
|
245
|
-
|
246
|
-
# set up df for plotting
|
247
|
-
plot_df = pd.concat(
|
248
|
-
[data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
|
249
|
-
axis=1,
|
250
|
-
)
|
251
|
-
plot_df.columns = covariate_names
|
252
|
-
plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
|
253
|
-
|
254
|
-
plot_df = plot_df.reset_index()
|
255
|
-
|
256
|
-
if len(covariate_names_zero) != 0:
|
257
|
-
if plot_facets:
|
258
|
-
if plot_zero_covariate and not plot_zero_cell_type:
|
259
|
-
plot_df = plot_df[plot_df["value"] != 0]
|
260
|
-
for covariate_name_zero in covariate_names_zero:
|
261
|
-
new_row = {
|
262
|
-
"Covariate": covariate_name_zero,
|
263
|
-
"Cell Type": "zero",
|
264
|
-
"value": 0,
|
265
|
-
}
|
266
|
-
plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
|
267
|
-
plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
|
268
|
-
plot_df = plot_df.sort_values(["covariate_"])
|
269
|
-
if not plot_zero_cell_type:
|
270
|
-
cell_type_names_zero = [
|
271
|
-
name
|
272
|
-
for name in plot_df["Cell Type"].unique()
|
273
|
-
if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
|
274
|
-
]
|
275
|
-
plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
|
276
|
-
|
277
|
-
# If plot as facets, create a FacetGrid and map barplot to it.
|
278
|
-
if plot_facets:
|
279
|
-
if isinstance(cmap, ListedColormap):
|
280
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
|
281
|
-
if figsize is not None:
|
282
|
-
height = figsize[0]
|
283
|
-
aspect = np.round(figsize[1] / figsize[0], 2)
|
284
|
-
else:
|
285
|
-
height = 3
|
286
|
-
aspect = 2
|
287
|
-
|
288
|
-
g = sns.FacetGrid(
|
289
|
-
plot_df,
|
290
|
-
col="Covariate",
|
291
|
-
sharey=True,
|
292
|
-
sharex=False,
|
293
|
-
height=height,
|
294
|
-
aspect=aspect,
|
295
|
-
)
|
296
|
-
|
297
|
-
g.map(
|
298
|
-
sns.barplot,
|
299
|
-
"Cell Type",
|
300
|
-
"value",
|
301
|
-
palette=cmap,
|
302
|
-
order=level_order,
|
303
|
-
**args_barplot,
|
304
|
-
)
|
305
|
-
g.set_xticklabels(rotation=90)
|
306
|
-
g.set(ylabel=parameter)
|
307
|
-
axes = g.axes.flatten()
|
308
|
-
for i, ax in enumerate(axes):
|
309
|
-
ax.set_title(covariate_names[i])
|
310
|
-
if len(ax.get_xticklabels()) < 5:
|
311
|
-
ax.set_aspect(10 / len(ax.get_xticklabels()))
|
312
|
-
if len(ax.get_xticklabels()) == 1:
|
313
|
-
if ax.get_xticklabels()[0]._text == "zero":
|
314
|
-
ax.set_xticks([])
|
315
|
-
return g
|
316
|
-
|
317
|
-
# If not plot as facets, call barplot to plot cell types on the x-axis.
|
318
|
-
else:
|
319
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
320
|
-
if len(covariate_names) == 1:
|
321
|
-
if isinstance(cmap, ListedColormap):
|
322
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
|
323
|
-
sns.barplot(
|
324
|
-
data=plot_df,
|
325
|
-
x="Cell Type",
|
326
|
-
y="value",
|
327
|
-
palette=cmap,
|
328
|
-
ax=ax,
|
329
|
-
)
|
330
|
-
ax.set_title(covariate_names[0])
|
331
|
-
else:
|
332
|
-
if isinstance(cmap, ListedColormap):
|
333
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(covariate_names))])
|
334
|
-
sns.barplot(
|
335
|
-
data=plot_df,
|
336
|
-
x="Cell Type",
|
337
|
-
y="value",
|
338
|
-
hue="Covariate",
|
339
|
-
palette=cmap,
|
340
|
-
ax=ax,
|
341
|
-
)
|
342
|
-
cell_types = pd.unique(plot_df["Cell Type"])
|
343
|
-
ax.set_xticklabels(cell_types, rotation=90)
|
344
|
-
return ax
|
345
|
-
|
346
|
-
@staticmethod
|
347
|
-
def boxplots( # pragma: no cover
|
348
|
-
data: Union[AnnData, MuData],
|
349
|
-
feature_name: str,
|
350
|
-
modality_key: str = "coda",
|
351
|
-
y_scale: Literal["relative", "log", "log10", "count"] = "relative",
|
352
|
-
plot_facets: bool = False,
|
353
|
-
add_dots: bool = False,
|
354
|
-
model: CompositionalModel2 = None,
|
355
|
-
cell_types: Optional[list] = None,
|
356
|
-
args_boxplot: Optional[dict] = None,
|
357
|
-
args_swarmplot: Optional[dict] = None,
|
358
|
-
figsize: Optional[tuple[float, float]] = None,
|
359
|
-
dpi: Optional[int] = 100,
|
360
|
-
cmap: Optional[str] = "Blues",
|
361
|
-
show_legend: Optional[bool] = True,
|
362
|
-
level_order: list[str] = None,
|
363
|
-
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
|
364
|
-
"""Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots,
|
365
|
-
with intra--group separation by a covariate from data.obs.
|
366
|
-
|
367
|
-
Args:
|
368
|
-
data: AnnData object or MuData object
|
369
|
-
feature_name: The name of the feature in data.obs to plot
|
370
|
-
modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
|
371
|
-
y_scale: Transformation to of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
|
372
|
-
"log10" - log10(count), "count" - absolute abundance (cell counts).
|
373
|
-
Defaults to "relative".
|
374
|
-
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
|
375
|
-
add_dots: If True, overlay a scatterplot with one dot for each data point. Defaults to False.
|
376
|
-
model: When draw_effects, specify a tasCODA model
|
377
|
-
cell_types: Subset of cell types that should be plotted. Defaults to None.
|
378
|
-
args_boxplot: Arguments passed to sns.boxplot. Defaults to {}.
|
379
|
-
args_swarmplot: Arguments passed to sns.swarmplot. Defaults to {}.
|
380
|
-
figsize: Figure size. Defaults to None.
|
381
|
-
dpi: Dpi setting. Defaults to 100.
|
382
|
-
cmap: The seaborn color map for the barplot. Defaults to "Blues".
|
383
|
-
show_legend: If True, adds a legend. Defaults to True.
|
384
|
-
level_order: Custom ordering of bars on the x-axis. Defaults to None.
|
385
|
-
|
386
|
-
Returns:
|
387
|
-
Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
|
388
|
-
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
389
|
-
|
390
|
-
Examples:
|
391
|
-
Example with scCODA:
|
392
|
-
>>> import pertpy as pt
|
393
|
-
>>> haber_cells = pt.dt.haber_2017_regions()
|
394
|
-
>>> sccoda = pt.tl.Sccoda()
|
395
|
-
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
396
|
-
sample_identifier="batch", covariate_obs=["condition"])
|
397
|
-
>>> pt.pl.coda.boxplots(mdata, feature_name="condition", add_dots=True)
|
398
|
-
"""
|
399
|
-
if args_boxplot is None:
|
400
|
-
args_boxplot = {}
|
401
|
-
if args_swarmplot is None:
|
402
|
-
args_swarmplot = {}
|
403
|
-
if isinstance(data, MuData):
|
404
|
-
data = data[modality_key]
|
405
|
-
if isinstance(data, AnnData):
|
406
|
-
data = data
|
407
|
-
# y scale transformations
|
408
|
-
if y_scale == "relative":
|
409
|
-
sample_sums = np.sum(data.X, axis=1, keepdims=True)
|
410
|
-
X = data.X / sample_sums
|
411
|
-
value_name = "Proportion"
|
412
|
-
# add pseudocount 0.5 if using log scale
|
413
|
-
elif y_scale == "log":
|
414
|
-
X = data.X.copy()
|
415
|
-
X[X == 0] = 0.5
|
416
|
-
X = np.log(X)
|
417
|
-
value_name = "log(count)"
|
418
|
-
elif y_scale == "log10":
|
419
|
-
X = data.X.copy()
|
420
|
-
X[X == 0] = 0.5
|
421
|
-
X = np.log(X)
|
422
|
-
value_name = "log10(count)"
|
423
|
-
elif y_scale == "count":
|
424
|
-
X = data.X
|
425
|
-
value_name = "count"
|
426
|
-
else:
|
427
|
-
raise ValueError("Invalid y_scale transformation")
|
428
|
-
|
429
|
-
count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
|
430
|
-
data.obs[feature_name], left_index=True, right_index=True
|
431
|
-
)
|
432
|
-
plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
|
433
|
-
if cell_types is not None:
|
434
|
-
plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
|
435
|
-
|
436
|
-
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
437
|
-
# We had to drop the dependency.
|
438
|
-
# Get credible effects results from model
|
439
|
-
# if draw_effects:
|
440
|
-
# if model is not None:
|
441
|
-
# credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
|
442
|
-
# else:
|
443
|
-
# print("[bold yellow]Specify a tasCODA model to draw effects")
|
444
|
-
# credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
|
445
|
-
# credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
|
446
|
-
# credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
|
447
|
-
|
448
|
-
# If plot as facets, create a FacetGrid and map boxplot to it.
|
449
|
-
if plot_facets:
|
450
|
-
if level_order is None:
|
451
|
-
level_order = pd.unique(plot_df[feature_name])
|
452
|
-
|
453
|
-
K = X.shape[1]
|
454
|
-
|
455
|
-
if figsize is not None:
|
456
|
-
height = figsize[0]
|
457
|
-
aspect = np.round(figsize[1] / figsize[0], 2)
|
458
|
-
else:
|
459
|
-
height = 3
|
460
|
-
aspect = 2
|
461
|
-
|
462
|
-
g = sns.FacetGrid(
|
463
|
-
plot_df,
|
464
|
-
col="Cell type",
|
465
|
-
sharey=False,
|
466
|
-
col_wrap=int(np.floor(np.sqrt(K))),
|
467
|
-
height=height,
|
468
|
-
aspect=aspect,
|
469
|
-
)
|
470
|
-
g.map(
|
471
|
-
sns.boxplot,
|
472
|
-
feature_name,
|
473
|
-
value_name,
|
474
|
-
palette=cmap,
|
475
|
-
order=level_order,
|
476
|
-
**args_boxplot,
|
477
|
-
)
|
478
|
-
|
479
|
-
if add_dots:
|
480
|
-
if "hue" in args_swarmplot:
|
481
|
-
hue = args_swarmplot.pop("hue")
|
482
|
-
else:
|
483
|
-
hue = None
|
484
|
-
|
485
|
-
if hue is None:
|
486
|
-
g.map(
|
487
|
-
sns.swarmplot,
|
488
|
-
feature_name,
|
489
|
-
value_name,
|
490
|
-
color="black",
|
491
|
-
order=level_order,
|
492
|
-
**args_swarmplot,
|
493
|
-
).set_titles("{col_name}")
|
494
|
-
else:
|
495
|
-
g.map(
|
496
|
-
sns.swarmplot,
|
497
|
-
feature_name,
|
498
|
-
value_name,
|
499
|
-
hue,
|
500
|
-
order=level_order,
|
501
|
-
**args_swarmplot,
|
502
|
-
).set_titles("{col_name}")
|
503
|
-
return g
|
504
|
-
|
505
|
-
# If not plot as facets, call boxplot to plot cell types on the x-axis.
|
506
|
-
else:
|
507
|
-
if level_order:
|
508
|
-
args_boxplot["hue_order"] = level_order
|
509
|
-
args_swarmplot["hue_order"] = level_order
|
510
|
-
|
511
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
512
|
-
|
513
|
-
ax = sns.boxplot(
|
514
|
-
x="Cell type",
|
515
|
-
y=value_name,
|
516
|
-
hue=feature_name,
|
517
|
-
data=plot_df,
|
518
|
-
fliersize=1,
|
519
|
-
palette=cmap,
|
520
|
-
ax=ax,
|
521
|
-
**args_boxplot,
|
522
|
-
)
|
523
|
-
|
524
|
-
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
525
|
-
# We had to drop the dependency.
|
526
|
-
# if draw_effects:
|
527
|
-
# pairs = [
|
528
|
-
# [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
|
529
|
-
# for _, row in credible_effects_df.iterrows()
|
530
|
-
# ]
|
531
|
-
# annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
|
532
|
-
# annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
|
533
|
-
# annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
|
534
|
-
# annot.annotate()
|
535
|
-
|
536
|
-
if add_dots:
|
537
|
-
sns.swarmplot(
|
538
|
-
x="Cell type",
|
539
|
-
y=value_name,
|
540
|
-
data=plot_df,
|
541
|
-
hue=feature_name,
|
542
|
-
ax=ax,
|
543
|
-
dodge=True,
|
544
|
-
color="black",
|
545
|
-
**args_swarmplot,
|
546
|
-
)
|
547
|
-
|
548
|
-
cell_types = pd.unique(plot_df["Cell type"])
|
549
|
-
ax.set_xticklabels(cell_types, rotation=90)
|
550
|
-
|
551
|
-
if show_legend:
|
552
|
-
handles, labels = ax.get_legend_handles_labels()
|
553
|
-
handout = []
|
554
|
-
labelout = []
|
555
|
-
for h, l in zip(handles, labels):
|
556
|
-
if l not in labelout:
|
557
|
-
labelout.append(l)
|
558
|
-
handout.append(h)
|
559
|
-
ax.legend(
|
560
|
-
handout,
|
561
|
-
labelout,
|
562
|
-
loc="upper left",
|
563
|
-
bbox_to_anchor=(1, 1),
|
564
|
-
ncol=1,
|
565
|
-
title=feature_name,
|
566
|
-
)
|
567
|
-
|
568
|
-
plt.tight_layout()
|
569
|
-
return ax
|
570
|
-
|
571
|
-
@staticmethod
|
572
|
-
def rel_abundance_dispersion_plot( # pragma: no cover
|
573
|
-
data: Union[AnnData, MuData],
|
574
|
-
modality_key: str = "coda",
|
575
|
-
abundant_threshold: Optional[float] = 0.9,
|
576
|
-
default_color: Optional[str] = "Grey",
|
577
|
-
abundant_color: Optional[str] = "Red",
|
578
|
-
label_cell_types: bool = True,
|
579
|
-
figsize: Optional[tuple[float, float]] = None,
|
580
|
-
dpi: Optional[int] = 100,
|
581
|
-
ax: Axes = None,
|
582
|
-
) -> plt.Axes:
|
583
|
-
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
|
584
|
-
|
585
|
-
If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
|
586
|
-
|
587
|
-
Args:
|
588
|
-
data: AnnData object or MuData object.
|
589
|
-
modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
|
590
|
-
Defaults to "coda".
|
591
|
-
abundant_threshold: Presence threshold for abundant cell types. Defaults to 0.9.
|
592
|
-
default_color: Bar color for all non-minimal cell types. Defaults to "Grey".
|
593
|
-
abundant_color: Bar color for cell types with abundant percentage larger than abundant_threshold.
|
594
|
-
Defaults to "Red".
|
595
|
-
label_cell_types: Label dots with cell type names. Defaults to True.
|
596
|
-
figsize: Figure size. Defaults to None.
|
597
|
-
dpi: Dpi setting. Defaults to 100.
|
598
|
-
ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
|
599
|
-
|
600
|
-
Returns:
|
601
|
-
A :class:`~matplotlib.axes.Axes` object
|
602
|
-
|
603
|
-
Examples:
|
604
|
-
Example with scCODA:
|
605
|
-
>>> import pertpy as pt
|
606
|
-
>>> haber_cells = pt.dt.haber_2017_regions()
|
607
|
-
>>> sccoda = pt.tl.Sccoda()
|
608
|
-
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
609
|
-
sample_identifier="batch", covariate_obs=["condition"])
|
610
|
-
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
611
|
-
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
612
|
-
>>> pt.pl.coda.rel_abundance_dispersion_plot(mdata)
|
613
|
-
"""
|
614
|
-
if isinstance(data, MuData):
|
615
|
-
data = data[modality_key]
|
616
|
-
if isinstance(data, AnnData):
|
617
|
-
data = data
|
618
|
-
if ax is None:
|
619
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
620
|
-
|
621
|
-
rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
|
622
|
-
|
623
|
-
percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
|
624
|
-
nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
|
625
|
-
|
626
|
-
# select reference
|
627
|
-
cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
|
628
|
-
|
629
|
-
is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
|
630
|
-
|
631
|
-
# Scatterplot
|
632
|
-
plot_df = pd.DataFrame(
|
633
|
-
{
|
634
|
-
"Total dispersion": cell_type_disp,
|
635
|
-
"Cell type": data.var.index,
|
636
|
-
"Presence": 1 - percent_zero,
|
637
|
-
"Is abundant": is_abundant,
|
638
|
-
}
|
639
|
-
)
|
640
|
-
|
641
|
-
if len(np.unique(plot_df["Is abundant"])) > 1:
|
642
|
-
palette = [default_color, abundant_color]
|
643
|
-
elif np.unique(plot_df["Is abundant"]) == [False]:
|
644
|
-
palette = [default_color]
|
645
|
-
else:
|
646
|
-
palette = [abundant_color]
|
647
|
-
|
648
|
-
ax = sns.scatterplot(
|
649
|
-
data=plot_df,
|
650
|
-
x="Presence",
|
651
|
-
y="Total dispersion",
|
652
|
-
hue="Is abundant",
|
653
|
-
palette=palette,
|
654
|
-
ax=ax,
|
655
|
-
)
|
656
|
-
|
657
|
-
# Text labels for abundant cell types
|
658
|
-
|
659
|
-
abundant_df = plot_df.loc[plot_df["Is abundant"], :]
|
660
|
-
|
661
|
-
def label_point(x, y, val, ax):
|
662
|
-
a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
|
663
|
-
texts = [
|
664
|
-
ax.text(
|
665
|
-
point["x"],
|
666
|
-
point["y"],
|
667
|
-
str(point["val"]),
|
668
|
-
)
|
669
|
-
for i, point in a.iterrows()
|
670
|
-
]
|
671
|
-
adjust_text(texts)
|
672
|
-
|
673
|
-
if label_cell_types:
|
674
|
-
label_point(
|
675
|
-
abundant_df["Presence"],
|
676
|
-
abundant_df["Total dispersion"],
|
677
|
-
abundant_df["Cell type"],
|
678
|
-
plt.gca(),
|
679
|
-
)
|
680
|
-
|
681
|
-
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
|
682
|
-
|
683
|
-
plt.tight_layout()
|
684
|
-
return ax
|
685
|
-
|
686
|
-
@staticmethod
def draw_tree(  # pragma: no cover
    data: Union[AnnData, MuData],
    modality_key: str = "coda",
    tree: Union[Tree, str] = "tree",
    tight_text: Optional[bool] = False,
    show_scale: Optional[bool] = False,
    show: Optional[bool] = True,
    file_name: Optional[str] = None,
    units: Optional[Literal["px", "mm", "in"]] = "px",
    h: Optional[float] = None,
    w: Optional[float] = None,
    dpi: Optional[int] = 90,
):
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree
            visualization in scenes with a lot of text faces. Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
            The image is saved whether `show` is True or not. Defaults to None.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches. Defaults to "px".
        h: Height of the image in units. Defaults to None.
        w: Width of the image in units. Defaults to None.
        dpi: Dots per inches. Defaults to 90.

    Returns:
        Depending on `show`, plots the tree inline (`show = True`) or returns
        :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`).

    Examples:
        Example with tascCODA:
        >>> import pertpy as pt
        >>> adata = pt.dt.smillie()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata, type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage", add_level_name=True
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> pt.pl.coda.draw_tree(mdata, tree="lineage")
    """
    # A MuData container is narrowed to the requested modality; an AnnData is used as is.
    if isinstance(data, MuData):
        data = data[modality_key]
    if isinstance(tree, str):
        tree = data.uns[tree]

    def my_layout(node):
        # Node names are drawn by this layout function; the built-in leaf names are disabled below.
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    # Shared by the file render and the inline render so both produce the same image.
    render_kwargs = {"tree_style": tree_style, "units": units, "w": w, "h": h, "dpi": dpi}

    if file_name is not None:
        tree.render(file_name, **render_kwargs)
    if show:
        return tree.render("%%inline", **render_kwargs)
    return tree, tree_style
|
758
|
-
|
759
|
-
@staticmethod
def draw_effects(  # pragma: no cover
    data: Union[AnnData, MuData],
    covariate: str,
    modality_key: str = "coda",
    tree: Union[Tree, str] = "tree",
    show_legend: Optional[bool] = None,
    show_leaf_effects: Optional[bool] = False,
    tight_text: Optional[bool] = False,
    show_scale: Optional[bool] = False,
    show: Optional[bool] = True,
    file_name: Optional[str] = None,
    units: Optional[Literal["px", "mm", "in"]] = "in",
    h: Optional[float] = None,
    w: Optional[float] = None,
    dpi: Optional[int] = 90,
):
    """Plot a tree with colored circles on the nodes indicating significant effects.

    Optionally, a bar plot next to the tree shows the leaf-level effects.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
        show_legend: Whether to show the legend of node effect sizes.
            If None, it is set to `not show_leaf_effects`. Defaults to None.
        show_leaf_effects: If True, plot bar plots which indicate leaf-level effects. Defaults to False.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
            producing slightly worse aligned text faces but improving the performance of tree
            visualization in scenes with a lot of text faces. Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        file_name: Path to the output image file. Without leaf effects the tree is rendered by ete3
            (valid extensions are .SVG, .PDF, .PNG); with leaf effects the combined figure is written
            with `matplotlib.pyplot.savefig`. The image is saved whether `show` is True or not.
            Defaults to None.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches.
            Only used when `show_leaf_effects` is False. Defaults to "in".
        h: Height of the image in units. Only used when `show_leaf_effects` is False. Defaults to None.
        w: Width of the image in units. Only used when `show_leaf_effects` is False. Defaults to None.
        dpi: Dots per inches. Only used when `show_leaf_effects` is False. Defaults to 90.

    Returns:
        If `show_leaf_effects` is True, nothing is returned (the matplotlib figure is drawn).
        Otherwise, depending on `show`, plots the tree inline (`show = True`) or returns
        :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`).

    Examples:
        Example with tascCODA:
        >>> import pertpy as pt
        >>> adata = pt.dt.smillie()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata, type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage", add_level_name=True
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
    """
    import tempfile

    # A MuData container is narrowed to the requested modality; an AnnData is used as is.
    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        show_legend = not show_leaf_effects
    elif show_legend:
        print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    # Node-level effects of the requested covariate, indexed by node name.
    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    # Leaf-level effects: stack the per-covariate effect tables, then select the requested covariate.
    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Add effect values; nodes without an entry get an effect of 0.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        if n.name in node_effs.index:
            n.add_feature("node_effect", node_effs.loc[n.name, "Final Parameter"])
        else:
            n.add_feature("node_effect", 0)
        if n.name in leaf_effs.index:
            n.add_feature("leaf_effect", leaf_effs.loc[n.name, "Effect"])
        else:
            n.add_feature("leaf_effect", 0)

    # Scale effect values to get nice node sizes
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # The largest absolute node effect maps to radius 10; zero effects draw no circle
        # (which also avoids dividing by eff_max when it is 0).
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # Circle radii are fixed fractions of the maximum radius (10). Computing them without
            # eff_max avoids a NaN radius when no node carries an effect (eff_max == 0).
            radius = 10 * i / 4
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_name].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # The tree is rendered to an intermediate PNG that is only needed to embed it into the
        # matplotlib figure, so it lives in a temporary directory instead of the working directory.
        # ete3 is handed a plain string path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tree_image_path = str(Path(tmp_dir) / "tree_effect.png")
            tree2.render(tree_image_path, tree_style=tree_style, units="in")
            img = mpimg.imread(tree_image_path)

        _, ax = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
        ax[0].imshow(img)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].set_frame_on(False)

        ax[1].get_yaxis().set_visible(False)
        ax[1].spines["left"].set_visible(False)
        ax[1].spines["right"].set_visible(False)
        ax[1].spines["top"].set_visible(False)
        # Symmetric limits keep zero centered in the bar plot.
        ax[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        plt.subplots_adjust(wspace=0)

        if file_name is not None:
            plt.savefig(file_name)
        return None

    if file_name is not None:
        tree2.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    if show:
        return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    return tree2, tree_style
|
942
|
-
|
943
|
-
@staticmethod
def effects_umap(  # pragma: no cover
    data: MuData,
    effect_name: Optional[Union[str, list]],
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    show: Optional[bool] = None,
    ax: Optional[Axes] = None,
    **kwargs,
):
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned
    to cell-level AnnData (default is data['rna']) depending on the cluster they were assigned to.
    The assigned effects are written to `.obs` of the cell-level AnnData, one column per effect name.

    Args:
        data: MuData object holding both the cell-level and the aggregated sample-level modality.
        effect_name: The name (or list of names) of the effect results in .varm of aggregated
            sample-level AnnData (default is data['coda']) to plot.
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
            To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
            Defaults to "coda".
        show: Whether to display the figure or return axis. Defaults to None.
        ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
            `vmin` / `vmax` may be given here; when omitted they default to the smallest / largest
            assigned effect over all plotted effect names.

    Returns:
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
            sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)

        >>> pt.pl.coda.effects_umap(mdata, effect_name="", cluster_key="")
        #TODO: Add effect_name parameter and cluster_key and test the example
    """
    data_rna = data[modality_key_1]
    data_coda = data[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]

    for effect in effect_name:
        # Fetch the effect table once per effect instead of once per cell.
        effect_df = data_coda.varm[effect]
        data_rna.obs[effect] = [effect_df.loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]

    # Pop vmin/vmax unconditionally and test against None: a truthiness test would treat an
    # explicit 0 as "not given", leave it in kwargs, and pass the keyword to scanpy twice.
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(data_rna, color=effect_name, vmax=vmax, vmin=vmin, ax=ax, show=show, **kwargs)
|