pertpy 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +3 -2
- pertpy/data/__init__.py +5 -1
- pertpy/data/_dataloader.py +2 -4
- pertpy/data/_datasets.py +203 -92
- pertpy/metadata/__init__.py +4 -0
- pertpy/metadata/_cell_line.py +826 -0
- pertpy/metadata/_compound.py +129 -0
- pertpy/metadata/_drug.py +242 -0
- pertpy/metadata/_look_up.py +582 -0
- pertpy/metadata/_metadata.py +73 -0
- pertpy/metadata/_moa.py +129 -0
- pertpy/plot/__init__.py +1 -9
- pertpy/plot/_augur.py +53 -116
- pertpy/plot/_coda.py +277 -677
- pertpy/plot/_guide_rna.py +17 -35
- pertpy/plot/_milopy.py +59 -134
- pertpy/plot/_mixscape.py +152 -391
- pertpy/preprocessing/_guide_rna.py +88 -4
- pertpy/tools/__init__.py +8 -13
- pertpy/tools/_augur.py +315 -17
- pertpy/tools/_cinemaot.py +143 -4
- pertpy/tools/_coda/_base_coda.py +1210 -65
- pertpy/tools/_coda/_sccoda.py +50 -21
- pertpy/tools/_coda/_tasccoda.py +27 -19
- pertpy/tools/_dialogue.py +164 -56
- pertpy/tools/_differential_gene_expression.py +240 -14
- pertpy/tools/_distances/_distance_tests.py +8 -8
- pertpy/tools/_distances/_distances.py +184 -34
- pertpy/tools/_enrichment.py +465 -0
- pertpy/tools/_milo.py +345 -11
- pertpy/tools/_mixscape.py +668 -50
- pertpy/tools/_perturbation_space/_clustering.py +5 -1
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +526 -0
- pertpy/tools/_perturbation_space/_perturbation_space.py +135 -43
- pertpy/tools/_perturbation_space/_simple.py +51 -10
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_scgen.py +701 -0
- pertpy/tools/_scgen/_utils.py +1 -3
- pertpy/tools/decoupler_LICENSE +674 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/METADATA +31 -12
- pertpy-0.7.0.dist-info/RECORD +53 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_cinemaot.py +0 -81
- pertpy/plot/_dialogue.py +0 -91
- pertpy/plot/_scgen.py +0 -337
- pertpy/tools/_metadata/__init__.py +0 -0
- pertpy/tools/_metadata/_cell_line.py +0 -613
- pertpy/tools/_metadata/_look_up.py +0 -342
- pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
- pertpy/tools/_scgen/_jax_scgen.py +0 -370
- pertpy-0.6.0.dist-info/RECORD +0 -50
- /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
- {pertpy-0.6.0.dist-info → pertpy-0.7.0.dist-info}/licenses/LICENSE +0 -0
pertpy/plot/_coda.py
CHANGED
@@ -1,21 +1,16 @@
|
|
1
|
-
|
1
|
+
import warnings
|
2
2
|
from typing import Literal, Optional, Union
|
3
3
|
|
4
|
-
import matplotlib.image as mpimg
|
5
4
|
import matplotlib.pyplot as plt
|
6
5
|
import numpy as np
|
7
|
-
import pandas as pd
|
8
|
-
import scanpy as sc
|
9
6
|
import seaborn as sns
|
10
|
-
from adjustText import adjust_text
|
11
7
|
from anndata import AnnData
|
12
|
-
from ete3 import CircleFace, NodeStyle, TextFace, Tree, TreeStyle, faces
|
13
8
|
from matplotlib import cm, rcParams
|
14
9
|
from matplotlib.axes import Axes
|
15
10
|
from matplotlib.colors import ListedColormap
|
16
11
|
from mudata import MuData
|
17
12
|
|
18
|
-
from pertpy.tools._coda._base_coda import CompositionalModel2
|
13
|
+
from pertpy.tools._coda._base_coda import CompositionalModel2
|
19
14
|
|
20
15
|
sns.set_style("ticks")
|
21
16
|
|
@@ -27,10 +22,10 @@ class CodaPlot:
|
|
27
22
|
type_names: list[str],
|
28
23
|
title: str,
|
29
24
|
level_names: list[str],
|
30
|
-
figsize:
|
31
|
-
dpi:
|
32
|
-
cmap:
|
33
|
-
show_legend:
|
25
|
+
figsize: tuple[float, float] | None = None,
|
26
|
+
dpi: int | None = 100,
|
27
|
+
cmap: ListedColormap | None = cm.tab20,
|
28
|
+
show_legend: bool | None = True,
|
34
29
|
) -> plt.Axes:
|
35
30
|
"""Plots a stacked barplot for one (discrete) covariate.
|
36
31
|
|
@@ -62,7 +57,7 @@ class CodaPlot:
|
|
62
57
|
cum_bars = np.zeros(n_bars)
|
63
58
|
|
64
59
|
for n in range(n_types):
|
65
|
-
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums)]
|
60
|
+
bars = [i / j * 100 for i, j in zip([y[k][n] for k in range(n_bars)], sample_sums, strict=False)]
|
66
61
|
plt.bar(
|
67
62
|
r,
|
68
63
|
bars,
|
@@ -85,13 +80,13 @@ class CodaPlot:
|
|
85
80
|
|
86
81
|
@staticmethod
|
87
82
|
def stacked_barplot( # pragma: no cover
|
88
|
-
data:
|
83
|
+
data: AnnData | MuData,
|
89
84
|
feature_name: str,
|
90
85
|
modality_key: str = "coda",
|
91
|
-
figsize:
|
92
|
-
dpi:
|
93
|
-
cmap:
|
94
|
-
show_legend:
|
86
|
+
figsize: tuple[float, float] | None = None,
|
87
|
+
dpi: int | None = 100,
|
88
|
+
cmap: ListedColormap | None = cm.tab20,
|
89
|
+
show_legend: bool | None = True,
|
95
90
|
level_order: list[str] = None,
|
96
91
|
) -> plt.Axes:
|
97
92
|
"""Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
|
@@ -110,79 +105,49 @@ class CodaPlot:
|
|
110
105
|
A :class:`~matplotlib.axes.Axes` object
|
111
106
|
|
112
107
|
Examples:
|
113
|
-
Example with scCODA:
|
114
108
|
>>> import pertpy as pt
|
115
109
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
116
110
|
>>> sccoda = pt.tl.Sccoda()
|
117
111
|
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
118
112
|
sample_identifier="batch", covariate_obs=["condition"])
|
119
|
-
>>>
|
113
|
+
>>> sccoda.plot_stacked_barplot(mdata, feature_name="samples")
|
120
114
|
"""
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
show_legend=show_legend,
|
142
|
-
)
|
143
|
-
else:
|
144
|
-
# Order levels
|
145
|
-
if level_order:
|
146
|
-
assert set(level_order) == set(data.obs[feature_name]), "level order is inconsistent with levels"
|
147
|
-
levels = level_order
|
148
|
-
elif hasattr(data.obs[feature_name], "cat"):
|
149
|
-
levels = data.obs[feature_name].cat.categories.to_list()
|
150
|
-
else:
|
151
|
-
levels = pd.unique(data.obs[feature_name])
|
152
|
-
n_levels = len(levels)
|
153
|
-
feature_totals = np.zeros([n_levels, data.X.shape[1]])
|
154
|
-
|
155
|
-
for level in range(n_levels):
|
156
|
-
l_indices = np.where(data.obs[feature_name] == levels[level])
|
157
|
-
feature_totals[level] = np.sum(data.X[l_indices], axis=0)
|
158
|
-
|
159
|
-
ax = CodaPlot.__stackbar(
|
160
|
-
feature_totals,
|
161
|
-
type_names=ct_names,
|
162
|
-
title=feature_name,
|
163
|
-
level_names=levels,
|
164
|
-
figsize=figsize,
|
165
|
-
dpi=dpi,
|
166
|
-
cmap=cmap,
|
167
|
-
show_legend=show_legend,
|
168
|
-
)
|
169
|
-
return ax
|
115
|
+
warnings.warn(
|
116
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
117
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
118
|
+
FutureWarning,
|
119
|
+
stacklevel=2,
|
120
|
+
)
|
121
|
+
|
122
|
+
from pertpy.tools import Sccoda
|
123
|
+
|
124
|
+
coda = Sccoda()
|
125
|
+
return coda.plot_stacked_barplot(
|
126
|
+
data=data,
|
127
|
+
feature_name=feature_name,
|
128
|
+
modality_key=modality_key,
|
129
|
+
figsize=figsize,
|
130
|
+
dpi=dpi,
|
131
|
+
palette=cmap,
|
132
|
+
show_legend=show_legend,
|
133
|
+
level_order=level_order,
|
134
|
+
)
|
170
135
|
|
171
136
|
@staticmethod
|
172
137
|
def effects_barplot( # pragma: no cover
|
173
|
-
data:
|
138
|
+
data: AnnData | MuData,
|
174
139
|
modality_key: str = "coda",
|
175
|
-
covariates:
|
140
|
+
covariates: str | list | None = None,
|
176
141
|
parameter: Literal["log2-fold change", "Final Parameter", "Expected Sample"] = "log2-fold change",
|
177
142
|
plot_facets: bool = True,
|
178
143
|
plot_zero_covariate: bool = True,
|
179
144
|
plot_zero_cell_type: bool = False,
|
180
|
-
figsize:
|
181
|
-
dpi:
|
182
|
-
cmap:
|
145
|
+
figsize: tuple[float, float] | None = None,
|
146
|
+
dpi: int | None = 100,
|
147
|
+
cmap: str | ListedColormap | None = cm.tab20,
|
183
148
|
level_order: list[str] = None,
|
184
|
-
args_barplot:
|
185
|
-
) ->
|
149
|
+
args_barplot: dict | None = None,
|
150
|
+
) -> plt.Axes | sns.axisgrid.FacetGrid | None:
|
186
151
|
"""Barplot visualization for effects.
|
187
152
|
|
188
153
|
The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types.
|
@@ -193,9 +158,12 @@ class CodaPlot:
|
|
193
158
|
modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
|
194
159
|
covariates: The name of the covariates in data.obs to plot. Defaults to None.
|
195
160
|
parameter: The parameter in effect summary to plot. Defaults to "log2-fold change".
|
196
|
-
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
|
197
|
-
|
198
|
-
|
161
|
+
plot_facets: If False, plot cell types on the x-axis. If True, plot as facets.
|
162
|
+
Defaults to True.
|
163
|
+
plot_zero_covariate: If True, plot covariate that have all zero effects. If False, do not plot.
|
164
|
+
Defaults to True.
|
165
|
+
plot_zero_cell_type: If True, plot cell type that have zero effect. If False, do not plot.
|
166
|
+
Defaults to False.
|
199
167
|
figsize: Figure size. Defaults to None.
|
200
168
|
dpi: Figure size. Defaults to 100.
|
201
169
|
cmap: The seaborn color map for the barplot. Defaults to cm.tab20.
|
@@ -207,7 +175,6 @@ class CodaPlot:
|
|
207
175
|
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
208
176
|
|
209
177
|
Examples:
|
210
|
-
Example with scCODA:
|
211
178
|
>>> import pertpy as pt
|
212
179
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
213
180
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -215,152 +182,50 @@ class CodaPlot:
|
|
215
182
|
sample_identifier="batch", covariate_obs=["condition"])
|
216
183
|
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
217
184
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
218
|
-
>>>
|
185
|
+
>>> sccoda.plot_effects_barplot(mdata)
|
219
186
|
"""
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
data = data
|
226
|
-
# Get covariate names from adata, partition into those with nonzero effects for min. one cell type/no cell types
|
227
|
-
covariate_names = data.uns["scCODA_params"]["covariate_names"]
|
228
|
-
if covariates is not None:
|
229
|
-
if isinstance(covariates, str):
|
230
|
-
covariates = [covariates]
|
231
|
-
partial_covariate_names = [
|
232
|
-
covariate_name
|
233
|
-
for covariate_name in covariate_names
|
234
|
-
if any(covariate in covariate_name for covariate in covariates)
|
235
|
-
]
|
236
|
-
covariate_names = partial_covariate_names
|
237
|
-
covariate_names_non_zero = [
|
238
|
-
covariate_name
|
239
|
-
for covariate_name in covariate_names
|
240
|
-
if data.varm[f"effect_df_{covariate_name}"][parameter].any()
|
241
|
-
]
|
242
|
-
covariate_names_zero = list(set(covariate_names) - set(covariate_names_non_zero))
|
243
|
-
if not plot_zero_covariate:
|
244
|
-
covariate_names = covariate_names_non_zero
|
245
|
-
|
246
|
-
# set up df for plotting
|
247
|
-
plot_df = pd.concat(
|
248
|
-
[data.varm[f"effect_df_{covariate_name}"][parameter] for covariate_name in covariate_names],
|
249
|
-
axis=1,
|
187
|
+
warnings.warn(
|
188
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
189
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
190
|
+
FutureWarning,
|
191
|
+
stacklevel=2,
|
250
192
|
)
|
251
|
-
plot_df.columns = covariate_names
|
252
|
-
plot_df = pd.melt(plot_df, ignore_index=False, var_name="Covariate")
|
253
|
-
|
254
|
-
plot_df = plot_df.reset_index()
|
255
|
-
|
256
|
-
if len(covariate_names_zero) != 0:
|
257
|
-
if plot_facets:
|
258
|
-
if plot_zero_covariate and not plot_zero_cell_type:
|
259
|
-
plot_df = plot_df[plot_df["value"] != 0]
|
260
|
-
for covariate_name_zero in covariate_names_zero:
|
261
|
-
new_row = {
|
262
|
-
"Covariate": covariate_name_zero,
|
263
|
-
"Cell Type": "zero",
|
264
|
-
"value": 0,
|
265
|
-
}
|
266
|
-
plot_df = pd.concat([plot_df, pd.DataFrame([new_row])], ignore_index=True)
|
267
|
-
plot_df["covariate_"] = pd.Categorical(plot_df["Covariate"], covariate_names)
|
268
|
-
plot_df = plot_df.sort_values(["covariate_"])
|
269
|
-
if not plot_zero_cell_type:
|
270
|
-
cell_type_names_zero = [
|
271
|
-
name
|
272
|
-
for name in plot_df["Cell Type"].unique()
|
273
|
-
if (plot_df[plot_df["Cell Type"] == name]["value"] == 0).all()
|
274
|
-
]
|
275
|
-
plot_df = plot_df[~plot_df["Cell Type"].isin(cell_type_names_zero)]
|
276
|
-
|
277
|
-
# If plot as facets, create a FacetGrid and map barplot to it.
|
278
|
-
if plot_facets:
|
279
|
-
if isinstance(cmap, ListedColormap):
|
280
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
|
281
|
-
if figsize is not None:
|
282
|
-
height = figsize[0]
|
283
|
-
aspect = np.round(figsize[1] / figsize[0], 2)
|
284
|
-
else:
|
285
|
-
height = 3
|
286
|
-
aspect = 2
|
287
|
-
|
288
|
-
g = sns.FacetGrid(
|
289
|
-
plot_df,
|
290
|
-
col="Covariate",
|
291
|
-
sharey=True,
|
292
|
-
sharex=False,
|
293
|
-
height=height,
|
294
|
-
aspect=aspect,
|
295
|
-
)
|
296
193
|
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
ax.set_xticks([])
|
315
|
-
return g
|
316
|
-
|
317
|
-
# If not plot as facets, call barplot to plot cell types on the x-axis.
|
318
|
-
else:
|
319
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
320
|
-
if len(covariate_names) == 1:
|
321
|
-
if isinstance(cmap, ListedColormap):
|
322
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(plot_df["Cell Type"].unique()))])
|
323
|
-
sns.barplot(
|
324
|
-
data=plot_df,
|
325
|
-
x="Cell Type",
|
326
|
-
y="value",
|
327
|
-
palette=cmap,
|
328
|
-
ax=ax,
|
329
|
-
)
|
330
|
-
ax.set_title(covariate_names[0])
|
331
|
-
else:
|
332
|
-
if isinstance(cmap, ListedColormap):
|
333
|
-
cmap = np.array([cmap(i % cmap.N) for i in range(len(covariate_names))])
|
334
|
-
sns.barplot(
|
335
|
-
data=plot_df,
|
336
|
-
x="Cell Type",
|
337
|
-
y="value",
|
338
|
-
hue="Covariate",
|
339
|
-
palette=cmap,
|
340
|
-
ax=ax,
|
341
|
-
)
|
342
|
-
cell_types = pd.unique(plot_df["Cell Type"])
|
343
|
-
ax.set_xticklabels(cell_types, rotation=90)
|
344
|
-
return ax
|
194
|
+
from pertpy.tools import Sccoda
|
195
|
+
|
196
|
+
coda = Sccoda()
|
197
|
+
return coda.plot_effects_barplot(
|
198
|
+
data=data,
|
199
|
+
modality_key=modality_key,
|
200
|
+
covariates=covariates,
|
201
|
+
parameter=parameter,
|
202
|
+
plot_facets=plot_facets,
|
203
|
+
plot_zero_covariate=plot_zero_covariate,
|
204
|
+
plot_zero_cell_type=plot_zero_cell_type,
|
205
|
+
figsize=figsize,
|
206
|
+
dpi=dpi,
|
207
|
+
palette=cmap,
|
208
|
+
level_order=level_order,
|
209
|
+
args_barplot=args_barplot,
|
210
|
+
)
|
345
211
|
|
346
212
|
@staticmethod
|
347
213
|
def boxplots( # pragma: no cover
|
348
|
-
data:
|
214
|
+
data: AnnData | MuData,
|
349
215
|
feature_name: str,
|
350
216
|
modality_key: str = "coda",
|
351
217
|
y_scale: Literal["relative", "log", "log10", "count"] = "relative",
|
352
218
|
plot_facets: bool = False,
|
353
219
|
add_dots: bool = False,
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
show_legend: Optional[bool] = True,
|
220
|
+
cell_types: list | None = None,
|
221
|
+
args_boxplot: dict | None = None,
|
222
|
+
args_swarmplot: dict | None = None,
|
223
|
+
figsize: tuple[float, float] | None = None,
|
224
|
+
dpi: int | None = 100,
|
225
|
+
cmap: str | None = "Blues",
|
226
|
+
show_legend: bool | None = True,
|
362
227
|
level_order: list[str] = None,
|
363
|
-
) ->
|
228
|
+
) -> plt.Axes | sns.axisgrid.FacetGrid | None:
|
364
229
|
"""Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots,
|
365
230
|
with intra--group separation by a covariate from data.obs.
|
366
231
|
|
@@ -388,196 +253,50 @@ class CodaPlot:
|
|
388
253
|
or :class:`~sns.axisgrid.FacetGrid` (`plot_facets = True`) object
|
389
254
|
|
390
255
|
Examples:
|
391
|
-
Example with scCODA:
|
392
256
|
>>> import pertpy as pt
|
393
257
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
394
258
|
>>> sccoda = pt.tl.Sccoda()
|
395
259
|
>>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
|
396
260
|
sample_identifier="batch", covariate_obs=["condition"])
|
397
|
-
>>>
|
261
|
+
>>> sccoda.plot_boxplots(mdata, feature_name="condition", add_dots=True)
|
398
262
|
"""
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
data = data[modality_key]
|
405
|
-
if isinstance(data, AnnData):
|
406
|
-
data = data
|
407
|
-
# y scale transformations
|
408
|
-
if y_scale == "relative":
|
409
|
-
sample_sums = np.sum(data.X, axis=1, keepdims=True)
|
410
|
-
X = data.X / sample_sums
|
411
|
-
value_name = "Proportion"
|
412
|
-
# add pseudocount 0.5 if using log scale
|
413
|
-
elif y_scale == "log":
|
414
|
-
X = data.X.copy()
|
415
|
-
X[X == 0] = 0.5
|
416
|
-
X = np.log(X)
|
417
|
-
value_name = "log(count)"
|
418
|
-
elif y_scale == "log10":
|
419
|
-
X = data.X.copy()
|
420
|
-
X[X == 0] = 0.5
|
421
|
-
X = np.log(X)
|
422
|
-
value_name = "log10(count)"
|
423
|
-
elif y_scale == "count":
|
424
|
-
X = data.X
|
425
|
-
value_name = "count"
|
426
|
-
else:
|
427
|
-
raise ValueError("Invalid y_scale transformation")
|
428
|
-
|
429
|
-
count_df = pd.DataFrame(X, columns=data.var.index, index=data.obs.index).merge(
|
430
|
-
data.obs[feature_name], left_index=True, right_index=True
|
263
|
+
warnings.warn(
|
264
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
265
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
266
|
+
FutureWarning,
|
267
|
+
stacklevel=2,
|
431
268
|
)
|
432
|
-
plot_df = pd.melt(count_df, id_vars=feature_name, var_name="Cell type", value_name=value_name)
|
433
|
-
if cell_types is not None:
|
434
|
-
plot_df = plot_df[plot_df["Cell type"].isin(cell_types)]
|
435
|
-
|
436
|
-
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
437
|
-
# We had to drop the dependency.
|
438
|
-
# Get credible effects results from model
|
439
|
-
# if draw_effects:
|
440
|
-
# if model is not None:
|
441
|
-
# credible_effects_df = model.credible_effects(data, modality_key).to_frame().reset_index()
|
442
|
-
# else:
|
443
|
-
# print("[bold yellow]Specify a tasCODA model to draw effects")
|
444
|
-
# credible_effects_df[feature_name] = credible_effects_df["Covariate"].str.removeprefix(f"{feature_name}[T.")
|
445
|
-
# credible_effects_df[feature_name] = credible_effects_df[feature_name].str.removesuffix("]")
|
446
|
-
# credible_effects_df = credible_effects_df[credible_effects_df["Final Parameter"]]
|
447
|
-
|
448
|
-
# If plot as facets, create a FacetGrid and map boxplot to it.
|
449
|
-
if plot_facets:
|
450
|
-
if level_order is None:
|
451
|
-
level_order = pd.unique(plot_df[feature_name])
|
452
|
-
|
453
|
-
K = X.shape[1]
|
454
|
-
|
455
|
-
if figsize is not None:
|
456
|
-
height = figsize[0]
|
457
|
-
aspect = np.round(figsize[1] / figsize[0], 2)
|
458
|
-
else:
|
459
|
-
height = 3
|
460
|
-
aspect = 2
|
461
|
-
|
462
|
-
g = sns.FacetGrid(
|
463
|
-
plot_df,
|
464
|
-
col="Cell type",
|
465
|
-
sharey=False,
|
466
|
-
col_wrap=int(np.floor(np.sqrt(K))),
|
467
|
-
height=height,
|
468
|
-
aspect=aspect,
|
469
|
-
)
|
470
|
-
g.map(
|
471
|
-
sns.boxplot,
|
472
|
-
feature_name,
|
473
|
-
value_name,
|
474
|
-
palette=cmap,
|
475
|
-
order=level_order,
|
476
|
-
**args_boxplot,
|
477
|
-
)
|
478
269
|
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
value_name,
|
499
|
-
hue,
|
500
|
-
order=level_order,
|
501
|
-
**args_swarmplot,
|
502
|
-
).set_titles("{col_name}")
|
503
|
-
return g
|
504
|
-
|
505
|
-
# If not plot as facets, call boxplot to plot cell types on the x-axis.
|
506
|
-
else:
|
507
|
-
if level_order:
|
508
|
-
args_boxplot["hue_order"] = level_order
|
509
|
-
args_swarmplot["hue_order"] = level_order
|
510
|
-
|
511
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
512
|
-
|
513
|
-
ax = sns.boxplot(
|
514
|
-
x="Cell type",
|
515
|
-
y=value_name,
|
516
|
-
hue=feature_name,
|
517
|
-
data=plot_df,
|
518
|
-
fliersize=1,
|
519
|
-
palette=cmap,
|
520
|
-
ax=ax,
|
521
|
-
**args_boxplot,
|
522
|
-
)
|
523
|
-
|
524
|
-
# Currently disabled because the latest statsannotations does not support the latest seaborn.
|
525
|
-
# We had to drop the dependency.
|
526
|
-
# if draw_effects:
|
527
|
-
# pairs = [
|
528
|
-
# [(row["Cell Type"], row[feature_name]), (row["Cell Type"], "Control")]
|
529
|
-
# for _, row in credible_effects_df.iterrows()
|
530
|
-
# ]
|
531
|
-
# annot = Annotator(ax, pairs, data=plot_df, x="Cell type", y=value_name, hue=feature_name)
|
532
|
-
# annot.configure(test=None, loc="outside", color="red", line_height=0, verbose=False)
|
533
|
-
# annot.set_custom_annotations([row[feature_name] for _, row in credible_effects_df.iterrows()])
|
534
|
-
# annot.annotate()
|
535
|
-
|
536
|
-
if add_dots:
|
537
|
-
sns.swarmplot(
|
538
|
-
x="Cell type",
|
539
|
-
y=value_name,
|
540
|
-
data=plot_df,
|
541
|
-
hue=feature_name,
|
542
|
-
ax=ax,
|
543
|
-
dodge=True,
|
544
|
-
color="black",
|
545
|
-
**args_swarmplot,
|
546
|
-
)
|
547
|
-
|
548
|
-
cell_types = pd.unique(plot_df["Cell type"])
|
549
|
-
ax.set_xticklabels(cell_types, rotation=90)
|
550
|
-
|
551
|
-
if show_legend:
|
552
|
-
handles, labels = ax.get_legend_handles_labels()
|
553
|
-
handout = []
|
554
|
-
labelout = []
|
555
|
-
for h, l in zip(handles, labels):
|
556
|
-
if l not in labelout:
|
557
|
-
labelout.append(l)
|
558
|
-
handout.append(h)
|
559
|
-
ax.legend(
|
560
|
-
handout,
|
561
|
-
labelout,
|
562
|
-
loc="upper left",
|
563
|
-
bbox_to_anchor=(1, 1),
|
564
|
-
ncol=1,
|
565
|
-
title=feature_name,
|
566
|
-
)
|
567
|
-
|
568
|
-
plt.tight_layout()
|
569
|
-
return ax
|
270
|
+
from pertpy.tools import Sccoda
|
271
|
+
|
272
|
+
coda = Sccoda()
|
273
|
+
return coda.plot_boxplots(
|
274
|
+
data=data,
|
275
|
+
feature_name=feature_name,
|
276
|
+
modality_key=modality_key,
|
277
|
+
y_scale=y_scale,
|
278
|
+
plot_facets=plot_facets,
|
279
|
+
add_dots=add_dots,
|
280
|
+
cell_types=cell_types,
|
281
|
+
args_boxplot=args_boxplot,
|
282
|
+
args_swarmplot=args_swarmplot,
|
283
|
+
figsize=figsize,
|
284
|
+
dpi=dpi,
|
285
|
+
palette=cmap,
|
286
|
+
show_legend=show_legend,
|
287
|
+
level_order=level_order,
|
288
|
+
)
|
570
289
|
|
571
290
|
@staticmethod
|
572
291
|
def rel_abundance_dispersion_plot( # pragma: no cover
|
573
|
-
data:
|
292
|
+
data: AnnData | MuData,
|
574
293
|
modality_key: str = "coda",
|
575
|
-
abundant_threshold:
|
576
|
-
default_color:
|
577
|
-
abundant_color:
|
294
|
+
abundant_threshold: float | None = 0.9,
|
295
|
+
default_color: str | None = "Grey",
|
296
|
+
abundant_color: str | None = "Red",
|
578
297
|
label_cell_types: bool = True,
|
579
|
-
figsize:
|
580
|
-
dpi:
|
298
|
+
figsize: tuple[float, float] | None = None,
|
299
|
+
dpi: int | None = 100,
|
581
300
|
ax: Axes = None,
|
582
301
|
) -> plt.Axes:
|
583
302
|
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
|
@@ -601,7 +320,6 @@ class CodaPlot:
|
|
601
320
|
A :class:`~matplotlib.axes.Axes` object
|
602
321
|
|
603
322
|
Examples:
|
604
|
-
Example with scCODA:
|
605
323
|
>>> import pertpy as pt
|
606
324
|
>>> haber_cells = pt.dt.haber_2017_regions()
|
607
325
|
>>> sccoda = pt.tl.Sccoda()
|
@@ -609,119 +327,74 @@ class CodaPlot:
|
|
609
327
|
sample_identifier="batch", covariate_obs=["condition"])
|
610
328
|
>>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
|
611
329
|
>>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
|
612
|
-
>>>
|
330
|
+
>>> sccoda.plot_rel_abundance_dispersion_plot(mdata)
|
613
331
|
"""
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
_, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
620
|
-
|
621
|
-
rel_abun = data.X / np.sum(data.X, axis=1, keepdims=True)
|
622
|
-
|
623
|
-
percent_zero = np.sum(data.X == 0, axis=0) / data.X.shape[0]
|
624
|
-
nonrare_ct = np.where(percent_zero < 1 - abundant_threshold)[0]
|
625
|
-
|
626
|
-
# select reference
|
627
|
-
cell_type_disp = np.var(rel_abun, axis=0) / np.mean(rel_abun, axis=0)
|
628
|
-
|
629
|
-
is_abundant = [x in nonrare_ct for x in range(data.X.shape[1])]
|
630
|
-
|
631
|
-
# Scatterplot
|
632
|
-
plot_df = pd.DataFrame(
|
633
|
-
{
|
634
|
-
"Total dispersion": cell_type_disp,
|
635
|
-
"Cell type": data.var.index,
|
636
|
-
"Presence": 1 - percent_zero,
|
637
|
-
"Is abundant": is_abundant,
|
638
|
-
}
|
332
|
+
warnings.warn(
|
333
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
334
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
335
|
+
FutureWarning,
|
336
|
+
stacklevel=2,
|
639
337
|
)
|
640
338
|
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
palette=palette,
|
339
|
+
from pertpy.tools import Sccoda
|
340
|
+
|
341
|
+
coda = Sccoda()
|
342
|
+
return coda.plot_rel_abundance_dispersion_plot(
|
343
|
+
data=data,
|
344
|
+
modality_key=modality_key,
|
345
|
+
abundant_threshold=abundant_threshold,
|
346
|
+
default_color=default_color,
|
347
|
+
abundant_color=abundant_color,
|
348
|
+
label_cell_types=label_cell_types,
|
349
|
+
figsize=figsize,
|
350
|
+
dpi=dpi,
|
654
351
|
ax=ax,
|
655
352
|
)
|
656
353
|
|
657
|
-
# Text labels for abundant cell types
|
658
|
-
|
659
|
-
abundant_df = plot_df.loc[plot_df["Is abundant"], :]
|
660
|
-
|
661
|
-
def label_point(x, y, val, ax):
|
662
|
-
a = pd.concat({"x": x, "y": y, "val": val}, axis=1)
|
663
|
-
texts = [
|
664
|
-
ax.text(
|
665
|
-
point["x"],
|
666
|
-
point["y"],
|
667
|
-
str(point["val"]),
|
668
|
-
)
|
669
|
-
for i, point in a.iterrows()
|
670
|
-
]
|
671
|
-
adjust_text(texts)
|
672
|
-
|
673
|
-
if label_cell_types:
|
674
|
-
label_point(
|
675
|
-
abundant_df["Presence"],
|
676
|
-
abundant_df["Total dispersion"],
|
677
|
-
abundant_df["Cell type"],
|
678
|
-
plt.gca(),
|
679
|
-
)
|
680
|
-
|
681
|
-
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
|
682
|
-
|
683
|
-
plt.tight_layout()
|
684
|
-
return ax
|
685
|
-
|
686
354
|
@staticmethod
|
687
355
|
def draw_tree( # pragma: no cover
|
688
|
-
data:
|
356
|
+
data: AnnData | MuData,
|
689
357
|
modality_key: str = "coda",
|
690
|
-
tree:
|
691
|
-
tight_text:
|
692
|
-
show_scale:
|
693
|
-
show:
|
694
|
-
|
695
|
-
units:
|
696
|
-
|
697
|
-
|
698
|
-
dpi: Optional[int] = 90,
|
358
|
+
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
|
359
|
+
tight_text: bool | None = False,
|
360
|
+
show_scale: bool | None = False,
|
361
|
+
show: bool | None = True,
|
362
|
+
save: str | None = None,
|
363
|
+
units: Literal["px", "mm", "in"] | None = "px",
|
364
|
+
figsize: tuple[float, float] | None = (None, None),
|
365
|
+
dpi: int | None = 90,
|
699
366
|
):
|
700
367
|
"""Plot a tree using input ete3 tree object.
|
701
368
|
|
702
369
|
Args:
|
703
370
|
data: AnnData object or MuData object.
|
704
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
705
|
-
|
371
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
372
|
+
Defaults to "coda".
|
373
|
+
tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
|
374
|
+
Defaults to "tree".
|
706
375
|
tight_text: When False, boundaries of the text are approximated according to general font metrics,
|
707
|
-
producing slightly worse aligned text faces but improving
|
376
|
+
producing slightly worse aligned text faces but improving
|
377
|
+
the performance of tree visualization in scenes with a lot of text faces.
|
708
378
|
Default to False.
|
709
|
-
show_scale: Include the scale legend in the tree image or not.
|
710
|
-
|
711
|
-
|
379
|
+
show_scale: Include the scale legend in the tree image or not.
|
380
|
+
Defaults to False.
|
381
|
+
show: If True, plot the tree inline. If false, return tree and tree_style objects.
|
382
|
+
Defaults to True.
|
383
|
+
file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
|
384
|
+
Output image can be saved whether show is True or not.
|
712
385
|
Defaults to None.
|
713
|
-
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
|
386
|
+
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
|
387
|
+
Defaults to "px".
|
714
388
|
h: Height of the image in units. Defaults to None.
|
715
389
|
w: Width of the image in units. Defaults to None.
|
716
390
|
dpi: Dots per inches. Defaults to 90.
|
717
391
|
|
718
392
|
Returns:
|
719
|
-
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or
|
393
|
+
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
|
720
394
|
|
721
395
|
Examples:
|
722
|
-
Example with tascCODA:
|
723
396
|
>>> import pertpy as pt
|
724
|
-
>>> adata = pt.dt.
|
397
|
+
>>> adata = pt.dt.tasccoda_example()
|
725
398
|
>>> tasccoda = pt.tl.Tasccoda()
|
726
399
|
>>> mdata = tasccoda.load(
|
727
400
|
>>> adata, type="sample_level",
|
@@ -732,56 +405,62 @@ class CodaPlot:
|
|
732
405
|
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
733
406
|
>>> )
|
734
407
|
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
735
|
-
>>>
|
408
|
+
>>> tasccoda.plot_draw_tree(mdata, tree="lineage")
|
409
|
+
|
410
|
+
Preview: #TODO: Add preview
|
736
411
|
"""
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
412
|
+
warnings.warn(
|
413
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
414
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
415
|
+
FutureWarning,
|
416
|
+
stacklevel=2,
|
417
|
+
)
|
418
|
+
|
419
|
+
from pertpy.tools import Tasccoda
|
420
|
+
|
421
|
+
coda = Tasccoda()
|
422
|
+
return coda.plot_draw_tree(
|
423
|
+
data=data,
|
424
|
+
modality_key=modality_key,
|
425
|
+
tree=tree,
|
426
|
+
tight_text=tight_text,
|
427
|
+
show_scale=show_scale,
|
428
|
+
show=show,
|
429
|
+
save=save,
|
430
|
+
units=units,
|
431
|
+
figsize=figsize,
|
432
|
+
dpi=dpi,
|
433
|
+
)
|
758
434
|
|
759
435
|
@staticmethod
|
760
436
|
def draw_effects( # pragma: no cover
|
761
|
-
data:
|
437
|
+
data: AnnData | MuData,
|
762
438
|
covariate: str,
|
763
439
|
modality_key: str = "coda",
|
764
|
-
tree:
|
765
|
-
show_legend:
|
766
|
-
show_leaf_effects:
|
767
|
-
tight_text:
|
768
|
-
show_scale:
|
769
|
-
show:
|
770
|
-
|
771
|
-
units:
|
772
|
-
|
773
|
-
|
774
|
-
dpi: Optional[int] = 90,
|
440
|
+
tree: str = "tree", # Also type ete3.Tree. Omitted due to import errors
|
441
|
+
show_legend: bool | None = None,
|
442
|
+
show_leaf_effects: bool | None = False,
|
443
|
+
tight_text: bool | None = False,
|
444
|
+
show_scale: bool | None = False,
|
445
|
+
show: bool | None = True,
|
446
|
+
save: str | None = None,
|
447
|
+
units: Literal["px", "mm", "in"] | None = "in",
|
448
|
+
figsize: tuple[float, float] | None = (None, None),
|
449
|
+
dpi: int | None = 90,
|
775
450
|
):
|
776
451
|
"""Plot a tree with colored circles on the nodes indicating significant effects, with bar plots indicating leaf-level significant effects.
|
777
452
|
|
778
453
|
Args:
|
779
454
|
data: AnnData object or MuData object.
|
780
455
|
covariate: The covariate, whose effects should be plotted.
|
781
|
-
modality_key: If data is a MuData object, specify which modality to use.
|
782
|
-
|
783
|
-
|
784
|
-
|
456
|
+
modality_key: If data is a MuData object, specify which modality to use.
|
457
|
+
Defaults to "coda".
|
458
|
+
tree: A ete3 tree object or a str to indicate the tree stored in `.uns`.
|
459
|
+
Defaults to "tree".
|
460
|
+
show_legend: Whether to show a legend for the node-level significant effects.
|
461
|
+
Defaults to False if show_leaf_effects is True.
|
462
|
+
show_leaf_effects: If True, plot bar plots which indicate leaf-level significant effects.
|
463
|
+
Defaults to False.
|
785
464
|
tight_text: When False, boundaries of the text are approximated according to general font metrics,
|
786
465
|
producing slightly worse aligned text faces but improving the performance of tree visualization in scenes with a lot of text faces.
|
787
466
|
Defaults to False.
|
@@ -799,9 +478,8 @@ class CodaPlot:
|
|
799
478
|
or plots the tree inline (`show = True`)
|
800
479
|
|
801
480
|
Examples:
|
802
|
-
Example with tascCODA:
|
803
481
|
>>> import pertpy as pt
|
804
|
-
>>> adata = pt.dt.
|
482
|
+
>>> adata = pt.dt.tasccoda_example()
|
805
483
|
>>> tasccoda = pt.tl.Tasccoda()
|
806
484
|
>>> mdata = tasccoda.load(
|
807
485
|
>>> adata, type="sample_level",
|
@@ -812,138 +490,38 @@ class CodaPlot:
|
|
812
490
|
>>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
|
813
491
|
>>> )
|
814
492
|
>>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
|
815
|
-
>>>
|
493
|
+
>>> tasccoda.plot_draw_effects(mdata, covariate="Health[T.Inflamed]", tree="lineage")
|
816
494
|
"""
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
show_legend = not show_leaf_effects
|
823
|
-
elif show_legend:
|
824
|
-
print("Tree leaves and leaf effect bars won't be aligned when legend is shown!")
|
825
|
-
|
826
|
-
if isinstance(tree, str):
|
827
|
-
tree = data.uns[tree]
|
828
|
-
# Collapse tree singularities
|
829
|
-
tree2 = collapse_singularities_2(tree)
|
830
|
-
|
831
|
-
node_effs = data.uns["scCODA_params"]["node_df"].loc[(covariate + "_node",),].copy()
|
832
|
-
node_effs.index = node_effs.index.get_level_values("Node")
|
833
|
-
|
834
|
-
covariates = data.uns["scCODA_params"]["covariate_names"]
|
835
|
-
effect_dfs = [data.varm[f"effect_df_{cov}"] for cov in covariates]
|
836
|
-
eff_df = pd.concat(effect_dfs)
|
837
|
-
eff_df.index = pd.MultiIndex.from_product(
|
838
|
-
(covariates, data.var.index.tolist()),
|
839
|
-
names=["Covariate", "Cell Type"],
|
495
|
+
warnings.warn(
|
496
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
497
|
+
" Please use the corresponding 'pt.tl' object",
|
498
|
+
FutureWarning,
|
499
|
+
stacklevel=2,
|
840
500
|
)
|
841
|
-
leaf_effs = eff_df.loc[(covariate,),].copy()
|
842
|
-
leaf_effs.index = leaf_effs.index.get_level_values("Cell Type")
|
843
|
-
|
844
|
-
# Add effect values
|
845
|
-
for n in tree2.traverse():
|
846
|
-
nstyle = NodeStyle()
|
847
|
-
nstyle["size"] = 0
|
848
|
-
n.set_style(nstyle)
|
849
|
-
if n.name in node_effs.index:
|
850
|
-
e = node_effs.loc[n.name, "Final Parameter"]
|
851
|
-
n.add_feature("node_effect", e)
|
852
|
-
else:
|
853
|
-
n.add_feature("node_effect", 0)
|
854
|
-
if n.name in leaf_effs.index:
|
855
|
-
e = leaf_effs.loc[n.name, "Effect"]
|
856
|
-
n.add_feature("leaf_effect", e)
|
857
|
-
else:
|
858
|
-
n.add_feature("leaf_effect", 0)
|
859
|
-
|
860
|
-
# Scale effect values to get nice node sizes
|
861
|
-
eff_max = np.max([np.abs(n.node_effect) for n in tree2.traverse()])
|
862
|
-
leaf_eff_max = np.max([np.abs(n.leaf_effect) for n in tree2.traverse()])
|
863
|
-
|
864
|
-
def my_layout(node):
|
865
|
-
text_face = TextFace(node.name, tight_text=tight_text)
|
866
|
-
text_face.margin_left = 10
|
867
|
-
faces.add_face_to_node(text_face, node, column=0, aligned=True)
|
868
|
-
|
869
|
-
# if node.is_leaf():
|
870
|
-
size = (np.abs(node.node_effect) * 10 / eff_max) if node.node_effect != 0 else 0
|
871
|
-
if np.sign(node.node_effect) == 1:
|
872
|
-
color = "blue"
|
873
|
-
elif np.sign(node.node_effect) == -1:
|
874
|
-
color = "red"
|
875
|
-
else:
|
876
|
-
color = "cyan"
|
877
|
-
if size != 0:
|
878
|
-
faces.add_face_to_node(CircleFace(radius=size, color=color), node, column=0)
|
879
|
-
|
880
|
-
tree_style = TreeStyle()
|
881
|
-
tree_style.show_leaf_name = False
|
882
|
-
tree_style.layout_fn = my_layout
|
883
|
-
tree_style.show_scale = show_scale
|
884
|
-
tree_style.draw_guiding_lines = True
|
885
|
-
tree_style.legend_position = 1
|
886
501
|
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
)
|
906
|
-
tree_style.legend.add_face(TextFace(f" {eff_max * i / 4:.2f}"), column=1)
|
907
|
-
|
908
|
-
if show_leaf_effects:
|
909
|
-
leaf_name = [node.name for node in tree2.traverse("postorder") if node.is_leaf()]
|
910
|
-
leaf_effs = leaf_effs.loc[leaf_name].reset_index()
|
911
|
-
palette = ["blue" if Effect > 0 else "red" for Effect in leaf_effs["Effect"].tolist()]
|
912
|
-
|
913
|
-
dir_path = Path.cwd()
|
914
|
-
dir_path = Path(dir_path / "tree_effect.png")
|
915
|
-
tree2.render(dir_path, tree_style=tree_style, units="in")
|
916
|
-
_, ax = plt.subplots(1, 2, figsize=(10, 10))
|
917
|
-
sns.barplot(data=leaf_effs, x="Effect", y="Cell Type", palette=palette, ax=ax[1])
|
918
|
-
img = mpimg.imread(dir_path)
|
919
|
-
ax[0].imshow(img)
|
920
|
-
ax[0].get_xaxis().set_visible(False)
|
921
|
-
ax[0].get_yaxis().set_visible(False)
|
922
|
-
ax[0].set_frame_on(False)
|
923
|
-
|
924
|
-
ax[1].get_yaxis().set_visible(False)
|
925
|
-
ax[1].spines["left"].set_visible(False)
|
926
|
-
ax[1].spines["right"].set_visible(False)
|
927
|
-
ax[1].spines["top"].set_visible(False)
|
928
|
-
plt.xlim(-leaf_eff_max, leaf_eff_max)
|
929
|
-
plt.subplots_adjust(wspace=0)
|
930
|
-
|
931
|
-
if file_name is not None:
|
932
|
-
plt.savefig(file_name)
|
933
|
-
|
934
|
-
if file_name is not None and not show_leaf_effects:
|
935
|
-
tree2.render(file_name, tree_style=tree_style, units=units)
|
936
|
-
if show:
|
937
|
-
if not show_leaf_effects:
|
938
|
-
return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
|
939
|
-
else:
|
940
|
-
if not show_leaf_effects:
|
941
|
-
return tree2, tree_style
|
502
|
+
from pertpy.tools import Tasccoda
|
503
|
+
|
504
|
+
coda = Tasccoda()
|
505
|
+
return coda.plot_draw_effects(
|
506
|
+
data=data,
|
507
|
+
modality_key=modality_key,
|
508
|
+
covariate=covariate,
|
509
|
+
tree=tree,
|
510
|
+
show_legend=show_legend,
|
511
|
+
show_leaf_effects=show_leaf_effects,
|
512
|
+
tight_text=tight_text,
|
513
|
+
show_scale=show_scale,
|
514
|
+
show=show,
|
515
|
+
save=save,
|
516
|
+
units=units,
|
517
|
+
figsize=figsize,
|
518
|
+
dpi=dpi,
|
519
|
+
)
|
942
520
|
|
943
521
|
@staticmethod
|
944
522
|
def effects_umap( # pragma: no cover
|
945
523
|
data: MuData,
|
946
|
-
effect_name:
|
524
|
+
effect_name: str | list | None,
|
947
525
|
cluster_key: str,
|
948
526
|
modality_key_1: str = "rna",
|
949
527
|
modality_key_2: str = "coda",
|
@@ -953,49 +531,71 @@ class CodaPlot:
|
|
953
531
|
):
|
954
532
|
"""Plot a UMAP visualization colored by effect strength.
|
955
533
|
|
956
|
-
Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
|
534
|
+
Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData
|
535
|
+
(default is data['rna']) depending on the cluster they were assigned to.
|
957
536
|
|
958
537
|
Args:
|
959
538
|
data: AnnData object or MuData object.
|
960
|
-
effect_name: The name of the effect results in .varm of aggregated sample-level AnnData
|
961
|
-
cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
|
539
|
+
effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot
|
540
|
+
cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
|
541
|
+
To assign cell types' effects to original cells.
|
962
542
|
modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
|
963
|
-
modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
|
543
|
+
modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
|
544
|
+
Defaults to "coda".
|
964
545
|
show: Whether to display the figure or return axis. Defaults to None.
|
965
|
-
ax: A matplotlib axes object. Only works if plotting a single component.
|
546
|
+
ax: A matplotlib axes object. Only works if plotting a single component.
|
547
|
+
Defaults to None.
|
966
548
|
**kwargs: All other keyword arguments are passed to `scanpy.plot.umap()`
|
967
549
|
|
968
550
|
Returns:
|
969
551
|
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
|
970
552
|
|
971
553
|
Examples:
|
972
|
-
Example with scCODA:
|
973
554
|
>>> import pertpy as pt
|
974
|
-
>>>
|
975
|
-
>>>
|
976
|
-
>>>
|
977
|
-
|
978
|
-
>>>
|
979
|
-
>>>
|
980
|
-
|
981
|
-
>>>
|
982
|
-
|
555
|
+
>>> import schist
|
556
|
+
>>> adata = pt.dt.haber_2017_regions()
|
557
|
+
>>> schist.inference.nested_model(adata, samples=100, random_seed=5678)
|
558
|
+
>>> tasccoda_model = pt.tl.Tasccoda()
|
559
|
+
>>> tasccoda_data = tasccoda_model.load(adata, type="cell_level",
|
560
|
+
>>> cell_type_identifier="nsbm_level_1",
|
561
|
+
>>> sample_identifier="batch", covariate_obs=["condition"],
|
562
|
+
>>> levels_orig=["nsbm_level_4", "nsbm_level_3", "nsbm_level_2", "nsbm_level_1"],
|
563
|
+
>>> add_level_name=True)
|
564
|
+
>>> tasccoda_model.prepare(
|
565
|
+
>>> tasccoda_data,
|
566
|
+
>>> modality_key="coda",
|
567
|
+
>>> reference_cell_type="18",
|
568
|
+
>>> formula="condition",
|
569
|
+
>>> pen_args={"phi": 0, "lambda_1": 3.5},
|
570
|
+
>>> tree_key="tree"
|
571
|
+
>>> )
|
572
|
+
>>> tasccoda_model.run_nuts(
|
573
|
+
... tasccoda_data, modality_key="coda", rng_key=1234, num_samples=10000, num_warmup=1000
|
574
|
+
... )
|
575
|
+
>>> tasccoda_model.plot_effects_umap(tasccoda_data,
|
576
|
+
>>> effect_name=["effect_df_condition[T.Salmonella]",
|
577
|
+
>>> "effect_df_condition[T.Hpoly.Day3]",
|
578
|
+
>>> "effect_df_condition[T.Hpoly.Day10]"],
|
579
|
+
>>> cluster_key="nsbm_level_1",
|
580
|
+
>>> )
|
983
581
|
"""
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
582
|
+
warnings.warn(
|
583
|
+
"This function is deprecated and will be removed in pertpy 0.8.0!"
|
584
|
+
" Please use the corresponding 'pt.tl' object for plotting function directly.",
|
585
|
+
FutureWarning,
|
586
|
+
stacklevel=2,
|
587
|
+
)
|
588
|
+
|
589
|
+
from pertpy.tools import Tasccoda
|
590
|
+
|
591
|
+
coda = Tasccoda()
|
592
|
+
coda.plot_effects_umap(
|
593
|
+
data=data,
|
594
|
+
effect_name=effect_name,
|
595
|
+
cluster_key=cluster_key,
|
596
|
+
modality_key_1=modality_key_1,
|
597
|
+
modality_key_2=modality_key_2,
|
598
|
+
show=show,
|
599
|
+
ax=ax,
|
600
|
+
**kwargs,
|
601
|
+
)
|