pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_coda.py
CHANGED
@@ -1,21 +1,16 @@
|
|
1
|
-
|
1
|
+
import warnings
|
2
2
|
from typing import Literal, Optional, Union
|
3
3
|
|
4
|
-
import matplotlib.image as mpimg
|
5
4
|
import matplotlib.pyplot as plt
|
6
5
|
import numpy as np
|
7
|
-
import pandas as pd
|
8
|
-
import scanpy as sc
|
9
6
|
import seaborn as sns
|
10
|
-
from adjustText import adjust_text
|
11
7
|
from anndata import AnnData
|
12
|
-
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
|
13
8
|
from matplotlib import cm, rcParams
|
14
9
|
from matplotlib.axes import Axes
|
15
10
|
from matplotlib.colors import ListedColormap
|
16
11
|
from mudata import MuData
|
17
12
|
|
18
|
-
from pertpy.tools._coda._base_coda import CompositionalModel2
|
13
|
+
from pertpy.tools._coda._base_coda import CompositionalModel2
|
19
14
|
|
20
15
|
sns.set_style("ticks")
|
21
16
|
|
@@ -27,10 +22,10 @@ class CodaPlot:
|
|
27
22
|
type_names: list[str],
|
28
23
|
title: str,
|
29
24
|
level_names: list[str],
|
30
|
-
figsize:
|
31
|
-
dpi:
|
32
|
-
cmap:
|
33
|
-
show_legend:
|
25
|
+
figsize: tuple[float, float] | None = None,
|
26
|
+
dpi: int | None = 100,
|
27
|
+
cmap: ListedColormap | None = cm.tab20,
|
28
|
+
show_legend: bool | None = True,
|
34
29
|
) -> plt.Axes:
|
35
30
|
"""Plots a stacked barplot for one (discrete) covariate.
|
36
31
|
|
@@ -62,7 +57,7 @@ class CodaPlot:
|
|
62
57
|
cum_bars = np.zeros(n_bars)
|
63
58
|
|
64
59
|
for n in range(n_types):
|
65
|
-
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums)]
|
60
|
+
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
|
66
61
|
plt.bar(
|
67
62
|
r,
|
68
63
|
bars,
|
@@ -85,13 +80,13 @@ class CodaPlot:
|
|
85
80
|
|
86
81
|
@staticmethod
|
87
82
|
def stacked_barplot( # pragma: no cover
|
88
|
-
data:
|
83
|
+
data: AnnData | MuData,
|
89
84
|
feature_name: str,
|
90
85
|
modality_key: str = "coda",
|
91
|
-
figsize:
|
92
|
-
dpi:
|
93
|
-
cmap:
|
94
|
-
show_legend:
|
86
|
+
figsize: tuple[float, float] | None = None,
|
87
|
+
dpi: int | None = 100,
|
88
|
+
cmap: ListedColormap | None = cm.tab20,
|
89
|
+
show_legend: bool | None = True,
|
95
90
|
level_order: list[str] = None,
|
96
91
|
) -> plt.Axes:
|
97
92
|
"""Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
|
@@ -110,79 +105,49 @@ class CodaPlot:
|
|
110
105
|
A :class:`~matplotlib.axes.Axes` object
|
111
106
|
|
112
107
|
Examples:
|
113
|
-
Example with scCODA:
|
114
108
|
>>> import pertpy as pt
|
115
109
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
116
110
|
>>> sccoda = pt.tl.Sccoda()
|
117
111
|
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
118
112
|
sample_identifier="batch", covariate_obs=["condition"])
|
119
|
-
>>>
|
113
|
+
>>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
|
120
114
|
"""
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
show_legend=show_legend,
|
142
|
-
)
|
143
|
-
else:
|
144
|
-
# Order levels
|
145
|
-
if level_order:
|
146
|
-
assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
|
147
|
-
levels = level_order
|
148
|
-
elif hasattr(data.obs[feature_name], "cat"):
|
149
|
-
levels = data.obs[feature_name].cat.categories.to_list()
|
150
|
-
else:
|
151
|
-
levels = pd.unique(data.obs[feature_name])
|
152
|
-
n_levels = len(levels)
|
153
|
-
feature_totals = np.zeros([n_levels, data.X.shape[1]])
|
154
|
-
|
155
|
-
for level in range(n_levels):
|
156
|
-
l_indices = np.where(data.obs[feature_name] == levels[level])
|
157
|
-
feature_totals[level] = np.sum(data.X[l_indices], axis=0)
|
158
|
-
|
159
|
-
ax = CodaPlot.__stackbar(
|
160
|
-
feature_totals,
|
161
|
-
type_names=ct_names,
|
162
|
-
title=feature_name,
|
163
|
-
level_names=levels,
|
164
|
-
figsize=figsize,
|
165
|
-
dpi=dpi,
|
166
|
-
cmap=cmap,
|
167
|
-
show_legend=show_legend,
|
168
|
-
)
|
169
|
-
return ax
|
115
|
+
warnings.warn(
|
116
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
117
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
118
|
+
FutureWarning,
|
119
|
+
stacklevel=2,
|
120
|
+
)
|
121
|
+
|
122
|
+
from pertpy.tools import Sccoda
|
123
|
+
|
124
|
+
coda = Sccoda()
|
125
|
+
return coda.plot_stacked_barplot(
|
126
|
+
data=data,
|
127
|
+
feature_name=feature_name,
|
128
|
+
modality_key=modality_key,
|
129
|
+
figsize=figsize,
|
130
|
+
dpi=dpi,
|
131
|
+
palette=cmap,
|
132
|
+
show_legend=show_legend,
|
133
|
+
level_order=level_order,
|
134
|
+
)
|
170
135
|
|
171
136
|
@staticmethod
|
172
137
|
def effects_barplot( # pragma: no cover
|
173
|
-
data:
|
138
|
+
data: AnnData | MuData,
|
174
139
|
modality_key: str = "coda",
|
175
|
-
covariates:
|
140
|
+
covariates: str | list | None = None,
|
176
141
|
parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
|
177
142
|
plot_facets: bool = True,
|
178
143
|
plot_zero_covariate: bool = True,
|
179
144
|
plot_zero_cell_type: bool = False,
|
180
|
-
figsize:
|
181
|
-
dpi:
|
182
|
-
cmap:
|
145
|
+
figsize: tuple[float, float] | None = None,
|
146
|
+
dpi: int | None = 100,
|
147
|
+
cmap: str | ListedColormap | None = cm.tab20,
|
183
148
|
level_order: list[str] = None,
|
184
|
-
args_barplot:
|
185
|
-
) ->
|
149
|
+
args_barplot: dict | None = None,
|
150
|
+
) -> plt.Axes | sns.axisgrid.FacetGrid | None:
|
186
151
|
"""Barplot visualization for effects.
|
187
152
|
|
188
153
|
The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
|
@@ -193,9 +158,12 @@ class CodaPlot:
|
|
193
158
|
modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
|
194
159
|
covariates: The name of the covariates in data.obs to plot. Defaults to None.
|
195
160
|
parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
|
196
|
-
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
|
197
|
-
|
198
|
-
|
161
|
+
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
|
162
|
+
Defaults to True.
|
163
|
+
plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
|
164
|
+
Defaults to True.
|
165
|
+
plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
|
166
|
+
Defaults to False.
|
199
167
|
figsize: Figure size. Defaults to None.
|
200
168
|
dpi: Figure size. Defaults to 100.
|
201
169
|
cmap: The seaborn color map for the barplot. Defaults to cm.tab20.
|
@@ -207,7 +175,6 @@ class CodaPlot:
|
|
207
175
|
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
208
176
|
|
209
177
|
Examples:
|
210
|
-
Example with scCODA:
|
211
178
|
>>> import pertpy as pt
|
212
179
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
213
180
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -215,152 +182,50 @@ class CodaPlot:
|
|
215
182
|
sample_identifier="batch", covariate_obs=["condition"])
|
216
183
|
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
217
184
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
218
|
-
>>>
|
185
|
+
>>> sccoda.plot_effects_barplot(mdata)
|
219
186
|
"""
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
data = data
|
226
|
-
# Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
|
227
|
-
covariate_names = data.uns["scCODA_params"]["covariate_names"]
|
228
|
-
if covariates is not None:
|
229
|
-
if isinstance(covariates, str):
|
230
|
-
covariates = [covariates]
|
231
|
-
partial_covariate_names = [
|
232
|
-
covariate_name
|
233
|
-
for covariate_name in covariate_names
|
234
|
-
if any(covariate in covariate_name for covariate in covariates)
|
235
|
-
]
|
236
|
-
covariate_names = partial_covariate_names
|
237
|
-
covariate_names_non_zero = [
|
238
|
-
covariate_name
|
239
|
-
for covariate_name in covariate_names
|
240
|
-
if data.varm[f"effect_df_{covariate_name}"][parameter].any()
|
241
|
-
]
|
242
|
-
covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
|
243
|
-
if not plot_zero_covariate:
|
244
|
-
covariate_names = covariate_names_non_zero
|
245
|
-
|
246
|
-
# set up df for plotting
|
247
|
-
plot_df = pd.concat(
|
248
|
-
[data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
|
249
|
-
axis=1,
|
187
|
+
warnings.warn(
|
188
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
189
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
190
|
+
FutureWarning,
|
191
|
+
stacklevel=2,
|
250
192
|
)
|
251
|
-
plot_df.columns = covariate_names
|
252
|
-
plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
|
253
|
-
|
254
|
-
plot_df = plot_df.reset_index()
|
255
|
-
|
256
|
-
if len(covariate_names_zero) != 0:
|
257
|
-
if plot_facets:
|
258
|
-
if plot_zero_covariate and not plot_zero_cell_type:
|
259
|
-
plot_df = plot_df[plot_df["value"] != 0]
|
260
|
-
for covariate_name_zero in covariate_names_zero:
|
261
|
-
new_row = {
|
262
|
-
"Covariate": covariate_name_zero,
|
263
|
-
"Cell Type": "zero",
|
264
|
-
"value": 0,
|
265
|
-
}
|
266
|
-
plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
|
267
|
-
plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
|
268
|
-
plot_df = plot_df.sort_values(["covariate_"])
|
269
|
-
if not plot_zero_cell_type:
|
270
|
-
cell_type_names_zero = [
|
271
|
-
name
|
272
|
-
for name in plot_df["Cell Type"].unique()
|
273
|
-
if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
|
274
|
-
]
|
275
|
-
plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
|
276
|
-
|
277
|
-
# If plot as facets, create a FacetGrid and map barplot to it.
|
278
|
-
if plot_facets:
|
279
|
-
if isinstance(cmap, ListedColormap):
|
280
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
|
281
|
-
if figsize is not None:
|
282
|
-
height = figsize[0]
|
283
|
-
aspect = np.round(figsize[1] / figsize[0], 2)
|
284
|
-
else:
|
285
|
-
height = 3
|
286
|
-
aspect = 2
|
287
|
-
|
288
|
-
g = sns.FacetGrid(
|
289
|
-
plot_df,
|
290
|
-
col="Covariate",
|
291
|
-
sharey=True,
|
292
|
-
sharex=False,
|
293
|
-
height=height,
|
294
|
-
aspect=aspect,
|
295
|
-
)
|
296
193
|
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
ax.set_xticks([])
|
315
|
-
return g
|
316
|
-
|
317
|
-
# If not plot as facets, call barplot to plot cell types on the x-axis.
|
318
|
-
else:
|
319
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
320
|
-
if len(covariate_names) == 1:
|
321
|
-
if isinstance(cmap, ListedColormap):
|
322
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
|
323
|
-
sns.barplot(
|
324
|
-
data=plot_df,
|
325
|
-
x="Cell Type",
|
326
|
-
y="value",
|
327
|
-
palette=cmap,
|
328
|
-
ax=ax,
|
329
|
-
)
|
330
|
-
ax.set_title(covariate_names[0])
|
331
|
-
else:
|
332
|
-
if isinstance(cmap, ListedColormap):
|
333
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(covariate_names))])
|
334
|
-
sns.barplot(
|
335
|
-
data=plot_df,
|
336
|
-
x="Cell Type",
|
337
|
-
y="value",
|
338
|
-
hue="Covariate",
|
339
|
-
palette=cmap,
|
340
|
-
ax=ax,
|
341
|
-
)
|
342
|
-
cell_types = pd.unique(plot_df["Cell Type"])
|
343
|
-
ax.set_xticklabels(cell_types, rotation=90)
|
344
|
-
return ax
|
194
|
+
from pertpy.tools import Sccoda
|
195
|
+
|
196
|
+
coda = Sccoda()
|
197
|
+
return coda.plot_effects_barplot(
|
198
|
+
data=data,
|
199
|
+
modality_key=modality_key,
|
200
|
+
covariates=covariates,
|
201
|
+
parameter=parameter,
|
202
|
+
plot_facets=plot_facets,
|
203
|
+
plot_zero_covariate=plot_zero_covariate,
|
204
|
+
plot_zero_cell_type=plot_zero_cell_type,
|
205
|
+
figsize=figsize,
|
206
|
+
dpi=dpi,
|
207
|
+
palette=cmap,
|
208
|
+
level_order=level_order,
|
209
|
+
args_barplot=args_barplot,
|
210
|
+
)
|
345
211
|
|
346
212
|
@staticmethod
|
347
213
|
def boxplots( # pragma: no cover
|
348
|
-
data:
|
214
|
+
data: AnnData | MuData,
|
349
215
|
feature_name: str,
|
350
216
|
modality_key: str = "coda",
|
351
217
|
y_scale: Literal["relative", "log", "log10", "count"] = "relative",
|
352
218
|
plot_facets: bool = False,
|
353
219
|
add_dots: bool = False,
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
show_legend: Optional[bool] = True,
|
220
|
+
cell_types: list | None = None,
|
221
|
+
args_boxplot: dict | None = None,
|
222
|
+
args_swarmplot: dict | None = None,
|
223
|
+
figsize: tuple[float, float] | None = None,
|
224
|
+
dpi: int | None = 100,
|
225
|
+
cmap: str | None = "Blues",
|
226
|
+
show_legend: bool | None = True,
|
362
227
|
level_order: list[str] = None,
|
363
|
-
) ->
|
228
|
+
) -> plt.Axes | sns.axisgrid.FacetGrid | None:
|
364
229
|
"""Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots,
|
365
230
|
with intra--group separation by a covariate from data.obs.
|
366
231
|
|
@@ -388,196 +253,50 @@ class CodaPlot:
|
|
388
253
|
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
389
254
|
|
390
255
|
Examples:
|
391
|
-
Example with scCODA:
|
392
256
|
>>> import pertpy as pt
|
393
257
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
394
258
|
>>> sccoda = pt.tl.Sccoda()
|
395
259
|
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
396
260
|
sample_identifier="batch", covariate_obs=["condition"])
|
397
|
-
>>>
|
261
|
+
>>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)
|
398
262
|
"""
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
data = data[modality_key]
|
405
|
-
if isinstance(data, AnnData):
|
406
|
-
data = data
|
407
|
-
# y scale transformations
|
408
|
-
if y_scale == "relative":
|
409
|
-
sample_sums = np.sum(data.X, axis=1, keepdims=True)
|
410
|
-
X = data.X / sample_sums
|
411
|
-
value_name = "Proportion"
|
412
|
-
# add pseudocount 0.5 if using log scale
|
413
|
-
elif y_scale == "log":
|
414
|
-
X = data.X.copy()
|
415
|
-
X[X == 0] = 0.5
|
416
|
-
X = np.log(X)
|
417
|
-
value_name = "log(count)"
|
418
|
-
elif y_scale == "log10":
|
419
|
-
X = data.X.copy()
|
420
|
-
X[X == 0] = 0.5
|
421
|
-
X = np.log(X)
|
422
|
-
value_name = "log10(count)"
|
423
|
-
elif y_scale == "count":
|
424
|
-
X = data.X
|
425
|
-
value_name = "count"
|
426
|
-
else:
|
427
|
-
raise ValueError("Invalid y_scale transformation")
|
428
|
-
|
429
|
-
count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
|
430
|
-
data.obs[feature_name], left_index=True, right_index=True
|
263
|
+
warnings.warn(
|
264
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
265
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
266
|
+
FutureWarning,
|
267
|
+
stacklevel=2,
|
431
268
|
)
|
432
|
-
plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
|
433
|
-
if cell_types is not None:
|
434
|
-
plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
|
435
|
-
|
436
|
-
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
437
|
-
# We had to drop the dependency.
|
438
|
-
# Get credible effects results from model
|
439
|
-
# if draw_effects:
|
440
|
-
# if model is not None:
|
441
|
-
# credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
|
442
|
-
# else:
|
443
|
-
# print("[bold yellow]Specify a tasCODA model to draw effects")
|
444
|
-
# credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
|
445
|
-
# credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
|
446
|
-
# credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
|
447
|
-
|
448
|
-
# If plot as facets, create a FacetGrid and map boxplot to it.
|
449
|
-
if plot_facets:
|
450
|
-
if level_order is None:
|
451
|
-
level_order = pd.unique(plot_df[feature_name])
|
452
|
-
|
453
|
-
K = X.shape[1]
|
454
|
-
|
455
|
-
if figsize is not None:
|
456
|
-
height = figsize[0]
|
457
|
-
aspect = np.round(figsize[1] / figsize[0], 2)
|
458
|
-
else:
|
459
|
-
height = 3
|
460
|
-
aspect = 2
|
461
|
-
|
462
|
-
g = sns.FacetGrid(
|
463
|
-
plot_df,
|
464
|
-
col="Cell type",
|
465
|
-
sharey=False,
|
466
|
-
col_wrap=int(np.floor(np.sqrt(K))),
|
467
|
-
height=height,
|
468
|
-
aspect=aspect,
|
469
|
-
)
|
470
|
-
g.map(
|
471
|
-
sns.boxplot,
|
472
|
-
feature_name,
|
473
|
-
value_name,
|
474
|
-
palette=cmap,
|
475
|
-
order=level_order,
|
476
|
-
**args_boxplot,
|
477
|
-
)
|
478
269
|
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
value_name,
|
499
|
-
hue,
|
500
|
-
order=level_order,
|
501
|
-
**args_swarmplot,
|
502
|
-
).set_titles("{col_name}")
|
503
|
-
return g
|
504
|
-
|
505
|
-
# If not plot as facets, call boxplot to plot cell types on the x-axis.
|
506
|
-
else:
|
507
|
-
if level_order:
|
508
|
-
args_boxplot["hue_order"] = level_order
|
509
|
-
args_swarmplot["hue_order"] = level_order
|
510
|
-
|
511
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
512
|
-
|
513
|
-
ax = sns.boxplot(
|
514
|
-
x="Cell type",
|
515
|
-
y=value_name,
|
516
|
-
hue=feature_name,
|
517
|
-
data=plot_df,
|
518
|
-
fliersize=1,
|
519
|
-
palette=cmap,
|
520
|
-
ax=ax,
|
521
|
-
**args_boxplot,
|
522
|
-
)
|
523
|
-
|
524
|
-
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
525
|
-
# We had to drop the dependency.
|
526
|
-
# if draw_effects:
|
527
|
-
# pairs = [
|
528
|
-
# [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
|
529
|
-
# for _, row in credible_effects_df.iterrows()
|
530
|
-
# ]
|
531
|
-
# annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
|
532
|
-
# annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
|
533
|
-
# annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
|
534
|
-
# annot.annotate()
|
535
|
-
|
536
|
-
if add_dots:
|
537
|
-
sns.swarmplot(
|
538
|
-
x="Cell type",
|
539
|
-
y=value_name,
|
540
|
-
data=plot_df,
|
541
|
-
hue=feature_name,
|
542
|
-
ax=ax,
|
543
|
-
dodge=True,
|
544
|
-
color="black",
|
545
|
-
**args_swarmplot,
|
546
|
-
)
|
547
|
-
|
548
|
-
cell_types = pd.unique(plot_df["Cell type"])
|
549
|
-
ax.set_xticklabels(cell_types, rotation=90)
|
550
|
-
|
551
|
-
if show_legend:
|
552
|
-
handles, labels = ax.get_legend_handles_labels()
|
553
|
-
handout = []
|
554
|
-
labelout = []
|
555
|
-
for h, l in zip(handles, labels):
|
556
|
-
if l not in labelout:
|
557
|
-
labelout.append(l)
|
558
|
-
handout.append(h)
|
559
|
-
ax.legend(
|
560
|
-
handout,
|
561
|
-
labelout,
|
562
|
-
loc="upper left",
|
563
|
-
bbox_to_anchor=(1, 1),
|
564
|
-
ncol=1,
|
565
|
-
title=feature_name,
|
566
|
-
)
|
567
|
-
|
568
|
-
plt.tight_layout()
|
569
|
-
return ax
|
270
|
+
from pertpy.tools import Sccoda
|
271
|
+
|
272
|
+
coda = Sccoda()
|
273
|
+
return coda.plot_boxplots(
|
274
|
+
data=data,
|
275
|
+
feature_name=feature_name,
|
276
|
+
modality_key=modality_key,
|
277
|
+
y_scale=y_scale,
|
278
|
+
plot_facets=plot_facets,
|
279
|
+
add_dots=add_dots,
|
280
|
+
cell_types=cell_types,
|
281
|
+
args_boxplot=args_boxplot,
|
282
|
+
args_swarmplot=args_swarmplot,
|
283
|
+
figsize=figsize,
|
284
|
+
dpi=dpi,
|
285
|
+
palette=cmap,
|
286
|
+
show_legend=show_legend,
|
287
|
+
level_order=level_order,
|
288
|
+
)
|
570
289
|
|
571
290
|
@staticmethod
|
572
291
|
def rel_abundance_dispersion_plot( # pragma: no cover
|
573
|
-
data:
|
292
|
+
data: AnnData | MuData,
|
574
293
|
modality_key: str = "coda",
|
575
|
-
abundant_threshold:
|
576
|
-
default_color:
|
577
|
-
abundant_color:
|
294
|
+
abundant_threshold: float | None = 0.9,
|
295
|
+
default_color: str | None = "Grey",
|
296
|
+
abundant_color: str | None = "Red",
|
578
297
|
label_cell_types: bool = True,
|
579
|
-
figsize:
|
580
|
-
dpi:
|
298
|
+
figsize: tuple[float, float] | None = None,
|
299
|
+
dpi: int | None = 100,
|
581
300
|
ax: Axes = None,
|
582
301
|
) -> plt.Axes:
|
583
302
|
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
|
@@ -601,7 +320,6 @@ class CodaPlot:
|
|
601
320
|
A :class:`~matplotlib.axes.Axes` object
|
602
321
|
|
603
322
|
Examples:
|
604
|
-
Example with scCODA:
|
605
323
|
>>> import pertpy as pt
|
606
324
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
607
325
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -609,119 +327,74 @@ class CodaPlot:
|
|
609
327
|
sample_identifier="batch", covariate_obs=["condition"])
|
610
328
|
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
611
329
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
612
|
-
>>>
|
330
|
+
>>> sccoda.plot_rel_abundance_dispersion_plot(mdata)
|
613
331
|
"""
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
620
|
-
|
621
|
-
rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
|
622
|
-
|
623
|
-
percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
|
624
|
-
nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
|
625
|
-
|
626
|
-
# select reference
|
627
|
-
cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
|
628
|
-
|
629
|
-
is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
|
630
|
-
|
631
|
-
# Scatterplot
|
632
|
-
plot_df = pd.DataFrame(
|
633
|
-
{
|
634
|
-
"Total dispersion": cell_type_disp,
|
635
|
-
"Cell type": data.var.index,
|
636
|
-
"Presence": 1 - percent_zero,
|
637
|
-
"Is abundant": is_abundant,
|
638
|
-
}
|
332
|
+
warnings.warn(
|
333
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
334
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
335
|
+
FutureWarning,
|
336
|
+
stacklevel=2,
|
639
337
|
)
|
640
338
|
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
palette=palette,
|
339
|
+
from pertpy.tools import Sccoda
|
340
|
+
|
341
|
+
coda = Sccoda()
|
342
|
+
return coda.plot_rel_abundance_dispersion_plot(
|
343
|
+
data=data,
|
344
|
+
modality_key=modality_key,
|
345
|
+
abundant_threshold=abundant_threshold,
|
346
|
+
default_color=default_color,
|
347
|
+
abundant_color=abundant_color,
|
348
|
+
label_cell_types=label_cell_types,
|
349
|
+
figsize=figsize,
|
350
|
+
dpi=dpi,
|
654
351
|
ax=ax,
|
655
352
|
)
|
656
353
|
|
657
|
-
# Text labels for abundant cell types
|
658
|
-
|
659
|
-
abundant_df = plot_df.loc[plot_df["Is abundant"], :]
|
660
|
-
|
661
|
-
def label_point(x, y, val, ax):
|
662
|
-
a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
|
663
|
-
texts = [
|
664
|
-
ax.text(
|
665
|
-
point["x"],
|
666
|
-
point["y"],
|
667
|
-
str(point["val"]),
|
668
|
-
)
|
669
|
-
for i, point in a.iterrows()
|
670
|
-
]
|
671
|
-
adjust_text(texts)
|
672
|
-
|
673
|
-
if label_cell_types:
|
674
|
-
label_point(
|
675
|
-
abundant_df["Presence"],
|
676
|
-
abundant_df["Total dispersion"],
|
677
|
-
abundant_df["Cell type"],
|
678
|
-
plt.gca(),
|
679
|
-
)
|
680
|
-
|
681
|
-
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
|
682
|
-
|
683
|
-
plt.tight_layout()
|
684
|
-
return ax
|
685
|
-
|
686
354
|
@staticmethod
|
687
355
|
def draw_tree( # pragma: no cover
|
688
|
-
data:
|
356
|
+
data: AnnData | MuData,
|
689
357
|
modality_key: str = "coda",
|
690
|
-
tree:
|
691
|
-
tight_text:
|
692
|
-
show_scale:
|
693
|
-
show:
|
694
|
-
|
695
|
-
units:
|
696
|
-
|
697
|
-
|
698
|
-
dpi: Optional[int] = 90,
|
358
|
+
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
|
359
|
+
tight_text: bool | None = False,
|
360
|
+
show_scale: bool | None = False,
|
361
|
+
show: bool | None = True,
|
362
|
+
save: str | None = None,
|
363
|
+
units: Literal["px", "mm", "in"] | None = "px",
|
364
|
+
figsize: tuple[float, float] | None = (None, None),
|
365
|
+
dpi: int | None = 90,
|
699
366
|
):
|
700
367
|
"""Plot a tree using input ete3 tree object.
|
701
368
|
|
702
369
|
Args:
|
703
370
|
data: AnnData object or MuData object.
|
704
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
705
|
-
|
371
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
372
|
+
Defaults to "coda".
|
373
|
+
tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
|
374
|
+
Defaults to "tree".
|
706
375
|
tight_text: When False, boundaries of the text are approximated according to general font metrics,
|
707
|
-
producing slightly worse aligned text faces but improving
|
376
|
+
producing slightly worse aligned text faces but improving
|
377
|
+
the performance of tree visualization in scenes with a lot of text faces.
|
708
378
|
Default to False.
|
709
|
-
show_scale: Include the scale legend in the tree image or not.
|
710
|
-
|
711
|
-
|
379
|
+
show_scale: Include the scale legend in the tree image or not.
|
380
|
+
Defaults to False.
|
381
|
+
show: If True, plot the tree inline. If false, return tree and tree_style objects.
|
382
|
+
Defaults to True.
|
383
|
+
file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
|
384
|
+
Output image can be saved whether show is True or not.
|
712
385
|
Defaults to None.
|
713
|
-
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
|
386
|
+
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
|
387
|
+
Defaults to "px".
|
714
388
|
h: Height of the image in units. Defaults to None.
|
715
389
|
w: Width of the image in units. Defaults to None.
|
716
390
|
dpi: Dots per inches. Defaults to 90.
|
717
391
|
|
718
392
|
Returns:
|
719
|
-
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or
|
393
|
+
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
|
720
394
|
|
721
395
|
Examples:
|
722
|
-
Example with tascCODA:
|
723
396
|
>>> import pertpy as pt
|
724
|
-
>>> adata = pt.dt.
|
397
|
+
>>> adata = pt.dt.tasccoda_example()
|
725
398
|
>>> tasccoda = pt.tl.Tasccoda()
|
726
399
|
>>> mdata = tasccoda.load(
|
727
400
|
>>> adata, type="sample_level",
|
@@ -732,56 +405,62 @@ class CodaPlot:
|
|
732
405
|
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
733
406
|
>>> )
|
734
407
|
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
735
|
-
>>>
|
408
|
+
>>> tasccoda.plot_draw_tree(mdata, tree="lineage")
|
409
|
+
|
410
|
+
Preview: #TODO: Add preview
|
736
411
|
"""
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
412
|
+
warnings.warn(
|
413
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
414
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
415
|
+
FutureWarning,
|
416
|
+
stacklevel=2,
|
417
|
+
)
|
418
|
+
|
419
|
+
from pertpy.tools import Tasccoda
|
420
|
+
|
421
|
+
coda = Tasccoda()
|
422
|
+
return coda.plot_draw_tree(
|
423
|
+
data=data,
|
424
|
+
modality_key=modality_key,
|
425
|
+
tree=tree,
|
426
|
+
tight_text=tight_text,
|
427
|
+
show_scale=show_scale,
|
428
|
+
show=show,
|
429
|
+
save=save,
|
430
|
+
units=units,
|
431
|
+
figsize=figsize,
|
432
|
+
dpi=dpi,
|
433
|
+
)
|
758
434
|
|
759
435
|
@staticmethod
|
760
436
|
def draw_effects( # pragma: no cover
|
761
|
-
data:
|
437
|
+
data: AnnData | MuData,
|
762
438
|
covariate: str,
|
763
439
|
modality_key: str = "coda",
|
764
|
-
tree:
|
765
|
-
show_legend:
|
766
|
-
show_leaf_effects:
|
767
|
-
tight_text:
|
768
|
-
show_scale:
|
769
|
-
show:
|
770
|
-
|
771
|
-
units:
|
772
|
-
|
773
|
-
|
774
|
-
dpi: Optional[int] = 90,
|
440
|
+
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
|
441
|
+
show_legend: bool | None = None,
|
442
|
+
show_leaf_effects: bool | None = False,
|
443
|
+
tight_text: bool | None = False,
|
444
|
+
show_scale: bool | None = False,
|
445
|
+
show: bool | None = True,
|
446
|
+
save: str | None = None,
|
447
|
+
units: Literal["px", "mm", "in"] | None = "in",
|
448
|
+
figsize: tuple[float, float] | None = (None, None),
|
449
|
+
dpi: int | None = 90,
|
775
450
|
):
|
776
451
|
"""Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
|
777
452
|
|
778
453
|
Args:
|
779
454
|
data: AnnData object or MuData object.
|
780
455
|
covariate: The covariate, whose effects should be plotted.
|
781
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
782
|
-
|
783
|
-
|
784
|
-
|
456
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
457
|
+
Defaults to "coda".
|
458
|
+
tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
|
459
|
+
Defaults to "tree".
|
460
|
+
show_legend: If show legend of nodes significant effects or not.
|
461
|
+
Defaults to False if show_leaf_effects is True.
|
462
|
+
show_leaf_effects: If True, plot bar plots which indicate leave-level significant effects.
|
463
|
+
Defaults to False.
|
785
464
|
tight_text: When False, boundaries of the text are approximated according to general font metrics,
|
786
465
|
producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
|
787
466
|
Defaults to False.
|
@@ -799,9 +478,8 @@ class CodaPlot:
|
|
799
478
|
or plot the tree inline (`show = False`)
|
800
479
|
|
801
480
|
Examples:
|
802
|
-
Example with tascCODA:
|
803
481
|
>>> import pertpy as pt
|
804
|
-
>>> adata = pt.dt.
|
482
|
+
>>> adata = pt.dt.tasccoda_example()
|
805
483
|
>>> tasccoda = pt.tl.Tasccoda()
|
806
484
|
>>> mdata = tasccoda.load(
|
807
485
|
>>> adata, type="sample_level",
|
@@ -812,138 +490,38 @@ class CodaPlot:
|
|
812
490
|
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
813
491
|
>>> )
|
814
492
|
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
815
|
-
>>>
|
493
|
+
>>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
|
816
494
|
"""
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
show_legend = not show_leaf_effects
|
823
|
-
elif show_legend:
|
824
|
-
print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")
|
825
|
-
|
826
|
-
if isinstance(tree, str):
|
827
|
-
tree = data.uns[tree]
|
828
|
-
# Collapse tree singularities
|
829
|
-
tree2 = collapse_singularities_2(tree)
|
830
|
-
|
831
|
-
node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
|
832
|
-
node_effs.index = node_effs.index.get_level_values("Node")
|
833
|
-
|
834
|
-
covariates = data.uns["scCODA_params"]["covariate_names"]
|
835
|
-
effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
|
836
|
-
eff_df = pd.concat(effect_dfs)
|
837
|
-
eff_df.index = pd.MultiIndex.from_product(
|
838
|
-
(covariates, data.var.index.tolist()),
|
839
|
-
names=["Covariate", "Cell Type"],
|
495
|
+
warnings.warn(
|
496
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
497
|
+
" Please use the corresponding 'pt.tl' object",
|
498
|
+
FutureWarning,
|
499
|
+
stacklevel=2,
|
840
500
|
)
|
841
|
-
leaf_effs = eff_df.loc[(covariate,),].copy()
|
842
|
-
leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")
|
843
|
-
|
844
|
-
# Add effect values
|
845
|
-
for n in tree2.traverse():
|
846
|
-
nstyle = NodeStyle()
|
847
|
-
nstyle["size"] = 0
|
848
|
-
n.set_style(nstyle)
|
849
|
-
if n.name in node_effs.index:
|
850
|
-
e = node_effs.loc[n.name, "Final Parameter"]
|
851
|
-
n.add_feature("node_effect", e)
|
852
|
-
else:
|
853
|
-
n.add_feature("node_effect", 0)
|
854
|
-
if n.name in leaf_effs.index:
|
855
|
-
e = leaf_effs.loc[n.name, "Effect"]
|
856
|
-
n.add_feature("leaf_effect", e)
|
857
|
-
else:
|
858
|
-
n.add_feature("leaf_effect", 0)
|
859
|
-
|
860
|
-
# Scale effect values to get nice node sizes
|
861
|
-
eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
|
862
|
-
leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
|
863
|
-
|
864
|
-
def my_layout(node):
|
865
|
-
text_face = TextFace(node.name, tight_text=tight_text)
|
866
|
-
text_face.margin_left = 10
|
867
|
-
faces.add_face_to_node(text_face, node, column=0, aligned=True)
|
868
|
-
|
869
|
-
# if node.is_leaf():
|
870
|
-
size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
|
871
|
-
if np.sign(node.node_effect) == 1:
|
872
|
-
color = "blue"
|
873
|
-
elif np.sign(node.node_effect) == -1:
|
874
|
-
color = "red"
|
875
|
-
else:
|
876
|
-
color = "cyan"
|
877
|
-
if size != 0:
|
878
|
-
faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)
|
879
|
-
|
880
|
-
tree_style = TreeStyle()
|
881
|
-
tree_style.show_leaf_name = False
|
882
|
-
tree_style.layout_fn = my_layout
|
883
|
-
tree_style.show_scale = show_scale
|
884
|
-
tree_style.draw_guiding_lines = True
|
885
|
-
tree_style.legend_position = 1
|
886
501
|
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
)
|
906
|
-
tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
|
907
|
-
|
908
|
-
if show_leaf_effects:
|
909
|
-
leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
|
910
|
-
leaf_effs = leaf_effs.loc[leaf_name].reset_index()
|
911
|
-
palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
|
912
|
-
|
913
|
-
dir_path = Path.cwd()
|
914
|
-
dir_path = Path(dir_path / "tree_effect.png")
|
915
|
-
tree2.render(dir_path, tree_style=tree_style, units="in")
|
916
|
-
_, ax = plt.subplots(1, 2, figsize=(10, 10))
|
917
|
-
sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
|
918
|
-
img = mpimg.imread(dir_path)
|
919
|
-
ax[0].imshow(img)
|
920
|
-
ax[0].get_xaxis().set_visible(False)
|
921
|
-
ax[0].get_yaxis().set_visible(False)
|
922
|
-
ax[0].set_frame_on(False)
|
923
|
-
|
924
|
-
ax[1].get_yaxis().set_visible(False)
|
925
|
-
ax[1].spines["left"].set_visible(False)
|
926
|
-
ax[1].spines["right"].set_visible(False)
|
927
|
-
ax[1].spines["top"].set_visible(False)
|
928
|
-
plt.xlim(-leaf_eff_max, leaf_eff_max)
|
929
|
-
plt.subplots_adjust(wspace=0)
|
930
|
-
|
931
|
-
if file_name is not None:
|
932
|
-
plt.savefig(file_name)
|
933
|
-
|
934
|
-
if file_name is not None and not show_leaf_effects:
|
935
|
-
tree2.render(file_name, tree_style=tree_style, units=units)
|
936
|
-
if show:
|
937
|
-
if not show_leaf_effects:
|
938
|
-
return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
|
939
|
-
else:
|
940
|
-
if not show_leaf_effects:
|
941
|
-
return tree2, tree_style
|
502
|
+
from pertpy.tools import Tasccoda
|
503
|
+
|
504
|
+
coda = Tasccoda()
|
505
|
+
return coda.plot_draw_effects(
|
506
|
+
data=data,
|
507
|
+
modality_key=modality_key,
|
508
|
+
covariate=covariate,
|
509
|
+
tree=tree,
|
510
|
+
show_legend=show_legend,
|
511
|
+
show_leaf_effects=show_leaf_effects,
|
512
|
+
tight_text=tight_text,
|
513
|
+
show_scale=show_scale,
|
514
|
+
show=show,
|
515
|
+
save=save,
|
516
|
+
units=units,
|
517
|
+
figsize=figsize,
|
518
|
+
dpi=dpi,
|
519
|
+
)
|
942
520
|
|
943
521
|
@staticmethod
|
944
522
|
def effects_umap( # pragma: no cover
|
945
523
|
data: MuData,
|
946
|
-
effect_name:
|
524
|
+
effect_name: str | list | None,
|
947
525
|
cluster_key: str,
|
948
526
|
modality_key_1: str = "rna",
|
949
527
|
modality_key_2: str = "coda",
|
@@ -953,49 +531,71 @@ class CodaPlot:
|
|
953
531
|
):
|
954
532
|
"""Plot a UMAP visualization colored by effect strength.
|
955
533
|
|
956
|
-
Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
|
534
|
+
Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
|
535
|
+
(default is data['rna']) depending on the cluster they were assigned to.
|
957
536
|
|
958
537
|
Args:
|
959
538
|
data: AnnData object or MuData object.
|
960
|
-
effect_name: The name of the effect results in .varm of aggregated sample-level AnnData
|
961
|
-
cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
|
539
|
+
effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot
|
540
|
+
cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
|
541
|
+
To assign cell types' effects to original cells.
|
962
542
|
modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
|
963
|
-
modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
|
543
|
+
modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
|
544
|
+
Defaults to "coda".
|
964
545
|
show: Whether to display the figure or return axis. Defaults to None.
|
965
|
-
ax: A matplotlib axes object. Only works if plotting a single component.
|
546
|
+
ax: A matplotlib axes object. Only works if plotting a single component.
|
547
|
+
Defaults to None.
|
966
548
|
**kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
|
967
549
|
|
968
550
|
Returns:
|
969
551
|
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
|
970
552
|
|
971
553
|
Examples:
|
972
|
-
Example with scCODA:
|
973
554
|
>>> import pertpy as pt
|
974
|
-
>>>
|
975
|
-
>>>
|
976
|
-
>>>
|
977
|
-
|
978
|
-
>>>
|
979
|
-
>>>
|
980
|
-
|
981
|
-
>>>
|
982
|
-
|
555
|
+
>>> import schist
|
556
|
+
>>> adata = pt.dt.haber_2017_regions()
|
557
|
+
>>> schist.inference.nested_model(adata, samples=100, random_seed=5678)
|
558
|
+
>>> tasccoda_model = pt.tl.Tasccoda()
|
559
|
+
>>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
|
560
|
+
>>> cell_type_identifier="nsbm_level_1",
|
561
|
+
>>> sample_identifier="batch", covariate_obs=["condition"],
|
562
|
+
>>> levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
|
563
|
+
>>> add_level_name=True)sccoda = pt.tl.Sccoda()
|
564
|
+
>>> tasccoda_model.prepare(
|
565
|
+
>>> tasccoda_data,
|
566
|
+
>>> modality_key="coda",
|
567
|
+
>>> reference_cell_type="18",
|
568
|
+
>>> formula="condition",
|
569
|
+
>>> pen_args={"phi": 0, "lambda_1": 3.5},
|
570
|
+
>>> tree_key="tree"
|
571
|
+
>>> )
|
572
|
+
>>> tasccoda_model.run_nuts(
|
573
|
+
... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
|
574
|
+
... )
|
575
|
+
>>> tasccoda_model.plot_effects_umap(tasccoda_data,
|
576
|
+
>>> effect_name=["effect_df_condition[T.Salmonella]",
|
577
|
+
>>> "effect_df_condition[T.Hpoly.Day3]",
|
578
|
+
>>> "effect_df_condition[T.Hpoly.Day10]"],
|
579
|
+
>>> cluster_key="nsbm_level_1",
|
580
|
+
>>> )
|
983
581
|
"""
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
582
|
+
warnings.warn(
|
583
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
584
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
585
|
+
FutureWarning,
|
586
|
+
stacklevel=2,
|
587
|
+
)
|
588
|
+
|
589
|
+
from pertpy.tools import Tasccoda
|
590
|
+
|
591
|
+
coda = Tasccoda()
|
592
|
+
coda.plot_effects_umap(
|
593
|
+
data=data,
|
594
|
+
effect_name=effect_name,
|
595
|
+
cluster_key=cluster_key,
|
596
|
+
modality_key_1=modality_key_1,
|
597
|
+
modality_key_2=modality_key_2,
|
598
|
+
show=show,
|
599
|
+
ax=ax,
|
600
|
+
**kwargs,
|
601
|
+
)
|