pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +4 -2
- pertpy/data/__init__.py +66 -1
- pertpy/data/_dataloader.py +28 -26
- pertpy/data/_datasets.py +261 -92
- pertpy/metadata/__init__.py +6 -0
- pertpy/metadata/_cell_line.py +795 -0
- pertpy/metadata/_compound.py +128 -0
- pertpy/metadata/_drug.py +238 -0
- pertpy/metadata/_look_up.py +569 -0
- pertpy/metadata/_metadata.py +70 -0
- pertpy/metadata/_moa.py +125 -0
- pertpy/plot/__init__.py +0 -13
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +89 -6
- pertpy/tools/__init__.py +48 -15
- pertpy/tools/_augur.py +329 -32
- pertpy/tools/_cinemaot.py +145 -6
- pertpy/tools/_coda/_base_coda.py +1237 -116
- pertpy/tools/_coda/_sccoda.py +66 -36
- pertpy/tools/_coda/_tasccoda.py +46 -39
- pertpy/tools/_dialogue.py +180 -77
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +29 -24
- pertpy/tools/_distances/_distances.py +584 -98
- pertpy/tools/_enrichment.py +460 -0
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +406 -49
- pertpy/tools/_mixscape.py +677 -55
- pertpy/tools/_perturbation_space/_clustering.py +10 -3
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
- pertpy/tools/_perturbation_space/_simple.py +52 -11
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +706 -0
- pertpy/tools/_scgen/_utils.py +3 -5
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -234
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_coda.py +0 -1001
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_guide_rna.py +0 -82
- pertpy/plot/_milopy.py +0 -284
- pertpy/plot/_mixscape.py +0 -594
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_differential_gene_expression.py +0 -99
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_coda.py
DELETED
@@ -1,1001 +0,0 @@
|
|
1
|
-
from pathlib import Path
|
2
|
-
from typing import Literal, Optional, Union
|
3
|
-
|
4
|
-
import matplotlib.image as mpimg
|
5
|
-
import matplotlib.pyplot as plt
|
6
|
-
import numpy as np
|
7
|
-
import pandas as pd
|
8
|
-
import scanpy as sc
|
9
|
-
import seaborn as sns
|
10
|
-
from adjustText import adjust_text
|
11
|
-
from anndata import AnnData
|
12
|
-
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
|
13
|
-
from matplotlib import cm, rcParams
|
14
|
-
from matplotlib.axes import Axes
|
15
|
-
from matplotlib.colors import ListedColormap
|
16
|
-
from mudata import MuData
|
17
|
-
|
18
|
-
from pertpy.tools._coda._base_coda import CompositionalModel2, collapse_singularities_2
|
19
|
-
|
20
|
-
# NOTE(review): executed at import time, so importing this module presumably restyles every
# seaborn/matplotlib plot drawn afterwards in the process.
sns.set_style(style="ticks")
|
21
|
-
|
22
|
-
|
23
|
-
class CodaPlot:
|
24
|
-
@staticmethod
def __stackbar(  # pragma: no cover
    y: np.ndarray,
    type_names: list[str],
    title: str,
    level_names: list[str],
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[ListedColormap] = cm.tab20,
    show_legend: Optional[bool] = True,
) -> plt.Axes:
    """Plots a stacked barplot for one (discrete) covariate.

    Typical use (only inside stacked_barplot): CodaPlot.__stackbar(data.X, data.var.index, "xyz", data.obs.index)

    Every row of ``y`` becomes one bar. Every column (cell type) becomes one stacked segment whose
    height is that column's share of the row total, expressed in percent (0-100).

    Args:
        y: The count data, collapsed onto the level of interest. i.e. a binary covariate has two rows,
           one for each group, containing the count mean of each cell type
        type_names: The names of all cell types
        title: Plot title, usually the covariate's name
        level_names: Names of the covariate's levels
        figsize: Figure size. Defaults to None (matplotlib's ``figure.figsize`` rcParam).
        dpi: Dpi setting. Defaults to 100.
        cmap: The color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.

    Returns:
        A :class:`~matplotlib.axes.Axes` object
    """
    n_bars, n_types = y.shape

    figsize = rcParams["figure.figsize"] if figsize is None else figsize

    _, ax = plt.subplots(figsize=figsize, dpi=dpi)
    r = np.arange(n_bars)

    # Row-normalize to percentages in one vectorized step. A row summing to zero yields NaN
    # segments, as the element-wise division did before.
    percentages = y / np.sum(y, axis=1, keepdims=True) * 100

    barwidth = 0.85
    cum_bars = np.zeros(n_bars)

    for n in range(n_types):
        bars = percentages[:, n]
        # Draw on the axes created above rather than the implicit "current" axes, so the
        # returned ``ax`` is guaranteed to be the one holding the bars.
        ax.bar(
            r,
            bars,
            bottom=cum_bars,
            color=cmap(n % cmap.N),
            width=barwidth,
            label=type_names[n],
            linewidth=0,
        )
        cum_bars += bars

    ax.set_title(title)
    if show_legend:
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
    ax.set_xticks(r)
    ax.set_xticklabels(level_names, rotation=45, ha="right")
    # NOTE(review): segment heights are percentages (0-100) although the axis is labelled "Proportion".
    ax.set_ylabel("Proportion")

    return ax
|
85
|
-
|
86
|
-
@staticmethod
def stacked_barplot(  # pragma: no cover
    data: Union[AnnData, MuData],
    feature_name: str,
    modality_key: str = "coda",
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[ListedColormap] = cm.tab20,
    show_legend: Optional[bool] = True,
    level_order: list[str] = None,
) -> plt.Axes:
    """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").

    With ``feature_name == "samples"`` every observation gets its own bar. Otherwise the counts of all
    observations sharing a level of ``data.obs[feature_name]`` are summed into one bar per level.

    Args:
        data: AnnData object or MuData object.
        feature_name: The name of the covariate to plot. If feature_name=="samples", one bar for every sample will be plotted
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        cmap: The matplotlib color map for the barplot. Defaults to cm.tab20.
        show_legend: If True, adds a legend. Defaults to True.
        level_order: Custom ordering of bars on the x-axis. Must contain exactly the sample names
            (for "samples") or the covariate's levels; otherwise an AssertionError is raised. Defaults to None.

    Returns:
        A :class:`~matplotlib.axes.Axes` object

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> pt.pl.coda.stacked_barplot(mdata, feature_name="samples")
    """
    adata = data[modality_key] if isinstance(data, MuData) else data
    shared_kwargs = {"figsize": figsize, "dpi": dpi, "cmap": cmap, "show_legend": show_legend}

    if feature_name == "samples":
        # One bar per observation, optionally reordered by sample name.
        if level_order:
            assert set(level_order) == set(adata.obs.index), "level order is inconsistent with levels"
            adata = adata[level_order]
        return CodaPlot.__stackbar(
            adata.X,
            type_names=adata.var.index,
            title="samples",
            level_names=adata.obs.index,
            **shared_kwargs,
        )

    groups = adata.obs[feature_name]
    if level_order:
        assert set(level_order) == set(groups), "level order is inconsistent with levels"
        levels = level_order
    elif hasattr(groups, "cat"):
        levels = groups.cat.categories.to_list()
    else:
        levels = pd.unique(groups)

    # Sum the counts of every observation that belongs to each level.
    feature_totals = np.zeros([len(levels), adata.X.shape[1]])
    for row, level in enumerate(levels):
        members = np.where(groups == level)
        feature_totals[row] = np.sum(adata.X[members], axis=0)

    return CodaPlot.__stackbar(
        feature_totals,
        type_names=adata.var.index,
        title=feature_name,
        level_names=levels,
        **shared_kwargs,
    )
|
170
|
-
|
171
|
-
@staticmethod
def effects_barplot(  # pragma: no cover
    data: Union[AnnData, MuData],
    modality_key: str = "coda",
    covariates: Optional[Union[str, list]] = None,
    parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
    plot_facets: bool = True,
    plot_zero_covariate: bool = True,
    plot_zero_cell_type: bool = False,
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[Union[str, ListedColormap]] = cm.tab20,
    level_order: list[str] = None,
    args_barplot: Optional[dict] = None,
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
    """Barplot visualization for effects.

    The effect results for each covariate are shown as a group of barplots, with intra-group separation by cell types.
    The covariate groups can either be ordered along the x-axis of a single plot (plot_facets=False) or as plot facets (plot_facets=True).

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        covariates: The name(s) of the covariates to plot. A stored covariate is kept when any of the given
            strings is a substring of its name. Defaults to None (all covariates).
        parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to True.
        plot_zero_covariate: If True, plot covariates that have all zero effects. If False, do not plot them. Defaults to True.
        plot_zero_cell_type: If True, plot cell types that have zero effect. If False, do not plot them. Defaults to False.
        figsize: Figure size. With plot_facets=True, figsize[0] is used as the facet height and
            figsize[1] / figsize[0] as the facet aspect. Defaults to None.
        dpi: Dpi setting (only used when plot_facets=False). Defaults to 100.
        cmap: The seaborn color map for the barplot. Defaults to cm.tab20.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.
        args_barplot: Arguments passed to sns.barplot. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If no covariate is left to plot after filtering.

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> pt.pl.coda.effects_barplot(mdata)
    """
    if args_barplot is None:
        args_barplot = {}
    if isinstance(data, MuData):
        data = data[modality_key]

    # Covariate names stored by the fitted model. Optionally keep only those whose name contains one
    # of the requested `covariates` strings.
    covariate_names = data.uns["scCODA_params"]["covariate_names"]
    if covariates is not None:
        if isinstance(covariates, str):
            covariates = [covariates]
        covariate_names = [name for name in covariate_names if any(cov in name for cov in covariates)]

    # Partition into covariates with a nonzero effect on at least one cell type and all-zero ones.
    covariate_names_non_zero = [
        name for name in covariate_names if data.varm[f"effect_df_{name}"][parameter].any()
    ]
    covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
    if not plot_zero_covariate:
        covariate_names = covariate_names_non_zero
        # All-zero covariates are not plotted, so they must not get placeholder facets either.
        covariate_names_zero = []

    if len(covariate_names) == 0:
        raise ValueError("No covariates left to plot; check `covariates` and `plot_zero_covariate`.")

    # Long format: one row per (cell type, covariate) pair, with the effect in column "value".
    # NOTE(review): relies on the effect_df_* frames being indexed by a level named "Cell Type".
    plot_df = pd.concat(
        [data.varm[f"effect_df_{name}"][parameter] for name in covariate_names],
        axis=1,
    )
    plot_df.columns = covariate_names
    plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate").reset_index()

    has_placeholders = False
    if plot_facets and covariate_names_zero:
        if not plot_zero_cell_type:
            plot_df = plot_df[plot_df["value"] != 0]
        # One placeholder row per all-zero covariate so that it still gets its own (empty) facet.
        placeholders = pd.DataFrame(
            [{"Covariate": name, "Cell Type": "zero", "value": 0} for name in covariate_names_zero]
        )
        plot_df = pd.concat([plot_df, placeholders], ignore_index=True)
        has_placeholders = True
        # Restore the covariate order. A stable sort keeps the cell-type order within each covariate.
        plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
        plot_df = plot_df.sort_values(["covariate_"], kind="stable")

    if not plot_zero_cell_type:
        cell_type_names_zero = {
            name
            for name in plot_df["Cell Type"].unique()
            if (plot_df.loc[plot_df["Cell Type"] == name, "value"] == 0).all()
        }
        if has_placeholders:
            # Keep the placeholders; otherwise all-zero covariates would vanish despite plot_zero_covariate=True.
            cell_type_names_zero.discard("zero")
        plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]

    # If plot as facets, create a FacetGrid and map barplot to it.
    if plot_facets:
        if isinstance(cmap, ListedColormap):
            cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
        if figsize is not None:
            # NOTE(review): figsize[0] drives the facet height here, unlike matplotlib's
            # (width, height) convention; kept as-is for backward compatibility.
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Covariate",
            sharey=True,
            sharex=False,
            height=height,
            aspect=aspect,
        )

        g.map(
            sns.barplot,
            "Cell Type",
            "value",
            palette=cmap,
            order=level_order,
            **args_barplot,
        )
        g.set_xticklabels(rotation=90)
        g.set(ylabel=parameter)
        # Title each facet with its own column value. Positional titles would go wrong whenever a
        # covariate contributes no rows.
        g.set_titles("{col_name}")
        for ax in g.axes.flat:
            ticklabels = ax.get_xticklabels()
            if 0 < len(ticklabels) < 5:
                ax.set_aspect(10 / len(ticklabels))
            if len(ticklabels) == 1 and ticklabels[0].get_text() == "zero":
                # Placeholder facet of an all-zero covariate: hide the dummy tick.
                ax.set_xticks([])
        return g

    # If not plot as facets, call barplot to plot cell types on the x-axis.
    _, ax = plt.subplots(figsize=figsize, dpi=dpi)
    if len(covariate_names) == 1:
        if isinstance(cmap, ListedColormap):
            cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            palette=cmap,
            order=level_order,
            ax=ax,
            **args_barplot,
        )
        ax.set_title(covariate_names[0])
    else:
        if isinstance(cmap, ListedColormap):
            cmap = np.array([cmap(i % cmap.N) for i in range(len(covariate_names))])
        sns.barplot(
            data=plot_df,
            x="Cell Type",
            y="value",
            hue="Covariate",
            palette=cmap,
            order=level_order,
            ax=ax,
            **args_barplot,
        )
    # Rotate the labels seaborn already placed, instead of overwriting them, so they stay correct
    # when a custom `level_order` is used.
    ax.tick_params(axis="x", labelrotation=90)
    return ax
|
345
|
-
|
346
|
-
@staticmethod
def boxplots(  # pragma: no cover
    data: Union[AnnData, MuData],
    feature_name: str,
    modality_key: str = "coda",
    y_scale: Literal["relative", "log", "log10", "count"] = "relative",
    plot_facets: bool = False,
    add_dots: bool = False,
    model: CompositionalModel2 = None,
    cell_types: Optional[list] = None,
    args_boxplot: Optional[dict] = None,
    args_swarmplot: Optional[dict] = None,
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    cmap: Optional[str] = "Blues",
    show_legend: Optional[bool] = True,
    level_order: list[str] = None,
) -> Optional[Union[plt.Axes, sns.axisgrid.FacetGrid]]:
    """Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots,
    with intra-group separation by a covariate from data.obs.

    Args:
        data: AnnData object or MuData object
        feature_name: The name of the feature in data.obs to plot
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        y_scale: Transformation of cell counts. Options: "relative" - Relative abundance, "log" - log(count),
                 "log10" - log10(count), "count" - absolute abundance (cell counts).
                 For the log scales, zero counts are replaced by a pseudocount of 0.5.
                 Defaults to "relative".
        plot_facets: If False, plot cell types on the x-axis. If True, plot as facets. Defaults to False.
        add_dots: If True, overlay a scatterplot with one dot for each data point. Defaults to False.
        model: Currently unused (credible-effect annotation is disabled).
        cell_types: Subset of cell types that should be plotted. Defaults to None.
        args_boxplot: Arguments passed to sns.boxplot. Not modified. Defaults to {}.
        args_swarmplot: Arguments passed to sns.swarmplot. Not modified. Defaults to {}.
        figsize: Figure size. Defaults to None.
        dpi: Dpi setting. Defaults to 100.
        cmap: The seaborn color map for the barplot. Defaults to "Blues".
        show_legend: If True, adds a legend. Defaults to True.
        level_order: Custom ordering of bars on the x-axis. Defaults to None.

    Returns:
        Depending on `plot_facets`, returns a :class:`~matplotlib.axes.Axes` (`plot_facets = False`)
        or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object

    Raises:
        ValueError: If `y_scale` is not one of the supported transformations.

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> pt.pl.coda.boxplots(mdata, feature_name="condition", add_dots=True)
    """
    # Work on copies: keys are added ("hue_order") and removed ("hue") below, and the caller's
    # dicts must not be changed by plotting.
    args_boxplot = {} if args_boxplot is None else dict(args_boxplot)
    args_swarmplot = {} if args_swarmplot is None else dict(args_swarmplot)
    if isinstance(data, MuData):
        data = data[modality_key]

    # y scale transformations. For the log scales, zeros get a 0.5 pseudocount. np.where yields a
    # new float array, so integer count matrices don't truncate 0.5 to 0 and data.X is not modified.
    if y_scale == "relative":
        sample_sums = np.sum(data.X, axis=1, keepdims=True)
        X = data.X / sample_sums
        value_name = "Proportion"
    elif y_scale == "log":
        X = np.log(np.where(data.X == 0, 0.5, data.X))
        value_name = "log(count)"
    elif y_scale == "log10":
        X = np.log10(np.where(data.X == 0, 0.5, data.X))
        value_name = "log10(count)"
    elif y_scale == "count":
        X = data.X
        value_name = "count"
    else:
        raise ValueError("Invalid y_scale transformation")

    count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
        data.obs[feature_name], left_index=True, right_index=True
    )
    plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
    if cell_types is not None:
        plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]

    # NOTE: annotating credible effects (formerly via statannotations, using `model`) is disabled
    # because statannotations does not support the current seaborn release.

    # If plot as facets, create a FacetGrid and map boxplot to it.
    if plot_facets:
        if level_order is None:
            level_order = pd.unique(plot_df[feature_name])

        K = X.shape[1]

        if figsize is not None:
            height = figsize[0]
            aspect = np.round(figsize[1] / figsize[0], 2)
        else:
            height = 3
            aspect = 2

        g = sns.FacetGrid(
            plot_df,
            col="Cell type",
            sharey=False,
            col_wrap=int(np.floor(np.sqrt(K))),
            height=height,
            aspect=aspect,
        )
        g.map(
            sns.boxplot,
            feature_name,
            value_name,
            palette=cmap,
            order=level_order,
            **args_boxplot,
        )

        if add_dots:
            # "hue" is passed positionally to swarmplot, so take it out of the keyword arguments.
            hue = args_swarmplot.pop("hue", None)
            if hue is None:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    color="black",
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
            else:
                g.map(
                    sns.swarmplot,
                    feature_name,
                    value_name,
                    hue,
                    order=level_order,
                    **args_swarmplot,
                ).set_titles("{col_name}")
        return g

    # If not plot as facets, call boxplot to plot cell types on the x-axis.
    if level_order:
        args_boxplot["hue_order"] = level_order
        args_swarmplot["hue_order"] = level_order

    _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax = sns.boxplot(
        x="Cell type",
        y=value_name,
        hue=feature_name,
        data=plot_df,
        fliersize=1,
        palette=cmap,
        ax=ax,
        **args_boxplot,
    )

    if add_dots:
        sns.swarmplot(
            x="Cell type",
            y=value_name,
            data=plot_df,
            hue=feature_name,
            ax=ax,
            dodge=True,
            color="black",
            **args_swarmplot,
        )

    plotted_cell_types = pd.unique(plot_df["Cell type"])
    ax.set_xticklabels(plotted_cell_types, rotation=90)

    if show_legend:
        handles, labels = ax.get_legend_handles_labels()
        # Box and swarm layers both register one entry per level; keep the first handle per label.
        first_handle_per_label = {}
        for handle, label in zip(handles, labels):
            first_handle_per_label.setdefault(label, handle)
        ax.legend(
            list(first_handle_per_label.values()),
            list(first_handle_per_label.keys()),
            loc="upper left",
            bbox_to_anchor=(1, 1),
            ncol=1,
            title=feature_name,
        )

    plt.tight_layout()
    return ax
|
570
|
-
|
571
|
-
@staticmethod
def rel_abundance_dispersion_plot(  # pragma: no cover
    data: Union[AnnData, MuData],
    modality_key: str = "coda",
    abundant_threshold: Optional[float] = 0.9,
    default_color: Optional[str] = "Grey",
    abundant_color: Optional[str] = "Red",
    label_cell_types: bool = True,
    figsize: Optional[tuple[float, float]] = None,
    dpi: Optional[int] = 100,
    ax: Axes = None,
) -> plt.Axes:
    """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.

    If the count of a cell type is larger than 0 in more than an `abundant_threshold` fraction of all samples,
    the cell type is marked in a different color.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        abundant_threshold: Presence threshold (fraction of samples) for abundant cell types. Defaults to 0.9.
        default_color: Dot color for cell types that are not abundant. Defaults to "Grey".
        abundant_color: Dot color for cell types whose presence exceeds abundant_threshold.
                        Defaults to "Red".
        label_cell_types: Label abundant cell types' dots with their names. Defaults to True.
        figsize: Figure size (only used when `ax` is None). Defaults to None.
        dpi: Dpi setting (only used when `ax` is None). Defaults to 100.
        ax: A matplotlib axes object to draw on. Defaults to None (a new figure is created).

    Returns:
        A :class:`~matplotlib.axes.Axes` object

    Examples:
        Example with scCODA:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
                sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> pt.pl.coda.rel_abundance_dispersion_plot(mdata)
    """
    if isinstance(data, MuData):
        data = data[modality_key]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi)

    rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)

    # Fraction of samples in which each cell type has a zero count.
    percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
    # Boolean mask per cell type: present in more than `abundant_threshold` of the samples.
    is_abundant = percent_zero < 1 - abundant_threshold

    # Variance-to-mean ratio of each cell type's relative abundance across samples.
    cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)

    # Scatterplot
    plot_df = pd.DataFrame(
        {
            "Total dispersion": cell_type_disp,
            "Cell type": data.var.index,
            "Presence": 1 - percent_zero,
            "Is abundant": is_abundant,
        }
    )

    # seaborn needs one palette entry per distinct hue value.
    abundant_flags = plot_df["Is abundant"]
    if abundant_flags.all():
        palette = [abundant_color]
    elif not abundant_flags.any():
        palette = [default_color]
    else:
        palette = [default_color, abundant_color]

    ax = sns.scatterplot(
        data=plot_df,
        x="Presence",
        y="Total dispersion",
        hue="Is abundant",
        palette=palette,
        ax=ax,
    )

    # Text labels for abundant cell types, drawn on (and de-overlapped within) this very axes,
    # which need not be the pyplot "current" axes when the caller supplied `ax`.
    if label_cell_types:
        abundant_df = plot_df.loc[plot_df["Is abundant"], :]
        texts = [
            ax.text(x, y, str(name))
            for x, y, name in zip(abundant_df["Presence"], abundant_df["Total dispersion"], abundant_df["Cell type"])
        ]
        if texts:
            adjust_text(texts, ax=ax)

    ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

    ax.figure.tight_layout()
    return ax
|
685
|
-
|
686
|
-
@staticmethod
def draw_tree(  # pragma: no cover
    data: Union[AnnData, MuData],
    modality_key: str = "coda",
    tree: Union[Tree, str] = "tree",
    tight_text: Optional[bool] = False,
    show_scale: Optional[bool] = False,
    show: Optional[bool] = True,
    file_name: Optional[str] = None,
    units: Optional[Literal["px", "mm", "in"]] = "px",
    h: Optional[float] = None,
    w: Optional[float] = None,
    dpi: Optional[int] = 90,
):
    """Plot a tree using input ete3 tree object.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization
                    in scenes with a lot of text faces. Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
                   Output image can be saved whether show is True or not. Defaults to None.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches. Defaults to "px".
        h: Height of the image in units. Defaults to None.
        w: Width of the image in units. Defaults to None.
        dpi: Dots per inches. Defaults to 90.

    Returns:
        Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
        or the inline rendering of the tree (`show = True`).

    Examples:
        Example with tascCODA:

        >>> import pertpy as pt
        >>> adata = pt.dt.smillie()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata, type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage", add_level_name=True
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> pt.pl.coda.draw_tree(mdata, tree="lineage")
    """
    # A MuData is reduced to the requested modality; an AnnData is used as-is.
    if isinstance(data, MuData):
        data = data[modality_key]
    # A string is interpreted as a key into `.uns` holding the tree object.
    if isinstance(tree, str):
        tree = data.uns[tree]

    def my_layout(node):
        # Leaf names are disabled in the style below; every node's name is drawn right of its branch instead.
        text_face = TextFace(node.name, tight_text=tight_text)
        faces.add_face_to_node(text_face, node, column=0, position="branch-right")

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale

    if file_name is not None:
        tree.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    if show:
        return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    return tree, tree_style
|
758
|
-
|
759
|
-
@staticmethod
def draw_effects(  # pragma: no cover
    data: Union[AnnData, MuData],
    covariate: str,
    modality_key: str = "coda",
    tree: Union[Tree, str] = "tree",
    show_legend: Optional[bool] = None,
    show_leaf_effects: Optional[bool] = False,
    tight_text: Optional[bool] = False,
    show_scale: Optional[bool] = False,
    show: Optional[bool] = True,
    file_name: Optional[str] = None,
    units: Optional[Literal["px", "mm", "in"]] = "in",
    h: Optional[float] = None,
    w: Optional[float] = None,
    dpi: Optional[int] = 90,
):
    """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.

    Node effects are read from ``data.uns["scCODA_params"]["node_df"]`` (rows of ``f"{covariate}_node"``,
    column "Final Parameter"). Leaf effects are read from the ``data.varm[f"effect_df_{cov}"]`` frames
    (column "Effect"). Circle radius is proportional to the absolute node effect; positive effects are
    drawn in blue, negative ones in red.

    Args:
        data: AnnData object or MuData object.
        covariate: The covariate, whose effects should be plotted.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        tree: A ete3 tree object or a str to indicate the tree stored in `.uns`. Defaults to "tree".
        show_legend: If show legend of nodes significant effects or not.
                     Defaults to False if show_leaf_effects is True, otherwise True.
        show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects. Defaults to False.
        tight_text: When False, boundaries of the text are approximated according to general font metrics,
                    producing slightly worse aligned text faces but improving the performance of tree visualization
                    in scenes with a lot of text faces. Defaults to False.
        show_scale: Include the scale legend in the tree image or not. Defaults to False.
        show: If True, plot the tree inline. If False, return tree and tree_style objects. Defaults to True.
        file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
                   Output image can be saved whether show is True or not. Defaults to None.
        units: Unit of image sizes. "px": pixels, "mm": millimeters, "in": inches. Defaults to "in".
        h: Height of the image in units. Defaults to None.
        w: Width of the image in units. Defaults to None.
        dpi: Dots per inches. Defaults to 90.

    Returns:
        If `show_leaf_effects` is False: the inline rendering of the tree (`show = True`) or
        :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`).
        If `show_leaf_effects` is True: the tree and the leaf-effect bar plot are drawn side by side on a new
        matplotlib figure (saved to `file_name` if given) and None is returned.

    Examples:
        Example with tascCODA:

        >>> import pertpy as pt
        >>> adata = pt.dt.smillie()
        >>> tasccoda = pt.tl.Tasccoda()
        >>> mdata = tasccoda.load(
        ...     adata, type="sample_level",
        ...     levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        ...     key_added="lineage", add_level_name=True
        ... )
        >>> mdata = tasccoda.prepare(
        ...     mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
        ... )
        >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
        >>> pt.pl.coda.draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
    """
    import tempfile

    if isinstance(data, MuData):
        data = data[modality_key]
    if show_legend is None:
        # The legend shifts the tree image, so it is hidden by default when leaf bars sit next to the tree.
        show_legend = not show_leaf_effects
    elif show_legend and show_leaf_effects:
        print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")

    if isinstance(tree, str):
        tree = data.uns[tree]
    # Collapse tree singularities
    tree2 = collapse_singularities_2(tree)

    node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
    node_effs.index = node_effs.index.get_level_values("Node")

    covariates = data.uns["scCODA_params"]["covariate_names"]
    effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
    eff_df = pd.concat(effect_dfs)
    eff_df.index = pd.MultiIndex.from_product(
        (covariates, data.var.index.tolist()),
        names=["Covariate", "Cell Type"],
    )
    leaf_effs = eff_df.loc[(covariate,),].copy()
    leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")

    # Attach node- and leaf-level effects to every node (0 when the node has no entry) and hide node dots.
    for n in tree2.traverse():
        nstyle = NodeStyle()
        nstyle["size"] = 0
        n.set_style(nstyle)
        if n.name in node_effs.index:
            n.add_feature("node_effect", node_effs.loc[n.name, "Final Parameter"])
        else:
            n.add_feature("node_effect", 0)
        if n.name in leaf_effs.index:
            n.add_feature("leaf_effect", leaf_effs.loc[n.name, "Effect"])
        else:
            n.add_feature("leaf_effect", 0)

    # Largest absolute effects, used to scale circle radii and the bar-plot x range.
    eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
    leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])

    def my_layout(node):
        text_face = TextFace(node.name, tight_text=tight_text)
        text_face.margin_left = 10
        faces.add_face_to_node(text_face, node, column=0, aligned=True)

        # Radius is |effect| scaled so the largest node effect gets radius 10; zero effects get no circle
        # (which also avoids dividing by eff_max when it is 0).
        size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
        if np.sign(node.node_effect) == 1:
            color = "blue"
        elif np.sign(node.node_effect) == -1:
            color = "red"
        else:
            color = "cyan"
        if size != 0:
            faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)

    tree_style = TreeStyle()
    tree_style.show_leaf_name = False
    tree_style.layout_fn = my_layout
    tree_style.show_scale = show_scale
    tree_style.draw_guiding_lines = True
    tree_style.legend_position = 1

    if show_legend:
        tree_style.legend.add_face(TextFace("Effects"), column=0)
        tree_style.legend.add_face(TextFace(" "), column=1)
        for i in range(4, 0, -1):
            # Legend circles at 100/75/50/25 % of the largest node effect, matching my_layout's scaling.
            # Written without dividing by eff_max so an all-zero effect tree does not yield NaN radii.
            radius = float(f"{10 * i / 4:.2f}")
            tree_style.legend.add_face(CircleFace(radius, "red"), column=0)
            tree_style.legend.add_face(TextFace(f"{-eff_max * i / 4:.2f} "), column=0)
            tree_style.legend.add_face(CircleFace(radius, "blue"), column=1)
            tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)

    if show_leaf_effects:
        # Order bars like the rendered leaves so they line up with the tree image.
        leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
        leaf_effs = leaf_effs.loc[leaf_name].reset_index()
        palette = ["blue" if effect > 0 else "red" for effect in leaf_effs["Effect"].tolist()]

        # Render the tree into a temporary PNG instead of the working directory, so no user file
        # is overwritten or left behind. ete3 expects a string file name.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = Path(tmp_dir) / "tree_effect.png"
            tree2.render(str(img_path), tree_style=tree_style, units="in")
            img = mpimg.imread(str(img_path))

        _, ax = plt.subplots(1, 2, figsize=(10, 10))
        sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
        ax[0].imshow(img)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].set_frame_on(False)

        ax[1].get_yaxis().set_visible(False)
        ax[1].spines["left"].set_visible(False)
        ax[1].spines["right"].set_visible(False)
        ax[1].spines["top"].set_visible(False)
        # A symmetric range centred on 0; skipped when all leaf effects are 0 (singular limits).
        if leaf_eff_max > 0:
            ax[1].set_xlim(-leaf_eff_max, leaf_eff_max)
        plt.subplots_adjust(wspace=0)

        if file_name is not None:
            plt.savefig(file_name)
        return None

    if file_name is not None:
        tree2.render(file_name, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    if show:
        return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
    return tree2, tree_style
|
942
|
-
|
943
|
-
@staticmethod
def effects_umap(  # pragma: no cover
    data: MuData,
    effect_name: Optional[Union[str, list]],
    cluster_key: str,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
    show: bool = None,
    ax: Axes = None,
    **kwargs,
):
    """Plot a UMAP visualization colored by effect strength.

    Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to
    cell-level AnnData (default is data['rna']) depending on the cluster they were assigned to.
    The assigned values are written to `.obs[effect]` of the cell-level AnnData.

    Args:
        data: MuData object holding both the cell-level and the aggregated sample-level AnnData.
        effect_name: The name (or list of names) of the effect results in .varm of aggregated sample-level
                     AnnData (default is data['coda']) to plot.
        cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
                     To assign cell types' effects to original cells.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
        show: Whether to display the figure or return axis. Defaults to None.
        ax: A matplotlib axes object. Only works if plotting a single component. Defaults to None.
        **kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`.
                  `vmin`/`vmax` default to the minimum/maximum over all plotted effects.

    Returns:
        If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    Examples:
        Example with scCODA:

        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label",
        ...                     sample_identifier="batch", covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> pt.pl.coda.effects_umap(mdata, effect_name="", cluster_key="")
        #TODO: Add effect_name parameter and cluster_key and test the example
    """
    data_rna = data[modality_key_1]
    data_coda = data[modality_key_2]
    if isinstance(effect_name, str):
        effect_name = [effect_name]

    # Each cell receives the effect of the cluster (cell type) it belongs to.
    for effect in effect_name:
        effect_df = data_coda.varm[effect]
        data_rna.obs[effect] = [effect_df.loc[f"{c}", "Effect"] for c in data_rna.obs[cluster_key]]

    # Pop explicitly (rather than testing truthiness) so a user-supplied 0 is honoured and the
    # keyword is never forwarded to scanpy a second time through **kwargs.
    vmin = kwargs.pop("vmin", None)
    if vmin is None:
        vmin = min(data_rna.obs[effect].min() for effect in effect_name)
    vmax = kwargs.pop("vmax", None)
    if vmax is None:
        vmax = max(data_rna.obs[effect].max() for effect in effect_name)

    return sc.pl.umap(data_rna, color=effect_name, vmax=vmax, vmin=vmin, ax=ax, show=show, **kwargs)
|